Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions src/substrait/builders/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,23 @@
See `examples/builder_example.py` for usage.
"""

from typing import Iterable, Optional, Union, Callable
from typing import Callable, Iterable, Optional, Union

import substrait.gen.proto.algebra_pb2 as stalg
from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension
import substrait.gen.proto.extended_expression_pb2 as stee
import substrait.gen.proto.plan_pb2 as stp
import substrait.gen.proto.type_pb2 as stt
import substrait.gen.proto.extended_expression_pb2 as stee
from substrait.extension_registry import ExtensionRegistry
from substrait.builders.extended_expression import (
ExtendedExpressionOrUnbound,
resolve_expression,
)
from substrait.extension_registry import ExtensionRegistry
from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension
from substrait.type_inference import infer_plan_schema
from substrait.utils import (
merge_extension_declarations,
merge_extension_urns,
merge_extension_uris,
merge_extension_urns,
)

UnboundPlan = Callable[[ExtensionRegistry], stp.Plan]
Expand Down Expand Up @@ -379,3 +379,33 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
)

return resolve


def write_named_table(
table_names: Union[str, Iterable[str]],
input: PlanOrUnbound,
create_mode: Union[stalg.WriteRel.CreateMode.ValueType, None] = None,
) -> UnboundPlan:
def resolve(registry: ExtensionRegistry) -> stp.Plan:
bound_input = input if isinstance(input, stp.Plan) else input(registry)
ns = infer_plan_schema(bound_input)
_table_names = [table_names] if isinstance(table_names, str) else table_names
_create_mode = create_mode or stalg.WriteRel.CREATE_MODE_ERROR_IF_EXISTS

write_rel = stalg.Rel(
write=stalg.WriteRel(
input=bound_input.relations[-1].root.input,
table_schema=ns,
op=stalg.WriteRel.WRITE_OP_CTAS,
create_mode=_create_mode,
named_table=stalg.NamedObjectWrite(names=_table_names),
)
)
return stp.Plan(
relations=[
stp.PlanRel(root=stalg.RelRoot(input=write_rel, names=ns.names))
],
**_merge_extensions(bound_input),
)

return resolve
48 changes: 48 additions & 0 deletions tests/builders/plan/test_write.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import substrait.gen.proto.algebra_pb2 as stalg
import substrait.gen.proto.plan_pb2 as stp
import substrait.gen.proto.type_pb2 as stt
from substrait.builders.plan import read_named_table, write_named_table
from substrait.builders.type import boolean, i64

struct = stt.Type.Struct(types=[i64(nullable=False), boolean()])

named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)


def test_write_rel():
actual = write_named_table(
"example_table_write_test",
read_named_table("example_table", named_struct),
)(None)

expected = stp.Plan(
relations=[
stp.PlanRel(
root=stalg.RelRoot(
input=stalg.Rel(
write=stalg.WriteRel(
input=stalg.Rel(
read=stalg.ReadRel(
common=stalg.RelCommon(
direct=stalg.RelCommon.Direct()
),
base_schema=named_struct,
named_table=stalg.ReadRel.NamedTable(
names=["example_table"]
),
)
),
op=stalg.WriteRel.WRITE_OP_CTAS,
table_schema=named_struct,
create_mode=stalg.WriteRel.CREATE_MODE_ERROR_IF_EXISTS,
named_table=stalg.NamedObjectWrite(
names=["example_table_write_test"]
),
)
),
names=["id", "is_applicable"],
)
)
]
)
assert actual == expected