Skip to content

Commit d21c9d4

Browse files
committed
feat(types): add many new types and tighten nullability/precision checks
Add wide-ranging support for additional scalar and parameterized Substrait types, improve parameter handling, and make nullability/precision checks stricter and more correct. Signed-off-by: MBWhite <[email protected]>
1 parent fa12088 commit d21c9d4

File tree

4 files changed

+379
-102
lines changed

4 files changed

+379
-102
lines changed

src/substrait/builders/type.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,14 @@ def precision_timestamp_tz(precision: int, nullable=True) -> stt.Type:
221221
)
222222
)
223223

224+
def timestamp(nullable=True) -> stt.Type:
    """Build a Substrait ``timestamp`` type message.

    Args:
        nullable: When truthy the type is marked ``NULLABILITY_NULLABLE``;
            otherwise it is marked ``NULLABILITY_REQUIRED``.

    Returns:
        An ``stt.Type`` whose ``timestamp`` field carries the chosen
        nullability.
    """
    if nullable:
        nullability = stt.Type.NULLABILITY_NULLABLE
    else:
        nullability = stt.Type.NULLABILITY_REQUIRED
    return stt.Type(timestamp=stt.Type.Timestamp(nullability=nullability))
224232

225233
def struct(types: Iterable[stt.Type], nullable=True) -> stt.Type:
226234
return stt.Type(

src/substrait/derivation_expression.py

Lines changed: 99 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,22 +65,116 @@ def _evaluate(x, values: dict):
6565
return Type(fp64=Type.FP64(nullability=nullability))
6666
elif isinstance(scalar_type, SubstraitTypeParser.BooleanContext):
6767
return Type(bool=Type.Boolean(nullability=nullability))
68+
elif isinstance(scalar_type, SubstraitTypeParser.StringContext):
69+
return Type(string=Type.String(nullability=nullability))
70+
elif isinstance(scalar_type, SubstraitTypeParser.TimestampContext):
71+
return Type(timestamp=Type.Timestamp(nullability=nullability))
72+
elif isinstance(scalar_type, SubstraitTypeParser.DateContext):
73+
return Type(date=Type.Date(nullability=nullability))
74+
elif isinstance(scalar_type, SubstraitTypeParser.IntervalYearContext):
75+
return Type(interval_year=Type.IntervalYear(nullability=nullability))
76+
elif isinstance(scalar_type, SubstraitTypeParser.UuidContext):
77+
return Type(uuid=Type.UUID(nullability=nullability))
78+
elif isinstance(scalar_type, SubstraitTypeParser.BinaryContext):
79+
return Type(binary=Type.Binary(nullability=nullability))
80+
elif isinstance(scalar_type, SubstraitTypeParser.TimeContext):
81+
return Type(time=Type.Time(nullability=nullability))
82+
elif isinstance(scalar_type, SubstraitTypeParser.TimestampTzContext):
83+
return Type(timestamp_tz=Type.TimestampTZ(nullability=nullability))
6884
else:
6985
raise Exception(f"Unknown scalar type {type(scalar_type)}")
7086
elif parametrized_type:
87+
nullability = (
88+
Type.NULLABILITY_NULLABLE
89+
if parametrized_type.isnull
90+
else Type.NULLABILITY_REQUIRED
91+
)
7192
if isinstance(parametrized_type, SubstraitTypeParser.DecimalContext):
7293
precision = _evaluate(parametrized_type.precision, values)
7394
scale = _evaluate(parametrized_type.scale, values)
74-
nullability = (
75-
Type.NULLABILITY_NULLABLE
76-
if parametrized_type.isnull
77-
else Type.NULLABILITY_REQUIRED
78-
)
7995
return Type(
8096
decimal=Type.Decimal(
8197
precision=precision, scale=scale, nullability=nullability
8298
)
8399
)
100+
elif isinstance(parametrized_type, SubstraitTypeParser.VarCharContext):
101+
length = _evaluate(parametrized_type.length, values)
102+
return Type(
103+
varchar=Type.VarChar(
104+
length=length,
105+
nullability=nullability,
106+
)
107+
)
108+
elif isinstance(parametrized_type, SubstraitTypeParser.FixedCharContext):
109+
length = _evaluate(parametrized_type.length, values)
110+
return Type(
111+
fixed_char=Type.FixedChar(
112+
length=length,
113+
nullability=nullability,
114+
)
115+
)
116+
elif isinstance(parametrized_type, SubstraitTypeParser.FixedBinaryContext):
117+
length = _evaluate(parametrized_type.length, values)
118+
return Type(
119+
fixed_binary=Type.FixedBinary(
120+
length=length,
121+
nullability=nullability,
122+
)
123+
)
124+
elif isinstance(parametrized_type, SubstraitTypeParser.PrecisionTimestampContext):
125+
precision = _evaluate(parametrized_type.precision, values)
126+
return Type(
127+
precision_timestamp=Type.PrecisionTimestamp(
128+
precision=precision,
129+
nullability=nullability,
130+
)
131+
)
132+
elif isinstance(parametrized_type, SubstraitTypeParser.PrecisionTimestampTZContext):
133+
precision = _evaluate(parametrized_type.precision, values)
134+
return Type(
135+
precision_timestamp_tz=Type.PrecisionTimestampTZ(
136+
precision=precision,
137+
nullability=nullability,
138+
)
139+
)
140+
elif isinstance(parametrized_type, SubstraitTypeParser.IntervalYearContext):
141+
return Type(
142+
interval_year=Type.IntervalYear(
143+
nullability=nullability,
144+
)
145+
)
146+
elif isinstance(parametrized_type, SubstraitTypeParser.StructContext):
147+
types = list(map(lambda x: _evaluate(x,values),parametrized_type.expr()))
148+
return Type(
149+
struct=Type.Struct(
150+
types=types,
151+
nullability=nullability,
152+
)
153+
)
154+
elif isinstance(parametrized_type, SubstraitTypeParser.ListContext):
155+
type = _evaluate(parametrized_type.expr(),values)
156+
return Type(
157+
list=Type.List(
158+
type=type,
159+
nullability=nullability,
160+
)
161+
)
162+
163+
elif isinstance(parametrized_type, SubstraitTypeParser.MapContext):
164+
return Type(
165+
map=Type.Map(
166+
key=_evaluate(parametrized_type.key,values),
167+
value=_evaluate(parametrized_type.value,values),
168+
nullability=nullability,
169+
)
170+
)
171+
elif isinstance(parametrized_type, SubstraitTypeParser.NStructContext):
172+
# Parsing the documented example fails, so the parser grammar may need to be updated.
173+
# Example from the docs (https://substrait.io/types/type_classes/): `evaluate("NSTRUCT<longitude: i32, latitude: i32>")`
174+
# Resulting parser error: line 1:17 extraneous input ':'
175+
raise NotImplementedError("Named structure type not implemented yet")
176+
# elif isinstance(parametrized_type, SubstraitTypeParser.UserDefinedContext):
177+
84178
raise Exception(f"Unknown parametrized type {type(parametrized_type)}")
85179
elif any_type:
86180
any_var = any_type.AnyVar()

src/substrait/extension_registry.py

Lines changed: 116 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -68,21 +68,24 @@ def normalize_substrait_type_names(typ: str) -> str:
6868
raise Exception(f"Unrecognized substrait type {typ}")
6969

7070

71-
def violates_integer_option(actual: int, option, parameters: dict, subset=False):
    """Report whether ``actual`` conflicts with an integer option of a type.

    Args:
        actual: The concrete integer value being checked.
        option: Either a ``NumericLiteralContext`` (compared against its
            literal number) or a ``NumericParameterNameContext`` (compared
            against the value bound to that name in ``parameters``).
        parameters: Mapping of parameter names to bound integers. A name
            that is not bound yet is bound to ``actual`` here, so this
            mapping is mutated in place.
        subset: When false, any difference is a violation
            (``actual != expected``). When true, only ``actual < expected``
            is a violation.

    Returns:
        ``True`` if ``actual`` violates the option, ``False`` otherwise.

    Raises:
        Exception: If ``option`` is neither of the two supported contexts.
    """
    if isinstance(option, SubstraitTypeParser.NumericLiteralContext):
        expected = int(str(option.Number()))
    elif isinstance(option, SubstraitTypeParser.NumericParameterNameContext):
        # setdefault binds a first-seen parameter to ``actual`` and otherwise
        # hands back the value it was bound to earlier.
        expected = parameters.setdefault(str(option.Identifier()), actual)
    else:
        raise Exception(
            f"Input should be either NumericLiteralContext or NumericParameterNameContext, got {type(option)} instead"
        )

    return actual < expected if subset else actual != expected
8689

8790

8891
def types_equal(type1: Type, type2: Type, check_nullability=False):
@@ -112,6 +115,27 @@ def handle_parameter_cover(
112115
return True
113116

114117

118+
def _check_nullability(check_nullability, parameterized_type, covered, kind) -> bool:
119+
if not check_nullability:
120+
return True
121+
# The ANTLR context stores a Token called ``isnull`` – it is
122+
# present when the type is declared as nullable.
123+
nullability = (
124+
Type.Nullability.NULLABILITY_NULLABLE
125+
if getattr(parameterized_type, "isnull", None) is not None
126+
else Type.Nullability.NULLABILITY_REQUIRED
127+
)
128+
# if nullability == Type.Nullability.NULLABILITY_NULLABLE:
129+
# return True # is still true even if the covered is required
130+
# The protobuf message stores its own enum – we compare the two.
131+
covered_nullability = getattr(
132+
getattr(covered, kind), # e.g. covered.varchar
133+
"nullability",
134+
None,
135+
)
136+
return nullability == covered_nullability
137+
138+
115139
def covers(
116140
covered: Type,
117141
covering: SubstraitTypeParser.TypeLiteralContext,
@@ -123,7 +147,6 @@ def covers(
123147
return handle_parameter_cover(
124148
covered, parameter_name, parameters, check_nullability
125149
)
126-
127150
covering: SubstraitTypeParser.TypeDefContext = covering.typeDef()
128151

129152
any_type: SubstraitTypeParser.AnyTypeContext = covering.anyType()
@@ -142,31 +165,99 @@ def covers(
142165

143166
parameterized_type = covering.parameterizedType()
144167
if parameterized_type:
145-
if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext):
146-
if covered.WhichOneof("kind") != "decimal":
168+
kind = covered.WhichOneof("kind")
169+
if isinstance(parameterized_type, SubstraitTypeParser.VarCharContext):
170+
if kind != "varchar":
171+
return False
172+
if hasattr(parameterized_type, "length") and violates_integer_option(
173+
covered.varchar.length, parameterized_type.length, parameters
174+
):
147175
return False
148176

149-
nullability = (
150-
Type.NULLABILITY_NULLABLE
151-
if parameterized_type.isnull
152-
else Type.NULLABILITY_REQUIRED
177+
return _check_nullability(
178+
check_nullability, parameterized_type, covered, kind
153179
)
154-
155-
if (
156-
check_nullability
157-
and nullability
158-
!= covered.__getattribute__(covered.WhichOneof("kind")).nullability
180+
if isinstance(parameterized_type, SubstraitTypeParser.FixedCharContext):
181+
if kind != "fixed_char":
182+
return False
183+
if hasattr(parameterized_type, "length") and violates_integer_option(
184+
covered.fixed_char.length, parameterized_type.length, parameters
159185
):
160186
return False
187+
return _check_nullability(
188+
check_nullability, parameterized_type, covered, kind
189+
)
161190

191+
if isinstance(parameterized_type, SubstraitTypeParser.FixedBinaryContext):
192+
if kind != "fixed_binary":
193+
return False
194+
if hasattr(parameterized_type, "length") and violates_integer_option(
195+
covered.fixed_binary.length, parameterized_type.length, parameters
196+
):
197+
return False
198+
# return True
199+
return _check_nullability(
200+
check_nullability, parameterized_type, covered, kind
201+
)
202+
if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext):
203+
if kind != "decimal":
204+
return False
205+
if not _check_nullability(
206+
check_nullability, parameterized_type, covered, kind
207+
):
208+
return False
209+
# precision / scale are both optional – a missing value means “no limit”.
210+
covered_scale = getattr(covered.decimal, "scale", 0)
211+
param_scale = getattr(parameterized_type, "scale", 0)
212+
covered_prec = getattr(covered.decimal, "precision", 0)
213+
param_prec = getattr(parameterized_type, "precision", 0)
162214
return not (
163-
violates_integer_option(
164-
covered.decimal.scale, parameterized_type.scale, parameters
165-
)
166-
or violates_integer_option(
167-
covered.decimal.precision, parameterized_type.precision, parameters
168-
)
215+
violates_integer_option(covered_scale, param_scale, parameters)
216+
or violates_integer_option(covered_prec, param_prec, parameters)
169217
)
218+
if isinstance(
219+
parameterized_type, SubstraitTypeParser.PrecisionTimestampContext
220+
):
221+
if kind != "precision_timestamp":
222+
return False
223+
if not _check_nullability(
224+
check_nullability, parameterized_type, covered, kind
225+
):
226+
return False
227+
# return True
228+
covered_prec = getattr(covered.precision_timestamp, "precision", 0)
229+
param_prec = getattr(parameterized_type, "precision", 0)
230+
return not violates_integer_option(covered_prec, param_prec, parameters)
231+
232+
if isinstance(
233+
parameterized_type, SubstraitTypeParser.PrecisionTimestampTZContext
234+
):
235+
if kind != "precision_timestamp_tz":
236+
return False
237+
if not _check_nullability(
238+
check_nullability, parameterized_type, covered, kind
239+
):
240+
return False
241+
# return True
242+
covered_prec = getattr(covered.precision_timestamp_tz, "precision", 0)
243+
param_prec = getattr(parameterized_type, "precision", 0)
244+
return not violates_integer_option(covered_prec, param_prec, parameters)
245+
246+
kind_mapping = {
247+
SubstraitTypeParser.ListContext: "list",
248+
SubstraitTypeParser.MapContext: "map",
249+
SubstraitTypeParser.StructContext: "struct",
250+
SubstraitTypeParser.UserDefinedContext: "user_defined",
251+
SubstraitTypeParser.PrecisionIntervalDayContext: "interval_day",
252+
}
253+
254+
for ctx_cls, expected_kind in kind_mapping.items():
255+
if isinstance(parameterized_type, ctx_cls):
256+
if kind != expected_kind:
257+
return False
258+
return _check_nullability(
259+
check_nullability, parameterized_type, covered, kind
260+
)
170261
else:
171262
raise Exception(f"Unhandled type {type(parameterized_type)}")
172263

0 commit comments

Comments
 (0)