Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions .github/workflows/codegen-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ jobs:
uses: devcontainers/[email protected]
with:
runCmd: |
# fetch submodule tags since actions/checkout does not
git submodule foreach 'git fetch --unshallow || true'
# Ensure dependencies are installed
uv sync --extra test --extra gen_proto
# Run all code generation steps
make antlr
./gen_proto.sh
make codegen-extensions
make codegen

- name: Check for uncommitted changes
run: |
Expand All @@ -36,9 +36,7 @@ jobs:
git diff src/substrait/gen/
echo ""
echo "To fix this, run:"
echo " make antlr"
echo " ./gen_proto.sh"
echo " make codegen-extensions"
echo " make codegen"
echo "Then commit the changes."
exit 1
fi
12 changes: 12 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ git submodule update --init --recursive

# Code generation

Run the full code generation with the command below, or use the individual commands in the following sections to regenerate only part of the generated code. Note that `make codegen` does not update the Substrait Git submodule.

```
make codegen
```

## Protobuf stubs

Run the upgrade script to upgrade the submodule and regenerate the protobuf stubs.
Expand All @@ -31,6 +37,12 @@ uv sync --extra gen_proto
uv run ./update_proto.sh <version>
```

Or run the proto codegen without updating the Substrait Git submodule:

```
make codegen-proto
```

## Antlr grammar

Substrait uses an ANTLR grammar to derive the output types of extension functions. Make sure Java is installed and the `ANTLR_JAR` environment variable is set. See `.devcontainer/Dockerfile` for an example setup.
Expand Down
11 changes: 11 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
# Aggregate target: runs every code generation step (ANTLR parser, protobuf
# stubs, extension models, version module) via its prerequisites.
codegen: antlr codegen-proto codegen-extensions codegen-version


# Generate the Python3 ANTLR parser for SubstraitType.g4 into
# src/substrait/gen/antlr, then remove the generated *.tokens and *.interp
# files. Requires java on PATH and ANTLR_JAR set to the ANTLR jar.
# The jar path is quoted so locations containing spaces work, and `rm -f`
# keeps the target from failing when a glob matches nothing.
antlr:
	cd third_party/substrait/grammar \
		&& java -jar "${ANTLR_JAR}" -o ../../../src/substrait/gen/antlr -Dlanguage=Python3 SubstraitType.g4 \
		&& rm -f ../../../src/substrait/gen/antlr/*.tokens \
			../../../src/substrait/gen/antlr/*.interp

# Write src/substrait/gen/version.py containing the Substrait submodule's
# `git describe --tags` output, without the leading "v", as:
#   substrait_version = "<version>"
# The version is captured first so that a failing `git describe` aborts the
# target before version.py is touched. printf is used because `echo -n` is
# not portable across /bin/sh implementations, and `$${version#v}` strips
# only a leading "v" rather than every "v" in the string.
codegen-version:
	version="$$(git -C third_party/substrait describe --tags)" \
		&& printf 'substrait_version = "%s"\n' "$${version#v}" > src/substrait/gen/version.py

# Run the protobuf code generation script (./gen_proto.sh) on its own.
codegen-proto:
	./gen_proto.sh

codegen-extensions:
uv run --with datamodel-code-generator datamodel-codegen \
--input-file-type jsonschema \
Expand Down
27 changes: 26 additions & 1 deletion src/substrait/builders/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
See `examples/builder_example.py` for usage.
"""

import re
from typing import Callable, Iterable, Optional, Union

import substrait.gen.proto.algebra_pb2 as stalg
Expand All @@ -23,12 +24,27 @@
merge_extension_uris,
merge_extension_urns,
)
from substrait.gen.version import substrait_version

UnboundPlan = Callable[[ExtensionRegistry], stp.Plan]

PlanOrUnbound = Union[stp.Plan, UnboundPlan]


def _create_default_version():
    """Build the module-level ``default_version`` from ``substrait_version``.

    Parses the leading ``<major>.<minor>.<patch>`` triple of the generated
    ``substrait_version`` string and stores it as an ``stp.Version`` message
    in the global ``default_version``. The pattern is anchored only at the
    start of the string, so any suffix after the triple is ignored.

    Raises:
        ValueError: If ``substrait_version`` does not start with a
            ``<major>.<minor>.<patch>`` triple.
    """
    global default_version

    match = re.match(r"(\d+)\.(\d+)\.(\d+)", substrait_version)
    if match is None:
        # This runs at import time; fail with an actionable message instead
        # of an AttributeError raised from calling .group() on None.
        raise ValueError(
            f"Cannot parse Substrait version {substrait_version!r}; "
            "expected '<major>.<minor>.<patch>'. Regenerate "
            "src/substrait/gen/version.py with `make codegen-version`."
        )

    major, minor, patch = map(int, match.groups())
    default_version = stp.Version(
        major_number=major,
        minor_number=minor,
        patch_number=patch,
    )


_create_default_version()


def _merge_extensions(*objs):
"""Merge extension URIs, URNs, and declarations from multiple plan/expression objects.

Expand Down Expand Up @@ -65,9 +81,10 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
)

