|
7 | 7 |
|
8 | 8 | from typing import Iterable, Optional, Union, Callable |
9 | 9 |
|
10 | | -import substrait.gen.proto.algebra_pb2 as stalg |
11 | 10 | from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension |
| 11 | +import substrait.gen.proto.algebra_pb2 as stalg |
| 12 | +import substrait.gen.proto.extended_expression_pb2 as stee |
12 | 13 | import substrait.gen.proto.plan_pb2 as stp |
13 | 14 | import substrait.gen.proto.type_pb2 as stt |
14 | | -import substrait.gen.proto.extended_expression_pb2 as stee |
15 | | -from substrait.extension_registry import ExtensionRegistry |
16 | 15 | from substrait.builders.extended_expression import ( |
17 | 16 | ExtendedExpressionOrUnbound, |
18 | 17 | resolve_expression, |
19 | 18 | ) |
| 19 | +from substrait.extension_registry import ExtensionRegistry |
20 | 20 | from substrait.type_inference import infer_plan_schema |
21 | 21 | from substrait.utils import ( |
22 | 22 | merge_extension_declarations, |
@@ -379,3 +379,33 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: |
379 | 379 | ) |
380 | 380 |
|
381 | 381 | return resolve |
| 382 | + |
| 383 | + |
| 384 | +def write_table( |
| 385 | + table_names: Union[str, Iterable[str]], |
| 386 | + input: PlanOrUnbound, |
| 387 | + create_mode: Union[stalg.WriteRel.CreateMode.ValueType, None] = None, |
| 388 | +) -> UnboundPlan: |
| 389 | + def resolve(registry: ExtensionRegistry) -> stp.Plan: |
| 390 | + bound_input = input if isinstance(input, stp.Plan) else input(registry) |
| 391 | + ns = infer_plan_schema(bound_input) |
| 392 | + _table_names = [table_names] if isinstance(table_names, str) else table_names |
| 393 | + _create_mode = create_mode or stalg.WriteRel.CREATE_MODE_ERROR_IF_EXISTS |
| 394 | + |
| 395 | + write_rel = stalg.Rel( |
| 396 | + write=stalg.WriteRel( |
| 397 | + input=bound_input.relations[-1].root.input, |
| 398 | + table_schema=ns, |
| 399 | + op=stalg.WriteRel.WRITE_OP_CTAS, |
| 400 | + create_mode=_create_mode, |
| 401 | + named_table=stalg.NamedObjectWrite(names=_table_names), |
| 402 | + ) |
| 403 | + ) |
| 404 | + return stp.Plan( |
| 405 | + relations=[ |
| 406 | + stp.PlanRel(root=stalg.RelRoot(input=write_rel, names=ns.names)) |
| 407 | + ], |
| 408 | + **_merge_extensions(bound_input), |
| 409 | + ) |
| 410 | + |
| 411 | + return resolve |
0 commit comments