diff --git a/Makefile b/Makefile index 4e95eb6..670c6eb 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,7 @@ codegen-extensions: --input third_party/substrait/text/simple_extensions_schema.yaml \ --output src/substrait/gen/json/simple_extensions.py \ --output-model-type dataclasses.dataclass \ + --target-python-version 3.10 \ --disable-timestamp lint: diff --git a/src/substrait/builders/extended_expression.py b/src/substrait/builders/extended_expression.py index abda416..88ec5fd 100644 --- a/src/substrait/builders/extended_expression.py +++ b/src/substrait/builders/extended_expression.py @@ -1,18 +1,19 @@ -from datetime import date import itertools +from datetime import date +from typing import Any, Callable, Iterable, Union + import substrait.gen.proto.algebra_pb2 as stalg -import substrait.gen.proto.type_pb2 as stp import substrait.gen.proto.extended_expression_pb2 as stee import substrait.gen.proto.extensions.extensions_pb2 as ste +import substrait.gen.proto.type_pb2 as stp from substrait.extension_registry import ExtensionRegistry +from substrait.type_inference import infer_extended_expression_schema from substrait.utils import ( - type_num_names, - merge_extension_urns, - merge_extension_uris, merge_extension_declarations, + merge_extension_uris, + merge_extension_urns, + type_num_names, ) -from substrait.type_inference import infer_extended_expression_schema -from typing import Callable, Any, Union, Iterable UnboundExtendedExpression = Callable[ [stp.NamedStruct, ExtensionRegistry], stee.ExtendedExpression @@ -21,7 +22,7 @@ def _alias_or_inferred( - alias: Union[Iterable[str], str], + alias: Union[Iterable[str], str, None], op: str, args: Iterable[str], ): @@ -44,7 +45,7 @@ def resolve_expression( def literal( - value: Any, type: stp.Type, alias: Union[Iterable[str], str] = None + value: Any, type: stp.Type, alias: Union[Iterable[str], str, None] = None ) -> UnboundExtendedExpression: """Builds a resolver for ExtendedExpression containing a literal expression""" @@ -154,7 +155,7 @@ def resolve( return resolve -def column(field: Union[str, int], alias: Union[Iterable[str], str] = None): +def column(field: Union[str, int], alias: Union[Iterable[str], str, None] = None): """Builds a resolver for ExtendedExpression containing a FieldReference expression Accepts either an index or a field name of a desired field. @@ -208,7 +209,7 @@ def scalar_function( urn: str, function: str, expressions: Iterable[ExtendedExpressionOrUnbound], - alias: Union[Iterable[str], str] = None, + alias: Union[Iterable[str], str, None] = None, ): """Builds a resolver for ExtendedExpression containing a ScalarFunction expression""" @@ -306,7 +307,7 @@ def aggregate_function( urn: str, function: str, expressions: Iterable[ExtendedExpressionOrUnbound], - alias: Union[Iterable[str], str] = None, + alias: Union[Iterable[str], str, None] = None, ): """Builds a resolver for ExtendedExpression containing a AggregateFunction measure""" @@ -402,7 +403,7 @@ def window_function( function: str, expressions: Iterable[ExtendedExpressionOrUnbound], partitions: Iterable[ExtendedExpressionOrUnbound] = [], - alias: Union[Iterable[str], str] = None, + alias: Union[Iterable[str], str, None] = None, ): """Builds a resolver for ExtendedExpression containing a WindowFunction expression""" @@ -512,7 +513,7 @@ def resolve( def if_then( ifs: Iterable[tuple[ExtendedExpressionOrUnbound, ExtendedExpressionOrUnbound]], _else: ExtendedExpressionOrUnbound, - alias: Union[Iterable[str], str] = None, + alias: Union[Iterable[str], str, None] = None, ): """Builds a resolver for ExtendedExpression containing an IfThen expression""" @@ -767,7 +768,11 @@ def resolve( return resolve -def cast(input: ExtendedExpressionOrUnbound, type: stp.Type): +def cast( + input: ExtendedExpressionOrUnbound, + type: stp.Type, + alias: Union[Iterable[str], str, None] = None, +): """Builds a resolver for ExtendedExpression containing a cast expression""" def resolve( @@ -785,7 +790,9 @@ def resolve( failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL, ) ), - output_names=["cast"], # TODO construct name from inputs + output_names=_alias_or_inferred( + alias, "cast", [bound_input.referred_expr[0].output_names[0]] + ), ) ], base_schema=base_schema, diff --git a/src/substrait/gen/json/simple_extensions.py b/src/substrait/gen/json/simple_extensions.py index 2885bb4..765fbef 100644 --- a/src/substrait/gen/json/simple_extensions.py +++ b/src/substrait/gen/json/simple_extensions.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, TypeAlias, Union class Functions(Enum): @@ -13,7 +13,7 @@ class Functions(Enum): SEPARATE = 'SEPARATE' -Type = Union[str, Dict[str, Any]] +Type: TypeAlias = Union[str, Dict[str, Any]] class Type1(Enum): @@ -24,7 +24,7 @@ class Type1(Enum): string = 'string' -EnumOptions = List[str] +EnumOptions: TypeAlias = List[str] @dataclass @@ -49,7 +49,7 @@ class TypeArg: description: Optional[str] = None -Arguments = List[Union[EnumerationArg, ValueArg, TypeArg]] +Arguments: TypeAlias = List[Union[EnumerationArg, ValueArg, TypeArg]] @dataclass @@ -58,7 +58,7 @@ class Options1: description: Optional[str] = None -Options = Dict[str, Options1] +Options: TypeAlias = Dict[str, Options1] class ParameterConsistency(Enum): @@ -73,10 +73,10 @@ class VariadicBehavior: parameterConsistency: Optional[ParameterConsistency] = None -Deterministic = bool +Deterministic: TypeAlias = bool -SessionDependent = bool +SessionDependent: TypeAlias = bool class NullabilityHandling(Enum): @@ -85,13 +85,13 @@ class NullabilityHandling(Enum): DISCRETE = 'DISCRETE' -ReturnValue = Type +ReturnValue: TypeAlias = Type -Implementation = Dict[str, str] +Implementation: TypeAlias = Dict[str, str] -Intermediate = Type +Intermediate: TypeAlias = Type class Decomposable(Enum): @@ -100,10 +100,10 @@ class Decomposable(Enum): MANY = 'MANY' -Maxset = float +Maxset: TypeAlias = float -Ordered = bool +Ordered: TypeAlias = bool @dataclass @@ -196,7 +196,7 @@ class TypeParamDef: optional: Optional[bool] = None -TypeParamDefs = List[TypeParamDef] +TypeParamDefs: TypeAlias = List[TypeParamDef] @dataclass diff --git a/tests/builders/extended_expression/test_cast.py b/tests/builders/extended_expression/test_cast.py index 704f80d..bdad8d1 100644 --- a/tests/builders/extended_expression/test_cast.py +++ b/tests/builders/extended_expression/test_cast.py @@ -1,6 +1,6 @@ import substrait.gen.proto.algebra_pb2 as stalg -import substrait.gen.proto.type_pb2 as stt import substrait.gen.proto.extended_expression_pb2 as stee +import substrait.gen.proto.type_pb2 as stt from substrait.builders.extended_expression import cast, literal from substrait.builders.type import i8, i16 from substrait.extension_registry import ExtensionRegistry @@ -37,7 +37,7 @@ def test_cast(): failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL, ) ), - output_names=["cast"], + output_names=["cast(Literal(3))"], ) ], base_schema=named_struct, @@ -48,6 +48,7 @@ def test_cast(): def test_cast_with_extension(): import yaml + import substrait.gen.proto.extensions.extensions_pb2 as ste from substrait.builders.extended_expression import scalar_function @@ -134,7 +135,7 @@ def test_cast_with_extension(): failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL, ) ), - output_names=["cast"], + output_names=["cast(add(Literal(1),Literal(2)))"], ) ], base_schema=named_struct,