diff --git a/src/substrait/builders/plan.py b/src/substrait/builders/plan.py index a4a2180..392960f 100644 --- a/src/substrait/builders/plan.py +++ b/src/substrait/builders/plan.py @@ -5,23 +5,23 @@ See `examples/builder_example.py` for usage. """ -from typing import Iterable, Optional, Union, Callable +from typing import Callable, Iterable, Optional, Union import substrait.gen.proto.algebra_pb2 as stalg -from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension +import substrait.gen.proto.extended_expression_pb2 as stee import substrait.gen.proto.plan_pb2 as stp import substrait.gen.proto.type_pb2 as stt -import substrait.gen.proto.extended_expression_pb2 as stee -from substrait.extension_registry import ExtensionRegistry from substrait.builders.extended_expression import ( ExtendedExpressionOrUnbound, resolve_expression, ) +from substrait.extension_registry import ExtensionRegistry +from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension from substrait.type_inference import infer_plan_schema from substrait.utils import ( merge_extension_declarations, - merge_extension_urns, merge_extension_uris, + merge_extension_urns, ) UnboundPlan = Callable[[ExtensionRegistry], stp.Plan] @@ -379,3 +379,33 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: ) return resolve + + +def write_named_table( + table_names: Union[str, Iterable[str]], + input: PlanOrUnbound, + create_mode: Union[stalg.WriteRel.CreateMode.ValueType, None] = None, +) -> UnboundPlan: + def resolve(registry: ExtensionRegistry) -> stp.Plan: + bound_input = input if isinstance(input, stp.Plan) else input(registry) + ns = infer_plan_schema(bound_input) + _table_names = [table_names] if isinstance(table_names, str) else table_names + _create_mode = create_mode or stalg.WriteRel.CREATE_MODE_ERROR_IF_EXISTS + + write_rel = stalg.Rel( + write=stalg.WriteRel( + input=bound_input.relations[-1].root.input, + table_schema=ns, + op=stalg.WriteRel.WRITE_OP_CTAS, + create_mode=_create_mode, + named_table=stalg.NamedObjectWrite(names=_table_names), + ) + ) + return stp.Plan( + relations=[ + stp.PlanRel(root=stalg.RelRoot(input=write_rel, names=ns.names)) + ], + **_merge_extensions(bound_input), + ) + + return resolve diff --git a/tests/builders/plan/test_write.py b/tests/builders/plan/test_write.py new file mode 100644 index 0000000..b0e1029 --- /dev/null +++ b/tests/builders/plan/test_write.py @@ -0,0 +1,48 @@ +import substrait.gen.proto.algebra_pb2 as stalg +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.type_pb2 as stt +from substrait.builders.plan import read_named_table, write_named_table +from substrait.builders.type import boolean, i64 + +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + +named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) + + +def test_write_rel(): + actual = write_named_table( + "example_table_write_test", + read_named_table("example_table", named_struct), + )(None) + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + write=stalg.WriteRel( + input=stalg.Rel( + read=stalg.ReadRel( + common=stalg.RelCommon( + direct=stalg.RelCommon.Direct() + ), + base_schema=named_struct, + named_table=stalg.ReadRel.NamedTable( + names=["example_table"] + ), + ) + ), + op=stalg.WriteRel.WRITE_OP_CTAS, + table_schema=named_struct, + create_mode=stalg.WriteRel.CREATE_MODE_ERROR_IF_EXISTS, + named_table=stalg.NamedObjectWrite( + names=["example_table_write_test"] + ), + ) + ), + names=["id", "is_applicable"], + ) + ) + ] + ) + assert actual == expected