Skip to content

Commit e24c5a9

Browse files
committed
fix examples and add CI/CD job to ensure they work
1 parent cb1fa49 commit e24c5a9

File tree

5 files changed

+77
-35
lines changed

5 files changed

+77
-35
lines changed

.github/workflows/example.yml

Lines changed: 36 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,36 @@
1+
name: Run examples
2+
3+
on:
4+
pull_request:
5+
push:
6+
branches: [ main ]
7+
8+
permissions:
9+
contents: read
10+
11+
jobs:
12+
example:
13+
name: Run ${{ matrix.example }}
14+
runs-on: ubuntu-latest
15+
strategy:
16+
matrix:
17+
example:
18+
- builder_example.py
19+
- duckdb_example.py
20+
- adbc_example.py
21+
- pyarrow_example.py
22+
steps:
23+
- name: Checkout code
24+
uses: actions/checkout@v5
25+
with:
26+
submodules: recursive
27+
- name: Install uv with python
28+
uses: astral-sh/setup-uv@v7
29+
with:
30+
python-version: "3.10"
31+
- name: Install package dependencies
32+
run: |
33+
uv sync --frozen --extra extensions
34+
- name: Run ${{ matrix.example }}
35+
run: |
36+
uv run examples/${{ matrix.example }}

examples/adbc_example.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,7 @@ def read_adbc_named_table(name: str, conn):
4545
table = filter(
4646
table,
4747
expression=scalar_function(
48-
"functions_comparison.yaml",
48+
"extension:io.substrait:functions_comparison",
4949
"gte",
5050
expressions=[column("ints"), literal(3, i64())],
5151
),

examples/builder_example.py

Lines changed: 33 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -21,17 +21,17 @@
2121
def basic_example():
2222
ns = named_struct(
2323
names=["id", "is_applicable"],
24-
struct=struct(types=[i64(nullable=False), boolean()]),
24+
struct=struct(types=[i64(nullable=False), boolean()], nullable=False),
2525
)
2626

2727
table = read_named_table("example_table", ns)
2828
table = filter(table, expression=column("is_applicable"))
2929
table = filter(
3030
table,
3131
expression=scalar_function(
32-
"functions_comparison.yaml",
32+
"extension:io.substrait:functions_comparison",
3333
"lt",
34-
expressions=[column("id"), literal(100, i64())],
34+
expressions=[column("id"), literal(100, i64(nullable=False))],
3535
),
3636
)
3737
table = project(table, expressions=[column("id")])
@@ -40,15 +40,16 @@ def basic_example():
4040
pretty_print_plan(table(registry), use_colors=True)
4141

4242
"""
43-
extension_urns {
44-
extension_urn_anchor: 13
45-
urn: "functions_comparison.yaml"
43+
extension_uris {
44+
extension_uri_anchor: 2
45+
uri: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml"
4646
}
4747
extensions {
4848
extension_function {
49-
extension_urn_reference: 13
50-
function_anchor: 495
51-
name: "lt"
49+
extension_uri_reference: 2
50+
function_anchor: 124
51+
name: "lt:any_any"
52+
extension_urn_reference: 2
5253
}
5354
}
5455
relations {
@@ -84,7 +85,7 @@ def basic_example():
8485
nullability: NULLABILITY_NULLABLE
8586
}
8687
}
87-
nullability: NULLABILITY_NULLABLE
88+
nullability: NULLABILITY_REQUIRED
8889
}
8990
}
9091
named_table {
@@ -107,10 +108,10 @@ def basic_example():
107108
}
108109
condition {
109110
scalar_function {
110-
function_reference: 495
111+
function_reference: 124
111112
output_type {
112113
bool {
113-
nullability: NULLABILITY_NULLABLE
114+
nullability: NULLABILITY_REQUIRED
114115
}
115116
}
116117
arguments {
@@ -129,7 +130,6 @@ def basic_example():
129130
value {
130131
literal {
131132
i64: 100
132-
nullable: true
133133
}
134134
}
135135
}
@@ -152,25 +152,29 @@ def basic_example():
152152
names: "id"
153153
}
154154
}
155-
"""
155+
extension_urns {
156+
extension_urn_anchor: 2
157+
urn: "extension:io.substrait:functions_comparison"
158+
}
159+
"""
156160

157161

158162
def advanced_example():
159163
print("=== Simple Example ===")
160164
# Simple example (original)
161165
ns = named_struct(
162166
names=["id", "is_applicable"],
163-
struct=struct(types=[i64(nullable=False), boolean()]),
167+
struct=struct(types=[i64(nullable=False), boolean()], nullable=False),
164168
)
165169

166170
table = read_named_table("example_table", ns)
167171
table = filter(table, expression=column("is_applicable"))
168172
table = filter(
169173
table,
170174
expression=scalar_function(
171-
"functions_comparison.yaml",
175+
"extension:io.substrait:functions_comparison",
172176
"lt",
173-
expressions=[column("id"), literal(100, i64())],
177+
expressions=[column("id"), literal(100, i64(nullable=False))],
174178
),
175179
)
176180
table = project(table, expressions=[column("id")])
@@ -190,7 +194,8 @@ def advanced_example():
190194
string(nullable=False), # name
191195
i64(nullable=False), # age
192196
fp64(nullable=False), # salary
193-
]
197+
],
198+
nullable=False,
194199
),
195200
)
196201

@@ -200,7 +205,7 @@ def advanced_example():
200205
adult_users = filter(
201206
users,
202207
expression=scalar_function(
203-
"functions_comparison.yaml",
208+
"extension:io.substrait:functions_comparison",
204209
"gt",
205210
expressions=[column("age"), literal(25, i64())],
206211
),
@@ -216,7 +221,7 @@ def advanced_example():
216221
column("salary"),
217222
# Add a calculated field (this would show function options if available)
218223
scalar_function(
219-
"functions_arithmetic.yaml",
224+
"extension:io.substrait:functions_arithmetic",
220225
"multiply",
221226
expressions=[column("salary"), literal(1.1, fp64())],
222227
alias="salary_with_bonus",
@@ -238,7 +243,8 @@ def advanced_example():
238243
i64(nullable=False), # order_id
239244
fp64(nullable=False), # amount
240245
string(nullable=False), # status
241-
]
246+
],
247+
nullable=False,
242248
),
243249
)
244250

@@ -248,7 +254,7 @@ def advanced_example():
248254
high_value_orders = filter(
249255
orders,
250256
expression=scalar_function(
251-
"functions_comparison.yaml",
257+
"extension:io.substrait:functions_comparison",
252258
"gt",
253259
expressions=[column("amount"), literal(50.0, fp64())],
254260
),
@@ -280,16 +286,16 @@ def expression_only_example():
280286
print("=== Expression-Only Example ===")
281287
# Show complex expression structure
282288
complex_expr = scalar_function(
283-
"functions_arithmetic.yaml",
289+
"extension:io.substrait:functions_arithmetic",
284290
"multiply",
285291
expressions=[
286292
scalar_function(
287-
"functions_arithmetic.yaml",
293+
"extension:io.substrait:functions_arithmetic",
288294
"add",
289295
expressions=[
290296
column("base_salary"),
291297
scalar_function(
292-
"functions_arithmetic.yaml",
298+
"extension:io.substrait:functions_arithmetic",
293299
"multiply",
294300
expressions=[
295301
column("base_salary"),
@@ -299,7 +305,7 @@ def expression_only_example():
299305
],
300306
),
301307
scalar_function(
302-
"functions_arithmetic.yaml",
308+
"extension:io.substrait:functions_arithmetic",
303309
"subtract",
304310
expressions=[
305311
literal(1.0, fp64()),
@@ -312,7 +318,7 @@ def expression_only_example():
312318
print("Complex salary calculation expression:")
313319
# Create a simple plan to wrap the expression
314320
dummy_schema = named_struct(
315-
names=["base_salary"], struct=struct(types=[fp64(nullable=False)])
321+
names=["base_salary"], struct=struct(types=[fp64(nullable=False)], nullable=False)
316322
)
317323
dummy_table = read_named_table("dummy", dummy_schema)
318324
dummy_plan = project(dummy_table, expressions=[complex_expr])

examples/duckdb_example.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313
from substrait.builders.extended_expression import column, scalar_function, literal
1414
from substrait.builders.type import i32
1515
from substrait.extension_registry import ExtensionRegistry
16-
from substrait.json import dump_json
1716
import pyarrow.substrait as pa_substrait
1817

1918
try:
@@ -42,14 +41,13 @@ def read_duckdb_named_table(name: str, conn):
4241
table = filter(
4342
table,
4443
expression=scalar_function(
45-
"functions_comparison.yaml",
44+
"extension:io.substrait:functions_comparison",
4645
"equal",
4746
expressions=[column("c_nationkey"), literal(3, i32())],
4847
),
4948
)
5049
table = project(
5150
table, expressions=[column("c_name"), column("c_address"), column("c_nationkey")]
5251
)
53-
54-
sql = f"CALL from_substrait_json('{dump_json(table(registry))}')"
55-
print(duckdb.sql(sql))
52+
sql = "CALL from_substrait(?)"
53+
print(duckdb.sql(sql, params=[table(registry).SerializeToString()]))

src/substrait/utils/display.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -721,8 +721,9 @@ def _stream_literal_value(
721721
f"{indent}{self._color('fp64', Colors.BLUE)}: {self._color(literal.fp64, Colors.GREEN)}\n"
722722
)
723723
elif literal.HasField("string"):
724+
string_value = f'"{literal.string}"'
724725
stream.write(
725-
f"{indent}{self._color('string', Colors.BLUE)}: {self._color(f'"{literal.string}"', Colors.GREEN)}\n"
726+
f'{indent}{self._color("string", Colors.BLUE)}: {self._color(string_value, Colors.GREEN)}\n'
726727
)
727728
elif literal.HasField("date"):
728729
stream.write(
@@ -782,8 +783,9 @@ def _stream_struct_literal(
782783
f"{self._get_indent_with_arrow(depth + 2)}{self._color('i32', Colors.BLUE)}: {self._color(field.i32, Colors.GREEN)}\n"
783784
)
784785
elif field.HasField("string"):
786+
field_string_value = f'"{field.string}"'
785787
stream.write(
786-
f"{self._get_indent_with_arrow(depth + 2)}{self._color('string', Colors.BLUE)}: {self._color(f'"{field.string}"', Colors.GREEN)}\n"
788+
f"{self._get_indent_with_arrow(depth + 2)}{self._color('string', Colors.BLUE)}: {self._color(field_string_value, Colors.GREEN)}\n"
787789
)
788790
elif field.HasField("boolean"):
789791
stream.write(

0 commit comments

Comments
 (0)