Skip to content

Commit ddede68

Browse files
giospadaAGV
andauthored
feat(write): add write table builder and tests (#129)
Added WriteTable builder for creating write operations in Substrait plans. Added test for write table operations. --------- Signed-off-by: AGV <[email protected]> Co-authored-by: AGV <[email protected]>
1 parent ee0cb3f commit ddede68

File tree

2 files changed

+83
-5
lines changed

2 files changed

+83
-5
lines changed

src/substrait/builders/plan.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,23 @@
55
See `examples/builder_example.py` for usage.
66
"""
77

8-
from typing import Iterable, Optional, Union, Callable
8+
from typing import Callable, Iterable, Optional, Union
99

1010
import substrait.gen.proto.algebra_pb2 as stalg
11-
from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension
11+
import substrait.gen.proto.extended_expression_pb2 as stee
1212
import substrait.gen.proto.plan_pb2 as stp
1313
import substrait.gen.proto.type_pb2 as stt
14-
import substrait.gen.proto.extended_expression_pb2 as stee
15-
from substrait.extension_registry import ExtensionRegistry
1614
from substrait.builders.extended_expression import (
1715
ExtendedExpressionOrUnbound,
1816
resolve_expression,
1917
)
18+
from substrait.extension_registry import ExtensionRegistry
19+
from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension
2020
from substrait.type_inference import infer_plan_schema
2121
from substrait.utils import (
2222
merge_extension_declarations,
23-
merge_extension_urns,
2423
merge_extension_uris,
24+
merge_extension_urns,
2525
)
2626

2727
UnboundPlan = Callable[[ExtensionRegistry], stp.Plan]
@@ -379,3 +379,33 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
379379
)
380380

381381
return resolve
382+
383+
384+
def write_named_table(
385+
table_names: Union[str, Iterable[str]],
386+
input: PlanOrUnbound,
387+
create_mode: Union[stalg.WriteRel.CreateMode.ValueType, None] = None,
388+
) -> UnboundPlan:
389+
def resolve(registry: ExtensionRegistry) -> stp.Plan:
390+
bound_input = input if isinstance(input, stp.Plan) else input(registry)
391+
ns = infer_plan_schema(bound_input)
392+
_table_names = [table_names] if isinstance(table_names, str) else table_names
393+
_create_mode = create_mode or stalg.WriteRel.CREATE_MODE_ERROR_IF_EXISTS
394+
395+
write_rel = stalg.Rel(
396+
write=stalg.WriteRel(
397+
input=bound_input.relations[-1].root.input,
398+
table_schema=ns,
399+
op=stalg.WriteRel.WRITE_OP_CTAS,
400+
create_mode=_create_mode,
401+
named_table=stalg.NamedObjectWrite(names=_table_names),
402+
)
403+
)
404+
return stp.Plan(
405+
relations=[
406+
stp.PlanRel(root=stalg.RelRoot(input=write_rel, names=ns.names))
407+
],
408+
**_merge_extensions(bound_input),
409+
)
410+
411+
return resolve

tests/builders/plan/test_write.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import substrait.gen.proto.algebra_pb2 as stalg
2+
import substrait.gen.proto.plan_pb2 as stp
3+
import substrait.gen.proto.type_pb2 as stt
4+
from substrait.builders.plan import read_named_table, write_named_table
5+
from substrait.builders.type import boolean, i64
6+
7+
struct = stt.Type.Struct(types=[i64(nullable=False), boolean()])
8+
9+
named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)
10+
11+
12+
def test_write_rel():
13+
actual = write_named_table(
14+
"example_table_write_test",
15+
read_named_table("example_table", named_struct),
16+
)(None)
17+
18+
expected = stp.Plan(
19+
relations=[
20+
stp.PlanRel(
21+
root=stalg.RelRoot(
22+
input=stalg.Rel(
23+
write=stalg.WriteRel(
24+
input=stalg.Rel(
25+
read=stalg.ReadRel(
26+
common=stalg.RelCommon(
27+
direct=stalg.RelCommon.Direct()
28+
),
29+
base_schema=named_struct,
30+
named_table=stalg.ReadRel.NamedTable(
31+
names=["example_table"]
32+
),
33+
)
34+
),
35+
op=stalg.WriteRel.WRITE_OP_CTAS,
36+
table_schema=named_struct,
37+
create_mode=stalg.WriteRel.CREATE_MODE_ERROR_IF_EXISTS,
38+
named_table=stalg.NamedObjectWrite(
39+
names=["example_table_write_test"]
40+
),
41+
)
42+
),
43+
names=["id", "is_applicable"],
44+
)
45+
)
46+
]
47+
)
48+
assert actual == expected

0 commit comments

Comments
 (0)