Skip to content
Merged
1 change: 1 addition & 0 deletions src/substrait/builders/type.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Iterable

import substrait.gen.proto.type_pb2 as stt


Expand Down
122 changes: 115 additions & 7 deletions src/substrait/derivation_expression.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Optional
from antlr4 import InputStream, CommonTokenStream

from antlr4 import CommonTokenStream, InputStream

from substrait.gen.antlr.SubstraitTypeLexer import SubstraitTypeLexer
from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser
from substrait.gen.proto.type_pb2 import Type
from substrait.gen.proto.type_pb2 import NamedStruct, Type


def _evaluate(x, values: dict):
Expand Down Expand Up @@ -65,22 +67,128 @@ def _evaluate(x, values: dict):
return Type(fp64=Type.FP64(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.BooleanContext):
return Type(bool=Type.Boolean(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.StringContext):
return Type(string=Type.String(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.TimestampContext):
return Type(timestamp=Type.Timestamp(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.DateContext):
return Type(date=Type.Date(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.IntervalYearContext):
return Type(interval_year=Type.IntervalYear(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.UuidContext):
return Type(uuid=Type.UUID(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.BinaryContext):
return Type(binary=Type.Binary(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.TimeContext):
return Type(time=Type.Time(nullability=nullability))
elif isinstance(scalar_type, SubstraitTypeParser.TimestampTzContext):
return Type(timestamp_tz=Type.TimestampTZ(nullability=nullability))
else:
raise Exception(f"Unknown scalar type {type(scalar_type)}")
elif parametrized_type:
nullability = (
Type.NULLABILITY_NULLABLE
if parametrized_type.isnull
else Type.NULLABILITY_REQUIRED
)
if isinstance(parametrized_type, SubstraitTypeParser.DecimalContext):
precision = _evaluate(parametrized_type.precision, values)
scale = _evaluate(parametrized_type.scale, values)
nullability = (
Type.NULLABILITY_NULLABLE
if parametrized_type.isnull
else Type.NULLABILITY_REQUIRED
)
return Type(
decimal=Type.Decimal(
precision=precision, scale=scale, nullability=nullability
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.VarCharContext):
length = _evaluate(parametrized_type.length, values)
return Type(
varchar=Type.VarChar(
length=length,
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.FixedCharContext):
length = _evaluate(parametrized_type.length, values)
return Type(
fixed_char=Type.FixedChar(
length=length,
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.FixedBinaryContext):
length = _evaluate(parametrized_type.length, values)
return Type(
fixed_binary=Type.FixedBinary(
length=length,
nullability=nullability,
)
)
elif isinstance(
parametrized_type, SubstraitTypeParser.PrecisionTimestampContext
):
precision = _evaluate(parametrized_type.precision, values)
return Type(
precision_timestamp=Type.PrecisionTimestamp(
precision=precision,
nullability=nullability,
)
)
elif isinstance(
parametrized_type, SubstraitTypeParser.PrecisionTimestampTZContext
):
precision = _evaluate(parametrized_type.precision, values)
return Type(
precision_timestamp_tz=Type.PrecisionTimestampTZ(
precision=precision,
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.IntervalYearContext):
return Type(
interval_year=Type.IntervalYear(
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.StructContext):
types = list(
map(lambda x: _evaluate(x, values), parametrized_type.expr())
)
return Type(
struct=Type.Struct(
types=types,
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.ListContext):
list_type = _evaluate(parametrized_type.expr(), values)
return Type(
list=Type.List(
type=list_type,
nullability=nullability,
)
)

elif isinstance(parametrized_type, SubstraitTypeParser.MapContext):
return Type(
map=Type.Map(
key=_evaluate(parametrized_type.key, values),
value=_evaluate(parametrized_type.value, values),
nullability=nullability,
)
)
elif isinstance(parametrized_type, SubstraitTypeParser.NStructContext):
names = list(map(lambda k: k.getText(), parametrized_type.Identifier()))
struct = Type.Struct(
types=list(
map(lambda k: _evaluate(k, values), parametrized_type.expr())
),
nullability=nullability,
)
return NamedStruct(
names=names,
struct=struct,
)

raise Exception(f"Unknown parametrized type {type(parametrized_type)}")
elif any_type:
any_var = any_type.AnyVar()
Expand Down
Loading