Skip to content

Commit 921cf80

Browse files
AGVgiospada
authored andcommitted
feat(write_table): added the write table builder and test
Signed-off-by: AGV <[email protected]>
1 parent fa12088 commit 921cf80

File tree

2 files changed

+81
-3
lines changed

2 files changed

+81
-3
lines changed

src/substrait/builders/plan.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@
77

88
from typing import Iterable, Optional, Union, Callable
99

10-
import substrait.gen.proto.algebra_pb2 as stalg
1110
from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension
11+
import substrait.gen.proto.algebra_pb2 as stalg
12+
import substrait.gen.proto.extended_expression_pb2 as stee
1213
import substrait.gen.proto.plan_pb2 as stp
1314
import substrait.gen.proto.type_pb2 as stt
14-
import substrait.gen.proto.extended_expression_pb2 as stee
15-
from substrait.extension_registry import ExtensionRegistry
1615
from substrait.builders.extended_expression import (
1716
ExtendedExpressionOrUnbound,
1817
resolve_expression,
1918
)
19+
from substrait.extension_registry import ExtensionRegistry
2020
from substrait.type_inference import infer_plan_schema
2121
from substrait.utils import (
2222
merge_extension_declarations,
@@ -379,3 +379,33 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
379379
)
380380

381381
return resolve
382+
383+
384+
def write_table(
385+
table_names: Union[str, Iterable[str]],
386+
input: PlanOrUnbound,
387+
create_mode: Union[stalg.WriteRel.CreateMode.ValueType, None] = None,
388+
) -> UnboundPlan:
389+
def resolve(registry: ExtensionRegistry) -> stp.Plan:
390+
bound_input = input if isinstance(input, stp.Plan) else input(registry)
391+
ns = infer_plan_schema(bound_input)
392+
_table_names = [table_names] if isinstance(table_names, str) else table_names
393+
_create_mode = create_mode or stalg.WriteRel.CREATE_MODE_ERROR_IF_EXISTS
394+
395+
write_rel = stalg.Rel(
396+
write=stalg.WriteRel(
397+
input=bound_input.relations[-1].root.input,
398+
table_schema=ns,
399+
op=stalg.WriteRel.WRITE_OP_CTAS,
400+
create_mode=_create_mode,
401+
named_table=stalg.NamedObjectWrite(names=_table_names),
402+
)
403+
)
404+
return stp.Plan(
405+
relations=[
406+
stp.PlanRel(root=stalg.RelRoot(input=write_rel, names=ns.names))
407+
],
408+
**_merge_extensions(bound_input),
409+
)
410+
411+
return resolve

tests/builders/plan/test_write.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import substrait.gen.proto.algebra_pb2 as stalg
2+
import substrait.gen.proto.plan_pb2 as stp
3+
import substrait.gen.proto.type_pb2 as stt
4+
from substrait.builders.plan import read_named_table, write_table
5+
from substrait.builders.type import boolean, i64
6+
7+
struct = stt.Type.Struct(types=[i64(nullable=False), boolean()])
8+
9+
named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)
10+
11+
12+
def test_write_rel():
13+
actual = write_table(
14+
"example_table_write_test",
15+
read_named_table("example_table", named_struct),
16+
)(None)
17+
18+
expected = stp.Plan(
19+
relations=[
20+
stp.PlanRel(
21+
root=stalg.RelRoot(
22+
input=stalg.Rel(
23+
write=stalg.WriteRel(
24+
input=stalg.Rel(
25+
read=stalg.ReadRel(
26+
common=stalg.RelCommon(
27+
direct=stalg.RelCommon.Direct()
28+
),
29+
base_schema=named_struct,
30+
named_table=stalg.ReadRel.NamedTable(
31+
names=["example_table"]
32+
),
33+
)
34+
),
35+
op=stalg.WriteRel.WRITE_OP_CTAS,
36+
table_schema=named_struct,
37+
create_mode=stalg.WriteRel.CREATE_MODE_ERROR_IF_EXISTS,
38+
named_table=stalg.NamedObjectWrite(
39+
names=["example_table_write_test"]
40+
),
41+
)
42+
),
43+
names=["id", "is_applicable"],
44+
)
45+
)
46+
]
47+
)
48+
assert actual == expected

0 commit comments

Comments
 (0)