Skip to content

Commit 975952d

Browse files
committed
fix remaining extended expressions missing uris + fmt
also added tests
1 parent 9a13e6e commit 975952d

17 files changed

+662
-93
lines changed

examples/builder_example.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,8 @@ def expression_only_example():
318318
print("Complex salary calculation expression:")
319319
# Create a simple plan to wrap the expression
320320
dummy_schema = named_struct(
321-
names=["base_salary"], struct=struct(types=[fp64(nullable=False)], nullable=False)
321+
names=["base_salary"],
322+
struct=struct(types=[fp64(nullable=False)], nullable=False),
322323
)
323324
dummy_table = read_named_table("dummy", dummy_schema)
324325
dummy_plan = project(dummy_table, expressions=[complex_expr])

src/substrait/builders/extended_expression.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,9 @@ def resolve(
250250
ste.SimpleExtensionDeclaration(
251251
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
252252
extension_urn_reference=registry.lookup_urn(urn),
253-
extension_uri_reference=registry.lookup_uri_anchor(uri) if uri else 0,
253+
extension_uri_reference=registry.lookup_uri_anchor(uri)
254+
if uri
255+
else 0,
254256
function_anchor=func[0].anchor,
255257
name=str(func[0]),
256258
)
@@ -346,7 +348,9 @@ def resolve(
346348
ste.SimpleExtensionDeclaration(
347349
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
348350
extension_urn_reference=registry.lookup_urn(urn),
349-
extension_uri_reference=registry.lookup_uri_anchor(uri) if uri else 0,
351+
extension_uri_reference=registry.lookup_uri_anchor(uri)
352+
if uri
353+
else 0,
350354
function_anchor=func[0].anchor,
351355
name=str(func[0]),
352356
)
@@ -444,14 +448,15 @@ def resolve(
444448
ste.SimpleExtensionDeclaration(
445449
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
446450
extension_urn_reference=registry.lookup_urn(urn),
447-
extension_uri_reference=registry.lookup_uri_anchor(uri) if uri else 0,
451+
extension_uri_reference=registry.lookup_uri_anchor(uri)
452+
if uri
453+
else 0,
448454
function_anchor=func[0].anchor,
449455
name=str(func[0]),
450456
)
451457
)
452458
]
453459

454-
455460
extension_urns = merge_extension_urns(
456461
func_extension_urns,
457462
*[b.extension_urns for b in bound_expressions],
@@ -609,6 +614,12 @@ def resolve(
609614
]
610615
bound_else = resolve_expression(_else, base_schema, registry)
611616

617+
extension_uris = merge_extension_uris(
618+
bound_match.extension_uris,
619+
*[b.extension_uris for _, b in bound_ifs],
620+
bound_else.extension_uris,
621+
)
622+
612623
extension_urns = merge_extension_urns(
613624
bound_match.extension_urns,
614625
*[b.extension_urns for _, b in bound_ifs],
@@ -644,6 +655,7 @@ def resolve(
644655
],
645656
base_schema=base_schema,
646657
extension_urns=extension_urns,
658+
extension_uris=extension_uris,
647659
extensions=extensions,
648660
)
649661

@@ -661,6 +673,10 @@ def resolve(
661673
bound_value = resolve_expression(value, base_schema, registry)
662674
bound_options = [resolve_expression(o, base_schema, registry) for o in options]
663675

676+
extension_uris = merge_extension_uris(
677+
bound_value.extension_uris, *[b.extension_uris for b in bound_options]
678+
)
679+
664680
extension_urns = merge_extension_urns(
665681
bound_value.extension_urns, *[b.extension_urns for b in bound_options]
666682
)
@@ -687,6 +703,7 @@ def resolve(
687703
],
688704
base_schema=base_schema,
689705
extension_urns=extension_urns,
706+
extension_uris=extension_uris,
690707
extensions=extensions,
691708
)
692709

@@ -707,12 +724,17 @@ def resolve(
707724
[resolve_expression(e, base_schema, registry) for e in o] for o in options
708725
]
709726

727+
extension_uris = merge_extension_uris(
728+
*[b.extension_uris for b in bound_value],
729+
*[e.extension_uris for b in bound_options for e in b],
730+
)
731+
710732
extension_urns = merge_extension_urns(
711733
*[b.extension_urns for b in bound_value],
712734
*[e.extension_urns for b in bound_options for e in b],
713735
)
714736

715-
extensions = merge_extension_urns(
737+
extensions = merge_extension_declarations(
716738
*[b.extensions for b in bound_value],
717739
*[e.extensions for b in bound_options for e in b],
718740
)
@@ -738,6 +760,7 @@ def resolve(
738760
],
739761
base_schema=base_schema,
740762
extension_urns=extension_urns,
763+
extension_uris=extension_uris,
741764
extensions=extensions,
742765
)
743766

@@ -767,6 +790,7 @@ def resolve(
767790
],
768791
base_schema=base_schema,
769792
extension_urns=bound_input.extension_urns,
793+
extension_uris=bound_input.extension_uris,
770794
extensions=bound_input.extensions,
771795
)
772796

src/substrait/builders/plan.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@
1818
resolve_expression,
1919
)
2020
from substrait.type_inference import infer_plan_schema
21-
from substrait.utils import merge_extension_declarations, merge_extension_urns, merge_extension_uris
21+
from substrait.utils import (
22+
merge_extension_declarations,
23+
merge_extension_urns,
24+
merge_extension_uris,
25+
)
2226

2327
UnboundPlan = Callable[[ExtensionRegistry], stp.Plan]
2428

