Skip to content

Commit fadc414

Browse files
authored
feat: enable providing advanced_extensions for relations (#105)
Adding the ability to pass in an optional 'advanced extension' to the relation builders. Signed-off-by: MBWhite <[email protected]>
1 parent c909f06 commit fadc414

File tree

2 files changed

+60
-5
lines changed

2 files changed

+60
-5
lines changed

src/substrait/builders/plan.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
See `examples/builder_example.py` for usage.
66
"""
77

8-
from typing import Iterable, Union, Callable
8+
from typing import Iterable, Optional, Union, Callable
99

1010
import substrait.gen.proto.algebra_pb2 as stalg
11+
from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension
1112
import substrait.gen.proto.plan_pb2 as stp
1213
import substrait.gen.proto.type_pb2 as stt
1314
import substrait.gen.proto.extended_expression_pb2 as stee
@@ -32,7 +33,9 @@ def _merge_extensions(*objs):
3233

3334

3435
def read_named_table(
35-
names: Union[str, Iterable[str]], named_struct: stt.NamedStruct
36+
names: Union[str, Iterable[str]],
37+
named_struct: stt.NamedStruct,
38+
extension: Optional[AdvancedExtension] = None,
3639
) -> UnboundPlan:
3740
if named_struct.struct.nullability is stt.Type.NULLABILITY_NULLABLE:
3841
raise Exception("NamedStruct must not contain a nullable struct")
@@ -47,6 +50,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
4750
common=stalg.RelCommon(direct=stalg.RelCommon.Direct()),
4851
base_schema=named_struct,
4952
named_table=stalg.ReadRel.NamedTable(names=_names),
53+
advanced_extension=extension,
5054
)
5155
)
5256

@@ -60,7 +64,9 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
6064

6165

6266
def project(
63-
plan: PlanOrUnbound, expressions: Iterable[ExtendedExpressionOrUnbound]
67+
plan: PlanOrUnbound,
68+
expressions: Iterable[ExtendedExpressionOrUnbound],
69+
extension: Optional[AdvancedExtension] = None,
6470
) -> UnboundPlan:
6571
def resolve(registry: ExtensionRegistry) -> stp.Plan:
6672
_plan = plan if isinstance(plan, stp.Plan) else plan(registry)
@@ -86,6 +92,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
8692
expressions=[
8793
e.expression for ee in bound_expressions for e in ee.referred_expr
8894
],
95+
advanced_extension=extension,
8996
)
9097
)
9198

@@ -97,7 +104,11 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
97104
return resolve
98105

99106

100-
def filter(plan: PlanOrUnbound, expression: ExtendedExpressionOrUnbound) -> UnboundPlan:
107+
def filter(
108+
plan: PlanOrUnbound,
109+
expression: ExtendedExpressionOrUnbound,
110+
extension: Optional[AdvancedExtension] = None,
111+
) -> UnboundPlan:
101112
def resolve(registry: ExtensionRegistry) -> stp.Plan:
102113
bound_plan = plan if isinstance(plan, stp.Plan) else plan(registry)
103114
ns = infer_plan_schema(bound_plan)
@@ -109,6 +120,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
109120
filter=stalg.FilterRel(
110121
input=bound_plan.relations[-1].root.input,
111122
condition=bound_expression.referred_expr[0].expression,
123+
advanced_extension=extension,
112124
)
113125
)
114126

@@ -130,6 +142,7 @@ def sort(
130142
tuple[ExtendedExpressionOrUnbound, stalg.SortField.SortDirection.ValueType],
131143
]
132144
],
145+
extension: Optional[AdvancedExtension] = None,
133146
) -> UnboundPlan:
134147
def resolve(registry: ExtensionRegistry) -> stp.Plan:
135148
bound_plan = plan if isinstance(plan, stp.Plan) else plan(registry)
@@ -155,7 +168,8 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
155168
)
156169
for e in bound_expressions
157170
],
158-
)
171+
advanced_extension=extension,
172+
),
159173
)
160174

161175
return stp.Plan(
@@ -193,6 +207,7 @@ def fetch(
193207
plan: PlanOrUnbound,
194208
offset: ExtendedExpressionOrUnbound,
195209
count: ExtendedExpressionOrUnbound,
210+
extension: Optional[AdvancedExtension] = None,
196211
) -> UnboundPlan:
197212
def resolve(registry: ExtensionRegistry) -> stp.Plan:
198213
bound_plan = plan if isinstance(plan, stp.Plan) else plan(registry)
@@ -208,6 +223,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
208223
if bound_offset
209224
else None,
210225
count_expr=bound_count.referred_expr[0].expression,
226+
advanced_extension=extension,
211227
)
212228
)
213229

@@ -230,6 +246,7 @@ def join(
230246
right: PlanOrUnbound,
231247
expression: ExtendedExpressionOrUnbound,
232248
type: stalg.JoinRel.JoinType,
249+
extension: Optional[AdvancedExtension] = None,
233250
) -> UnboundPlan:
234251
def resolve(registry: ExtensionRegistry) -> stp.Plan:
235252
bound_left = left if isinstance(left, stp.Plan) else left(registry)
@@ -254,6 +271,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
254271
right=bound_right.relations[-1].root.input,
255272
expression=bound_expression.referred_expr[0].expression,
256273
type=type,
274+
advanced_extension=extension,
257275
)
258276
)
259277

@@ -268,6 +286,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
268286
def cross(
269287
left: PlanOrUnbound,
270288
right: PlanOrUnbound,
289+
extension: Optional[AdvancedExtension] = None,
271290
) -> UnboundPlan:
272291
def resolve(registry: ExtensionRegistry) -> stp.Plan:
273292
bound_left = left if isinstance(left, stp.Plan) else left(registry)
@@ -287,6 +306,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
287306
cross=stalg.CrossRel(
288307
left=bound_left.relations[-1].root.input,
289308
right=bound_right.relations[-1].root.input,
309+
advanced_extension=extension,
290310
)
291311
)
292312

@@ -303,6 +323,7 @@ def aggregate(
303323
input: PlanOrUnbound,
304324
grouping_expressions: Iterable[ExtendedExpressionOrUnbound],
305325
measures: Iterable[ExtendedExpressionOrUnbound],
326+
extension: Optional[AdvancedExtension] = None,
306327
) -> UnboundPlan:
307328
def resolve(registry: ExtensionRegistry) -> stp.Plan:
308329
bound_input = input if isinstance(input, stp.Plan) else input(registry)
@@ -332,6 +353,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
332353
stalg.AggregateRel.Measure(measure=m.referred_expr[0].measure)
333354
for m in bound_measures
334355
],
356+
advanced_extension=extension,
335357
)
336358
)
337359

tests/builders/plan/test_read.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from substrait.builders.type import boolean, i64
55
from substrait.builders.plan import read_named_table
66
import pytest
7+
from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension
8+
from google.protobuf import any
9+
from google.protobuf.wrappers_pb2 import StringValue
710

811
struct = stt.Type.Struct(
912
types=[i64(nullable=False), boolean()],
@@ -74,3 +77,33 @@ def test_read_rel_schema_nullable():
7477
Exception, match=r"NamedStruct must not contain a nullable struct"
7578
):
7679
read_named_table("example_table", named_struct)(None)
80+
81+
82+
def test_read_rel_ae():
83+
extension = AdvancedExtension(optimization=[any.pack(StringValue(value="Opt1"))])
84+
85+
actual = read_named_table(["example_db", "example_table"], named_struct, extension)(
86+
None
87+
)
88+
89+
expected = stp.Plan(
90+
relations=[
91+
stp.PlanRel(
92+
root=stalg.RelRoot(
93+
input=stalg.Rel(
94+
read=stalg.ReadRel(
95+
common=stalg.RelCommon(direct=stalg.RelCommon.Direct()),
96+
base_schema=named_struct,
97+
named_table=stalg.ReadRel.NamedTable(
98+
names=["example_db", "example_table"]
99+
),
100+
advanced_extension=extension,
101+
)
102+
),
103+
names=["id", "is_applicable"],
104+
)
105+
)
106+
]
107+
)
108+
109+
assert actual == expected

0 commit comments

Comments
 (0)