Skip to content

Commit d145e28

Browse files
committed
Fix remaining extended expressions that were missing extension URIs, and apply code formatting (fmt)
Also add tests
1 parent 125d6b3 commit d145e28

17 files changed

+672
-84
lines changed

examples/builder_example.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,8 @@ def expression_only_example():
318318
print("Complex salary calculation expression:")
319319
# Create a simple plan to wrap the expression
320320
dummy_schema = named_struct(
321-
names=["base_salary"], struct=struct(types=[fp64(nullable=False)], nullable=False)
321+
names=["base_salary"],
322+
struct=struct(types=[fp64(nullable=False)], nullable=False),
322323
)
323324
dummy_table = read_named_table("dummy", dummy_schema)
324325
dummy_plan = project(dummy_table, expressions=[complex_expr])

src/substrait/builders/extended_expression.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,9 @@ def resolve(
250250
ste.SimpleExtensionDeclaration(
251251
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
252252
extension_urn_reference=registry.lookup_urn(urn),
253-
extension_uri_reference=registry.lookup_uri_anchor(uri) if uri else 0,
253+
extension_uri_reference=registry.lookup_uri_anchor(uri)
254+
if uri
255+
else 0,
254256
function_anchor=func[0].anchor,
255257
name=str(func[0]),
256258
)
@@ -346,7 +348,9 @@ def resolve(
346348
ste.SimpleExtensionDeclaration(
347349
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
348350
extension_urn_reference=registry.lookup_urn(urn),
349-
extension_uri_reference=registry.lookup_uri_anchor(uri) if uri else 0,
351+
extension_uri_reference=registry.lookup_uri_anchor(uri)
352+
if uri
353+
else 0,
350354
function_anchor=func[0].anchor,
351355
name=str(func[0]),
352356
)
@@ -444,14 +448,15 @@ def resolve(
444448
ste.SimpleExtensionDeclaration(
445449
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
446450
extension_urn_reference=registry.lookup_urn(urn),
447-
extension_uri_reference=registry.lookup_uri_anchor(uri) if uri else 0,
451+
extension_uri_reference=registry.lookup_uri_anchor(uri)
452+
if uri
453+
else 0,
448454
function_anchor=func[0].anchor,
449455
name=str(func[0]),
450456
)
451457
)
452458
]
453459

454-
455460
extension_urns = merge_extension_urns(
456461
func_extension_urns,
457462
*[b.extension_urns for b in bound_expressions],
@@ -609,6 +614,12 @@ def resolve(
609614
]
610615
bound_else = resolve_expression(_else, base_schema, registry)
611616

617+
extension_uris = merge_extension_uris(
618+
bound_match.extension_uris,
619+
*[b.extension_uris for _, b in bound_ifs],
620+
bound_else.extension_uris,
621+
)
622+
612623
extension_urns = merge_extension_urns(
613624
bound_match.extension_urns,
614625
*[b.extension_urns for _, b in bound_ifs],
@@ -644,6 +655,7 @@ def resolve(
644655
],
645656
base_schema=base_schema,
646657
extension_urns=extension_urns,
658+
extension_uris=extension_uris,
647659
extensions=extensions,
648660
)
649661

@@ -661,6 +673,10 @@ def resolve(
661673
bound_value = resolve_expression(value, base_schema, registry)
662674
bound_options = [resolve_expression(o, base_schema, registry) for o in options]
663675

676+
extension_uris = merge_extension_uris(
677+
bound_value.extension_uris, *[b.extension_uris for b in bound_options]
678+
)
679+
664680
extension_urns = merge_extension_urns(
665681
bound_value.extension_urns, *[b.extension_urns for b in bound_options]
666682
)
@@ -687,6 +703,7 @@ def resolve(
687703
],
688704
base_schema=base_schema,
689705
extension_urns=extension_urns,
706+
extension_uris=extension_uris,
690707
extensions=extensions,
691708
)
692709

@@ -707,12 +724,17 @@ def resolve(
707724
[resolve_expression(e, base_schema, registry) for e in o] for o in options
708725
]
709726

727+
extension_uris = merge_extension_uris(
728+
*[b.extension_uris for b in bound_value],
729+
*[e.extension_uris for b in bound_options for e in b],
730+
)
731+
710732
extension_urns = merge_extension_urns(
711733
*[b.extension_urns for b in bound_value],
712734
*[e.extension_urns for b in bound_options for e in b],
713735
)
714736

715-
extensions = merge_extension_urns(
737+
extensions = merge_extension_declarations(
716738
*[b.extensions for b in bound_value],
717739
*[e.extensions for b in bound_options for e in b],
718740
)
@@ -738,6 +760,7 @@ def resolve(
738760
],
739761
base_schema=base_schema,
740762
extension_urns=extension_urns,
763+
extension_uris=extension_uris,
741764
extensions=extensions,
742765
)
743766

@@ -767,6 +790,7 @@ def resolve(
767790
],
768791
base_schema=base_schema,
769792
extension_urns=bound_input.extension_urns,
793+
extension_uris=bound_input.extension_uris,
770794
extensions=bound_input.extensions,
771795
)
772796

src/substrait/builders/plan.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@
1818
resolve_expression,
1919
)
2020
from substrait.type_inference import infer_plan_schema
21-
from substrait.utils import merge_extension_declarations, merge_extension_urns, merge_extension_uris
21+
from substrait.utils import (
22+
merge_extension_declarations,
23+
merge_extension_urns,
24+
merge_extension_uris,
25+
)
2226

