Skip to content

Commit 02a65f4

Browse files
authored
feat(types): add comprehensive type support and stricter validation (#130)
Expanded type support with new scalar and parameterized types. Improved parameter validation and stricter nullability/precision checks. Signed-off-by: MBWhite <[email protected]>
1 parent 9bb6bb9 commit 02a65f4

File tree

5 files changed

+561
-148
lines changed

5 files changed

+561
-148
lines changed

src/substrait/builders/type.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Iterable
2+
23
import substrait.gen.proto.type_pb2 as stt
34

45

src/substrait/derivation_expression.py

Lines changed: 115 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from typing import Optional
2-
from antlr4 import InputStream, CommonTokenStream
2+
3+
from antlr4 import CommonTokenStream, InputStream
4+
35
from substrait.gen.antlr.SubstraitTypeLexer import SubstraitTypeLexer
46
from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser
5-
from substrait.gen.proto.type_pb2 import Type
7+
from substrait.gen.proto.type_pb2 import NamedStruct, Type
68

79

810
def _evaluate(x, values: dict):
@@ -65,22 +67,128 @@ def _evaluate(x, values: dict):
6567
return Type(fp64=Type.FP64(nullability=nullability))
6668
elif isinstance(scalar_type, SubstraitTypeParser.BooleanContext):
6769
return Type(bool=Type.Boolean(nullability=nullability))
70+
elif isinstance(scalar_type, SubstraitTypeParser.StringContext):
71+
return Type(string=Type.String(nullability=nullability))
72+
elif isinstance(scalar_type, SubstraitTypeParser.TimestampContext):
73+
return Type(timestamp=Type.Timestamp(nullability=nullability))
74+
elif isinstance(scalar_type, SubstraitTypeParser.DateContext):
75+
return Type(date=Type.Date(nullability=nullability))
76+
elif isinstance(scalar_type, SubstraitTypeParser.IntervalYearContext):
77+
return Type(interval_year=Type.IntervalYear(nullability=nullability))
78+
elif isinstance(scalar_type, SubstraitTypeParser.UuidContext):
79+
return Type(uuid=Type.UUID(nullability=nullability))
80+
elif isinstance(scalar_type, SubstraitTypeParser.BinaryContext):
81+
return Type(binary=Type.Binary(nullability=nullability))
82+
elif isinstance(scalar_type, SubstraitTypeParser.TimeContext):
83+
return Type(time=Type.Time(nullability=nullability))
84+
elif isinstance(scalar_type, SubstraitTypeParser.TimestampTzContext):
85+
return Type(timestamp_tz=Type.TimestampTZ(nullability=nullability))
6886
else:
6987
raise Exception(f"Unknown scalar type {type(scalar_type)}")
7088
elif parametrized_type:
89+
nullability = (
90+
Type.NULLABILITY_NULLABLE
91+
if parametrized_type.isnull
92+
else Type.NULLABILITY_REQUIRED
93+
)
7194
if isinstance(parametrized_type, SubstraitTypeParser.DecimalContext):
7295
precision = _evaluate(parametrized_type.precision, values)
7396
scale = _evaluate(parametrized_type.scale, values)
74-
nullability = (
75-
Type.NULLABILITY_NULLABLE
76-
if parametrized_type.isnull
77-
else Type.NULLABILITY_REQUIRED
78-
)
7997
return Type(
8098
decimal=Type.Decimal(
8199
precision=precision, scale=scale, nullability=nullability
82100
)
83101
)
102+
elif isinstance(parametrized_type, SubstraitTypeParser.VarCharContext):
103+
length = _evaluate(parametrized_type.length, values)
104+
return Type(
105+
varchar=Type.VarChar(
106+
length=length,
107+
nullability=nullability,
108+
)
109+
)
110+
elif isinstance(parametrized_type, SubstraitTypeParser.FixedCharContext):
111+
length = _evaluate(parametrized_type.length, values)
112+
return Type(
113+
fixed_char=Type.FixedChar(
114+
length=length,
115+
nullability=nullability,
116+
)
117+
)
118+
elif isinstance(parametrized_type, SubstraitTypeParser.FixedBinaryContext):
119+
length = _evaluate(parametrized_type.length, values)
120+
return Type(
121+
fixed_binary=Type.FixedBinary(
122+
length=length,
123+
nullability=nullability,
124+
)
125+
)
126+
elif isinstance(
127+
parametrized_type, SubstraitTypeParser.PrecisionTimestampContext
128+
):
129+
precision = _evaluate(parametrized_type.precision, values)
130+
return Type(
131+
precision_timestamp=Type.PrecisionTimestamp(
132+
precision=precision,
133+
nullability=nullability,
134+
)
135+
)
136+
elif isinstance(
137+
parametrized_type, SubstraitTypeParser.PrecisionTimestampTZContext
138+
):
139+
precision = _evaluate(parametrized_type.precision, values)
140+
return Type(
141+
precision_timestamp_tz=Type.PrecisionTimestampTZ(
142+
precision=precision,
143+
nullability=nullability,
144+
)
145+
)
146+
elif isinstance(parametrized_type, SubstraitTypeParser.IntervalYearContext):
147+
return Type(
148+
interval_year=Type.IntervalYear(
149+
nullability=nullability,
150+
)
151+
)
152+
elif isinstance(parametrized_type, SubstraitTypeParser.StructContext):
153+
types = list(
154+
map(lambda x: _evaluate(x, values), parametrized_type.expr())
155+
)
156+
return Type(
157+
struct=Type.Struct(
158+
types=types,
159+
nullability=nullability,
160+
)
161+
)
162+
elif isinstance(parametrized_type, SubstraitTypeParser.ListContext):
163+
list_type = _evaluate(parametrized_type.expr(), values)
164+
return Type(
165+
list=Type.List(
166+
type=list_type,
167+
nullability=nullability,
168+
)
169+
)
170+
171+
elif isinstance(parametrized_type, SubstraitTypeParser.MapContext):
172+
return Type(
173+
map=Type.Map(
174+
key=_evaluate(parametrized_type.key, values),
175+
value=_evaluate(parametrized_type.value, values),
176+
nullability=nullability,
177+
)
178+
)
179+
elif isinstance(parametrized_type, SubstraitTypeParser.NStructContext):
180+
names = list(map(lambda k: k.getText(), parametrized_type.Identifier()))
181+
struct = Type.Struct(
182+
types=list(
183+
map(lambda k: _evaluate(k, values), parametrized_type.expr())
184+
),
185+
nullability=nullability,
186+
)
187+
return NamedStruct(
188+
names=names,
189+
struct=struct,
190+
)
191+
84192
raise Exception(f"Unknown parametrized type {type(parametrized_type)}")
85193
elif any_type:
86194
any_var = any_type.AnyVar()

0 commit comments

Comments
 (0)