Skip to content

Commit a165544

Browse files
committed
fix: ruff and type hints
1 parent 648a31a commit a165544

File tree

4 files changed

+45
-29
lines changed

4 files changed

+45
-29
lines changed

src/substrait/builders/plan.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
See `examples/builder_example.py` for usage.
66
"""
77

8-
from typing import Iterable, Optional, Union, Callable
98

9+
from typing import Callable, Iterable, Optional, TypedDict, Union
1010
import substrait.gen.proto.algebra_pb2 as stalg
1111
from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension
1212
import substrait.gen.proto.plan_pb2 as stp
@@ -20,16 +20,23 @@
2020
from substrait.type_inference import infer_plan_schema
2121
from substrait.utils import (
2222
merge_extension_declarations,
23-
merge_extension_urns,
2423
merge_extension_uris,
24+
merge_extension_urns,
2525
)
2626

2727
UnboundPlan = Callable[[ExtensionRegistry], stp.Plan]
2828

2929
PlanOrUnbound = Union[stp.Plan, UnboundPlan]
3030

31+
_ExtensionDict = TypedDict(
32+
"_ExtensionDict",
33+
{"extension_uris": list, "extension_urns": list, "extensions": list},
34+
)
35+
3136

32-
def _merge_extensions(*objs):
37+
def _merge_extensions(
38+
*objs,
39+
) -> _ExtensionDict:
3340
"""Merge extension URIs, URNs, and declarations from multiple plan/expression objects.
3441
3542
During the URI -> URN migration period, we maintain both URI and URN references

src/substrait/derivation_expression.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Optional
2-
from antlr4 import InputStream, CommonTokenStream
2+
3+
from antlr4 import CommonTokenStream, InputStream
4+
35
from substrait.gen.antlr.SubstraitTypeLexer import SubstraitTypeLexer
46
from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser
57
from substrait.gen.proto.type_pb2 import Type
@@ -9,8 +11,9 @@ def _evaluate(x, values: dict):
911
if isinstance(x, SubstraitTypeParser.BinaryExprContext):
1012
left = _evaluate(x.left, values)
1113
right = _evaluate(x.right, values)
12-
13-
if x.op.text == "+":
14+
if x.op is None:
15+
raise Exception("Undefined operator op")
16+
elif x.op.text == "+":
1417
return left + right
1518
elif x.op.text == "-":
1619
return left - right
@@ -121,15 +124,19 @@ def _evaluate(x, values: dict):
121124
nullability=nullability,
122125
)
123126
)
124-
elif isinstance(parametrized_type, SubstraitTypeParser.PrecisionTimestampContext):
127+
elif isinstance(
128+
parametrized_type, SubstraitTypeParser.PrecisionTimestampContext
129+
):
125130
precision = _evaluate(parametrized_type.precision, values)
126131
return Type(
127132
precision_timestamp=Type.PrecisionTimestamp(
128133
precision=precision,
129134
nullability=nullability,
130135
)
131136
)
132-
elif isinstance(parametrized_type, SubstraitTypeParser.PrecisionTimestampTZContext):
137+
elif isinstance(
138+
parametrized_type, SubstraitTypeParser.PrecisionTimestampTZContext
139+
):
133140
precision = _evaluate(parametrized_type.precision, values)
134141
return Type(
135142
precision_timestamp_tz=Type.PrecisionTimestampTZ(
@@ -144,27 +151,29 @@ def _evaluate(x, values: dict):
144151
)
145152
)
146153
elif isinstance(parametrized_type, SubstraitTypeParser.StructContext):
147-
types = list(map(lambda x: _evaluate(x,values),parametrized_type.expr()))
154+
types = list(
155+
map(lambda x: _evaluate(x, values), parametrized_type.expr())
156+
)
148157
return Type(
149158
struct=Type.Struct(
150159
types=types,
151160
nullability=nullability,
152161
)
153162
)
154163
elif isinstance(parametrized_type, SubstraitTypeParser.ListContext):
155-
type = _evaluate(parametrized_type.expr(),values)
164+
child_type = _evaluate(parametrized_type.expr(), values)
156165
return Type(
157166
list=Type.List(
158-
type=type,
167+
type=child_type,
159168
nullability=nullability,
160169
)
161170
)
162171

