1- from datetime import date
21import itertools
2+ from datetime import date
3+ from typing import Any , Callable , Iterable , Union
4+
35import substrait .gen .proto .algebra_pb2 as stalg
4- import substrait .gen .proto .type_pb2 as stp
56import substrait .gen .proto .extended_expression_pb2 as stee
67import substrait .gen .proto .extensions .extensions_pb2 as ste
8+ import substrait .gen .proto .type_pb2 as stp
79from substrait .extension_registry import ExtensionRegistry
10+ from substrait .type_inference import infer_extended_expression_schema
811from substrait .utils import (
9- type_num_names ,
10- merge_extension_urns ,
11- merge_extension_uris ,
1212 merge_extension_declarations ,
13+ merge_extension_uris ,
14+ merge_extension_urns ,
15+ type_num_names ,
1316)
14- from substrait .type_inference import infer_extended_expression_schema
15- from typing import Callable , Any , Union , Iterable
1617
1718UnboundExtendedExpression = Callable [
1819 [stp .NamedStruct , ExtensionRegistry ], stee .ExtendedExpression
2122
2223
2324def _alias_or_inferred (
24- alias : Union [Iterable [str ], str ],
25+ alias : Union [Iterable [str ], str , None ],
2526 op : str ,
2627 args : Iterable [str ],
2728):
@@ -44,7 +45,7 @@ def resolve_expression(
4445
4546
4647def literal (
47- value : Any , type : stp .Type , alias : Union [Iterable [str ], str ] = None
48+ value : Any , type : stp .Type , alias : Union [Iterable [str ], str , None ] = None
4849) -> UnboundExtendedExpression :
4950 """Builds a resolver for ExtendedExpression containing a literal expression"""
5051
@@ -154,7 +155,7 @@ def resolve(
154155 return resolve
155156
156157
157- def column (field : Union [str , int ], alias : Union [Iterable [str ], str ] = None ):
158+ def column (field : Union [str , int ], alias : Union [Iterable [str ], str , None ] = None ):
158159 """Builds a resolver for ExtendedExpression containing a FieldReference expression
159160
160161 Accepts either an index or a field name of a desired field.
@@ -208,7 +209,7 @@ def scalar_function(
208209 urn : str ,
209210 function : str ,
210211 expressions : Iterable [ExtendedExpressionOrUnbound ],
211- alias : Union [Iterable [str ], str ] = None ,
212+ alias : Union [Iterable [str ], str , None ] = None ,
212213):
213214 """Builds a resolver for ExtendedExpression containing a ScalarFunction expression"""
214215
@@ -306,7 +307,7 @@ def aggregate_function(
306307 urn : str ,
307308 function : str ,
308309 expressions : Iterable [ExtendedExpressionOrUnbound ],
309- alias : Union [Iterable [str ], str ] = None ,
310+ alias : Union [Iterable [str ], str , None ] = None ,
310311):
311312 """Builds a resolver for ExtendedExpression containing a AggregateFunction measure"""
312313
@@ -402,7 +403,7 @@ def window_function(
402403 function : str ,
403404 expressions : Iterable [ExtendedExpressionOrUnbound ],
404405 partitions : Iterable [ExtendedExpressionOrUnbound ] = [],
405- alias : Union [Iterable [str ], str ] = None ,
406+ alias : Union [Iterable [str ], str , None ] = None ,
406407):
407408 """Builds a resolver for ExtendedExpression containing a WindowFunction expression"""
408409
@@ -512,7 +513,7 @@ def resolve(
512513def if_then (
513514 ifs : Iterable [tuple [ExtendedExpressionOrUnbound , ExtendedExpressionOrUnbound ]],
514515 _else : ExtendedExpressionOrUnbound ,
515- alias : Union [Iterable [str ], str ] = None ,
516+ alias : Union [Iterable [str ], str , None ] = None ,
516517):
517518 """Builds a resolver for ExtendedExpression containing an IfThen expression"""
518519
@@ -767,7 +768,11 @@ def resolve(
767768 return resolve
768769
769770
770- def cast (input : ExtendedExpressionOrUnbound , type : stp .Type ):
771+ def cast (
772+ input : ExtendedExpressionOrUnbound ,
773+ type : stp .Type ,
774+ alias : Union [Iterable [str ], str , None ] = None ,
775+ ):
771776 """Builds a resolver for ExtendedExpression containing a cast expression"""
772777
773778 def resolve (
@@ -785,7 +790,9 @@ def resolve(
785790 failure_behavior = stalg .Expression .Cast .FAILURE_BEHAVIOR_RETURN_NULL ,
786791 )
787792 ),
788- output_names = ["cast" ], # TODO construct name from inputs
793+ output_names = _alias_or_inferred (
794+ alias , "cast" , [bound_input .referred_expr [0 ].output_names [0 ]]
795+ ),
789796 )
790797 ],
791798 base_schema = base_schema ,
0 commit comments