@@ -29,7 +33,7 @@ def _merge_extensions(*objs):
2933
"""Merge extension URIs, URNs, and declarations from multiple plan/expression objects.
3034
3135
During the URI -> URN migration period, we maintain both URI and URN references
32-
for backwards compatibility.
36+
for backwards compatibility.
3337
"""
3438
return {
3539
"extension_uris": merge_extension_uris(*[b.extension_uris for b in objs if b]),

src/substrait/sql/sql_to_substrait.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@
3636
"Eq": ("extension:io.substrait:functions_comparison", "equal"),
3737
}
3838

39-
aggregate_function_mapping = {"SUM": ("extension:io.substrait:functions_arithmetic", "sum")}
39+
aggregate_function_mapping = {
40+
"SUM": ("extension:io.substrait:functions_arithmetic", "sum")
41+
}
4042

4143
window_function_mapping = {
4244
"row_number": ("extension:io.substrait:functions_arithmetic", "row_number"),

src/substrait/utils/display.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,7 @@ def _stream_literal_value(
723723
elif literal.HasField("string"):
724724
string_value = f'"{literal.string}"'
725725
stream.write(
726-
f'{indent}{self._color("string", Colors.BLUE)}: {self._color(string_value, Colors.GREEN)}\n'
726+
f"{indent}{self._color('string', Colors.BLUE)}: {self._color(string_value, Colors.GREEN)}\n"
727727
)
728728
elif literal.HasField("date"):
729729
stream.write(

tests/builders/extended_expression/test_aggregate_function.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@
3838

3939

4040
registry = ExtensionRegistry(load_default_extensions=False)
41-
registry.register_extension_dict(yaml.safe_load(content), uri="https://test.example.com/test.yaml")
41+
registry.register_extension_dict(
42+
yaml.safe_load(content), uri="https://test.example.com/test.yaml"
43+
)
4244

4345

4446
def test_aggregate_count():
@@ -57,15 +59,21 @@ def test_aggregate_count():
5759
)(named_struct, registry)
5860

5961
expected = stee.ExtendedExpression(
60-
extension_urns=[ste.SimpleExtensionURN(extension_urn_anchor=1, urn="extension:test:urn")],
61-
extension_uris=[ste.SimpleExtensionURI(extension_uri_anchor=1, uri="https://test.example.com/test.yaml")],
62+
extension_urns=[
63+
ste.SimpleExtensionURN(extension_urn_anchor=1, urn="extension:test:urn")
64+
],
65+
extension_uris=[
66+
ste.SimpleExtensionURI(
67+
extension_uri_anchor=1, uri="https://test.example.com/test.yaml"
68+
)
69+
],
6270
extensions=[
6371
ste.SimpleExtensionDeclaration(
6472
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
6573
extension_urn_reference=1,
6674
extension_uri_reference=1,
6775
function_anchor=1,
68-
name="count:any"
76+
name="count:any",
6977
)
7078
)
7179
],

tests/builders/extended_expression/test_cast.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,100 @@ def test_cast():
4444
)
4545

4646
assert e == expected
47+
48+
49+
def test_cast_with_extension():
50+
import yaml
51+
import substrait.gen.proto.extensions.extensions_pb2 as ste
52+
from substrait.builders.extended_expression import scalar_function
53+
54+
registry_with_ext = ExtensionRegistry(load_default_extensions=False)
55+
content = """%YAML 1.2
56+
---
57+
urn: extension:test:functions
58+
scalar_functions:
59+
- name: "add"
60+
description: ""
61+
impls:
62+
- args:
63+
- value: i8
64+
- value: i8
65+
return: i8
66+
"""
67+
registry_with_ext.register_extension_dict(
68+
yaml.safe_load(content), uri="https://test.example.com/functions.yaml"
69+
)
70+
71+
actual = cast(
72+
input=scalar_function(
73+
"extension:test:functions",
74+
"add",
75+
expressions=[literal(1, i8()), literal(2, i8())],
76+
),
77+
type=i16(),
78+
)(named_struct, registry_with_ext)
79+
80+
expected = stee.ExtendedExpression(
81+
extension_uris=[
82+
ste.SimpleExtensionURI(
83+
extension_uri_anchor=1, uri="https://test.example.com/functions.yaml"
84+
)
85+
],
86+
extension_urns=[
87+
ste.SimpleExtensionURN(
88+
extension_urn_anchor=1, urn="extension:test:functions"
89+
)
90+
],
91+
extensions=[
92+
ste.SimpleExtensionDeclaration(
93+
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
94+
extension_uri_reference=1,
95+
extension_urn_reference=1,
96+
function_anchor=1,
97+
name="add:i8_i8",
98+
)
99+
)
100+
],
101+
referred_expr=[
102+
stee.ExpressionReference(
103+
expression=stalg.Expression(
104+
cast=stalg.Expression.Cast(
105+
type=stt.Type(
106+
i16=stt.Type.I16(nullability=stt.Type.NULLABILITY_NULLABLE)
107+
),
108+
input=stalg.Expression(
109+
scalar_function=stalg.Expression.ScalarFunction(
110+
function_reference=1,
111+
output_type=stt.Type(
112+
i8=stt.Type.I8(
113+
nullability=stt.Type.NULLABILITY_NULLABLE
114+
)
115+
),
116+
arguments=[
117+
stalg.FunctionArgument(
118+
value=stalg.Expression(
119+
literal=stalg.Expression.Literal(
120+
i8=1, nullable=True
121+
)
122+
)
123+
),
124+
stalg.FunctionArgument(
125+
value=stalg.Expression(
126+
literal=stalg.Expression.Literal(
127+
i8=2, nullable=True
128+
)
129+
)
130+
),
131+
],
132+
)
133+
),
134+
failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL,
135+
)
136+
),
137+
output_names=["cast"],
138+
)
139+
],
140+
base_schema=named_struct,
141+
)
142+
143+
assert actual == expected

0 commit comments

Comments
 (0)