Skip to content

Commit ee0cb3f

Browse files
authored
fix: correct if_then and switch_expression type inference calls (#128)
Fixed type inference method calls in if_then and switch_expression builders.
1 parent fa12088 commit ee0cb3f

File tree

2 files changed

+164
-11
lines changed

2 files changed

+164
-11
lines changed

src/substrait/type_inference.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import substrait.gen.proto.algebra_pb2 as stalg
22
import substrait.gen.proto.extended_expression_pb2 as stee
3-
import substrait.gen.proto.type_pb2 as stt
43
import substrait.gen.proto.plan_pb2 as stp
4+
import substrait.gen.proto.type_pb2 as stt
55

66

77
def infer_literal_type(literal: stalg.Expression.Literal) -> stt.Type:
@@ -127,7 +127,7 @@ def infer_literal_type(literal: stalg.Expression.Literal) -> stt.Type:
127127
raise Exception(f"Unknown literal_type {literal_type}")
128128

129129

130-
def infer_nested_type(nested: stalg.Expression.Nested) -> stt.Type:
130+
def infer_nested_type(nested: stalg.Expression.Nested, parent_schema) -> stt.Type:
131131
nested_type = nested.WhichOneof("nested_type")
132132

133133
nullability = (
@@ -139,22 +139,27 @@ def infer_nested_type(nested: stalg.Expression.Nested) -> stt.Type:
139139
if nested_type == "struct":
140140
return stt.Type(
141141
struct=stt.Type.Struct(
142-
types=[infer_expression_type(f) for f in nested.struct.fields],
142+
types=[
143+
infer_expression_type(f, parent_schema)
144+
for f in nested.struct.fields
145+
],
143146
nullability=nullability,
144147
)
145148
)
146149
elif nested_type == "list":
147150
return stt.Type(
148151
list=stt.Type.List(
149-
type=infer_expression_type(nested.list.values[0]),
152+
type=infer_expression_type(nested.list.values[0], parent_schema),
150153
nullability=nullability,
151154
)
152155
)
153156
elif nested_type == "map":
154157
return stt.Type(
155158
map=stt.Type.Map(
156-
key=infer_expression_type(nested.map.key_values[0].key),
157-
value=infer_expression_type(nested.map.key_values[0].value),
159+
key=infer_expression_type(nested.map.key_values[0].key, parent_schema),
160+
value=infer_expression_type(
161+
nested.map.key_values[0].value, parent_schema
162+
),
158163
nullability=nullability,
159164
)
160165
)
@@ -191,17 +196,19 @@ def infer_expression_type(
191196
elif rex_type == "window_function":
192197
return expression.window_function.output_type
193198
elif rex_type == "if_then":
194-
return infer_expression_type(expression.if_then.ifs[0].then)
199+
return infer_expression_type(expression.if_then.ifs[0].then, parent_schema)
195200
elif rex_type == "switch_expression":
196-
return infer_expression_type(expression.switch_expression.ifs[0].then)
201+
return infer_expression_type(
202+
expression.switch_expression.ifs[0].then, parent_schema
203+
)
197204
elif rex_type == "cast":
198205
return expression.cast.type
199206
elif rex_type == "singular_or_list" or rex_type == "multi_or_list":
200207
return stt.Type(
201208
bool=stt.Type.Boolean(nullability=stt.Type.Nullability.NULLABILITY_NULLABLE)
202209
)
203210
elif rex_type == "nested":
204-
return infer_nested_type(expression.nested)
211+
return infer_nested_type(expression.nested, parent_schema)
205212
elif rex_type == "subquery":
206213
subquery_type = expression.subquery.WhichOneof("subquery_type")
207214

tests/test_type_inference.py

Lines changed: 148 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import substrait.gen.proto.algebra_pb2 as stalg
22
import substrait.gen.proto.type_pb2 as stt
3-
from substrait.type_inference import infer_rel_schema
4-
3+
from substrait.type_inference import (
4+
infer_expression_type,
5+
infer_nested_type,
6+
infer_rel_schema,
7+
)
58

69
struct = stt.Type.Struct(
710
types=[
@@ -312,3 +315,146 @@ def test_inference_join_left_mark():
312315
)
313316

314317
assert infer_rel_schema(rel) == expected
318+
319+
320+
def test_infer_expression_type_literal():
321+
"""Test infer_expression_type with a literal expression."""
322+
expr = stalg.Expression(literal=stalg.Expression.Literal(i64=42, nullable=False))
323+
324+
result = infer_expression_type(expr, struct)
325+
326+
expected = stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED))
327+
assert result == expected
328+
329+
330+
def test_infer_expression_type_selection():
331+
"""Test infer_expression_type with a field selection expression."""
332+
expr = stalg.Expression(
333+
selection=stalg.Expression.FieldReference(
334+
root_reference=stalg.Expression.FieldReference.RootReference(),
335+
direct_reference=stalg.Expression.ReferenceSegment(
336+
struct_field=stalg.Expression.ReferenceSegment.StructField(field=0),
337+
),
338+
)
339+
)
340+
341+
result = infer_expression_type(expr, struct)
342+
343+
# Should return the type of field 0 from the struct (i64)
344+
expected = stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED))
345+
assert result == expected
346+
347+
348+
def test_infer_expression_type_window_function():
349+
"""Test infer_expression_type with a window function expression."""
350+
expr = stalg.Expression(
351+
window_function=stalg.Expression.WindowFunction(
352+
function_reference=0,
353+
output_type=stt.Type(
354+
i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_NULLABLE)
355+
),
356+
)
357+
)
358+
359+
result = infer_expression_type(expr, struct)
360+
361+
expected = stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_NULLABLE))
362+
assert result == expected
363+
364+
365+
def test_infer_nested_type_struct():
366+
"""Test infer_nested_type with a struct nested expression."""
367+
expr = stalg.Expression(
368+
nested=stalg.Expression.Nested(
369+
struct=stalg.Expression.Nested.Struct(
370+
fields=[
371+
stalg.Expression(
372+
literal=stalg.Expression.Literal(i32=1, nullable=False)
373+
),
374+
stalg.Expression(
375+
literal=stalg.Expression.Literal(string="test", nullable=True)
376+
),
377+
]
378+
),
379+
nullable=False,
380+
)
381+
)
382+
383+
result = infer_nested_type(expr.nested, struct)
384+
385+
expected = stt.Type(
386+
struct=stt.Type.Struct(
387+
types=[
388+
stt.Type(i32=stt.Type.I32(nullability=stt.Type.NULLABILITY_REQUIRED)),
389+
stt.Type(
390+
string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)
391+
),
392+
],
393+
nullability=stt.Type.NULLABILITY_REQUIRED,
394+
)
395+
)
396+
assert result == expected
397+
398+
399+
def test_infer_nested_type_list():
400+
"""Test infer_nested_type with a list nested expression."""
401+
expr = stalg.Expression(
402+
nested=stalg.Expression.Nested(
403+
list=stalg.Expression.Nested.List(
404+
values=[
405+
stalg.Expression(
406+
literal=stalg.Expression.Literal(fp32=3.14, nullable=False)
407+
),
408+
]
409+
),
410+
nullable=False,
411+
)
412+
)
413+
414+
result = infer_nested_type(expr.nested, struct)
415+
416+
expected = stt.Type(
417+
list=stt.Type.List(
418+
type=stt.Type(
419+
fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_REQUIRED)
420+
),
421+
nullability=stt.Type.NULLABILITY_REQUIRED,
422+
)
423+
)
424+
assert result == expected
425+
426+
427+
def test_infer_nested_type_map():
428+
"""Test infer_nested_type with a map nested expression."""
429+
expr = stalg.Expression(
430+
nested=stalg.Expression.Nested(
431+
map=stalg.Expression.Nested.Map(
432+
key_values=[
433+
stalg.Expression.Nested.Map.KeyValue(
434+
key=stalg.Expression(
435+
literal=stalg.Expression.Literal(
436+
string="key", nullable=False
437+
)
438+
),
439+
value=stalg.Expression(
440+
literal=stalg.Expression.Literal(i32=42, nullable=False)
441+
),
442+
),
443+
]
444+
),
445+
nullable=False,
446+
)
447+
)
448+
449+
result = infer_nested_type(expr.nested, struct)
450+
451+
expected = stt.Type(
452+
map=stt.Type.Map(
453+
key=stt.Type(
454+
string=stt.Type.String(nullability=stt.Type.NULLABILITY_REQUIRED)
455+
),
456+
value=stt.Type(i32=stt.Type.I32(nullability=stt.Type.NULLABILITY_REQUIRED)),
457+
nullability=stt.Type.NULLABILITY_REQUIRED,
458+
)
459+
)
460+
assert result == expected

0 commit comments

Comments
 (0)