163172
elif isinstance(parametrized_type, SubstraitTypeParser.MapContext):
164173
return Type(
165174
map=Type.Map(
166-
key=_evaluate(parametrized_type.key,values),
167-
value=_evaluate(parametrized_type.value,values),
175+
key=_evaluate(parametrized_type.key, values),
176+
value=_evaluate(parametrized_type.value, values),
168177
nullability=nullability,
169178
)
170179
)

src/substrait/extension_registry.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
1-
import yaml
21
import itertools
32
import re
4-
from substrait.gen.proto.type_pb2 import Type
5-
from importlib.resources import files as importlib_files
63
from collections import defaultdict
4+
from importlib.resources import files as importlib_files
75
from pathlib import Path
86
from typing import Optional, Union
9-
from .derivation_expression import evaluate, _evaluate, _parse
7+
8+
import yaml
9+
1010
from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser
1111
from substrait.gen.json import simple_extensions as se
12+
from substrait.gen.proto.type_pb2 import Type
1213
from substrait.simple_extension_utils import build_simple_extensions
13-
from .bimap import UriUrnBiDiMap
1414

15+
from .bimap import UriUrnBiDiMap
16+
from .derivation_expression import _evaluate, _parse, evaluate
1517

1618
DEFAULT_URN_PREFIX = "https://github.com/substrait-io/substrait/blob/main/extensions"
1719

@@ -290,7 +292,7 @@ def __init__(
290292
def __repr__(self) -> str:
291293
return f"{self.name}:{'_'.join(self.normalized_inputs)}"
292294

293-
def satisfies_signature(self, signature: tuple) -> Optional[str]:
295+
def satisfies_signature(self, signature: tuple | list) -> Optional[str]:
294296
if self.impl.variadic:
295297
min_args_allowed = self.impl.variadic.min or 0
296298
if len(signature) < min_args_allowed:
@@ -322,14 +324,12 @@ def satisfies_signature(self, signature: tuple) -> Optional[str]:
322324
output_type = evaluate(self.impl.return_, parameters)
323325

324326
if self.nullability == se.NullabilityHandling.MIRROR:
325-
sig_contains_nullable = any(
326-
[
327-
p.__getattribute__(p.WhichOneof("kind")).nullability
328-
== Type.NULLABILITY_NULLABLE
329-
for p in signature
330-
if isinstance(p, Type)
331-
]
332-
)
327+
sig_contains_nullable = any([
328+
p.__getattribute__(p.WhichOneof("kind")).nullability
329+
== Type.NULLABILITY_NULLABLE
330+
for p in signature
331+
if isinstance(p, Type)
332+
])
333333
output_type.__getattribute__(output_type.WhichOneof("kind")).nullability = (
334334
Type.NULLABILITY_NULLABLE
335335
if sig_contains_nullable
@@ -417,7 +417,7 @@ def register_extension_dict(self, definitions: dict, uri: str) -> None:
417417

418418
# TODO add an optional return type check
419419
def lookup_function(
420-
self, urn: str, function_name: str, signature: tuple
420+
self, urn: str, function_name: str, signature: tuple[Type] | list[Type]
421421
) -> Optional[tuple[FunctionEntry, Type]]:
422422
if (
423423
urn not in self._function_mapping

tests/test_extension_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -572,4 +572,4 @@ def test_register_requires_uri():
572572

573573
# During migration, URI is required - this should fail with TypeError
574574
with pytest.raises(TypeError):
575-
registry.register_extension_dict(yaml.safe_load(content))
575+
registry.register_extension_dict(yaml.safe_load(content))

0 commit comments

Comments
 (0)