Skip to content

Commit 9a13e6e

Browse files
committed
readd extension_uris to extended expression + add test
1 parent 66a11f1 commit 9a13e6e

File tree

2 files changed

+149
-1
lines changed

2 files changed

+149
-1
lines changed

src/substrait/builders/extended_expression.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,12 @@ def resolve(
524524

525525
bound_else = resolve_expression(_else, base_schema, registry)
526526

527+
extension_uris = merge_extension_uris(
528+
*[b[0].extension_uris for b in bound_ifs],
529+
*[b[1].extension_uris for b in bound_ifs],
530+
bound_else.extension_uris,
531+
)
532+
527533
extension_urns = merge_extension_urns(
528534
*[b[0].extension_urns for b in bound_ifs],
529535
*[b[1].extension_urns for b in bound_ifs],
@@ -575,6 +581,7 @@ def resolve(
575581
)
576582
],
577583
base_schema=base_schema,
584+
extension_uris=extension_uris,
578585
extension_urns=extension_urns,
579586
extensions=extensions,
580587
)

tests/builders/extended_expression/test_if_then.py

Lines changed: 142 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import substrait.gen.proto.algebra_pb2 as stalg
22
import substrait.gen.proto.type_pb2 as stt
33
import substrait.gen.proto.extended_expression_pb2 as stee
4-
from substrait.builders.extended_expression import if_then, literal
4+
import substrait.gen.proto.extensions.extensions_pb2 as ste
5+
from substrait.builders.extended_expression import if_then, literal, scalar_function, column
6+
from substrait.extension_registry import ExtensionRegistry
57

68

79
struct = stt.Type.Struct(
@@ -75,3 +77,142 @@ def test_if_else():
7577
)
7678

7779
assert actual == expected
80+
81+
def test_if_then_with_extension():
82+
"""Test if_then with scalar function to verify both URI and URN are present."""
83+
import yaml
84+
85+
# Use a minimal registry with only the comparison extension
86+
registry = ExtensionRegistry(load_default_extensions=False)
87+
content = """%YAML 1.2
88+
---
89+
urn: extension:io.substrait:functions_comparison
90+
scalar_functions:
91+
- name: "gt"
92+
description: ""
93+
impls:
94+
- args:
95+
- value: fp32
96+
- value: fp32
97+
return: boolean
98+
"""
99+
registry.register_extension_dict(
100+
yaml.safe_load(content),
101+
uri="https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml"
102+
)
103+
104+
# Create if_then: if order_total > 100 then "expensive" else "cheap"
105+
actual = if_then(
106+
ifs=[
107+
(
108+
scalar_function(
109+
"extension:io.substrait:functions_comparison",
110+
"gt",
111+
expressions=[
112+
column("order_total"),
113+
literal(100.0, type=stt.Type(
114+
fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_REQUIRED)
115+
)),
116+
],
117+
),
118+
literal(
119+
"expensive",
120+
type=stt.Type(
121+
string=stt.Type.String(nullability=stt.Type.NULLABILITY_REQUIRED)
122+
),
123+
),
124+
)
125+
],
126+
_else=literal(
127+
"cheap",
128+
type=stt.Type(
129+
string=stt.Type.String(nullability=stt.Type.NULLABILITY_REQUIRED)
130+
),
131+
),
132+
)(named_struct, registry)
133+
134+
expected = stee.ExtendedExpression(
135+
extension_uris=[
136+
ste.SimpleExtensionURI(
137+
extension_uri_anchor=1,
138+
uri="https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml"
139+
)
140+
],
141+
extension_urns=[
142+
ste.SimpleExtensionURN(
143+
extension_urn_anchor=1,
144+
urn="extension:io.substrait:functions_comparison"
145+
)
146+
],
147+
extensions=[
148+
ste.SimpleExtensionDeclaration(
149+
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
150+
extension_uri_reference=1,
151+
extension_urn_reference=1,
152+
function_anchor=1,
153+
name="gt:fp32_fp32"
154+
)
155+
)
156+
],
157+
referred_expr=[
158+
stee.ExpressionReference(
159+
expression=stalg.Expression(
160+
if_then=stalg.Expression.IfThen(
161+
**{
162+
"ifs": [
163+
stalg.Expression.IfThen.IfClause(
164+
**{
165+
"if": stalg.Expression(
166+
scalar_function=stalg.Expression.ScalarFunction(
167+
function_reference=1,
168+
output_type=stt.Type(
169+
bool=stt.Type.Boolean(
170+
nullability=stt.Type.NULLABILITY_NULLABLE
171+
)
172+
),
173+
arguments=[
174+
stalg.FunctionArgument(
175+
value=stalg.Expression(
176+
selection=stalg.Expression.FieldReference(
177+
direct_reference=stalg.Expression.ReferenceSegment(
178+
struct_field=stalg.Expression.ReferenceSegment.StructField(
179+
field=2
180+
)
181+
),
182+
root_reference=stalg.Expression.FieldReference.RootReference()
183+
)
184+
)
185+
),
186+
stalg.FunctionArgument(
187+
value=stalg.Expression(
188+
literal=stalg.Expression.Literal(
189+
fp32=100.0
190+
)
191+
)
192+
),
193+
]
194+
)
195+
),
196+
"then": stalg.Expression(
197+
literal=stalg.Expression.Literal(
198+
string="expensive"
199+
)
200+
),
201+
}
202+
)
203+
],
204+
"else": stalg.Expression(
205+
literal=stalg.Expression.Literal(
206+
string="cheap"
207+
)
208+
),
209+
}
210+
)
211+
),
212+
output_names=["IfThen(gt(order_total,Literal(100.0)),Literal(expensive),Literal(cheap))"],
213+
)
214+
],
215+
base_schema=named_struct,
216+
)
217+
218+
assert actual == expected

0 commit comments

Comments
 (0)