return stp.Plan(
version=default_version,
relations=[
stp.PlanRel(root=stalg.RelRoot(input=rel, names=named_struct.names))
]
],
)

return resolve
Expand Down Expand Up @@ -107,6 +124,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
)

return stp.Plan(
version=default_version,
relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))],
**_merge_extensions(_plan, *bound_expressions),
)
Expand Down Expand Up @@ -137,6 +155,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
names = ns.names

return stp.Plan(
version=default_version,
relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))],
**_merge_extensions(bound_plan, bound_expression),
)
Expand Down Expand Up @@ -183,6 +202,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
)

return stp.Plan(
version=default_version,
relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))],
**_merge_extensions(bound_plan, *[e[0] for e in bound_expressions]),
)
Expand All @@ -200,6 +220,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
)

return stp.Plan(
version=default_version,
relations=[
stp.PlanRel(
root=stalg.RelRoot(
Expand Down Expand Up @@ -238,6 +259,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
)

return stp.Plan(
version=default_version,
relations=[
stp.PlanRel(
root=stalg.RelRoot(
Expand Down Expand Up @@ -286,6 +308,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
)

return stp.Plan(
version=default_version,
relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))],
**_merge_extensions(bound_left, bound_right, bound_expression),
)
Expand Down Expand Up @@ -321,6 +344,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
)

return stp.Plan(
version=default_version,
relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))],
**_merge_extensions(bound_left, bound_right),
)
Expand Down Expand Up @@ -372,6 +396,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
] + [e.referred_expr[0].output_names[0] for e in bound_measures]

return stp.Plan(
version=default_version,
relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))],
**_merge_extensions(
bound_input, *bound_grouping_expressions, *bound_measures
Expand Down
1 change: 1 addition & 0 deletions src/substrait/gen/__init__.pyi

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/substrait/gen/version.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion tests/builders/plan/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import substrait.gen.proto.algebra_pb2 as stalg
import substrait.gen.proto.extensions.extensions_pb2 as ste
from substrait.builders.type import boolean, i64
from substrait.builders.plan import read_named_table, aggregate
from substrait.builders.plan import read_named_table, aggregate, default_version
from substrait.builders.extended_expression import column, aggregate_function
from substrait.extension_registry import ExtensionRegistry
from substrait.type_inference import infer_plan_schema
Expand Down Expand Up @@ -56,6 +56,7 @@ def test_aggregate():
ns = infer_plan_schema(table(None))

expected = stp.Plan(
version=default_version,
extension_urns=[
ste.SimpleExtensionURN(extension_urn_anchor=1, urn="extension:test:urn")
],
Expand Down
5 changes: 3 additions & 2 deletions tests/builders/plan/test_cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import substrait.gen.proto.plan_pb2 as stp
import substrait.gen.proto.algebra_pb2 as stalg
from substrait.builders.type import boolean, i64, string
from substrait.builders.plan import read_named_table, cross
from substrait.builders.plan import read_named_table, cross, default_version
from substrait.extension_registry import ExtensionRegistry

registry = ExtensionRegistry(load_default_extensions=False)
Expand All @@ -28,6 +28,7 @@ def test_cross_join():
actual = cross(table, table2)(registry)

expected = stp.Plan(
version=default_version,
relations=[
stp.PlanRel(
root=stalg.RelRoot(
Expand All @@ -40,7 +41,7 @@ def test_cross_join():
names=["id", "is_applicable", "fk_id", "name"],
)
)
]
],
)

assert actual == expected
5 changes: 3 additions & 2 deletions tests/builders/plan/test_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import substrait.gen.proto.plan_pb2 as stp
import substrait.gen.proto.algebra_pb2 as stalg
from substrait.builders.type import boolean, i64
from substrait.builders.plan import read_named_table, fetch
from substrait.builders.plan import read_named_table, fetch, default_version
from substrait.builders.extended_expression import literal
from substrait.extension_registry import ExtensionRegistry

Expand All @@ -24,6 +24,7 @@ def test_fetch():
actual = fetch(table, offset=offset, count=count)(registry)

expected = stp.Plan(
version=default_version,
relations=[
stp.PlanRel(
root=stalg.RelRoot(
Expand All @@ -37,7 +38,7 @@ def test_fetch():
names=["id", "is_applicable"],
)
)
]
],
)

