55See `examples/builder_example.py` for usage.
66"""
77
8- from typing import Iterable , Union , Callable
8+ from typing import Iterable , Optional , Union , Callable
99
1010import substrait .gen .proto .algebra_pb2 as stalg
11+ from substrait .gen .proto .extensions .extensions_pb2 import AdvancedExtension
1112import substrait .gen .proto .plan_pb2 as stp
1213import substrait .gen .proto .type_pb2 as stt
1314import substrait .gen .proto .extended_expression_pb2 as stee
@@ -32,7 +33,9 @@ def _merge_extensions(*objs):
3233
3334
3435def read_named_table (
35- names : Union [str , Iterable [str ]], named_struct : stt .NamedStruct
36+ names : Union [str , Iterable [str ]],
37+ named_struct : stt .NamedStruct ,
38+ extension : Optional [AdvancedExtension ] = None ,
3639) -> UnboundPlan :
3740 if named_struct .struct .nullability is stt .Type .NULLABILITY_NULLABLE :
3841 raise Exception ("NamedStruct must not contain a nullable struct" )
@@ -47,6 +50,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
4750 common = stalg .RelCommon (direct = stalg .RelCommon .Direct ()),
4851 base_schema = named_struct ,
4952 named_table = stalg .ReadRel .NamedTable (names = _names ),
53+ advanced_extension = extension ,
5054 )
5155 )
5256
@@ -60,7 +64,9 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
6064
6165
6266def project (
63- plan : PlanOrUnbound , expressions : Iterable [ExtendedExpressionOrUnbound ]
67+ plan : PlanOrUnbound ,
68+ expressions : Iterable [ExtendedExpressionOrUnbound ],
69+ extension : Optional [AdvancedExtension ] = None ,
6470) -> UnboundPlan :
6571 def resolve (registry : ExtensionRegistry ) -> stp .Plan :
6672 _plan = plan if isinstance (plan , stp .Plan ) else plan (registry )
@@ -86,6 +92,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
8692 expressions = [
8793 e .expression for ee in bound_expressions for e in ee .referred_expr
8894 ],
95+ advanced_extension = extension ,
8996 )
9097 )
9198
@@ -97,7 +104,11 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
97104 return resolve
98105
99106
100- def filter (plan : PlanOrUnbound , expression : ExtendedExpressionOrUnbound ) -> UnboundPlan :
107+ def filter (
108+ plan : PlanOrUnbound ,
109+ expression : ExtendedExpressionOrUnbound ,
110+ extension : Optional [AdvancedExtension ] = None ,
111+ ) -> UnboundPlan :
101112 def resolve (registry : ExtensionRegistry ) -> stp .Plan :
102113 bound_plan = plan if isinstance (plan , stp .Plan ) else plan (registry )
103114 ns = infer_plan_schema (bound_plan )
@@ -109,6 +120,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
109120 filter = stalg .FilterRel (
110121 input = bound_plan .relations [- 1 ].root .input ,
111122 condition = bound_expression .referred_expr [0 ].expression ,
123+ advanced_extension = extension ,
112124 )
113125 )
114126
@@ -130,6 +142,7 @@ def sort(
130142 tuple [ExtendedExpressionOrUnbound , stalg .SortField .SortDirection .ValueType ],
131143 ]
132144 ],
145+ extension : Optional [AdvancedExtension ] = None ,
133146) -> UnboundPlan :
134147 def resolve (registry : ExtensionRegistry ) -> stp .Plan :
135148 bound_plan = plan if isinstance (plan , stp .Plan ) else plan (registry )
@@ -155,7 +168,8 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
155168 )
156169 for e in bound_expressions
157170 ],
158- )
171+ advanced_extension = extension ,
172+ ),
159173 )
160174
161175 return stp .Plan (
@@ -193,6 +207,7 @@ def fetch(
193207 plan : PlanOrUnbound ,
194208 offset : ExtendedExpressionOrUnbound ,
195209 count : ExtendedExpressionOrUnbound ,
210+ extension : Optional [AdvancedExtension ] = None ,
196211) -> UnboundPlan :
197212 def resolve (registry : ExtensionRegistry ) -> stp .Plan :
198213 bound_plan = plan if isinstance (plan , stp .Plan ) else plan (registry )
@@ -208,6 +223,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
208223 if bound_offset
209224 else None ,
210225 count_expr = bound_count .referred_expr [0 ].expression ,
226+ advanced_extension = extension ,
211227 )
212228 )
213229
@@ -230,6 +246,7 @@ def join(
230246 right : PlanOrUnbound ,
231247 expression : ExtendedExpressionOrUnbound ,
232248 type : stalg .JoinRel .JoinType ,
249+ extension : Optional [AdvancedExtension ] = None ,
233250) -> UnboundPlan :
234251 def resolve (registry : ExtensionRegistry ) -> stp .Plan :
235252 bound_left = left if isinstance (left , stp .Plan ) else left (registry )
@@ -254,6 +271,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
254271 right = bound_right .relations [- 1 ].root .input ,
255272 expression = bound_expression .referred_expr [0 ].expression ,
256273 type = type ,
274+ advanced_extension = extension ,
257275 )
258276 )
259277
@@ -268,6 +286,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
268286def cross (
269287 left : PlanOrUnbound ,
270288 right : PlanOrUnbound ,
289+ extension : Optional [AdvancedExtension ] = None ,
271290) -> UnboundPlan :
272291 def resolve (registry : ExtensionRegistry ) -> stp .Plan :
273292 bound_left = left if isinstance (left , stp .Plan ) else left (registry )
@@ -287,6 +306,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
287306 cross = stalg .CrossRel (
288307 left = bound_left .relations [- 1 ].root .input ,
289308 right = bound_right .relations [- 1 ].root .input ,
309+ advanced_extension = extension ,
290310 )
291311 )
292312
@@ -303,6 +323,7 @@ def aggregate(
303323 input : PlanOrUnbound ,
304324 grouping_expressions : Iterable [ExtendedExpressionOrUnbound ],
305325 measures : Iterable [ExtendedExpressionOrUnbound ],
326+ extension : Optional [AdvancedExtension ] = None ,
306327) -> UnboundPlan :
307328 def resolve (registry : ExtensionRegistry ) -> stp .Plan :
308329 bound_input = input if isinstance (input , stp .Plan ) else input (registry )
@@ -332,6 +353,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
332353 stalg .AggregateRel .Measure (measure = m .referred_expr [0 ].measure )
333354 for m in bound_measures
334355 ],
356+ advanced_extension = extension ,
335357 )
336358 )
337359
0 commit comments