Skip to content

Commit fa12088

Browse files
authored
feat(cast): add alias for cast expression (#127)
Added alias for the cast expression builder. Includes test coverage for cast alias functionality.
1 parent 1354c46 commit fa12088

File tree

4 files changed

+41
-32
lines changed

4 files changed

+41
-32
lines changed

Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ codegen-extensions:
1010
--input third_party/substrait/text/simple_extensions_schema.yaml \
1111
--output src/substrait/gen/json/simple_extensions.py \
1212
--output-model-type dataclasses.dataclass \
13+
--target-python-version 3.10 \
1314
--disable-timestamp
1415

1516
lint:

src/substrait/builders/extended_expression.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
1-
from datetime import date
21
import itertools
2+
from datetime import date
3+
from typing import Any, Callable, Iterable, Union
4+
35
import substrait.gen.proto.algebra_pb2 as stalg
4-
import substrait.gen.proto.type_pb2 as stp
56
import substrait.gen.proto.extended_expression_pb2 as stee
67
import substrait.gen.proto.extensions.extensions_pb2 as ste
8+
import substrait.gen.proto.type_pb2 as stp
79
from substrait.extension_registry import ExtensionRegistry
10+
from substrait.type_inference import infer_extended_expression_schema
811
from substrait.utils import (
9-
type_num_names,
10-
merge_extension_urns,
11-
merge_extension_uris,
1212
merge_extension_declarations,
13+
merge_extension_uris,
14+
merge_extension_urns,
15+
type_num_names,
1316
)
14-
from substrait.type_inference import infer_extended_expression_schema
15-
from typing import Callable, Any, Union, Iterable
1617

1718
UnboundExtendedExpression = Callable[
1819
[stp.NamedStruct, ExtensionRegistry], stee.ExtendedExpression
@@ -21,7 +22,7 @@
2122

2223

2324
def _alias_or_inferred(
24-
alias: Union[Iterable[str], str],
25+
alias: Union[Iterable[str], str, None],
2526
op: str,
2627
args: Iterable[str],
2728
):
@@ -44,7 +45,7 @@ def resolve_expression(
4445

4546

4647
def literal(
47-
value: Any, type: stp.Type, alias: Union[Iterable[str], str] = None
48+
value: Any, type: stp.Type, alias: Union[Iterable[str], str, None] = None
4849
) -> UnboundExtendedExpression:
4950
"""Builds a resolver for ExtendedExpression containing a literal expression"""
5051

@@ -154,7 +155,7 @@ def resolve(
154155
return resolve
155156

156157

157-
def column(field: Union[str, int], alias: Union[Iterable[str], str] = None):
158+
def column(field: Union[str, int], alias: Union[Iterable[str], str, None] = None):
158159
"""Builds a resolver for ExtendedExpression containing a FieldReference expression
159160
160161
Accepts either an index or a field name of a desired field.
@@ -208,7 +209,7 @@ def scalar_function(
208209
urn: str,
209210
function: str,
210211
expressions: Iterable[ExtendedExpressionOrUnbound],
211-
alias: Union[Iterable[str], str] = None,
212+
alias: Union[Iterable[str], str, None] = None,
212213
):
213214
"""Builds a resolver for ExtendedExpression containing a ScalarFunction expression"""
214215

@@ -306,7 +307,7 @@ def aggregate_function(
306307
urn: str,
307308
function: str,
308309
expressions: Iterable[ExtendedExpressionOrUnbound],
309-
alias: Union[Iterable[str], str] = None,
310+
alias: Union[Iterable[str], str, None] = None,
310311
):
311312
"""Builds a resolver for ExtendedExpression containing a AggregateFunction measure"""
312313

@@ -402,7 +403,7 @@ def window_function(
402403
function: str,
403404
expressions: Iterable[ExtendedExpressionOrUnbound],
404405
partitions: Iterable[ExtendedExpressionOrUnbound] = [],
405-
alias: Union[Iterable[str], str] = None,
406+
alias: Union[Iterable[str], str, None] = None,
406407
):
407408
"""Builds a resolver for ExtendedExpression containing a WindowFunction expression"""
408409

@@ -512,7 +513,7 @@ def resolve(
512513
def if_then(
513514
ifs: Iterable[tuple[ExtendedExpressionOrUnbound, ExtendedExpressionOrUnbound]],
514515
_else: ExtendedExpressionOrUnbound,
515-
alias: Union[Iterable[str], str] = None,
516+
alias: Union[Iterable[str], str, None] = None,
516517
):
517518
"""Builds a resolver for ExtendedExpression containing an IfThen expression"""
518519

@@ -767,7 +768,11 @@ def resolve(
767768
return resolve
768769

769770

770-
def cast(input: ExtendedExpressionOrUnbound, type: stp.Type):
771+
def cast(
772+
input: ExtendedExpressionOrUnbound,
773+
type: stp.Type,
774+
alias: Union[Iterable[str], str, None] = None,
775+
):
771776
"""Builds a resolver for ExtendedExpression containing a cast expression"""
772777

773778
def resolve(
@@ -785,7 +790,9 @@ def resolve(
785790
failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL,
786791
)
787792
),
788-
output_names=["cast"], # TODO construct name from inputs
793+
output_names=_alias_or_inferred(
794+
alias, "cast", [bound_input.referred_expr[0].output_names[0]]
795+
),
789796
)
790797
],
791798
base_schema=base_schema,

src/substrait/gen/json/simple_extensions.py

Lines changed: 13 additions & 13 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/builders/extended_expression/test_cast.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import substrait.gen.proto.algebra_pb2 as stalg
2-
import substrait.gen.proto.type_pb2 as stt
32
import substrait.gen.proto.extended_expression_pb2 as stee
3+
import substrait.gen.proto.type_pb2 as stt
44
from substrait.builders.extended_expression import cast, literal
55
from substrait.builders.type import i8, i16
66
from substrait.extension_registry import ExtensionRegistry
@@ -37,7 +37,7 @@ def test_cast():
3737
failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL,
3838
)
3939
),
40-
output_names=["cast"],
40+
output_names=["cast(Literal(3))"],
4141
)
4242
],
4343
base_schema=named_struct,
@@ -48,6 +48,7 @@ def test_cast():
4848

4949
def test_cast_with_extension():
5050
import yaml
51+
5152
import substrait.gen.proto.extensions.extensions_pb2 as ste
5253
from substrait.builders.extended_expression import scalar_function
5354

@@ -134,7 +135,7 @@ def test_cast_with_extension():
134135
failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL,
135136
)
136137
),
137-
output_names=["cast"],
138+
output_names=["cast(add(Literal(1),Literal(2)))"],
138139
)
139140
],
140141
base_schema=named_struct,

0 commit comments

Comments
 (0)