2327
UnboundPlan = Callable[[ExtensionRegistry], stp.Plan]
2428

@@ -29,7 +33,7 @@ def _merge_extensions(*objs):
2933
"""Merge extension URIs, URNs, and declarations from multiple plan/expression objects.
3034
3135
During the URI -> URN migration period, we maintain both URI and URN references
32-
for backwards compatibility.
36+
for backwards compatibility.
3337
"""
3438
return {
3539
"extension_uris": merge_extension_uris(*[b.extension_uris for b in objs if b]),

src/substrait/sql/sql_to_substrait.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@
3636
"Eq": ("extension:io.substrait:functions_comparison", "equal"),
3737
}
3838

39-
aggregate_function_mapping = {"SUM": ("extension:io.substrait:functions_arithmetic", "sum")}
39+
aggregate_function_mapping = {
40+
"SUM": ("extension:io.substrait:functions_arithmetic", "sum")
41+
}
4042

4143
window_function_mapping = {
4244
"row_number": ("extension:io.substrait:functions_arithmetic", "row_number"),

src/substrait/utils/display.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,7 @@ def _stream_literal_value(
723723
elif literal.HasField("string"):
724724
string_value = f'"{literal.string}"'
725725
stream.write(
726-
f'{indent}{self._color("string", Colors.BLUE)}: {self._color(string_value, Colors.GREEN)}\n'
726+
f"{indent}{self._color('string', Colors.BLUE)}: {self._color(string_value, Colors.GREEN)}\n"
727727
)
728728
elif literal.HasField("date"):
729729
stream.write(

tests/builders/extended_expression/test_aggregate_function.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@
3838

3939

4040
registry = ExtensionRegistry(load_default_extensions=False)
41-
registry.register_extension_dict(yaml.safe_load(content), uri="https://test.example.com/test.yaml")
41+
registry.register_extension_dict(
42+
yaml.safe_load(content), uri="https://test.example.com/test.yaml"
43+
)
4244

4345

4446
def test_aggregate_count():
@@ -57,15 +59,21 @@ def test_aggregate_count():
5759
)(named_struct, registry)
5860

5961
expected = stee.ExtendedExpression(
60-
extension_urns=[ste.SimpleExtensionURN(extension_urn_anchor=1, urn="extension:test:urn")],
61-
extension_uris=[ste.SimpleExtensionURI(extension_uri_anchor=1, uri="https://test.example.com/test.yaml")],
62+
extension_urns=[
63+
ste.SimpleExtensionURN(extension_urn_anchor=1, urn="extension:test:urn")
64+
],
65+
extension_uris=[
66+
ste.SimpleExtensionURI(
67+
extension_uri_anchor=1, uri="https://test.example.com/test.yaml"
68+
)
69+
],
6270
extensions=[
6371
ste.SimpleExtensionDeclaration(
6472
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
6573
extension_urn_reference=1,
6674
extension_uri_reference=1,
6775
function_anchor=1,
68-
name="count:any"
76+
name="count:any",
6977
)
7078
)
7179
],

tests/builders/extended_expression/test_cast.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,103 @@ def test_cast():
4444
)
4545

4646
assert e == expected
47+
48+
49+
def test_cast_with_extension():
50+
"""Test cast with scalar function to verify both URI and URN are present."""
51+
import yaml
52+
import substrait.gen.proto.extensions.extensions_pb2 as ste
53+
from substrait.builders.extended_expression import scalar_function
54+
55+
# Register a minimal extension
56+
registry_with_ext = ExtensionRegistry(load_default_extensions=False)
57+
content = """%YAML 1.2
58+
---
59+
urn: extension:test:functions
60+
scalar_functions:
61+
- name: "add"
62+
description: ""
63+
impls:
64+
- args:
65+
- value: i8
66+
- value: i8
67+
return: i8
68+
"""
69+
registry_with_ext.register_extension_dict(
70+
yaml.safe_load(content), uri="https://test.example.com/functions.yaml"
71+
)
72+
73+
# cast on scalar function result
74+
actual = cast(
75+
input=scalar_function(
76+
"extension:test:functions",
77+
"add",
78+
expressions=[literal(1, i8()), literal(2, i8())],
79+
),
80+
type=i16(),
81+
)(named_struct, registry_with_ext)
82+
83+
expected = stee.ExtendedExpression(
84+
extension_uris=[
85+
ste.SimpleExtensionURI(
86+
extension_uri_anchor=1, uri="https://test.example.com/functions.yaml"
87+
)
88+
],
89+
extension_urns=[
90+
ste.SimpleExtensionURN(
91+
extension_urn_anchor=1, urn="extension:test:functions"
92+
)
93+
],
94+
extensions=[
95+
ste.SimpleExtensionDeclaration(
96+
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
97+
extension_uri_reference=1,
98+
extension_urn_reference=1,
99+
function_anchor=1,
100+
name="add:i8_i8",
101+
)
102+
)
103+
],
104+
referred_expr=[
105+
stee.ExpressionReference(
106+
expression=stalg.Expression(
107+
cast=stalg.Expression.Cast(
108+
type=stt.Type(
109+
i16=stt.Type.I16(nullability=stt.Type.NULLABILITY_NULLABLE)
110+
),
111+
input=stalg.Expression(
112+
scalar_function=stalg.Expression.ScalarFunction(
113+
function_reference=1,
114+
output_type=stt.Type(
115+
i8=stt.Type.I8(
116+
nullability=stt.Type.NULLABILITY_NULLABLE
117+
)
118+
),
119+
arguments=[
120+
stalg.FunctionArgument(
121+
value=stalg.Expression(
122+
literal=stalg.Expression.Literal(
123+
i8=1, nullable=True
124+
)
125+
)
126+
),
127+
stalg.FunctionArgument(
128+
value=stalg.Expression(
129+
literal=stalg.Expression.Literal(
130+
i8=2, nullable=True
131+
)
132+
)
133+
),
134+
],
135+
)
136+
),
137+
failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL,
138+
)
139+
),
140+
output_names=["cast"],
141+
)
142+
],
143+
base_schema=named_struct,
144+
)
145+
146+
assert actual == expected

0 commit comments

Comments
 (0)