assert actual == expected
5 changes: 3 additions & 2 deletions tests/builders/plan/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import substrait.gen.proto.plan_pb2 as stp
import substrait.gen.proto.algebra_pb2 as stalg
from substrait.builders.type import boolean, i64
from substrait.builders.plan import read_named_table, filter
from substrait.builders.plan import read_named_table, filter, default_version
from substrait.builders.extended_expression import literal
from substrait.extension_registry import ExtensionRegistry

Expand All @@ -21,6 +21,7 @@ def test_filter():
actual = filter(table, literal(True, boolean()))(registry)

expected = stp.Plan(
version=default_version,
relations=[
stp.PlanRel(
root=stalg.RelRoot(
Expand All @@ -37,7 +38,7 @@ def test_filter():
names=["id", "is_applicable"],
)
)
]
],
)

assert actual == expected
5 changes: 3 additions & 2 deletions tests/builders/plan/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import substrait.gen.proto.plan_pb2 as stp
import substrait.gen.proto.algebra_pb2 as stalg
from substrait.builders.type import boolean, i64, string
from substrait.builders.plan import read_named_table, join
from substrait.builders.plan import read_named_table, join, default_version
from substrait.builders.extended_expression import literal
from substrait.extension_registry import ExtensionRegistry

Expand Down Expand Up @@ -31,6 +31,7 @@ def test_join():
)(registry)

expected = stp.Plan(
version=default_version,
relations=[
stp.PlanRel(
root=stalg.RelRoot(
Expand All @@ -47,7 +48,7 @@ def test_join():
names=["id", "is_applicable", "fk_id", "name"],
)
)
]
],
)

assert actual == expected
5 changes: 3 additions & 2 deletions tests/builders/plan/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import substrait.gen.proto.plan_pb2 as stp
import substrait.gen.proto.algebra_pb2 as stalg
from substrait.builders.type import boolean, i64
from substrait.builders.plan import read_named_table, project
from substrait.builders.plan import read_named_table, project, default_version
from substrait.builders.extended_expression import column
from substrait.extension_registry import ExtensionRegistry

Expand All @@ -21,6 +21,7 @@ def test_project():
actual = project(table, [column("id")])(registry)

expected = stp.Plan(
version=default_version,
relations=[
stp.PlanRel(
root=stalg.RelRoot(
Expand All @@ -47,7 +48,7 @@ def test_project():
names=["id"],
)
)
]
],
)

assert actual == expected
11 changes: 7 additions & 4 deletions tests/builders/plan/test_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import substrait.gen.proto.plan_pb2 as stp
import substrait.gen.proto.algebra_pb2 as stalg
from substrait.builders.type import boolean, i64
from substrait.builders.plan import read_named_table
from substrait.builders.plan import read_named_table, default_version
import pytest
from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension
from google.protobuf import any_pb2
Expand All @@ -20,6 +20,7 @@ def test_read_rel():
actual = read_named_table("example_table", named_struct)(None)

expected = stp.Plan(
version=default_version,
relations=[
stp.PlanRel(
root=stalg.RelRoot(
Expand All @@ -35,7 +36,7 @@ def test_read_rel():
names=["id", "is_applicable"],
)
)
]
],
)

assert actual == expected
Expand All @@ -45,6 +46,7 @@ def test_read_rel_db():
actual = read_named_table(["example_db", "example_table"], named_struct)(None)

expected = stp.Plan(
version=default_version,
relations=[
stp.PlanRel(
root=stalg.RelRoot(
Expand All @@ -60,7 +62,7 @@ def test_read_rel_db():
names=["id", "is_applicable"],
)
)
]
],
)

assert actual == expected
Expand Down Expand Up @@ -90,6 +92,7 @@ def test_read_rel_ae():
)

expected = stp.Plan(
version=default_version,
relations=[
stp.PlanRel(
root=stalg.RelRoot(
Expand All @@ -106,7 +109,7 @@ def test_read_rel_ae():
names=["id", "is_applicable"],
)
)
]
],
)

assert actual == expected
Loading