From 2a071867a49509fce5709dc0ddf30f0e6d32395d Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Wed, 25 Feb 2026 17:36:17 -0500 Subject: [PATCH 1/6] docs: add ADR for migration from Pest to LALRPOP --- adrs/lalrpop.md | 253 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 253 insertions(+) create mode 100644 adrs/lalrpop.md diff --git a/adrs/lalrpop.md b/adrs/lalrpop.md new file mode 100644 index 0000000..f3a677f --- /dev/null +++ b/adrs/lalrpop.md @@ -0,0 +1,253 @@ +# ADR: Replace Pest with LALRPOP for Line-Level Parsing + +**Status:** Proposed +**Date:** 2026-02-25 + +## Context + +The parser for substrait-explain's text format is currently built with [Pest](https://pest.rs/), a PEG parser generator. Pest produces a generic parse tree (`Pairs`) that must be walked manually to build Substrait protobuf structs. This works but has several pain points: + +- **Manual tree walking** — ~5,000 lines of `RuleIter`, `ParsePair`, and `ScopedParsePair` code to consume Pest pairs and build protobuf structs directly. +- **`RuleIter::done()` ceremony** — The `Drop`-based consumption enforcement requires calling `done()` before any early return, forcing an awkward "extract all data first, validate after" pattern across the codebase. +- **No intermediate AST** — Pest pairs are converted directly to protobuf, tightly coupling the parser to the Substrait proto schema and mixing syntax parsing with semantic validation (e.g., extension lookups happen during parsing). +- **Silent ambiguity** — PEG's ordered choice silently resolves grammar ambiguities. Ordering bugs produce wrong parses rather than errors. +- **Per-relation grammar rules** — Each relation type has its own Pest rule despite the text format being designed around a uniform `Name[args => columns]` structure. + +## Decision + +Replace the Pest expression parser with [LALRPOP](https://github.com/lalrpop/lalrpop), introducing an intermediate AST and a lowering pass from AST to Substrait protobuf. 
+ +### Scope: Strategy A — Replace Only the Line Parser + +The structural parser (`structural.rs`) — which handles section headers (`=== Extensions`, `=== Plan`), indentation-based tree building, and line dispatching — remains unchanged. Only the per-line content parsing (expressions, types, relation arguments) moves to LALRPOP. + +This is the minimal migration. The structural parser is clean and well-suited to its job; the pain points are entirely in the Pest layer. + +### Lexer: LALRPOP Built-in + +Use LALRPOP's built-in regex-based lexer. The token set is small (~25 tokens: `$`, `#`, `@`, `&`, `=>`, delimiters, identifiers, numbers, strings, keywords). The built-in lexer handles this comfortably, and automatic whitespace skipping is appropriate since each line is parsed independently. + +If we later need context-sensitive lexing (e.g., for more complex string escapes), we can upgrade to [Logos](https://github.com/maciejhirsz/logos) without changing the grammar. + +### Grammar: Abstract/Generic Relations + +Instead of per-relation grammar rules, define a single generic relation rule: + +```lalrpop +Relation: ast::Relation = { + "[" "=>" > "]" + => ast::Relation { name, args, outputs: cols }, + "[" "]" + => ast::Relation { name, args, outputs: vec![] }, + "[" "]" + => ast::Relation { name, args: vec![], outputs: vec![] }, +}; + +// Positional arguments followed by optional named arguments. +// Named arguments always come last — they cannot be intermingled with positional args. +ArgList: ast::ArgList = { + > "," > + => ast::ArgList { positional: pos, named }, + > + => ast::ArgList { positional: pos, named: vec![] }, + > + => ast::ArgList { positional: vec![], named }, +}; + +// Positional argument — expressions, enums, named columns, tuples, wildcard. +// Tuples also contain only positional arguments (no named args inside tuples). 
+Arg: ast::Arg = { + => ast::Arg::Expr(e), + => ast::Arg::Enum(e), + ":" => ast::Arg::NamedColumn(n, t), + "(" > ")" => ast::Arg::Tuple(args), + "_" => ast::Arg::Wildcard, +}; + +NamedArg: ast::NamedArg = { + "=" => ast::NamedArg { name: n, value: a }, +}; +``` + +Relation-specific validation (e.g., "Filter expects an expression before `=>` and references after") moves to the lowering pass. This reflects the text format's stated design principle: *"All relations follow the same structure: `Name[arguments => columns]`"*. + +**Argument ordering:** The grammar enforces that named arguments (`name=value`) always follow positional arguments — they cannot be intermingled. Tuples can only contain positional arguments. This is a structural constraint that belongs in the grammar, not the lowering pass, since it applies uniformly to all relations. + +**Implementation note (2026-02-26):** The current implementation parses argument entries as a single ordered list and records whether named/positional ordering is invalid, then reports that as a lowering validation error. This was chosen to keep the grammar conflict-free while preserving the same user-visible constraint. + +**Implementation note (2026-02-26):** Extension declaration lines in the +`=== Extensions` section (`@...` URNs and `#...` declarations) are parsed via +the same LALRPOP -> AST pipeline and then validated in parser/lowering code. + +**Benefits:** +- Smaller grammar with fewer LR conflict risks. +- Adding new relation types (Window, Set, HashJoin) requires only lowering code, no grammar recompilation. +- Cleaner separation: grammar owns syntax, lowering owns semantics. + +**Cost:** +- Parse errors are less relation-specific. Mitigated by the lowering pass, which has full context (including line reference) to produce clear semantic errors like "Filter's output must be field references, not named columns." 
+ +### Intermediate AST + +Introduce a lightweight AST as the grammar's output, rather than building protobuf directly: + +```rust +// src/ast.rs (sketch) + +struct Document { + extensions: Extensions, + relations: Vec, +} + +struct RelationTree { + relation: Relation, + children: Vec, + line: LineInfo, // for error reporting +} + +struct Relation { + name: RelationName, + args: ArgList, + outputs: Vec, // positional only — no named args after => +} + +enum RelationName { + Standard(String), // "Read", "Filter", ... + Extension(ExtensionType, String), // "ExtensionLeaf:ParquetScan" +} + +/// Input arguments: positional args followed by optional named args. +/// Named args always come last; they cannot be intermingled. +struct ArgList { + positional: Vec, + named: Vec, +} + +/// A positional argument. Used both in input args and output columns. +/// Tuples also contain only positional args. +enum Arg { + Expr(Expr), + Enum(String), // &Inner, &AscNullsFirst + NamedColumn(String, Type), // name:type + Tuple(Vec), // (ref, enum) + Wildcard, // _ +} + +/// A named argument (name=value). Only valid in input position. +struct NamedArg { + name: String, + value: Arg, +} + +enum Expr { + FieldRef(u32), // $0, $1 + Literal(LiteralValue, Option), // 42, 'hello':string + FunctionCall { + name: String, + anchor: Option, + urn_anchor: Option, + args: Vec, + output_type: Option, + }, + IfThen { + clauses: Vec<(Expr, Expr)>, + else_expr: Box, + }, +} + +enum LiteralValue { + Integer(i64), + Float(f64), + Boolean(bool), + String(String), +} +``` + +### Lowering Pass: AST to Protobuf + +A new `lower` module converts the AST to `substrait::proto::Plan`: + +1. **Extension resolution** — resolve `#anchor` and `@urn` references against `SimpleExtensions`. This moves out of the parser entirely. +2. **Relation dispatch** — match on relation name, validate argument shapes, build the appropriate `Rel` variant. +3. **Type construction** — convert AST types to protobuf type structs. +4. 
**Error reporting** — errors reference the source line via `LineInfo`, providing context like "line 5: Filter expects a boolean expression before `=>`, got named column `foo:i64`". + +This cleanly separates the three concerns currently mixed in the parser: syntax (grammar), name resolution (extensions), and semantic validation (relation-specific rules). + +### Parse Algorithm: LR(1) with LALR(1) Fallback + +Start with LR(1) (LALRPOP's default), which accepts a broader grammar class. If build times become a concern, switch to LALR(1) via `grammar["LALR(1)"];` in the grammar file, which generates smaller code at the cost of rejecting some valid LR(1) grammars. + +## Consequences + +### What We Gain + +- **~3,000 lines of pair-walking code removed** — `RuleIter`, `ParsePair`, `ScopedParsePair`, and most of `common.rs`, `expressions.rs`, `relations.rs`, `types.rs` are replaced by inline grammar actions and a lowering pass. +- **Compile-time grammar validation** — Ambiguous grammars produce shift/reduce conflict errors instead of silent misbehavior. +- **Clean architecture** — Three distinct phases (parse → AST → protobuf) with clear boundaries. +- **Extension resolution decoupled from parsing** — Errors are categorized as syntactic vs. semantic, and extension changes don't affect the grammar. +- **Easier to add relations** — New relation types are code-only changes in the lowering pass. +- **`get_input_field_count()` bug fixable** — The incomplete field-count helper (currently only handles Read/Filter/Project) can be properly implemented in the lowering pass where all relation types are handled uniformly. + +### What We Lose / Risk + +- **Build time increase** — LALRPOP uses `build.rs` code generation. For our grammar size (~50–100 rules), this should add ~10–20s to builds, not minutes. Worth measuring early. +- **Opaque generated code** — LALRPOP's generated parser is not meant to be read. Grammar debugging relies on conflict diagnostics rather than code inspection. 
+- **Migration effort** — Touches every file in `src/parser/`. Estimated at 2–3 focused sessions. The existing roundtrip tests provide strong regression coverage. +- **New dependency** — `lalrpop` (build-dep) + `lalrpop-util` (runtime dep). Both are well-maintained (~2.8M downloads/month, latest release Feb 2026). +- **LR grammar engineering** — Some PEG patterns (ordered choice for disambiguation) need rethinking as explicit precedence. The abstract relation grammar minimizes this risk. + +### What Stays the Same + +- **Structural parser** (`structural.rs`) — Line-by-line state machine, indentation tree building, section header handling. Unchanged. +- **Textifier** (`src/textify/`) — Completely unaffected; it reads protobuf, not the parser's internals. +- **Public API** — `parse()` still returns `Result`. +- **Roundtrip tests** — Continue to validate parse→format→parse. They don't depend on parser internals. +- **Extension system** — `SimpleExtensions` and `ExtensionRegistry` remain; they're just consumed by the lowering pass instead of during parsing. + +## Migration Plan + +### Phase 1: AST Types and Grammar + +1. Add `lalrpop` to `[build-dependencies]` and `lalrpop-util` to `[dependencies]`. +2. Create `build.rs` with `lalrpop::process_root().unwrap()`. +3. Define AST types in `src/ast.rs`. +4. Write the LALRPOP grammar (`src/parser/grammar.lalrpop`) covering expressions, types, arguments, and the generic relation rule. +5. Unit-test the grammar in isolation against known-good inputs. + +### Phase 2: Lowering Pass + +1. Implement `src/lower.rs` — AST → `substrait::proto` conversion with extension resolution. +2. Implement per-relation validation and argument destructuring. +3. Wire up error reporting with line references. + +### Phase 3: Integration and Cutover + +1. Modify `structural.rs` to call the LALRPOP parser instead of Pest for line content. +2. Route LALRPOP AST through the lowering pass to produce protobuf. +3. 
Run the full roundtrip test suite — this is the primary validation. +4. Remove Pest: delete `expression_grammar.pest`, `common.rs` (RuleIter, ParsePair), and the per-relation/per-expression pair-walking code. + +### Phase 4: Cleanup + +1. Remove the `pest` and `pest_derive` dependencies. +2. Simplify error types (parse errors vs. lowering errors are now distinct). +3. Update GRAMMAR.md examples if any syntax details changed. + +## Alternatives Considered + +### Keep Pest, add an AST layer + +We could introduce an intermediate AST while keeping Pest. This would fix the "no AST" and "extension resolution during parsing" problems but would *not* fix RuleIter ergonomics, silent ambiguity, or per-relation grammar coupling. The pair-walking code would still exist, just targeting AST types instead of protobuf. + +### Hand-written recursive descent parser + +Maximum control and debuggability. No build-time code generation. Used by Ruff (which migrated *from* LALRPOP). However, for our grammar size, the engineering effort is higher than LALRPOP and we'd lose compile-time ambiguity checking. Better suited for larger, more complex grammars where error recovery and performance are critical. + +### Parser combinators (nom, winnow, chumsky) + +No grammar file, pure Rust. Good composability. But for a grammar of this size with clear structure, a grammar file is more readable and maintainable than combinator chains. Chumsky has excellent error recovery but is less mature than LALRPOP. + +### Stay with Pest as-is + +The parser works. The pain points are real but manageable. This is a valid choice if the migration effort is better spent on feature work (new relation types, better error messages, etc.). The risk is that each new relation adds more pair-walking boilerplate, compounding the maintenance cost. 
From be3e2e60b37db048ee7c685cfb636b2be59fb454 Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Wed, 25 Feb 2026 19:38:56 -0500 Subject: [PATCH 2/6] feat: migrate parser from Pest to LALRPOP with structured AST and lowering pipeline Replace the Pest PEG parser with a LALRPOP LR(1) grammar, introducing a two-phase parse-then-lower architecture. The grammar produces typed AST nodes that are lowered to Substrait protobuf with relation-specific semantic validation. Includes extension declaration parsing via LALRPOP, extraction of the extension lowering module, FromStr/From trait implementations, and in-code documentation for all major modules. --- CONTRIBUTING.md | 38 +- Cargo.lock | 241 ++++- Cargo.toml | 6 +- GRAMMAR.md | 6 +- build.rs | 3 + src/extensions/any.rs | 11 +- src/extensions/args.rs | 4 +- src/extensions/conversion.rs | 131 +++ src/extensions/simple.rs | 7 + src/fixtures.rs | 6 +- src/parser/ast.rs | 539 ++++++++++ src/parser/common.rs | 346 ------- src/parser/errors.rs | 200 +++- src/parser/expression_grammar.pest | 288 ------ src/parser/expressions.rs | 979 ------------------ src/parser/extensions.rs | 431 ++------ src/parser/lalrpop_line.rs | 215 ++++ src/parser/line_grammar.lalrpop | 266 +++++ src/parser/lower/expr.rs | 170 ++++ src/parser/lower/extensions.rs | 165 +++ src/parser/lower/literals.rs | 257 +++++ src/parser/lower/mod.rs | 98 ++ src/parser/lower/relations.rs | 449 +++++++++ src/parser/lower/types.rs | 194 ++++ src/parser/lower/validate.rs | 157 +++ src/parser/mod.rs | 25 +- src/parser/relations.rs | 1505 ---------------------------- src/parser/structural.rs | 1088 ++++++-------------- src/parser/types.rs | 428 -------- src/structural/mod.rs | 438 -------- src/textify/expressions.rs | 74 +- src/textify/foundation.rs | 8 + src/textify/mod.rs | 7 +- src/textify/plan.rs | 29 +- src/textify/rels.rs | 43 +- src/textify/types.rs | 9 + tests/literal_roundtrip.rs | 67 ++ tests/plan_roundtrip.rs | 6 +- tests/types.rs | 245 ++--- 39 files changed, 
3728 insertions(+), 5451 deletions(-) create mode 100644 build.rs create mode 100644 src/extensions/conversion.rs create mode 100644 src/parser/ast.rs delete mode 100644 src/parser/common.rs create mode 100644 src/parser/lalrpop_line.rs create mode 100644 src/parser/line_grammar.lalrpop create mode 100644 src/parser/lower/expr.rs create mode 100644 src/parser/lower/extensions.rs create mode 100644 src/parser/lower/literals.rs create mode 100644 src/parser/lower/mod.rs create mode 100644 src/parser/lower/relations.rs create mode 100644 src/parser/lower/types.rs create mode 100644 src/parser/lower/validate.rs delete mode 100644 src/parser/types.rs delete mode 100644 src/structural/mod.rs diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 21fae2f..f2f2528 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -90,22 +90,16 @@ This project implements Substrait protobuf plan parsing and formatting. See the This project uses different error handling patterns depending on the context: -##### Pest Parser Context +##### Parser Context -Unwraps are safe where the grammar enforces the invariants. Always assert that the right rule was used to ensure the same invariants: +Use structured parse and lowering errors with explicit context (line number + source line). Syntax errors should come from LALRPOP parse entrypoints, and semantic validation errors should come from lowering: ```rust -// Pest parse functions should assert the rule matches -fn parse_pair(pair: pest::iterators::Pair) -> Self { - assert_eq!(pair.as_rule(), Self::rule(), "Expected rule {:?}", Self::rule()); - // Now safe to unwrap based on grammar guarantees - let inner = pair.into_inner().next().unwrap(); - // ... 
rest of parsing logic -} +let relation = lalrpop_line::parse_relation(line).map_err(|e| { + ParseError::Plan(ctx, MessageParseError::invalid("relation", e)) +})?; ``` -_Note: Ideally, `pest_derive` would remove the need for unwraps by providing structural types that encode the invariants, ensuring at compile-time the grammar and code remain in sync._ - ##### Other Parsing Context (e.g., lookups) These can reasonably fail. For the parser, we error on bad input and attempt to provide good error messages: @@ -116,7 +110,7 @@ let value = input.parse::().map_err(|_| "Invalid integer")?; let first = chars.next().ok_or("Expected non-empty input")?; ``` -_Note: When parsing with pest, use `pest::Span` and `pest::Error` types to provide clearer error messages about where in the input parsing failed. This helps users understand exactly what went wrong and where._ +_Note: Include enough context (line number, offending line, and message category) so users can quickly locate and fix bad input._ ##### Formatting Context @@ -157,22 +151,14 @@ assert!(result.is_ok()); - Return `Result` types for operations that can legitimately fail - Collect and continue for formatting errors, fail-fast for parsing errors -#### RuleIter Consumption Patterns - -`RuleIter` enforces complete consumption via its `Drop` implementation. This can cause assertion failures if validation functions return early before calling `iter.done()`. +#### Parser Pipeline Pattern -**Current workaround:** Extract all data first, then validate: +Keep parser internals split into two phases: -```rust -// Extract data without validation -let (limit_pair, offset_pair) = /* extract pairs */; -iter.done(); // Call before any early returns - -// Then validate using functional patterns -let count_mode = limit_pair.map(|pair| CountMode::parse_pair(extensions, pair)).transpose()?; -``` +1. Parse single lines to AST with LALRPOP (`line_grammar.lalrpop`). +2. 
Lower AST nodes to protobuf with relation-specific semantic validation (`parser/lower/`). -_Note: The RuleIter interface could be improved to handle this pattern more naturally._ +This separation keeps grammar concerns and semantic checks independent and easier to test. #### Documentation Formatting @@ -192,7 +178,7 @@ Since markdown files are included in Rustdoc, use `##` for actual `#` characters When working with [`GRAMMAR.md`](GRAMMAR.md): -- Use PEG format with `rule_name := definition` syntax +- Keep grammar notation in `GRAMMAR.md` implementation-independent (PEG/EBNF style), while ensuring examples remain accepted by the implemented parser (`line_grammar.lalrpop`) - Run `cargo test --doc` to verify all code examples compile - Ensure all referenced identifiers are properly defined somewhere in the grammar diff --git a/Cargo.lock b/Cargo.lock index df1582b..98e8a0d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -82,6 +82,15 @@ version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" +[[package]] +name = "ascii-canvas" +version = "4.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef1e3e699d84ab1b0911a1010c5c106aa34ae89aeac103be5ce0c3859db1e891" +dependencies = [ + "term", +] + [[package]] name = "autocfg" version = "1.5.0" @@ -94,6 +103,21 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "bit-set" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" + 
[[package]] name = "bitflags" version = "2.10.0" @@ -252,6 +276,15 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "ena" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eabffdaee24bd1bf95c5ef7cec31260444317e72ea56c4c91750e8b7ee58d5f1" +dependencies = [ + "log", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -411,6 +444,47 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "keccak" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb26cec98cce3a3d96cbb7bced3c4b16e3d13f27ec56dbd62cbc8f39cfb9d653" +dependencies = [ + "cpufeatures", +] + +[[package]] +name = "lalrpop" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba4ebbd48ce411c1d10fb35185f5a51a7bfa3d8b24b4e330d30c9e3a34129501" +dependencies = [ + "ascii-canvas", + "bit-set", + "ena", + "itertools", + "lalrpop-util", + "petgraph 0.7.1", + "pico-args", + "regex", + "regex-syntax", + "sha3", + "string_cache", + "term", + "unicode-xid", + "walkdir", +] + +[[package]] +name = "lalrpop-util" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5baa5e9ff84f1aefd264e6869907646538a52147a755d494517a8007fb48733" +dependencies = [ + "regex-automata", + "rustversion", +] + [[package]] name = "libc" version = "0.2.180" @@ -423,6 +497,15 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + [[package]] name = "log" version = "0.4.29" @@ -441,6 
+524,12 @@ version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" +[[package]] +name = "new_debug_unreachable" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" + [[package]] name = "num-traits" version = "0.2.19" @@ -462,6 +551,29 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + [[package]] name = "pbjson" version = "0.8.0" @@ -500,58 +612,46 @@ dependencies = [ ] [[package]] -name = "pest" -version = "2.8.5" +name = "petgraph" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c9eb05c21a464ea704b53158d358a31e6425db2f63a1a7312268b05fe2b75f7" +checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" dependencies = [ - "memchr", - "ucd-trie", + "fixedbitset", + "indexmap", ] [[package]] -name = "pest_derive" -version = "2.8.5" +name = "petgraph" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68f9dbced329c441fa79d80472764b1a2c7e57123553b8519b36663a2fb234ed" +checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" dependencies = [ - "pest", - 
"pest_generator", + "fixedbitset", + "hashbrown 0.15.5", + "indexmap", ] [[package]] -name = "pest_generator" -version = "2.8.5" +name = "phf_shared" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bb96d5051a78f44f43c8f712d8e810adb0ebf923fc9ed2655a7f66f63ba8ee5" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" dependencies = [ - "pest", - "pest_meta", - "proc-macro2", - "quote", - "syn", + "siphasher", ] [[package]] -name = "pest_meta" -version = "2.8.5" +name = "pico-args" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "602113b5b5e8621770cfd490cfd90b9f84ab29bd2b0e49ad83eb6d186cef2365" -dependencies = [ - "pest", - "sha2", -] +checksum = "5be167a7af36ee22fe3115051bc51f6e6c7054c9348e28deb4f49bd6f705a315" [[package]] -name = "petgraph" -version = "0.8.3" +name = "precomputed-hash" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" -dependencies = [ - "fixedbitset", - "hashbrown 0.15.5", - "indexmap", -] +checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" [[package]] name = "prettyplease" @@ -592,7 +692,7 @@ dependencies = [ "itertools", "log", "multimap", - "petgraph", + "petgraph 0.8.3", "prettyplease", "prost", "prost-types", @@ -647,6 +747,15 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + [[package]] name = "regex" version = "1.12.2" @@ -744,6 +853,12 @@ dependencies = [ "syn", ] +[[package]] +name = "scopeguard" +version = "1.2.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "semver" version = "1.0.27" @@ -834,14 +949,13 @@ dependencies = [ ] [[package]] -name = "sha2" -version = "0.10.9" +name = "sha3" +version = "0.10.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" dependencies = [ - "cfg-if", - "cpufeatures", "digest", + "keccak", ] [[package]] @@ -850,6 +964,30 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "siphasher" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "string_cache" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf776ba3fa74f83bf4b63c3dcbbf82173db2632ed8452cb2d891d33f459de70f" +dependencies = [ + "new_debug_unreachable", + "parking_lot", + "phf_shared", + "precomputed-hash", +] + [[package]] name = "strsim" version = "0.11.1" @@ -890,9 +1028,9 @@ dependencies = [ "chrono", "clap", "indexmap", + "lalrpop", + "lalrpop-util", "pbjson-types", - "pest", - "pest_derive", "prost", "prost-types", "serde", @@ -926,6 +1064,15 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "term" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8c27177b12a6399ffc08b98f76f7c9a1f4fe9fc967c784c5a071fa8d93cf7e1" +dependencies = [ + 
"windows-sys", +] + [[package]] name = "thiserror" version = "2.0.18" @@ -999,18 +1146,18 @@ dependencies = [ "typify-impl", ] -[[package]] -name = "ucd-trie" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" - [[package]] name = "unicode-ident" version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "unsafe-libyaml" version = "0.2.11" diff --git a/Cargo.toml b/Cargo.toml index d234dba..3599a91 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,9 +47,8 @@ anyhow = { version = "1.0", optional = true } chrono = "0.4.41" clap = { version = "4.5", features = ["derive"], optional = true } indexmap = "2" +lalrpop-util = { version = "0.22.2", features = ["lexer"] } pbjson-types = { version = "0.8.0", optional = true } -pest = "2.8.0" -pest_derive = "2.8.0" prost = "0.14" prost-types = "0.14" serde_json = { version = "1.0", optional = true } @@ -61,3 +60,6 @@ thiserror = "2.0.12" # Used in examples serde = { version = "1.0", features = ["derive"] } serde_yaml = "0.9" + +[build-dependencies] +lalrpop = "0.22.2" diff --git a/GRAMMAR.md b/GRAMMAR.md index 6003f80..befadb6 100644 --- a/GRAMMAR.md +++ b/GRAMMAR.md @@ -57,9 +57,8 @@ This document uses **PEG (Parsing Expression Grammar)** notation: - **`element*`** - Zero or more repetitions - **`element+`** - One or more repetitions - **`element1 / element2`** - Choice (try element1 first) - - _Implementation Note: Pest uses `|` instead of `/`_ - **`element1 element2`** - Sequence - - _Implementation Note: Pest uses `~` for explicit concatenation_ +- **`// comment`** - Line comment (allowed at line 
start or after whitespace) ## Basic Example @@ -199,7 +198,7 @@ A literal can come in the form of an integer, float, boolean, or string, and can - String literals with type annotations for non-primitive types - Examples: `'2023-01-01':date`, `'2023-12-25T14:30:45.123':timestamp` -All literal types (`integer`, `float`, `boolean`, and `string`) are now supported in the current implementation. Support for typed literals (string literals with non-primitive type annotations like `'2023-01-01':date`) remains to be implemented. +All literal types (`integer`, `float`, `boolean`, and `string`) are supported in the current implementation. Typed string literals are also supported for currently implemented target types (for example `date`, `time`, and `timestamp`). ## Types @@ -467,6 +466,7 @@ argument := literal / expression / enum / tuple tuple := "(" argument ("," argument)* ")" arguments := argument ("," argument)* named_arguments := name "=" argument ("," name "=" argument)* +comment := "//" (!newline .)* ``` #### Examples diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..841fa41 --- /dev/null +++ b/build.rs @@ -0,0 +1,3 @@ +fn main() { + lalrpop::process_src().expect("failed to generate LALRPOP parsers"); +} diff --git a/src/extensions/any.rs b/src/extensions/any.rs index b68218d..cc4aab2 100644 --- a/src/extensions/any.rs +++ b/src/extensions/any.rs @@ -1,3 +1,10 @@ +//! Unified wrapper for the protobuf `Any` type. +//! +//! Substrait extension relations carry opaque payloads as protobuf `Any` +//! messages. The Rust ecosystem has two incompatible `Any` types +//! (`prost_types::Any` and `pbjson_types::Any`); this module provides [`Any`] +//! and [`AnyRef`] to convert between them transparently. + use prost::{Message, Name}; use crate::extensions::registry::ExtensionError; @@ -11,8 +18,8 @@ pub struct Any { } /// A reference to a protobuf `Any` type. 
Can be created from references to -/// [`prost_types::Any`](prost_types::Any), -/// [`pbjson_types::Any`](pbjson_types::Any), or our own [`Any`](Any) type. +/// [`prost_types::Any`], +/// [`pbjson_types::Any`], or our own [`Any`] type. #[derive(Debug, Copy, Clone, PartialEq)] pub struct AnyRef<'a> { pub type_url: &'a str, diff --git a/src/extensions/args.rs b/src/extensions/args.rs index a14f8fd..f4a7f47 100644 --- a/src/extensions/args.rs +++ b/src/extensions/args.rs @@ -339,8 +339,8 @@ impl ExtensionRelationType { } } -// Note: create_rel is implemented in parser/extensions.rs to avoid -// pulling in protobuf dependencies in the core args module +// Note: relation construction lives in parser/lower/extensions.rs so this +// core args module stays parser- and protobuf-agnostic. impl ExtensionArgs { /// Create a new empty ExtensionArgs diff --git a/src/extensions/conversion.rs b/src/extensions/conversion.rs new file mode 100644 index 0000000..3334b3d --- /dev/null +++ b/src/extensions/conversion.rs @@ -0,0 +1,131 @@ +//! Conversion utilities between parsed extension arguments and registry types. +//! +//! This module provides helpers to convert between structured extension values +//! and text format fragments used by the formatter. 
+ +use crate::extensions::{ExtensionArgs, ExtensionColumn, ExtensionValue}; + +/// Convert ExtensionArgs back to text format for textification +pub fn format_extension_args_to_text(args: &ExtensionArgs, extension_name: &str) -> String { + let mut parts = Vec::new(); + + // Add extension name + parts.push(extension_name.to_string()); + parts.push("[".to_string()); + + let mut inner_parts = Vec::new(); + + // Add positional arguments + for value in &args.positional { + inner_parts.push(format_extension_value_to_text(value)); + } + + // Add named arguments + for (name, value) in &args.named { + inner_parts.push(format!( + "{}={}", + name, + format_extension_value_to_text(value) + )); + } + + // Join the arguments + let args_str = inner_parts.join(", "); + + // Add output columns if present + if !args.output_columns.is_empty() { + let columns_str = args + .output_columns + .iter() + .map(format_extension_column_to_text) + .collect::>() + .join(", "); + + if args_str.is_empty() { + parts.push(format!("=> {columns_str}")); + } else { + parts.push(format!("{args_str} => {columns_str}")); + } + } else { + parts.push(args_str); + } + + parts.push("]".to_string()); + + parts.join("") +} + +/// Format an ExtensionValue to text +fn format_extension_value_to_text(value: &ExtensionValue) -> String { + match value { + ExtensionValue::String(s) => format!("'{s}'"), // Add quotes for strings + ExtensionValue::Integer(i) => i.to_string(), + ExtensionValue::Float(f) => f.to_string(), + ExtensionValue::Boolean(b) => b.to_string(), + ExtensionValue::Reference(r) => format!("${r}"), + ExtensionValue::Expression(e) => e.to_string(), + } +} + +/// Format an ExtensionColumn to text +fn format_extension_column_to_text(column: &ExtensionColumn) -> String { + match column { + ExtensionColumn::Named { name, type_spec } => format!("{name}:{type_spec}"), + ExtensionColumn::Reference(r) => format!("${r}"), + ExtensionColumn::Expression(e) => e.to_string(), + } +} + +#[cfg(test)] +mod tests { + 
use super::*; + use crate::extensions::ExtensionRelationType; + use crate::parser::{ParseError, Parser}; + + #[test] + fn test_format_extension_args() { + let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf); + args.positional.push(ExtensionValue::Reference(42)); + args.named.insert( + "path".to_string(), + ExtensionValue::String("data/*.parquet".to_string()), + ); + args.named.insert("batch_size".to_string(), ExtensionValue::Integer(1024)); + args.output_columns.push(ExtensionColumn::Named { + name: "col1".to_string(), + type_spec: "i32".to_string(), + }); + + let text = format_extension_args_to_text(&args, "ExtensionLeaf:ParquetScan"); + assert!(text.contains("$42")); + assert!(text.contains("path='data/*.parquet'")); + assert!(text.contains("batch_size=1024")); + assert!(text.contains("=> col1:i32")); + } + + #[test] + fn test_integration_extension_missing_registry_errors() { + // Test a complete plan with extension that has unregistered extension + let plan_text = r#" +=== Extensions +URNs: + @ 1: https://example.com/test +Functions: + # 10 @ 1: test_extension + +=== Plan +Root[result] + ExtensionLeaf:TestExtension[$5, test_arg='hello' => col1:i32] +"#; + + let result = Parser::parse(plan_text); + + match result { + Err(ParseError::UnregisteredExtension { name, .. }) => { + assert_eq!(name, "TestExtension"); + } + Err(e) => panic!("Expected UnregisteredExtension error, got {e}"), + Ok(_) => panic!("Expected parsing to fail without a registry"), + } + } +} diff --git a/src/extensions/simple.rs b/src/extensions/simple.rs index 9b74086..9b4b59b 100644 --- a/src/extensions/simple.rs +++ b/src/extensions/simple.rs @@ -1,3 +1,10 @@ +//! Anchor-to-name lookup table for Substrait simple extensions. +//! +//! [`SimpleExtensions`] maps extension anchors (`#10`, `@1`) to URNs and +//! function/type names. It is the shared data structure used by both the +//! parser (to resolve references while lowering) and the textifier (to +//! 
display human-readable names instead of numeric anchors). + use std::collections::BTreeMap; use std::collections::btree_map::Entry; use std::fmt; diff --git a/src/fixtures.rs b/src/fixtures.rs index 062eba0..21c4f27 100644 --- a/src/fixtures.rs +++ b/src/fixtures.rs @@ -3,7 +3,7 @@ use crate::extensions::simple::ExtensionKind; use crate::extensions::{ExtensionRegistry, SimpleExtensions}; use crate::format; -use crate::parser::{MessageParseError, Parser, ScopedParse}; +use crate::parser::Parser; use crate::textify::foundation::{ErrorAccumulator, ErrorList}; use crate::textify::plan::PlanWriter; use crate::textify::{ErrorQueue, OutputOptions, Scope, ScopedContext, Textify}; @@ -90,10 +90,6 @@ impl TestContext { assert!(errs.is_empty(), "{} Errors: {}", errs.0.len(), errs.0[0]); s } - - pub fn parse(&self, input: &str) -> Result { - T::parse(&self.extensions, input) - } } /// Roundtrip a plan and verify that the output is the same as the input, after diff --git a/src/parser/ast.rs b/src/parser/ast.rs new file mode 100644 index 0000000..93f812d --- /dev/null +++ b/src/parser/ast.rs @@ -0,0 +1,539 @@ +//! AST for the text representation of `substrait-explain`, used by the +//! [`crate::parser`] module. +//! +//! Parsing happens in two steps: +//! +//! 1. Conversion from text to the AST in the [`crate::parser`] module via +//! LALRPOP, with syntactical validation. +//! 2. Conversion from the AST to a Protobuf [`substrait::proto::Plan`], via the +//! [`crate::parser::lower`] module, with semantic validation, e.g. function +//! references, anchors, column references, argument types/shapes match. +//! +//! This AST is intentionally close to the text syntax and does not encode all +//! semantic constraints; semantic validation happens in lowering. 
+ +use std::fmt; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ExtensionUrnDeclaration { + pub anchor: u32, + pub urn: String, +} + +impl ExtensionUrnDeclaration { + pub fn new(anchor: u32, urn: impl Into) -> Self { + Self { + anchor, + urn: urn.into(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ExtensionDeclaration { + pub anchor: u32, + pub urn_anchor: u32, + pub name: String, +} + +impl ExtensionDeclaration { + pub fn new(anchor: u32, urn_anchor: u32, name: impl Into) -> Self { + Self { + anchor, + urn_anchor, + name: name.into(), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Relation { + pub name: RelationName, + pub args: ArgList, + pub outputs: Vec, +} + +impl Relation { + pub fn new(name: RelationName, args: ArgList, outputs: Vec) -> Self { + Self { + name, + args, + outputs, + } + } + + pub fn with_outputs(name: RelationName, outputs: Vec) -> Self { + Self::new(name, ArgList::default(), outputs) + } + + pub fn with_entries(name: RelationName, entries: Vec, outputs: Vec) -> Self { + Self::new(name, ArgList::from_entries(entries), outputs) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum RelationName { + Standard(String), + Extension(ExtensionType, String), +} + +impl RelationName { + pub fn standard(name: impl Into) -> Self { + Self::Standard(name.into()) + } + + pub fn extension(kind: ExtensionType, name: impl Into) -> Self { + Self::Extension(kind, name.into()) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ExtensionType { + Leaf, + Single, + Multi, +} + +#[derive(Debug, Clone, PartialEq, Default)] +pub struct ArgList { + /// Positional arguments in source order. + pub positional: Vec, + /// Named arguments in source order. + pub named: Vec, + /// True when positional arguments appear after named arguments. + /// + /// The parser records this and lowering turns it into a user-facing error. 
+ pub invalid_order: bool, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum Arg { + Expr(Expr), + Enum(String), + NamedColumn(String, TypeExpr), + Tuple(Vec), + Wildcard, +} + +impl Arg { + pub fn expr(value: Expr) -> Self { + Self::Expr(value) + } + + pub fn enum_variant(value: impl Into) -> Self { + Self::Enum(value.into()) + } + + pub fn named_column(name: impl Into, typ: TypeExpr) -> Self { + Self::NamedColumn(name.into(), typ) + } + + pub fn tuple(values: Vec) -> Self { + Self::Tuple(values) + } + + pub fn wildcard() -> Self { + Self::Wildcard + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct NamedArg { + pub name: String, + pub value: Arg, +} + +impl NamedArg { + pub fn new(name: impl Into, value: Arg) -> Self { + Self { + name: name.into(), + value, + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum Expr { + FieldRef(i32), + Literal(Literal), + FunctionCall(FunctionCall), + IfThen(IfThenExpr), + Identifier(String), + /// A dot-separated name like `schema.table` or a single quoted name like + /// `"sales.q1"`. Each element is one segment. 
+ DottedName(Vec), +} + +impl Expr { + pub fn field_ref(index: i32) -> Self { + Self::FieldRef(index) + } + + pub fn literal(value: Literal) -> Self { + Self::Literal(value) + } + + pub fn function_call(value: FunctionCall) -> Self { + Self::FunctionCall(value) + } + + pub fn if_then(value: IfThenExpr) -> Self { + Self::IfThen(value) + } + + pub fn identifier(name: impl Into) -> Self { + Self::Identifier(name.into()) + } + + pub fn dotted_name(names: Vec) -> Self { + if names.len() == 1 { + Self::Identifier(names.into_iter().next().unwrap()) + } else { + Self::DottedName(names) + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct FunctionCall { + pub name: String, + pub anchor: Option, + pub urn_anchor: Option, + pub args: Vec, + pub output_type: Option, +} + +impl FunctionCall { + pub fn new( + name: impl Into, + anchor: Option, + urn_anchor: Option, + args: Vec, + output_type: Option, + ) -> Self { + Self { + name: name.into(), + anchor, + urn_anchor, + args, + output_type, + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct IfThenExpr { + pub clauses: Vec<(Expr, Expr)>, + pub else_expr: Box, +} + +impl IfThenExpr { + pub fn new(clauses: Vec<(Expr, Expr)>, else_expr: Expr) -> Self { + Self { + clauses, + else_expr: Box::new(else_expr), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Literal { + pub value: LiteralValue, + pub typ: Option, +} + +impl Literal { + pub fn new(value: LiteralValue, typ: Option) -> Self { + Self { value, typ } + } + + pub fn integer(value: i64, typ: Option) -> Self { + Self::new(LiteralValue::Integer(value), typ) + } + + pub fn float(value: f64, typ: Option) -> Self { + Self::new(LiteralValue::Float(value), typ) + } + + pub fn boolean(value: bool, typ: Option) -> Self { + Self::new(LiteralValue::Boolean(value), typ) + } + + pub fn string(value: impl Into, typ: Option) -> Self { + Self::new(LiteralValue::String(value.into()), typ) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum LiteralValue { + 
Integer(i64), + Float(f64), + Boolean(bool), + String(String), +} + +#[derive(Debug, Clone, PartialEq)] +pub enum TypeExpr { + Simple { + name: String, + nullability: Nullability, + }, + List { + nullability: Nullability, + inner: Box, + }, + UserDefined { + name: String, + anchor: Option, + urn_anchor: Option, + nullability: Nullability, + parameters: Vec, + }, +} + +impl TypeExpr { + pub fn simple(name: impl Into, nullability: Nullability) -> Self { + Self::Simple { + name: name.into(), + nullability, + } + } + + pub fn list(nullability: Nullability, inner: TypeExpr) -> Self { + Self::List { + nullability, + inner: Box::new(inner), + } + } + + pub fn user_defined( + name: impl Into, + anchor: Option, + urn_anchor: Option, + nullability: Nullability, + parameters: Vec, + ) -> Self { + Self::UserDefined { + name: name.into(), + anchor, + urn_anchor, + nullability, + parameters, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum Nullability { + #[default] + Required, + Nullable, + Unspecified, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum ArgEntry { + Positional(Arg), + Named(NamedArg), +} + +impl ArgEntry { + pub fn positional(value: Arg) -> Self { + Self::Positional(value) + } + + pub fn named(name: impl Into, value: Arg) -> Self { + Self::Named(NamedArg::new(name, value)) + } +} + +impl ArgList { + /// Build an argument list while preserving order constraints. + pub fn from_entries(entries: Vec) -> Self { + let mut positional = Vec::new(); + let mut named = Vec::new(); + let mut seen_named = false; + let mut invalid_order = false; + + for entry in entries { + match entry { + ArgEntry::Positional(arg) => { + if seen_named { + invalid_order = true; + } + positional.push(arg); + } + ArgEntry::Named(arg) => { + seen_named = true; + named.push(arg); + } + } + } + + Self { + positional, + named, + invalid_order, + } + } +} + +/// Unescape a quoted identifier or string literal. 
+/// +/// The input is expected to include surrounding quotes. +pub(crate) fn unescape_quoted(input: &str) -> String { + if input.len() < 2 { + return input.to_string(); + } + + let mut out = String::with_capacity(input.len().saturating_sub(2)); + let mut chars = input[1..input.len() - 1].chars(); + while let Some(c) = chars.next() { + if c == '\\' { + if let Some(next) = chars.next() { + match next { + 'n' => out.push('\n'), + 'r' => out.push('\r'), + 't' => out.push('\t'), + _ => out.push(next), + } + } + } else { + out.push(c); + } + } + out +} + +impl fmt::Display for Nullability { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Nullability::Required => Ok(()), + Nullability::Nullable => write!(f, "?"), + // Internal sentinel for "present but unspecified". + Nullability::Unspecified => write!(f, "⁉"), + } + } +} + +impl fmt::Display for TypeExpr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TypeExpr::Simple { name, nullability } => write!(f, "{name}{nullability}"), + TypeExpr::List { nullability, inner } => write!(f, "list{nullability}<{inner}>"), + TypeExpr::UserDefined { + name, + anchor, + urn_anchor, + nullability, + parameters, + } => { + write!(f, "{name}")?; + if let Some(anchor) = anchor { + write!(f, "#{anchor}")?; + } + if let Some(urn_anchor) = urn_anchor { + write!(f, "@{urn_anchor}")?; + } + write!(f, "{nullability}")?; + if !parameters.is_empty() { + write!(f, "<")?; + for (idx, parameter) in parameters.iter().enumerate() { + if idx > 0 { + write!(f, ", ")?; + } + write!(f, "{parameter}")?; + } + write!(f, ">")?; + } + Ok(()) + } + } + } +} + +impl fmt::Display for Literal { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.value { + LiteralValue::Integer(v) => write!(f, "{v}")?, + LiteralValue::Float(v) => write!(f, "{v}")?, + LiteralValue::Boolean(v) => write!(f, "{v}")?, + LiteralValue::String(v) => write!(f, "'{}'", v.replace('\'', "\\'"))?, + } + if let 
Some(typ) = &self.typ { + write!(f, ":{typ}")?; + } + Ok(()) + } +} + +impl fmt::Display for Expr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Expr::FieldRef(idx) => write!(f, "${idx}"), + Expr::Literal(lit) => write!(f, "{lit}"), + Expr::Identifier(name) => write!(f, "{name}"), + Expr::DottedName(names) => { + for (i, name) in names.iter().enumerate() { + if i > 0 { + write!(f, ".")?; + } + write!(f, "{name}")?; + } + Ok(()) + } + Expr::FunctionCall(call) => { + write!(f, "{}", call.name)?; + if let Some(anchor) = call.anchor { + write!(f, "#{anchor}")?; + } + if let Some(urn_anchor) = call.urn_anchor { + write!(f, "@{urn_anchor}")?; + } + write!(f, "(")?; + for (idx, arg) in call.args.iter().enumerate() { + if idx > 0 { + write!(f, ", ")?; + } + write!(f, "{arg}")?; + } + write!(f, ")")?; + if let Some(typ) = &call.output_type { + write!(f, ":{typ}")?; + } + Ok(()) + } + Expr::IfThen(if_then) => { + write!(f, "if_then(")?; + for (idx, (cond, result)) in if_then.clauses.iter().enumerate() { + if idx > 0 { + write!(f, ", ")?; + } + write!(f, "{cond} -> {result}")?; + } + write!(f, ", _ -> {})", if_then.else_expr) + } + } + } +} + +impl fmt::Display for Arg { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Arg::Expr(expr) => write!(f, "{expr}"), + Arg::Enum(value) => write!(f, "&{value}"), + Arg::NamedColumn(name, typ) => write!(f, "{name}:{typ}"), + Arg::Tuple(args) => { + write!(f, "(")?; + for (idx, arg) in args.iter().enumerate() { + if idx > 0 { + write!(f, ", ")?; + } + write!(f, "{arg}")?; + } + write!(f, ")") + } + Arg::Wildcard => write!(f, "_"), + } + } +} diff --git a/src/parser/common.rs b/src/parser/common.rs deleted file mode 100644 index 4db45a7..0000000 --- a/src/parser/common.rs +++ /dev/null @@ -1,346 +0,0 @@ -use std::fmt; - -use pest::Parser as PestParser; -use pest_derive::Parser as PestDeriveParser; -use thiserror::Error; - -use crate::extensions::SimpleExtensions; -use 
crate::extensions::simple::MissingReference; - -#[derive(PestDeriveParser)] -#[grammar = "parser/expression_grammar.pest"] // Path relative to src -pub struct ExpressionParser; - -/// An error that occurs when parsing a message within a specific line. Contains -/// context pointing at that specific error. -#[derive(Error, Debug, Clone)] -#[error("{kind} Error parsing {message}:\n{error}")] -pub struct MessageParseError { - pub message: &'static str, - pub kind: ErrorKind, - #[source] - pub error: Box>, -} - -#[derive(Debug, Clone)] -pub enum ErrorKind { - Syntax, - InvalidValue, - Lookup(MissingReference), -} - -impl MessageParseError { - pub fn syntax(message: &'static str, span: pest::Span, description: impl ToString) -> Self { - let error = pest::error::Error::new_from_span( - pest::error::ErrorVariant::CustomError { - message: description.to_string(), - }, - span, - ); - Self::new(message, ErrorKind::Syntax, Box::new(error)) - } - pub fn invalid(message: &'static str, span: pest::Span, description: impl ToString) -> Self { - let error = pest::error::Error::new_from_span( - pest::error::ErrorVariant::CustomError { - message: description.to_string(), - }, - span, - ); - Self::new(message, ErrorKind::InvalidValue, Box::new(error)) - } - - pub fn lookup( - message: &'static str, - missing: MissingReference, - span: pest::Span, - description: impl ToString, - ) -> Self { - let error = pest::error::Error::new_from_span( - pest::error::ErrorVariant::CustomError { - message: description.to_string(), - }, - span, - ); - Self::new(message, ErrorKind::Lookup(missing), Box::new(error)) - } -} - -impl MessageParseError { - pub fn new( - message: &'static str, - kind: ErrorKind, - error: Box>, - ) -> Self { - Self { - message, - kind, - error, - } - } -} - -impl fmt::Display for ErrorKind { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - ErrorKind::Syntax => write!(f, "Syntax"), - ErrorKind::InvalidValue => write!(f, "Invalid value"), - 
ErrorKind::Lookup(e) => write!(f, "Invalid reference ({e})"), - } - } -} - -pub fn unwrap_single_pair(pair: pest::iterators::Pair) -> pest::iterators::Pair { - let mut pairs = pair.into_inner(); - let pair = pairs.next().unwrap(); - assert_eq!(pairs.next(), None); - pair -} - -/// Unescapes a quoted string literal, handling escape sequences. -/// -/// # Arguments -/// * `pair` - The pest pair containing the string to unescape (must be Rule::string_literal or Rule::quoted_name). -/// -/// # Returns -/// * `String` with the unescaped contents. -/// -/// # Panics -/// Panics if the rule is not `string_literal` or `quoted_name` (this should never happen -/// if the pest grammar is working correctly). -/// -pub fn unescape_string(pair: pest::iterators::Pair) -> String { - let s = pair.as_str(); - - // Determine opener/closer based on rule type - let (opener, closer) = match pair.as_rule() { - Rule::string_literal => ('\'', '\''), - Rule::quoted_name => ('"', '"'), - _ => panic!( - "unescape_string called with unexpected rule: {:?}", - pair.as_rule() - ), - }; - - let mut result = String::new(); - let mut chars = s.chars(); - let first = chars.next().expect("Empty string literal"); - - assert_eq!( - first, opener, - "Expected opening quote '{opener}', got '{first}'" - ); - - // Skip the opening quote - while let Some(c) = chars.next() { - match c { - c if c == closer => { - // Skip the closing quote, and assert that there are no more characters. - assert_eq!( - chars.next(), - None, - "Unexpected characters after closing quote" - ); - break; - } - '\\' => { - let next = chars - .next() - .expect("Incomplete escape sequence at end of string"); - match next { - 'n' => result.push('\n'), - 't' => result.push('\t'), - 'r' => result.push('\r'), - // For all other characters (especially `"`, `'`, and `\`), we just - // push the character. 
- _ => result.push(next), - } - } - _ => result.push(c), - } - } - result -} - -// A trait for converting a pest::iterators::Pair into a Rust type. This -// is used to convert from the uniformly structured nesting -// pest::iterators::Pair into more structured types. -pub trait ParsePair: Sized { - // The rule that this type is parsed from. - fn rule() -> Rule; - - // The name of the protobuf message type that this type corresponds to. - fn message() -> &'static str; - - // Parse a single instance of this type from a pest::iterators::Pair. - // The input must match the rule returned by `rule`; otherwise, a panic is - // expected. - fn parse_pair(pair: pest::iterators::Pair) -> Self; - - fn parse_str(s: &str) -> Result { - let mut pairs = >::parse(Self::rule(), s) - .map_err(|e| MessageParseError::new(Self::message(), ErrorKind::Syntax, Box::new(e)))?; - assert_eq!(pairs.as_str(), s); - let pair = pairs.next().unwrap(); - assert_eq!(pairs.next(), None); - Ok(Self::parse_pair(pair)) - } -} - -/// A trait for types that can be directly parsed from a string input, -/// regardless of context. -pub trait Parse { - fn parse(input: &str) -> Result - where - Self: Sized; -} - -impl Parse for T { - fn parse(input: &str) -> Result { - T::parse_str(input) - } -} - -/// A trait for types that are parsed from a `pest::iterators::Pair` that -/// depends on the context - e.g. extension lookups or other contextual -/// information. This is used for types that are not directly parsed from the -/// grammar, but rather require additional context to parse correctly. -pub trait ScopedParsePair: Sized { - // The rule that this type is parsed from. - fn rule() -> Rule; - - // The name of the protobuf message type that this type corresponds to. - fn message() -> &'static str; - - // Parse a single instance of this type from a `pest::iterators::Pair`. - // The input must match the rule returned by `rule`; otherwise, a panic is - // expected. 
- fn parse_pair( - extensions: &SimpleExtensions, - pair: pest::iterators::Pair, - ) -> Result; -} - -pub trait ScopedParse: Sized { - fn parse(extensions: &SimpleExtensions, input: &str) -> Result - where - Self: Sized; -} - -impl ScopedParse for T { - fn parse(extensions: &SimpleExtensions, input: &str) -> Result { - let mut pairs = ExpressionParser::parse(Self::rule(), input) - .map_err(|e| MessageParseError::new(Self::message(), ErrorKind::Syntax, Box::new(e)))?; - assert_eq!(pairs.as_str(), input); - let pair = pairs.next().unwrap(); - assert_eq!(pairs.next(), None); - Self::parse_pair(extensions, pair) - } -} - -pub fn iter_pairs(pair: pest::iterators::Pairs<'_, Rule>) -> RuleIter<'_> { - RuleIter { - iter: pair, - done: false, - } -} - -pub struct RuleIter<'a> { - iter: pest::iterators::Pairs<'a, Rule>, - // Set to true when done is called, so destructor doesn't panic - done: bool, -} - -impl<'a> From> for RuleIter<'a> { - fn from(iter: pest::iterators::Pairs<'a, Rule>) -> Self { - RuleIter { iter, done: false } - } -} - -impl<'a> RuleIter<'a> { - pub fn peek(&self) -> Option> { - self.iter.peek() - } - - // Pop the next pair if it matches the rule. Returns None if not. - pub fn try_pop(&mut self, rule: Rule) -> Option> { - match self.peek() { - Some(pair) if pair.as_rule() == rule => { - self.iter.next(); - Some(pair) - } - _ => None, - } - } - - // Pop the next pair, asserting it matches the given rule. Panics if not. - pub fn pop(&mut self, rule: Rule) -> pest::iterators::Pair<'a, Rule> { - let pair = self.iter.next().expect("expected another pair"); - assert_eq!( - pair.as_rule(), - rule, - "expected rule {:?}, got {:?}", - rule, - pair.as_rule() - ); - pair - } - - // Parse the next pair if it matches the rule. Returns None if not. 
- pub fn parse_if_next(&mut self) -> Option { - match self.peek() { - Some(pair) if pair.as_rule() == T::rule() => { - self.iter.next(); - Some(T::parse_pair(pair)) - } - _ => None, - } - } - - // Parse the next pair if it matches the rule. Returns None if not. - pub fn parse_if_next_scoped( - &mut self, - extensions: &SimpleExtensions, - ) -> Option> { - match self.peek() { - Some(pair) if pair.as_rule() == T::rule() => { - self.iter.next(); - Some(T::parse_pair(extensions, pair)) - } - _ => None, - } - } - - // Parse the next pair, assuming it matches the rule. Panics if not. - pub fn parse_next(&mut self) -> T { - let pair = self.iter.next().unwrap(); - T::parse_pair(pair) - } - - // Parse the next pair, assuming it matches the rule. Panics if not. - pub fn parse_next_scoped( - &mut self, - extensions: &SimpleExtensions, - ) -> Result { - let pair = self.iter.next().unwrap(); - T::parse_pair(extensions, pair) - } - - pub fn done(mut self) { - self.done = true; - assert_eq!(self.iter.next(), None); - } -} - -/// Make sure that the iterator was completely consumed when the iterator is -/// dropped - that we didn't leave any partially-parsed tokens. -/// -/// This is not strictly necessary, but it's a good way to catch bugs. -impl Drop for RuleIter<'_> { - fn drop(&mut self) { - if self.done || std::thread::panicking() { - return; - } - // If the iterator is not done, something probably went wrong. - assert_eq!(self.iter.next(), None); - } -} diff --git a/src/parser/errors.rs b/src/parser/errors.rs index a2fb1fe..61499e5 100644 --- a/src/parser/errors.rs +++ b/src/parser/errors.rs @@ -1,8 +1,21 @@ +//! Error types for the parser. +//! +//! Three layers of errors, matching the parsing phases: +//! +//! - [`SyntaxErrorDetail`] — LALRPOP parse failures with span and expected-token +//! info, for rendering caret diagnostics on a single line. +//! - [`MessageParseError`] — semantic errors during lowering (invalid values, +//! 
missing references), categorised by [`ErrorKind`]. +//! - [`ParseError`] — top-level error that attaches [`ParseContext`] (line +//! number + source text) to any of the above. + +use std::fmt; + use thiserror::Error; -use super::MessageParseError; use super::extensions::ExtensionParseError; use crate::extensions::registry::ExtensionError; +use crate::extensions::simple::MissingReference; /// Context for parse errors and warnings #[derive(Debug, Clone)] @@ -23,9 +36,174 @@ impl std::fmt::Display for ParseContext { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct TextSpan { + pub start: usize, + pub end: usize, +} + +impl TextSpan { + pub fn new(start: usize, end: usize) -> Self { + Self { start, end } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SyntaxErrorKind { + InvalidToken, + UnexpectedEof, + UnexpectedToken { found: String }, + ExtraToken { found: String }, + User { message: String }, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SyntaxErrorDetail { + pub kind: SyntaxErrorKind, + pub span: Option, + pub expected: Vec, +} + +impl SyntaxErrorDetail { + pub fn new(kind: SyntaxErrorKind, span: Option, expected: Vec) -> Self { + Self { + kind, + span, + expected, + } + } + + pub fn describe_with_line(&self, line: &str) -> String { + let mut message = self.to_string(); + if let Some(column) = self.column(line) { + message.push_str(&format!(" at column {column}")); + if let Some(caret) = self.caret_line(line) { + message.push('\n'); + message.push_str(line); + message.push('\n'); + message.push_str(&caret); + } + } + message + } + + fn column(&self, line: &str) -> Option { + let span = self.span?; + let idx = clamp_to_char_boundary(line, span.start); + Some(line[..idx].chars().count() + 1) + } + + fn caret_line(&self, line: &str) -> Option { + let span = self.span?; + let idx = clamp_to_char_boundary(line, span.start); + let col0 = line[..idx].chars().count(); + Some(format!("{}^", " ".repeat(col0))) + } +} + +impl fmt::Display 
for SyntaxErrorDetail { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.kind { + SyntaxErrorKind::InvalidToken => write!(f, "invalid token")?, + SyntaxErrorKind::UnexpectedEof => write!(f, "unexpected end of input")?, + SyntaxErrorKind::UnexpectedToken { found } => write!(f, "unexpected token '{found}'")?, + SyntaxErrorKind::ExtraToken { found } => write!(f, "extra token '{found}'")?, + SyntaxErrorKind::User { message } => write!(f, "{message}")?, + } + + if !self.expected.is_empty() { + write!(f, "; expected {}", Self::format_expected(&self.expected))?; + } + + Ok(()) + } +} + +impl SyntaxErrorDetail { + fn format_expected(expected: &[String]) -> String { + const MAX_ITEMS: usize = 6; + if expected.len() <= MAX_ITEMS { + expected.join(", ") + } else { + let mut display = expected[..MAX_ITEMS].join(", "); + display.push_str(", ..."); + display + } + } +} + +fn clamp_to_char_boundary(line: &str, mut idx: usize) -> usize { + idx = idx.min(line.len()); + while idx > 0 && !line.is_char_boundary(idx) { + idx -= 1; + } + idx +} + +#[derive(Debug, Clone)] +pub enum ErrorKind { + Syntax, + InvalidValue, + Lookup(MissingReference), +} + +impl std::fmt::Display for ErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ErrorKind::Syntax => write!(f, "Syntax"), + ErrorKind::InvalidValue => write!(f, "Invalid value"), + ErrorKind::Lookup(e) => write!(f, "Invalid reference ({e})"), + } + } +} + +#[derive(Error, Debug, Clone)] +#[error("{kind} Error parsing {message}: {description}")] +pub struct MessageParseError { + pub message: &'static str, + pub kind: ErrorKind, + pub description: String, +} + +impl MessageParseError { + pub fn syntax(message: &'static str, description: impl ToString) -> Self { + Self::new(message, ErrorKind::Syntax, description) + } + + pub fn invalid(message: &'static str, description: impl ToString) -> Self { + Self::new(message, ErrorKind::InvalidValue, description) + } + + pub fn 
lookup( + message: &'static str, + missing: MissingReference, + description: impl ToString, + ) -> Self { + Self::new(message, ErrorKind::Lookup(missing), description) + } + + pub fn new(message: &'static str, kind: ErrorKind, description: impl ToString) -> Self { + Self { + message, + kind, + description: description.to_string(), + } + } +} + /// Parse errors that prevent successful parsing #[derive(Debug, Clone, Error)] pub enum ParseError { + #[error( + "Syntax error parsing {target} on {context}: {message}", + message = .detail.describe_with_line(&.context.line) + )] + Syntax { + target: &'static str, + context: ParseContext, + detail: Box, + }, + #[error("Error parsing extension on {0}: {1}")] Extension(ParseContext, #[source] ExtensionParseError), @@ -35,32 +213,14 @@ pub enum ParseError { #[error("Error parsing plan on {0}: {1}")] Plan(ParseContext, #[source] MessageParseError), - #[error("Expected '{expected}' on {context}")] - Expected { - expected: &'static str, - context: ParseContext, - }, - - #[error("No relation found in plan")] - NoRelationFound, - #[error("Unregistered extension '{name}' on {context}")] UnregisteredExtension { name: String, context: ParseContext }, #[error("Failed to parse relation on {0}: {1}")] RelationParse(ParseContext, String), - #[error("Unknown extension relation type: {0}")] - UnknownExtensionRelationType(String), - - #[error("Validation error: {0}")] - ValidationError(String), - - #[error("Error parsing section header on line {0}: {1}")] + #[error("Error parsing section header on {0}: {1}")] Initial(ParseContext, #[source] MessageParseError), - - #[error("Error parsing relation: {0}")] - Relation(ParseContext, #[source] MessageParseError), } /// Result type for the public Parser API diff --git a/src/parser/expression_grammar.pest b/src/parser/expression_grammar.pest index 27e8c2a..e69de29 100644 --- a/src/parser/expression_grammar.pest +++ b/src/parser/expression_grammar.pest @@ -1,288 +0,0 @@ -// Explicit whitespace, used 
below. -// Pest allows implicit whitespace with WHITESPACE; we don't do that here. -whitespace = _{ " " } -// Shorthand for optional whitespace between tokens -sp = _{ whitespace* } - -// Basic Terminals -identifier = @{ ASCII_ALPHA ~ (ASCII_ALPHANUMERIC | "_")* } -integer = @{ "-"? ~ ASCII_DIGIT+ } -float = @{ "-"? ~ ASCII_DIGIT+ ~ "." ~ ASCII_DIGIT+ } -boolean = @{ "true" | "false" } - -// Basic string literals, does not handle escape sequences for simplicity here. -// String literals are for literal values -string_literal = @{ - "\'" ~ // opening quote - (("\\" ~ ANY) // Escape sequence - | (!"\'" ~ ANY) // Any character except the closing quote - )* // middle - ~ "\'" // closing quote -} -// Quoted names are for function names or other identifiers that need to be quoted -quoted_name = @{ - "\"" ~ // opening quote - (("\\" ~ ANY) // Escape sequence - | (!"\"" ~ ANY) // Any character except the closing quote - )* // middle - ~ "\"" // closing quote -} - -// A name of a function or table can either be an unquoted identifier or a quoted string -name = { identifier | quoted_name } - -// Expression Components -// Field reference -reference = { "$" ~ integer } -// Literal -literal = { (float | integer | boolean | string_literal) ~ (":" ~ sp ~ type)? } - -// -- Components for types and functions -anchor = { "#" ~ sp ~ integer } -urn_anchor = { "@" ~ sp ~ integer } - -parameter = { type | integer | name } - -// An arbitrary parameter list for use with type expressions or function calls. -// Example: -// -// Note that if a type can have no parameters, there should be no parameter list. -// A type that can take parameters but has an empty parameter list should be -// written with the empty parameter list. -// -// As examples: 'i64' and 'struct<>' are valid, but 'i64<>' and 'struct' are not. -parameters = { "<" ~ (parameter ~ ("," ~ sp ~ parameter)*)? ~ ">" } - -// Nullability - ? for nullable, nothing for non-nullable. And for unspecified, ⁉. -nullability = { ("?" | "⁉")? 
} - -// -- Type Expressions -- -// See . These must be lowercase. -simple_type_name = { - "boolean" - | "i8" - | "i16" - | "i32" - | "i64" - | "u8" - | "u16" - | "u32" - | "u64" - | "fp32" - | "fp64" - | "string" - | "binary" - | "timestamp_tz" - | "timestamp" - | "date" - | "time" - | "interval_year" - | "uuid" -} - -// A simple native type expression, composed of a name and optional nullability. -// Examples: i64, bool?, string -simple_type = { simple_type_name ~ nullability } - -// -- Compound Types -- -// -// Note that the names here are case-insensitive. In the docs, they are shown as -// upper-case, but I prefer them lowercase... - -// A list type expression, composed of a list keyword and a type. -// Example: list -list_type = { ^"list" ~ nullability ~ "<" ~ type ~ ">" } - -// A map type expression, composed of a map keyword and a key type and a value type. -// Example: map -map_type = { ^"map" ~ nullability ~ "<" ~ type ~ "," ~ sp ~ type ~ ">" } - -struct_type = { ^"struct" ~ nullability ~ parameters } - -// A compound type expression, composed of a name, optional parameters, and optional nullability. -// Example: LIST?, MAP -// TODO: precisiontime types, interval_day -compound_type = { list_type | map_type | struct_type } - -// A user-defined type expression. -// -// Example: point#8@2? -// -// - The `u!` prefix is optional. -// - The anchor is optional if the name is unique in the extensions; otherwise, -// required. -// - The URN anchor is optional. Note that this is different from type -// expressions in the YAML, where these do not exist. -// - If the type is nullable, the nullability is required. -// - The parameters are required if they exist for this type. -user_defined_type = { - "u!"? ~ name ~ anchor? ~ urn_anchor? ~ nullability ~ parameters? -} - -// A type expression is a simple type, a compound type, or a user-defined type. 
-type = { simple_type | compound_type | user_defined_type } - -// -- Function Calls -- - -argument_list = { "(" ~ (expression ~ (sp ~ "," ~ sp ~ expression)*)? ~ ")" } - -// Scalar function call: -// - Function name -// - Optional Anchor, e.g. #1 -// - Optional URN Anchor, e.g. @1 -// - Arguments `()` are required, e.g. (1, 2, 3) -// - Optional output type, e.g. :i64 -function_call = { - name ~ sp ~ anchor? ~ sp ~ urn_anchor? ~ sp ~ argument_list ~ (":" ~ sp ~ type)? -} - -if_clause = { - expression ~ sp ~ "->" ~ sp ~ expression -} - -// For the form: if_then(true || false -> true, _ -> false) -if_then = { - "if_then(" ~ sp ~ (if_clause ~ sp ~ "," ~ sp)+ ~ sp ~ "_" ~ sp ~ "->" ~ sp ~ expression ~ ")" -} - -// Top-level Expression Rule -// Order matters for PEGs: Since an identifer can be a function call, we put that first. -expression = { if_then | function_call | reference | literal } - -// == Extensions == -// These rules are for parsing extension declarations, by line. - -// -- URN Extension Declaration -- -// Format: @anchor: urn_value -// Example: @1: /my/urn1 -// `urn_anchor` is defined above as "@" ~ integer. -// `sp` is defined above as whitespace*. - -// A URN value can be a sequence of non-whitespace characters -urn = @{ (!whitespace ~ ANY)+ } - -extension_urn_declaration = { urn_anchor ~ ":" ~ sp ~ urn } - -// -- Simple Extension Declaration (Function, Type, TypeVariation) -- -// Format: #anchor@urn_ref: name -// Example: #10@1: my_func -// `anchor` is defined above as "#" ~ integer. -// `urn_anchor` is defined above as "@" ~ integer. -// `name` is defined above as identifier | quoted_name. -simple_extension = { anchor ~ sp ~ urn_anchor ~ sp ~ ":" ~ sp ~ name } - -// == Relations == -// These rules are for parsing relations - well, for parsing one relation on one line. -// A relation is in the form: -// NAME[parameters => columns] -// -// Parameters are optional, and are a comma-separated list of expressions. 
-// Columns are required, and are a comma-separated list of named columns. -// -// Example: -// -// name "[" (arguments "=>")? columns "]" -// where: -// - arguments are a comma-separated list of table names or expressions - -table_name = { (name ~ ("." ~ name)*) } - -named_column = { name ~ ":" ~ type } -named_column_list = { (named_column ~ ("," ~ sp ~ named_column)*)? } - -relation = { - read_relation - | filter_relation - | project_relation - | aggregate_relation - | sort_relation - | fetch_relation - | join_relation - | extension_relation -} - -// At the top level, we can have a RelRoot, or a regular relation; both are valid. -top_level_relation = { - root_relation - | relation -} - -root_relation = { "Root" ~ "[" ~ root_name_list ~ "]" } -root_name_list = { (name ~ (sp ~ "," ~ sp ~ name)*)? } - -read_relation = { "Read" ~ "[" ~ table_name ~ sp ~ "=>" ~ sp ~ named_column_list ~ "]" } - -filter_relation = { "Filter" ~ "[" ~ expression ~ sp ~ "=>" ~ sp ~ reference_list ~ "]" } -reference_list = { (reference ~ (sp ~ "," ~ sp ~ reference)*)? } - -project_relation = { "Project" ~ "[" ~ project_argument_list ~ "]" } -project_argument_list = { (project_argument ~ (sp ~ "," ~ sp ~ project_argument)*)? } -project_argument = { reference | expression } - -// Aggregate relation: groups by specified fields and applies aggregate functions -// Format: Aggregate[group_by_fields => output_items] -// Examples: -// Aggregate[$0 => $0, sum($1), count($1)] # common case, 1 group, grouping on field 0, output is measure expressions and expressions used in group by -// Aggregate[$0, $1 => $0, $1, sum($2), count($2)] # common case, 1 group, grouping on fields 0 and 1, output sum, count and group by expressions -// Aggregate[($0, $1), ($1) => $0, $1, sum($2)] # multiple grouping sets. 
group by fields 0,1, and field 1, output are fields 0 and 1 and sum on field 2 -// Aggregate[_ => sum($0), count($1)] # No grouping (global aggregates) -aggregate_relation = { "Aggregate" ~ "[" ~ aggregate_group_by ~ sp ~ "=>" ~ sp ~ aggregate_output ~ "]" } - -aggregate_group_by = { grouping_set_list | expression_list} -grouping_set_list = {grouping_set ~ sp ~ ("," ~ sp ~ grouping_set)* } -grouping_set = { ("(" ~ sp ~ expression_list ~ sp ~ ")" ) | empty } -expression_list = { expression ~ sp ~ ( "," ~ sp ~ expression)* } - -aggregate_output = { (aggregate_output_item ~ (sp ~ "," ~ sp ~ aggregate_output_item)*)? } -aggregate_output_item = { reference | aggregate_measure } -// TODO: Add support for filter, invocation, etc. -aggregate_measure = { function_call } - -// Empty grouping symbol -empty = { "_" } - -// SortRel: Sort[($0, &AscNullsFirst), ($2, &DescNullsLast) => ...] -sort_relation = { "Sort" ~ "[" ~ sort_field_list ~ sp ~ "=>" ~ sp ~ reference_list ~ "]" } -sort_field_list = { (sort_field ~ (sp ~ "," ~ sp ~ sort_field)*)? } -sort_field = { "(" ~ sp ~ reference ~ sp ~ "," ~ sp ~ sort_direction ~ sp ~ ")" } -sort_direction = { "&AscNullsFirst" | "&AscNullsLast" | "&DescNullsFirst" | "&DescNullsLast" } - -// FetchRel: Fetch[limit=..., offset=... => ...] 
(named arguments only, any order, or _ for empty) -fetch_arg_name = { "limit" | "offset" } -fetch_value = { integer | expression } -fetch_named_arg = { fetch_arg_name ~ sp ~ "=" ~ sp ~ fetch_value } -fetch_named_arg_list = { fetch_named_arg ~ (sp ~ "," ~ sp ~ fetch_named_arg)* } -fetch_args = _{ fetch_named_arg_list | empty } -fetch_relation = { "Fetch" ~ "[" ~ fetch_args ~ sp ~ "=>" ~ sp ~ reference_list ~ "]" } - -// JoinRel: Join[&Inner, eq($0, $2) => $0, $1, $2, $3] -// Note: Longer prefixes must come before shorter ones (e.g., "&LeftSemi" before "&Left") -// so that the parser doesn't match "&Left" when it should match "&LeftSemi" -join_type = { "&Inner" | "&LeftSemi" | "&LeftAnti" | "&LeftSingle" | "&LeftMark" | "&Left" | "&RightSemi" | "&RightAnti" | "&RightSingle" | "&RightMark" | "&Right" | "&Outer" } -join_relation = { "Join" ~ "[" ~ join_type ~ sp ~ "," ~ sp ~ expression ~ sp ~ "=>" ~ sp ~ reference_list ~ "]" } - -// Extension relations for custom operators not covered by standard Substrait -// Format: ExtensionType:ExtensionName[arguments, named_arguments => columns] -// Follows the general relation grammar pattern exactly -// Examples: -// - ExtensionLeaf:ParquetScan[path='data/*.parquet', schema='auto' => col1:i32, col2:string] -// - ExtensionSingle:VectorNormalize[method='l2', dimensions=128 => $0, $1] -// - ExtensionMulti:FuzzyJoin[algorithm='lsh', threshold=0.8 => $0, $1, $2, $3] -extension_relation = { extension_name ~ "[" ~ (empty | (extension_arguments ~ (sp ~ "," ~ sp ~ extension_named_arguments)?) | extension_named_arguments)? ~ (sp ~ "=>" ~ sp ~ extension_columns)? ~ "]" } - -// Extension name format: ExtensionType:CustomExtensionName -extension_name = { extension_type ~ ":" ~ name } -extension_type = { "ExtensionLeaf" | "ExtensionSingle" | "ExtensionMulti" } - -// Positional arguments (expressions, literals, references) -// TODO: May want to add tuples, enums, …? 
-extension_arguments = { extension_argument ~ (sp ~ "," ~ sp ~ extension_argument)* } -extension_argument = {reference | literal | expression } - -// Named arguments (name=value pairs) -extension_named_arguments = { extension_named_argument ~ (sp ~ "," ~ sp ~ extension_named_argument)* } -extension_named_argument = { name ~ sp ~ "=" ~ sp ~ extension_argument } - -// Output columns - can mix named columns, references, and expressions -extension_columns = { extension_column ~ (sp ~ "," ~ sp ~ extension_column)* } -extension_column = { named_column | reference | expression } diff --git a/src/parser/expressions.rs b/src/parser/expressions.rs index de26949..e69de29 100644 --- a/src/parser/expressions.rs +++ b/src/parser/expressions.rs @@ -1,979 +0,0 @@ -use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime}; -use substrait::proto::aggregate_rel::Measure; -use substrait::proto::expression::field_reference::ReferenceType; -use substrait::proto::expression::if_then::IfClause; -use substrait::proto::expression::literal::LiteralType; -use substrait::proto::expression::{ - FieldReference, IfThen, Literal, ReferenceSegment, RexType, ScalarFunction, reference_segment, -}; -use substrait::proto::function_argument::ArgType; -use substrait::proto::r#type::{Fp64, I64, Kind, Nullability}; -use substrait::proto::{AggregateFunction, Expression, FunctionArgument, Type}; - -use super::types::get_and_validate_anchor; -use super::{ - MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unescape_string, - unwrap_single_pair, -}; -use crate::extensions::SimpleExtensions; -use crate::extensions::simple::ExtensionKind; -use crate::parser::ErrorKind; - -/// A field index (e.g., parsed from "$0" -> 0). -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct FieldIndex(pub i32); - -impl FieldIndex { - /// Convert this field index to a FieldReference for use in expressions. 
- pub fn to_field_reference(self) -> FieldReference { - // XXX: Why is it so many layers to make a struct field reference? This is - // surprisingly complex - FieldReference { - reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { - reference_type: Some(reference_segment::ReferenceType::StructField(Box::new( - reference_segment::StructField { - field: self.0, - child: None, - }, - ))), - })), - root_type: None, - } - } -} - -impl ParsePair for FieldIndex { - fn rule() -> Rule { - Rule::reference - } - - fn message() -> &'static str { - "FieldIndex" - } - - fn parse_pair(pair: pest::iterators::Pair) -> Self { - assert_eq!(pair.as_rule(), Self::rule()); - let inner = unwrap_single_pair(pair); - let index: i32 = inner.as_str().parse().unwrap(); - FieldIndex(index) - } -} - -impl ParsePair for FieldReference { - fn rule() -> Rule { - Rule::reference - } - - fn message() -> &'static str { - "FieldReference" - } - - fn parse_pair(pair: pest::iterators::Pair) -> Self { - assert_eq!(pair.as_rule(), Self::rule()); - - // TODO: Other types of references. - FieldIndex::parse_pair(pair).to_field_reference() - } -} - -fn to_int_literal( - value: pest::iterators::Pair, - typ: Option, -) -> Result { - assert_eq!(value.as_rule(), Rule::integer); - let parsed_value: i64 = value.as_str().parse().unwrap(); - - const DEFAULT_KIND: Kind = Kind::I64(I64 { - type_variation_reference: 0, - nullability: Nullability::Required as i32, - }); - - // If no type is provided, we assume i64, Nullability::Required. - let kind = typ.and_then(|t| t.kind).unwrap_or(DEFAULT_KIND); - - let (lit, nullability, tvar) = match &kind { - // If no type is provided, we assume i64, Nullability::Required. 
- Kind::I8(i) => ( - LiteralType::I8(parsed_value as i32), - i.nullability, - i.type_variation_reference, - ), - Kind::I16(i) => ( - LiteralType::I16(parsed_value as i32), - i.nullability, - i.type_variation_reference, - ), - Kind::I32(i) => ( - LiteralType::I32(parsed_value as i32), - i.nullability, - i.type_variation_reference, - ), - Kind::I64(i) => ( - LiteralType::I64(parsed_value), - i.nullability, - i.type_variation_reference, - ), - k => { - let pest_error = pest::error::Error::new_from_span( - pest::error::ErrorVariant::CustomError { - message: format!("Invalid type for integer literal: {k:?}"), - }, - value.as_span(), - ); - let error = MessageParseError { - message: "int_literal_type", - kind: ErrorKind::InvalidValue, - error: Box::new(pest_error), - }; - return Err(error); - } - }; - - Ok(Literal { - literal_type: Some(lit), - nullable: nullability != Nullability::Required as i32, - type_variation_reference: tvar, - }) -} - -fn to_float_literal( - value: pest::iterators::Pair, - typ: Option, -) -> Result { - assert_eq!(value.as_rule(), Rule::float); - let parsed_value: f64 = value.as_str().parse().unwrap(); - - const DEFAULT_KIND: Kind = Kind::Fp64(Fp64 { - type_variation_reference: 0, - nullability: Nullability::Required as i32, - }); - - // If no type is provided, we assume fp64, Nullability::Required. 
- let kind = typ.and_then(|t| t.kind).unwrap_or(DEFAULT_KIND); - - let (lit, nullability, tvar) = match &kind { - Kind::Fp32(f) => ( - LiteralType::Fp32(parsed_value as f32), - f.nullability, - f.type_variation_reference, - ), - Kind::Fp64(f) => ( - LiteralType::Fp64(parsed_value), - f.nullability, - f.type_variation_reference, - ), - k => { - let pest_error = pest::error::Error::new_from_span( - pest::error::ErrorVariant::CustomError { - message: format!("Invalid type for float literal: {k:?}"), - }, - value.as_span(), - ); - let error = MessageParseError { - message: "float_literal_type", - kind: ErrorKind::InvalidValue, - error: Box::new(pest_error), - }; - return Err(error); - } - }; - - Ok(Literal { - literal_type: Some(lit), - nullable: nullability != Nullability::Required as i32, - type_variation_reference: tvar, - }) -} - -fn to_boolean_literal(value: pest::iterators::Pair) -> Result { - assert_eq!(value.as_rule(), Rule::boolean); - let parsed_value: bool = value.as_str().parse().unwrap(); - - Ok(Literal { - literal_type: Some(LiteralType::Boolean(parsed_value)), - nullable: false, - type_variation_reference: 0, - }) -} - -fn to_string_literal( - value: pest::iterators::Pair, - typ: Option, -) -> Result { - assert_eq!(value.as_rule(), Rule::string_literal); - let string_value = unescape_string(value.clone()); - - // If no type is provided, default to string - let Some(typ) = typ else { - return Ok(Literal { - literal_type: Some(LiteralType::String(string_value)), - nullable: false, - type_variation_reference: 0, - }); - }; - - let Some(kind) = typ.kind else { - return Ok(Literal { - literal_type: Some(LiteralType::String(string_value)), - nullable: false, - type_variation_reference: 0, - }); - }; - - match &kind { - Kind::Date(d) => { - // Parse date in ISO 8601 format: YYYY-MM-DD - let date_days = parse_date_to_days(&string_value, value.as_span())?; - Ok(Literal { - literal_type: Some(LiteralType::Date(date_days)), - nullable: d.nullability != 
Nullability::Required as i32, - type_variation_reference: d.type_variation_reference, - }) - } - Kind::Time(t) => { - // Parse time in ISO 8601 format: HH:MM:SS[.fff] - let time_microseconds = parse_time_to_microseconds(&string_value, value.as_span())?; - Ok(Literal { - literal_type: Some(LiteralType::Time(time_microseconds)), - nullable: t.nullability != Nullability::Required as i32, - type_variation_reference: t.type_variation_reference, - }) - } - #[allow(deprecated)] - Kind::Timestamp(ts) => { - // Parse timestamp in ISO 8601 format: YYYY-MM-DDTHH:MM:SS[.fff] or YYYY-MM-DD HH:MM:SS[.fff] - let timestamp_microseconds = - parse_timestamp_to_microseconds(&string_value, value.as_span())?; - Ok(Literal { - literal_type: Some(LiteralType::Timestamp(timestamp_microseconds)), - nullable: ts.nullability != Nullability::Required as i32, - type_variation_reference: ts.type_variation_reference, - }) - } - _ => { - // For other types, treat as string - Ok(Literal { - literal_type: Some(LiteralType::String(string_value)), - nullable: false, - type_variation_reference: 0, - }) - } - } -} - -/// Parse a date string using chrono to days since Unix epoch -fn parse_date_to_days(date_str: &str, span: pest::Span) -> Result { - // Try multiple date formats for flexibility - let formats = ["%Y-%m-%d", "%Y/%m/%d"]; - - for format in &formats { - if let Ok(date) = NaiveDate::parse_from_str(date_str, format) { - // Calculate days since Unix epoch (1970-01-01) - let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); - let days = date.signed_duration_since(epoch).num_days(); - return Ok(days as i32); - } - } - - Err(MessageParseError { - message: "date_parse_format", - kind: ErrorKind::InvalidValue, - error: Box::new(pest::error::Error::new_from_span( - pest::error::ErrorVariant::CustomError { - message: format!( - "Invalid date format: '{date_str}'. 
Expected YYYY-MM-DD or YYYY/MM/DD" - ), - }, - span, - )), - }) -} - -/// Parse a time string using chrono to microseconds since midnight -fn parse_time_to_microseconds(time_str: &str, span: pest::Span) -> Result { - // Try multiple time formats for flexibility - let formats = ["%H:%M:%S%.f", "%H:%M:%S"]; - - for format in &formats { - if let Ok(time) = NaiveTime::parse_from_str(time_str, format) { - // Convert to microseconds since midnight - let midnight = NaiveTime::from_hms_opt(0, 0, 0).unwrap(); - let duration = time.signed_duration_since(midnight); - return Ok(duration.num_microseconds().unwrap_or(0)); - } - } - - Err(MessageParseError { - message: "time_parse_format", - kind: ErrorKind::InvalidValue, - error: Box::new(pest::error::Error::new_from_span( - pest::error::ErrorVariant::CustomError { - message: format!( - "Invalid time format: '{time_str}'. Expected HH:MM:SS or HH:MM:SS.fff" - ), - }, - span, - )), - }) -} - -/// Parse a timestamp string using chrono to microseconds since Unix epoch -fn parse_timestamp_to_microseconds( - timestamp_str: &str, - span: pest::Span, -) -> Result { - // Try multiple timestamp formats for flexibility - let formats = [ - "%Y-%m-%dT%H:%M:%S%.f", // ISO 8601 with T and fractional seconds - "%Y-%m-%dT%H:%M:%S", // ISO 8601 with T - "%Y-%m-%d %H:%M:%S%.f", // Space separator with fractional seconds - "%Y-%m-%d %H:%M:%S", // Space separator - "%Y/%m/%dT%H:%M:%S%.f", // Alternative date format with T - "%Y/%m/%dT%H:%M:%S", // Alternative date format with T - "%Y/%m/%d %H:%M:%S%.f", // Alternative date format with space - "%Y/%m/%d %H:%M:%S", // Alternative date format with space - ]; - - for format in &formats { - if let Ok(datetime) = NaiveDateTime::parse_from_str(timestamp_str, format) { - // Calculate microseconds since Unix epoch (1970-01-01 00:00:00) - let epoch = DateTime::from_timestamp(0, 0).unwrap().naive_utc(); - let duration = datetime.signed_duration_since(epoch); - return 
Ok(duration.num_microseconds().unwrap_or(0)); - } - } - - Err(MessageParseError { - message: "timestamp_parse_format", - kind: ErrorKind::InvalidValue, - error: Box::new(pest::error::Error::new_from_span( - pest::error::ErrorVariant::CustomError { - message: format!( - "Invalid timestamp format: '{timestamp_str}'. Expected YYYY-MM-DDTHH:MM:SS or YYYY-MM-DD HH:MM:SS" - ), - }, - span, - )), - }) -} - -impl ScopedParsePair for Literal { - fn rule() -> Rule { - Rule::literal - } - - fn message() -> &'static str { - "Literal" - } - - fn parse_pair( - extensions: &SimpleExtensions, - pair: pest::iterators::Pair, - ) -> Result { - assert_eq!(pair.as_rule(), Self::rule()); - let mut pairs = pair.into_inner(); - let value = pairs.next().unwrap(); // First item is always the value - let typ = pairs.next(); // Second item is optional type - assert!(pairs.next().is_none()); - let typ = match typ { - Some(t) => Some(Type::parse_pair(extensions, t)?), - None => None, - }; - match value.as_rule() { - Rule::integer => to_int_literal(value, typ), - Rule::float => to_float_literal(value, typ), - Rule::boolean => to_boolean_literal(value), - Rule::string_literal => to_string_literal(value, typ), - _ => unreachable!("Literal unexpected rule: {:?}", value.as_rule()), - } - } -} - -impl ScopedParsePair for ScalarFunction { - fn rule() -> Rule { - Rule::function_call - } - - fn message() -> &'static str { - "ScalarFunction" - } - - fn parse_pair( - extensions: &SimpleExtensions, - pair: pest::iterators::Pair, - ) -> Result { - assert_eq!(pair.as_rule(), Self::rule()); - let span = pair.as_span(); - let mut iter = RuleIter::from(pair.into_inner()); - - // Parse function name (required) - let name = iter.parse_next::(); - - // Parse optional URN anchor (e.g., #1) - let anchor = iter - .try_pop(Rule::anchor) - .map(|n| unwrap_single_pair(n).as_str().parse::().unwrap()); - - // Parse optional URN anchor (e.g., @1) - let _urn_anchor = iter - .try_pop(Rule::urn_anchor) - .map(|n| 
unwrap_single_pair(n).as_str().parse::().unwrap()); - - // Parse argument list (required) - let argument_list = iter.pop(Rule::argument_list); - let mut arguments = Vec::new(); - for e in argument_list.into_inner() { - arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(Expression::parse_pair(extensions, e)?)), - }); - } - - // Parse optional output type (e.g., :i64) - let output_type = match iter.try_pop(Rule::r#type) { - Some(t) => Some(Type::parse_pair(extensions, t)?), - None => None, - }; - - iter.done(); - let anchor = - get_and_validate_anchor(extensions, ExtensionKind::Function, anchor, &name.0, span)?; - Ok(ScalarFunction { - function_reference: anchor, - arguments, - options: vec![], // TODO: Function Options - output_type, - #[allow(deprecated)] - args: vec![], - }) - } -} - -impl ScopedParsePair for Expression { - fn rule() -> Rule { - Rule::expression - } - - fn message() -> &'static str { - "Expression" - } - - fn parse_pair( - extensions: &SimpleExtensions, - pair: pest::iterators::Pair, - ) -> Result { - assert_eq!(pair.as_rule(), Self::rule()); - let inner = unwrap_single_pair(pair); - match inner.as_rule() { - Rule::literal => Ok(Expression { - rex_type: Some(RexType::Literal(Literal::parse_pair(extensions, inner)?)), - }), - Rule::function_call => Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction::parse_pair( - extensions, inner, - )?)), - }), - Rule::reference => Ok(Expression { - rex_type: Some(RexType::Selection(Box::new(FieldReference::parse_pair( - inner, - )))), - }), - Rule::if_then => Ok(Expression { - rex_type: Some(RexType::IfThen(Box::new(IfThen::parse_pair( - extensions, inner, - )?))), - }), - _ => unreachable!( - "Grammar guarantees expression can only be literal, function_call, reference, or if_then, got: {:?}", - inner.as_rule() - ), - } - } -} - -impl ScopedParsePair for IfClause { - fn rule() -> Rule { - Rule::if_clause - } - - fn message() -> &'static str { - "IfClause" - } - - fn parse_pair( 
- extensions: &SimpleExtensions, - pair: pest::iterators::Pair, - ) -> Result { - assert_eq!(pair.as_rule(), Self::rule()); - let mut pairs = pair.into_inner(); // should have 2 children, 2 expressions - - let condition = pairs.next().unwrap(); - let result = pairs.next().unwrap(); - assert!(pairs.next().is_none()); - - let ex1 = Some(Expression::parse_pair(extensions, condition)?); - let ex2 = Some(Expression::parse_pair(extensions, result)?); - - Ok(IfClause { - r#if: ex1, - then: ex2, - }) - } -} - -impl ScopedParsePair for IfThen { - fn rule() -> Rule { - Rule::if_then - } - fn message() -> &'static str { - "IfThen" - } - - fn parse_pair( - extensions: &SimpleExtensions, - pair: pest::iterators::Pair, - ) -> Result { - assert_eq!(pair.as_rule(), Self::rule()); - - let mut iter = RuleIter::from(pair.into_inner()); // should have 2 or more children - - let mut ifs: Vec = Vec::new(); - - // gets all of the if clauses - while let Some(p) = iter.try_pop(Rule::if_clause) { - let if_clause = IfClause::parse_pair(extensions, p)?; - ifs.push(if_clause); - } - - let pair = iter.try_pop(Rule::expression).unwrap(); // should be else expression - iter.done(); - let else_clause = Some(Box::new(Expression::parse_pair(extensions, pair)?)); - - Ok(IfThen { - ifs, - r#else: else_clause, - }) - } -} -pub struct Name(pub String); - -impl ParsePair for Name { - fn rule() -> Rule { - Rule::name - } - - fn message() -> &'static str { - "Name" - } - - fn parse_pair(pair: pest::iterators::Pair) -> Self { - assert_eq!(pair.as_rule(), Self::rule()); - let inner = unwrap_single_pair(pair); - match inner.as_rule() { - Rule::identifier => Name(inner.as_str().to_string()), - Rule::quoted_name => Name(unescape_string(inner)), - _ => unreachable!("Name unexpected rule: {:?}", inner.as_rule()), - } - } -} - -impl ScopedParsePair for Measure { - fn rule() -> Rule { - Rule::aggregate_measure - } - - fn message() -> &'static str { - "Measure" - } - - fn parse_pair( - extensions: &SimpleExtensions, 
- pair: pest::iterators::Pair, - ) -> Result { - assert_eq!(pair.as_rule(), Self::rule()); - - // Extract the inner function_call from aggregate_measure - let function_call_pair = unwrap_single_pair(pair); - assert_eq!(function_call_pair.as_rule(), Rule::function_call); - - // Parse as ScalarFunction, then convert to AggregateFunction - let scalar = ScalarFunction::parse_pair(extensions, function_call_pair)?; - Ok(Measure { - measure: Some(AggregateFunction { - function_reference: scalar.function_reference, - arguments: scalar.arguments, - options: scalar.options, - output_type: scalar.output_type, - invocation: 0, // TODO: support invocation (ALL, DISTINCT, etc.) - phase: 0, // TODO: support phase (INITIAL_TO_RESULT, PARTIAL_TO_INTERMEDIATE, etc.) - sorts: vec![], // TODO: support sorts for ordered aggregates - #[allow(deprecated)] - args: scalar.args, - }), - filter: None, // TODO: support filter conditions on aggregate measures - }) - } -} - -#[cfg(test)] -mod tests { - use pest::Parser as PestParser; - - use super::*; - use crate::parser::ExpressionParser; - - fn parse_exact(rule: Rule, input: &'_ str) -> pest::iterators::Pair<'_, Rule> { - let mut pairs = ExpressionParser::parse(rule, input).unwrap(); - assert_eq!(pairs.as_str(), input); - let pair = pairs.next().unwrap(); - assert_eq!(pairs.next(), None); - pair - } - - fn assert_parses_to(input: &str, expected: T) { - let pair = parse_exact(T::rule(), input); - let actual = T::parse_pair(pair); - assert_eq!(actual, expected); - } - - fn assert_parses_with( - ext: &SimpleExtensions, - input: &str, - expected: T, - ) { - let pair = parse_exact(T::rule(), input); - let actual = T::parse_pair(ext, pair).unwrap(); - assert_eq!(actual, expected); - } - - #[test] - fn test_parse_field_reference() { - assert_parses_to("$1", FieldIndex(1).to_field_reference()); - } - - #[test] - fn test_parse_integer_literal() { - let extensions = SimpleExtensions::default(); - let expected = Literal { - literal_type: 
Some(LiteralType::I64(1)), - nullable: false, - type_variation_reference: 0, - }; - assert_parses_with(&extensions, "1", expected); - } - - #[test] - fn test_parse_float_literal() { - // First test that the grammar can parse floats - let pairs = ExpressionParser::parse(Rule::float, "3.82").unwrap(); - let parsed_text = pairs.as_str(); - assert_eq!(parsed_text, "3.82"); - - let extensions = SimpleExtensions::default(); - let expected = Literal { - literal_type: Some(LiteralType::Fp64(3.82)), - nullable: false, - type_variation_reference: 0, - }; - assert_parses_with(&extensions, "3.82", expected); - } - - #[test] - fn test_parse_negative_float_literal() { - let extensions = SimpleExtensions::default(); - let expected = Literal { - literal_type: Some(LiteralType::Fp64(-2.5)), - nullable: false, - type_variation_reference: 0, - }; - assert_parses_with(&extensions, "-2.5", expected); - } - - #[test] - fn test_parse_boolean_true_literal() { - let extensions = SimpleExtensions::default(); - let expected = Literal { - literal_type: Some(LiteralType::Boolean(true)), - nullable: false, - type_variation_reference: 0, - }; - assert_parses_with(&extensions, "true", expected); - } - - #[test] - fn test_parse_boolean_false_literal() { - let extensions = SimpleExtensions::default(); - let expected = Literal { - literal_type: Some(LiteralType::Boolean(false)), - nullable: false, - type_variation_reference: 0, - }; - assert_parses_with(&extensions, "false", expected); - } - - #[test] - fn test_parse_float_literal_with_fp32_type() { - let extensions = SimpleExtensions::default(); - let pair = parse_exact(Rule::literal, "3.82:fp32"); - let result = Literal::parse_pair(&extensions, pair).unwrap(); - - match result.literal_type { - Some(LiteralType::Fp32(val)) => assert!((val - 3.82).abs() < f32::EPSILON), - _ => panic!("Expected Fp32 literal type"), - } - } - - #[test] - fn test_parse_date_literal() { - let extensions = SimpleExtensions::default(); - let pair = 
parse_exact(Rule::literal, "'2023-12-25':date"); - let result = Literal::parse_pair(&extensions, pair).unwrap(); - - match result.literal_type { - Some(LiteralType::Date(days)) => { - // 2023-12-25 should be a positive number of days since 1970-01-01 - assert!( - days > 0, - "Expected positive days since epoch, got: {}", - days - ); - } - _ => panic!("Expected Date literal type, got: {:?}", result.literal_type), - } - } - - #[test] - fn test_parse_time_literal() { - let extensions = SimpleExtensions::default(); - let pair = parse_exact(Rule::literal, "'14:30:45':time"); - let result = Literal::parse_pair(&extensions, pair).unwrap(); - - match result.literal_type { - Some(LiteralType::Time(microseconds)) => { - // 14:30:45 = (14*3600 + 30*60 + 45) * 1_000_000 microseconds - let expected = (14 * 3600 + 30 * 60 + 45) * 1_000_000; - assert_eq!(microseconds, expected); - } - _ => panic!("Expected Time literal type, got: {:?}", result.literal_type), - } - } - - #[test] - fn test_parse_timestamp_literal_with_t() { - let extensions = SimpleExtensions::default(); - let pair = parse_exact(Rule::literal, "'2023-01-01T12:00:00':timestamp"); - let result = Literal::parse_pair(&extensions, pair).unwrap(); - - match result.literal_type { - #[allow(deprecated)] - Some(LiteralType::Timestamp(microseconds)) => { - assert!( - microseconds > 0, - "Expected positive microseconds since epoch" - ); - } - _ => panic!( - "Expected Timestamp literal type, got: {:?}", - result.literal_type - ), - } - } - - #[test] - fn test_parse_timestamp_literal_with_space() { - let extensions = SimpleExtensions::default(); - let pair = parse_exact(Rule::literal, "'2023-01-01 12:00:00':timestamp"); - let result = Literal::parse_pair(&extensions, pair).unwrap(); - - match result.literal_type { - #[allow(deprecated)] - Some(LiteralType::Timestamp(microseconds)) => { - assert!( - microseconds > 0, - "Expected positive microseconds since epoch" - ); - } - _ => panic!( - "Expected Timestamp literal type, got: 
{:?}", - result.literal_type - ), - } - } - - /// Helper function to create a literal boolean expression - fn make_literal_bool(value: bool) -> Expression { - Expression { - rex_type: Some(RexType::Literal(Literal { - literal_type: Some(LiteralType::Boolean(value)), - nullable: false, - type_variation_reference: 0, - })), - } - } - - #[test] - fn test_parse_if_then_single_clause() { - let extensions = SimpleExtensions::default(); - let input = "if_then(true -> 42, _ -> 0)"; - let pair = parse_exact(Rule::if_then, input); - let result = IfThen::parse_pair(&extensions, pair).unwrap(); - - assert_eq!(result.ifs.len(), 1); - assert!(result.r#else.is_some()); - } - - #[test] - fn test_parse_if_then_with_typed_literals() { - let extensions = SimpleExtensions::default(); - let input = "if_then(true -> 100:i32, _ -> -100:i32)"; - let pair = parse_exact(Rule::if_then, input); - let result = IfThen::parse_pair(&extensions, pair).unwrap(); - - assert_eq!(result.ifs.len(), 1); - assert!(result.r#else.is_some()); - } - - #[test] - fn test_parse_if_then_with_date_literals() { - let extensions = SimpleExtensions::default(); - let input = "if_then(true -> '2023-12-25':date, _ -> '1970-01-01':date)"; - let pair = parse_exact(Rule::if_then, input); - let result = IfThen::parse_pair(&extensions, pair).unwrap(); - - assert_eq!(result.ifs.len(), 1); - assert!(result.r#else.is_some()); - } - - #[test] - fn test_parse_if_then_with_time_literals() { - let extensions = SimpleExtensions::default(); - let input = "if_then(true -> '14:30:45':time, _ -> '00:00:00':time)"; - let pair = parse_exact(Rule::if_then, input); - let result = IfThen::parse_pair(&extensions, pair).unwrap(); - - assert_eq!(result.ifs.len(), 1); - assert!(result.r#else.is_some()); - } - - #[test] - fn test_parse_if_then_with_timestamp_literals() { - let extensions = SimpleExtensions::default(); - let input = "if_then(true -> '2023-01-01T12:00:00':timestamp, _ -> '1970-01-01T00:00:00':timestamp)"; - let pair = 
parse_exact(Rule::if_then, input); - let result = IfThen::parse_pair(&extensions, pair).unwrap(); - - assert_eq!(result.ifs.len(), 1); - assert!(result.r#else.is_some()); - } - - #[test] - fn test_parse_if_clause_with_whitespace_variations() { - let extensions = SimpleExtensions::default(); - - // Test with various whitespace patterns - let inputs = vec!["true->false", "true -> false", "true -> false"]; - - for input in inputs { - let pair = parse_exact(Rule::if_clause, input); - let result = IfClause::parse_pair(&extensions, pair).unwrap(); - assert!(result.r#if.is_some()); - assert!(result.then.is_some()); - } - } - - #[test] - fn test_if_clause_structure() { - let extensions = SimpleExtensions::default(); - let pair = parse_exact(Rule::if_clause, "42 -> 100"); - let result = IfClause::parse_pair(&extensions, pair).unwrap(); - - // Verify the if clause has both condition and result - let if_expr = result.r#if.as_ref().unwrap(); - let then_expr = result.then.as_ref().unwrap(); - - // Check that they are literal expressions - match (&if_expr.rex_type, &then_expr.rex_type) { - (Some(RexType::Literal(_)), Some(RexType::Literal(_))) => { - // Success - both are literals as expected - } - _ => panic!("Expected both if and then to be literals"), - } - } - - #[test] - fn test_if_then_structure() { - let extensions = SimpleExtensions::default(); - let input = "if_then(true -> 1, false -> 2, _ -> 0)"; - let pair = parse_exact(Rule::if_then, input); - let result = IfThen::parse_pair(&extensions, pair).unwrap(); - - // Verify structure - assert_eq!(result.ifs.len(), 2); - - // Check each if clause - for clause in &result.ifs { - assert!(clause.r#if.is_some(), "If clause condition should exist"); - assert!(clause.then.is_some(), "If clause result should exist"); - } - - // Check else clause - assert!(result.r#else.is_some(), "Else clause should exist"); - } - - #[test] - fn test_parse_if_then_mixed_types_in_conditions() { - let extensions = SimpleExtensions::default(); - // 
Different types in conditions (not results) - let input = "if_then(true -> 1, true -> 'yes', 'yes' -> true, 42 -> 2, $0 -> 3, _ -> 0)"; - let pair = parse_exact(Rule::if_then, input); - let result = IfThen::parse_pair(&extensions, pair).unwrap(); - - assert_eq!(result.ifs.len(), 5); - assert!(result.r#else.is_some()); - } - - #[test] - fn test_if_then_preserves_clause_order() { - let extensions = SimpleExtensions::default(); - let input = "if_then(1 -> 10, 2 -> 20, 3 -> 30, _ -> 0)"; - let pair = parse_exact(Rule::if_then, input); - let result = IfThen::parse_pair(&extensions, pair).unwrap(); - - assert_eq!(result.ifs.len(), 3); - - // Verify the clauses are in order by checking the literal values - for (i, clause) in result.ifs.iter().enumerate() { - if let Some(Expression { - rex_type: Some(RexType::Literal(lit)), - }) = &clause.r#if - && let Some(LiteralType::I64(val)) = &lit.literal_type - { - assert_eq!(*val, (i as i64) + 1); - } - } - } - - #[test] - fn test_parse_if_then() { - let extensions = SimpleExtensions::default(); - - let c1 = IfClause { - r#if: Some(make_literal_bool(true)), - then: Some(make_literal_bool(true)), - }; - - let c2 = IfClause { - r#if: Some(make_literal_bool(false)), - then: Some(make_literal_bool(false)), - }; - - let if_clause = IfThen { - ifs: vec![c1, c2], - r#else: Some(Box::new(make_literal_bool(false))), - }; - assert_parses_with( - &extensions, - "if_then(true -> true , false -> false, _ -> false)", - if_clause, - ); - } -} diff --git a/src/parser/extensions.rs b/src/parser/extensions.rs index ebdb0d0..f1af6e0 100644 --- a/src/parser/extensions.rs +++ b/src/parser/extensions.rs @@ -1,14 +1,21 @@ +//! Incremental parser for the `=== Extensions` section. +//! +//! The structural parser feeds lines one at a time to [`ExtensionParser`], +//! which uses a state machine ([`ExtensionParserState`]) to track which +//! subsection it's in (URNs, Functions, Types, or Type Variations) and +//! 
dispatches each line to the appropriate LALRPOP entry point. + use std::fmt; use std::str::FromStr; use thiserror::Error; -use super::{ParsePair, Rule, RuleIter, unescape_string, unwrap_single_pair}; +use crate::extensions::any::Any; +use crate::extensions::registry::ExtensionError; use crate::extensions::simple::{self, ExtensionKind}; -use crate::extensions::{ - ExtensionArgs, ExtensionColumn, ExtensionRelationType, ExtensionValue, InsertError, - RawExpression, SimpleExtensions, -}; +use crate::extensions::{ExtensionArgs, ExtensionRegistry, InsertError, SimpleExtensions}; +use crate::parser::ast; +use crate::parser::errors::{MessageParseError, ParseContext, ParseError, SyntaxErrorDetail}; use crate::parser::structural::IndentedLine; #[derive(Debug, Clone, Error)] @@ -18,21 +25,13 @@ pub enum ExtensionParseError { #[error("Error adding extension: {0}")] ExtensionError(#[from] InsertError), #[error("Error parsing message: {0}")] - Message(#[from] super::MessageParseError), + Message(#[from] MessageParseError), } -/// The state of the extension parser - tracking what section of extension -/// parsing we are in. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum ExtensionParserState { - // The extensions section, after parsing the 'Extensions:' header, before - // parsing any subsection headers. Extensions, - // The extension URNs section, after parsing the 'URNs:' subsection header, - // and any URNs so far. ExtensionUrns, - // In a subsection, after parsing the subsection header, and any - // declarations so far. ExtensionDeclarations(ExtensionKind), } @@ -48,12 +47,6 @@ impl fmt::Display for ExtensionParserState { } } -/// The parser for the extension section of the Substrait file format. -/// -/// This is responsible for parsing the extension section of the file, which -/// contains the extension URNs and declarations. Note that this parser does not -/// parse the header; otherwise, this is symmetric with the -/// SimpleExtensions::write method. 
#[derive(Debug)] pub struct ExtensionParser { state: ExtensionParserState, @@ -70,10 +63,8 @@ impl Default for ExtensionParser { } impl ExtensionParser { - pub fn parse_line(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> { + pub fn parse_line(&mut self, line: IndentedLine<'_>) -> Result<(), ExtensionParseError> { if line.1.is_empty() { - // Blank lines are allowed between subsections, so if we see - // one, we revert out of the subsection. self.state = ExtensionParserState::Extensions; return Ok(()); } @@ -87,7 +78,7 @@ impl ExtensionParser { } } - fn parse_subsection(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> { + fn parse_subsection(&mut self, line: IndentedLine<'_>) -> Result<(), ExtensionParseError> { match line { IndentedLine(0, simple::EXTENSION_URNS_HEADER) => { self.state = ExtensionParserState::ExtensionUrns; @@ -110,12 +101,11 @@ impl ExtensionParser { } } - fn parse_extension_urns(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> { + fn parse_extension_urns(&mut self, line: IndentedLine<'_>) -> Result<(), ExtensionParseError> { match line { - IndentedLine(0, _s) => self.parse_subsection(line), // Pass the original line with 0 indent + IndentedLine(0, _s) => self.parse_subsection(line), IndentedLine(1, s) => { - let urn = - URNExtensionDeclaration::from_str(s).map_err(ExtensionParseError::Message)?; + let urn = parse_urn_declaration(s)?; self.extensions.add_extension_urn(urn.urn, urn.anchor)?; Ok(()) } @@ -125,13 +115,13 @@ impl ExtensionParser { fn parse_declarations( &mut self, - line: IndentedLine, + line: IndentedLine<'_>, extension_kind: ExtensionKind, ) -> Result<(), ExtensionParseError> { match line { - IndentedLine(0, _s) => self.parse_subsection(line), // Pass the original line with 0 indent + IndentedLine(0, _s) => self.parse_subsection(line), IndentedLine(1, s) => { - let decl = SimpleExtensionDeclaration::from_str(s)?; + let decl = parse_simple_declaration(s)?; 
self.extensions.add_extension( extension_kind, decl.urn_anchor, @@ -153,328 +143,58 @@ impl ExtensionParser { } } -#[derive(Debug, Clone, PartialEq)] -pub struct URNExtensionDeclaration { - pub anchor: u32, - pub urn: String, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct SimpleExtensionDeclaration { - pub anchor: u32, - pub urn_anchor: u32, - pub name: String, -} - -impl ParsePair for URNExtensionDeclaration { - fn rule() -> Rule { - Rule::extension_urn_declaration - } - - fn message() -> &'static str { - "URNExtensionDeclaration" - } - - fn parse_pair(pair: pest::iterators::Pair) -> Self { - assert_eq!(pair.as_rule(), Self::rule()); - - let mut iter = RuleIter::from(pair.into_inner()); - let anchor_pair = iter.pop(Rule::urn_anchor); - let anchor = unwrap_single_pair(anchor_pair) - .as_str() - .parse::() - .unwrap(); - let urn = iter.pop(Rule::urn).as_str().to_string(); - iter.done(); - - URNExtensionDeclaration { anchor, urn } - } -} - -impl FromStr for URNExtensionDeclaration { - type Err = super::MessageParseError; - - fn from_str(s: &str) -> Result { - Self::parse_str(s) - } -} - -impl ParsePair for SimpleExtensionDeclaration { - fn rule() -> Rule { - Rule::simple_extension - } - - fn message() -> &'static str { - "SimpleExtensionDeclaration" - } - - fn parse_pair(pair: pest::iterators::Pair) -> Self { - assert_eq!(pair.as_rule(), Self::rule()); - let mut iter = RuleIter::from(pair.into_inner()); - let anchor_pair = iter.pop(Rule::anchor); - let anchor = unwrap_single_pair(anchor_pair) - .as_str() - .parse::() - .unwrap(); - let urn_anchor_pair = iter.pop(Rule::urn_anchor); - let urn_anchor = unwrap_single_pair(urn_anchor_pair) - .as_str() - .parse::() - .unwrap(); - let name_pair = iter.pop(Rule::name); - let name = unwrap_single_pair(name_pair).as_str().to_string(); - iter.done(); - - SimpleExtensionDeclaration { - anchor, - urn_anchor, - name, - } - } -} - -impl FromStr for SimpleExtensionDeclaration { - type Err = super::MessageParseError; - - fn 
from_str(s: &str) -> Result { - Self::parse_str(s) - } -} - -// Extension relation parsing implementations -// These were moved from extensions/registry.rs to maintain clean architecture - -use crate::extensions::any::Any; -use crate::parser::expressions::{FieldIndex, Name}; - -impl ParsePair for ExtensionValue { - fn rule() -> Rule { - Rule::extension_argument - } - - fn message() -> &'static str { - "ExtensionValue" - } - - fn parse_pair(pair: pest::iterators::Pair) -> Self { - assert_eq!(pair.as_rule(), Self::rule()); - - let inner = unwrap_single_pair(pair); // Extract the actual content - - match inner.as_rule() { - Rule::reference => { - // Reuse the existing FieldIndex parser, then extract the i32 - let field_index = FieldIndex::parse_pair(inner); - ExtensionValue::Reference(field_index.0) - } - Rule::literal => { - // Literal can contain integer, float, boolean, or string_literal - let mut literal_inner = inner.into_inner(); - let value_pair = literal_inner.next().unwrap(); - match value_pair.as_rule() { - Rule::string_literal => ExtensionValue::String(unescape_string(value_pair)), - Rule::integer => { - let int_val = value_pair.as_str().parse::().unwrap(); - ExtensionValue::Integer(int_val) - } - Rule::float => { - let float_val = value_pair.as_str().parse::().unwrap(); - ExtensionValue::Float(float_val) - } - Rule::boolean => { - let bool_val = value_pair.as_str() == "true"; - ExtensionValue::Boolean(bool_val) - } - _ => panic!("Unexpected literal value type: {:?}", value_pair.as_rule()), - } - } - Rule::string_literal => ExtensionValue::String(unescape_string(inner)), - Rule::integer => { - // Direct integer (not wrapped in literal rule) - let int_val = inner.as_str().parse::().unwrap(); - ExtensionValue::Integer(int_val) - } - Rule::float => { - // Direct float (not wrapped in literal rule) - let float_val = inner.as_str().parse::().unwrap(); - ExtensionValue::Float(float_val) - } - Rule::boolean => { - // Direct boolean (not wrapped in literal rule) - 
let bool_val = inner.as_str() == "true"; - ExtensionValue::Boolean(bool_val) - } - Rule::expression => { - ExtensionValue::Expression(RawExpression::new(inner.as_str().to_string())) - } - _ => panic!("Unexpected extension argument type: {:?}", inner.as_rule()), - } +fn parse_urn_declaration(input: &str) -> Result { + let parsed: ast::ExtensionUrnDeclaration = + parse_declaration_line(input, "URNExtensionDeclaration")?; + if parsed.urn.trim().is_empty() { + return Err(MessageParseError::invalid( + "URNExtensionDeclaration", + "URN cannot be empty", + )); } + Ok(parsed) } -impl ParsePair for ExtensionColumn { - fn rule() -> Rule { - Rule::extension_column - } - - fn message() -> &'static str { - "ExtensionColumn" - } - - fn parse_pair(pair: pest::iterators::Pair) -> Self { - assert_eq!(pair.as_rule(), Self::rule()); - - let inner = unwrap_single_pair(pair); // Extract the actual content - - match inner.as_rule() { - Rule::named_column => { - let mut iter = inner.into_inner(); - let name_pair = iter.next().unwrap(); // Grammar guarantees name exists - let type_pair = iter.next().unwrap(); // Grammar guarantees type exists - - let name = Name::parse_pair(name_pair).0.to_string(); // Reuse existing Name parser - let type_spec = type_pair.as_str().to_string(); // Types are complex, store as string for now - - ExtensionColumn::Named { name, type_spec } - } - Rule::reference => { - // Reuse the existing FieldIndex parser, then extract the i32 - let field_index = FieldIndex::parse_pair(inner); - ExtensionColumn::Reference(field_index.0) - } - Rule::expression => { - ExtensionColumn::Expression(RawExpression::new(inner.as_str().to_string())) - } - _ => panic!("Unexpected extension column type: {:?}", inner.as_rule()), - } +fn parse_simple_declaration(input: &str) -> Result { + let parsed: ast::ExtensionDeclaration = + parse_declaration_line(input, "SimpleExtensionDeclaration")?; + if parsed.name.trim().is_empty() { + return Err(MessageParseError::invalid( + 
"SimpleExtensionDeclaration", + "name cannot be empty", + )); } + Ok(parsed) } -/// Fully parsed extension invocation, including the user-supplied name and the -/// structured argument payload. -#[derive(Debug, Clone)] -pub struct ExtensionInvocation { - pub name: String, - pub args: ExtensionArgs, -} - -impl ParsePair for ExtensionInvocation { - fn rule() -> Rule { - Rule::extension_relation - } - - fn message() -> &'static str { - "ExtensionInvocation" - } - - fn parse_pair(pair: pest::iterators::Pair) -> Self { - assert_eq!(pair.as_rule(), Self::rule()); - - let mut iter = pair.into_inner(); - - // Parse extension name to determine relation type and custom name - let extension_name_pair = iter.next().unwrap(); // Grammar guarantees extension_name exists - let full_extension_name = extension_name_pair.as_str(); - - // Extract the relation type and custom name from the extension name - // (e.g., "ExtensionLeaf:ParquetScan" -> "ExtensionLeaf" and "ParquetScan") - let (relation_type_str, custom_name) = if full_extension_name.contains(':') { - let parts: Vec<&str> = full_extension_name.splitn(2, ':').collect(); - (parts[0], parts[1].to_string()) - } else { - (full_extension_name, "UnknownExtension".to_string()) - }; - - let relation_type = ExtensionRelationType::from_str(relation_type_str).unwrap(); - let mut args = ExtensionArgs::new(relation_type); - - // Parse optional arguments and columns - for inner_pair in iter { - match inner_pair.as_rule() { - Rule::extension_arguments => { - // Parse positional arguments - for arg_pair in inner_pair.into_inner() { - if arg_pair.as_rule() == Rule::extension_argument { - let value = ExtensionValue::parse_pair(arg_pair); - args.positional.push(value); - } - } - } - Rule::extension_named_arguments => { - // Parse named arguments - for arg_pair in inner_pair.into_inner() { - if arg_pair.as_rule() == Rule::extension_named_argument { - let mut arg_iter = arg_pair.into_inner(); - let name_pair = arg_iter.next().unwrap(); - let 
value_pair = arg_iter.next().unwrap(); - - let name = Name::parse_pair(name_pair).0.to_string(); - let value = ExtensionValue::parse_pair(value_pair); - args.named.insert(name, value); - } - } - } - Rule::extension_columns => { - // Parse output columns - for col_pair in inner_pair.into_inner() { - if col_pair.as_rule() == Rule::extension_column { - let column = ExtensionColumn::parse_pair(col_pair); - args.output_columns.push(column); - } - } - } - Rule::empty => {} // "_" — no arguments - r => panic!("Unexpected rule in ExtensionArgs: {:?}", r), - } - } - - ExtensionInvocation { - name: custom_name, - args, - } - } +fn parse_declaration_line(input: &str, target: &'static str) -> Result +where + T: FromStr, +{ + input + .parse::() + .map_err(|e| MessageParseError::syntax(target, e.describe_with_line(input))) } -impl ExtensionRelationType { - /// Create appropriate relation structure from extension detail and children. - /// This method handles the structural logic for creating different extension relation types. 
- pub fn create_rel( - self, - detail: Option, - children: Vec>, - ) -> Result { - use substrait::proto::rel::RelType; - use substrait::proto::{ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel}; - - // Validate child count matches relation type - self.validate_child_count(children.len())?; - - let rel_type = match self { - ExtensionRelationType::Leaf => RelType::ExtensionLeaf(ExtensionLeafRel { - common: None, - detail: detail.map(Into::into), - }), - ExtensionRelationType::Single => { - let input = children.into_iter().next().map(|child| *child); - RelType::ExtensionSingle(Box::new(ExtensionSingleRel { - common: None, - detail: detail.map(Into::into), - input: input.map(Box::new), - })) - } - ExtensionRelationType::Multi => { - let inputs = children.into_iter().map(|child| *child).collect(); - RelType::ExtensionMulti(ExtensionMultiRel { - common: None, - detail: detail.map(Into::into), - inputs, - }) - } - }; - - Ok(substrait::proto::Rel { - rel_type: Some(rel_type), - }) +pub fn resolve_extension_detail( + registry: &ExtensionRegistry, + line_no: i64, + line: &str, + extension_name: &str, + extension_args: &ExtensionArgs, +) -> Result, ParseError> { + let detail = registry.parse_extension(extension_name, extension_args); + + match detail { + Ok(any) => Ok(Some(any)), + Err(ExtensionError::NotFound { .. 
}) => Err(ParseError::UnregisteredExtension { + name: extension_name.to_string(), + context: ParseContext::new(line_no, line.to_string()), + }), + Err(err) => Err(ParseError::ExtensionDetail( + ParseContext::new(line_no, line.to_string()), + err, + )), } } @@ -486,36 +206,26 @@ mod tests { #[test] fn test_parse_urn_extension_declaration() { let line = "@1: /my/urn1"; - let urn = URNExtensionDeclaration::parse_str(line).unwrap(); + let urn = parse_urn_declaration(line).unwrap(); assert_eq!(urn.anchor, 1); assert_eq!(urn.urn, "/my/urn1"); } #[test] fn test_parse_simple_extension_declaration() { - // Assumes a format like "@anchor: urn_anchor:name" let line = "#5@2: my_function_name"; - let decl = SimpleExtensionDeclaration::from_str(line).unwrap(); + let decl = parse_simple_declaration(line).unwrap(); assert_eq!(decl.anchor, 5); assert_eq!(decl.urn_anchor, 2); assert_eq!(decl.name, "my_function_name"); - // Test with a different name format, e.g. with underscores and numbers let line2 = "#10 @200: another_ext_123"; - let decl = SimpleExtensionDeclaration::from_str(line2).unwrap(); + let decl = parse_simple_declaration(line2).unwrap(); assert_eq!(decl.anchor, 10); assert_eq!(decl.urn_anchor, 200); assert_eq!(decl.name, "another_ext_123"); } - #[test] - fn test_parse_urn_extension_declaration_str() { - let line = "@1: /my/urn1"; - let urn = URNExtensionDeclaration::parse_str(line).unwrap(); - assert_eq!(urn.anchor, 1); - assert_eq!(urn.urn, "/my/urn1"); - } - #[test] fn test_extensions_round_trip_plan() { let input = r#" @@ -533,22 +243,17 @@ Type Variations: "# .trim_start(); - // Parse the input using the structural parser let plan = Parser::parse(input).unwrap(); - // Verify the plan has the expected extensions assert_eq!(plan.extension_urns.len(), 2); assert_eq!(plan.extensions.len(), 4); - // Convert the plan extensions back to SimpleExtensions let (extensions, errors) = SimpleExtensions::from_extensions(&plan.extension_urns, &plan.extensions); 
assert!(errors.is_empty()); - // Convert back to string let output = extensions.to_string(" "); - // The output should match the input assert_eq!(output, input); } } diff --git a/src/parser/lalrpop_line.rs b/src/parser/lalrpop_line.rs new file mode 100644 index 0000000..8f7166d --- /dev/null +++ b/src/parser/lalrpop_line.rs @@ -0,0 +1,215 @@ +//! Thin wrappers around LALRPOP entry points for single-line parsing. + +use std::str::FromStr; + +use lalrpop_util::ParseError as LalrpopError; +use lalrpop_util::lexer::Token; + +use super::errors::{SyntaxErrorDetail, SyntaxErrorKind, TextSpan}; +use crate::parser::ast; + +lalrpop_util::lalrpop_mod!( + #[allow(clippy::all)] + #[allow(warnings)] + line_grammar, + "/parser/line_grammar.rs" +); + +pub fn parse_relation(input: &str) -> Result { + line_grammar::RelationParser::new() + .parse(input) + .map_err(|e| map_syntax_error(input, e)) +} + +/// Parse only root names (`Root[...]`) for top-level lines. +/// +/// Callers can fall back to [`parse_relation`] if this fails. 
+pub fn parse_root_names(input: &str) -> Result, SyntaxErrorDetail> { + line_grammar::RootNamesParser::new() + .parse(input) + .map_err(|e| map_syntax_error(input, e)) +} + +pub fn parse_extension_urn_declaration( + input: &str, +) -> Result { + line_grammar::ExtensionUrnDeclParser::new() + .parse(input) + .map_err(|e| map_syntax_error(input, e)) +} + +pub fn parse_extension_declaration( + input: &str, +) -> Result { + line_grammar::ExtensionDeclParser::new() + .parse(input) + .map_err(|e| map_syntax_error(input, e)) +} + +impl FromStr for ast::ExtensionUrnDeclaration { + type Err = SyntaxErrorDetail; + + fn from_str(s: &str) -> Result { + parse_extension_urn_declaration(s) + } +} + +impl FromStr for ast::ExtensionDeclaration { + type Err = SyntaxErrorDetail; + + fn from_str(s: &str) -> Result { + parse_extension_declaration(s) + } +} + +fn map_syntax_error( + input: &str, + error: LalrpopError, &'static str>, +) -> SyntaxErrorDetail { + match error { + LalrpopError::InvalidToken { location } => SyntaxErrorDetail::new( + SyntaxErrorKind::InvalidToken, + Some(single_char_span(input, location)), + Vec::new(), + ), + LalrpopError::UnrecognizedEof { location, expected } => SyntaxErrorDetail::new( + SyntaxErrorKind::UnexpectedEof, + Some(TextSpan::new(location, location)), + normalize_expected(expected), + ), + LalrpopError::UnrecognizedToken { + token: (start, token, end), + expected, + } => SyntaxErrorDetail::new( + SyntaxErrorKind::UnexpectedToken { + found: token.1.to_string(), + }, + Some(TextSpan::new(start, end)), + normalize_expected(expected), + ), + LalrpopError::ExtraToken { + token: (start, token, end), + } => SyntaxErrorDetail::new( + SyntaxErrorKind::ExtraToken { + found: token.1.to_string(), + }, + Some(TextSpan::new(start, end)), + Vec::new(), + ), + LalrpopError::User { error } => SyntaxErrorDetail::new( + SyntaxErrorKind::User { + message: error.to_string(), + }, + None, + Vec::new(), + ), + } +} + +fn normalize_expected(mut expected: Vec) -> Vec { + 
expected.sort(); + expected.dedup(); + expected +} + +fn single_char_span(input: &str, start: usize) -> TextSpan { + let end = input + .get(start..) + .and_then(|suffix| suffix.chars().next().map(|ch| start + ch.len_utf8())) + .unwrap_or(start); + TextSpan::new(start, end) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::parser::ast::{Arg, Expr, Nullability, RelationName, TypeExpr}; + + #[test] + fn parses_generic_relation_with_named_args() { + let relation = parse_relation("Filter[$0, keep=true => out:i64]").unwrap(); + assert_eq!(relation.name, RelationName::Standard("Filter".to_string())); + assert_eq!(relation.args.positional.len(), 1); + assert_eq!(relation.args.named.len(), 1); + assert_eq!(relation.outputs.len(), 1); + } + + #[test] + fn parses_extension_relation() { + let relation = parse_relation("ExtensionLeaf:ParquetScan[path='x']").unwrap(); + assert!(matches!(relation.name, RelationName::Extension(_, _))); + } + + #[test] + fn parses_root_names() { + let names = parse_root_names("Root[a, b, c]").unwrap(); + assert_eq!(names, vec!["a", "b", "c"]); + } + + #[test] + fn named_args_must_follow_positional_args() { + let relation = parse_relation("Project[$0, keep=true]").unwrap(); + assert!(matches!( + relation.args.positional[0], + Arg::Expr(Expr::FieldRef(0)) + )); + } + + #[test] + fn parses_unspecified_nullability_suffix() { + let relation = parse_relation("Read[t => c:i32⁉]").unwrap(); + assert!(matches!( + &relation.outputs[0], + Arg::NamedColumn( + _, + TypeExpr::Simple { + nullability: Nullability::Unspecified, + .. 
+ } + ) + )); + } + + #[test] + fn parses_extension_urn_declaration() { + let decl = parse_extension_urn_declaration("@ 1: https://example.com/functions.yaml") + .expect("parse urn declaration"); + assert_eq!(decl.anchor, 1); + assert_eq!(decl.urn, "https://example.com/functions.yaml"); + } + + #[test] + fn parses_extension_declaration() { + let decl = parse_extension_declaration("# 10 @ 1: my_function").expect("parse declaration"); + assert_eq!(decl.anchor, 10); + assert_eq!(decl.urn_anchor, 1); + assert_eq!(decl.name, "my_function"); + } + + #[test] + fn from_str_parses_extension_urn_declaration() { + let decl: ast::ExtensionUrnDeclaration = "@ 1: https://example.com/functions.yaml" + .parse() + .expect("parse urn declaration"); + assert_eq!(decl.anchor, 1); + assert_eq!(decl.urn, "https://example.com/functions.yaml"); + } + + #[test] + fn from_str_parses_extension_declaration() { + let decl: ast::ExtensionDeclaration = + "# 10 @ 1: my_function".parse().expect("parse declaration"); + assert_eq!(decl.anchor, 10); + assert_eq!(decl.urn_anchor, 1); + assert_eq!(decl.name, "my_function"); + } + + #[test] + fn from_str_reports_syntax_details() { + let err = "@ nope: urn" + .parse::() + .expect_err("should fail"); + assert!(matches!(err.kind, SyntaxErrorKind::UnexpectedToken { .. 
})); + assert!(!err.expected.is_empty()); + } +} diff --git a/src/parser/line_grammar.lalrpop b/src/parser/line_grammar.lalrpop new file mode 100644 index 0000000..e1893bd --- /dev/null +++ b/src/parser/line_grammar.lalrpop @@ -0,0 +1,266 @@ +grammar; + +use crate::parser::ast::{ + ExtensionDeclaration, ExtensionUrnDeclaration, unescape_quoted, Arg, ArgEntry, ArgList, Expr, + ExtensionType, FunctionCall, IfThenExpr, Literal, Nullability, Relation, RelationName, + TypeExpr, +}; + +// -------------------- +// Lexer +// -------------------- + +match { + r"\s*" => { }, + r"-?[0-9]+\.[0-9]+" => FLOAT, + r"-?[0-9]+" => INT, + r"([A-Za-z][A-Za-z0-9+.\-]*://[^\s]+)|(/[^\s]+)" => URN_VALUE, + r"\$[0-9]+" => FIELD, + r#""([^"\\]|\\.)*""# => QUOTED_NAME, + r"'([^'\\]|\\.)*'" => STRING, + r"[A-Za-z_][A-Za-z0-9_]*" => IDENT, + _ +} + +// -------------------- +// Root / Relation Lines +// -------------------- + +pub RootNames: Vec = { + "Root" "[" "]" => names, +}; + +pub ExtensionUrnDecl: ExtensionUrnDeclaration = { + ":" => ExtensionUrnDeclaration::new(anchor, urn), +}; + +pub ExtensionDecl: ExtensionDeclaration = { + ":" + => ExtensionDeclaration::new(anchor, urn_anchor, name), +}; + +NameListOpt: Vec = { + => vec![], + > => names, +}; + +UrnValue: String = { + => value.to_string(), + => value.to_string(), + => unescape_quoted(value), +}; + +pub Relation: Relation = { + "[" "]" => Relation::new(name, body.0, body.1), +}; + +RelationBody: (ArgList, Vec) = { + => (ArgList::default(), vec![]), + "=>" => (ArgList::default(), outputs), + > => (ArgList::from_entries(entries), vec![]), + > "=>" => { + (ArgList::from_entries(entries), outputs) + }, +}; + +ArgListOpt: Vec = { + => vec![], + > => args, +}; + +RelationName: RelationName = { + "ExtensionLeaf" ":" => RelationName::extension(ExtensionType::Leaf, name), + "ExtensionSingle" ":" => RelationName::extension(ExtensionType::Single, name), + "ExtensionMulti" ":" => RelationName::extension(ExtensionType::Multi, name), + => 
RelationName::standard(name), +}; + +// -------------------- +// Arguments +// -------------------- + +ArgEntry: ArgEntry = { + "=" => ArgEntry::named(name, value), + => ArgEntry::positional(value), +}; + +Arg: Arg = { + ":" => Arg::named_column(name, typ), + "&" => Arg::enum_variant(value), + "(" ")" => Arg::tuple(values), + "_" => Arg::wildcard(), + => Arg::expr(expr), +}; + +// -------------------- +// Expressions +// -------------------- + +// Keep `if_then` and function-call productions ahead of `Name` so bare +// identifiers do not greedily consume expression starts. +Expr: Expr = { + => Expr::if_then(if_then), + => Expr::function_call(call), + // TODO(parser-hardening): `FIELD` can lex values larger than i32; avoid + // panicking parse here and surface a structured syntax/validation error. + => Expr::field_ref(field[1..].parse::().expect("field index")), + => Expr::literal(literal), + => Expr::dotted_name(names), +}; + +FunctionCallExpr: FunctionCall = { + "(" ")" + => FunctionCall::new(header.0, header.1, header.2, args, output_type), +}; + +FunctionHeader: (String, Option, Option) = { + => (name, None, None), + => (name, Some(anchor), None), + => (name, None, Some(urn_anchor)), + => (name, Some(anchor), Some(urn_anchor)), +}; + +ExprListOpt: Vec = { + => vec![], + > => values, +}; + +OutputTypeOpt: Option = { + => None, + ":" => Some(typ), +}; + +IfThenExprRule: IfThenExpr = { + "if_then" "(" > "," "_" "->" ")" + => IfThenExpr::new(clauses, else_expr), +}; + +IfClause: (Expr, Expr) = { + "->" => (condition, result), +}; + +// -------------------- +// Literals +// -------------------- + +LiteralExpr: Literal = { + => Literal::float(value.parse::().expect("float"), typ), + // TODO(parser-hardening): `INT` accepts arbitrary length digits; overflow + // in parse::() currently panics instead of returning a parse error. 
+ => Literal::integer(value.parse::().expect("integer"), typ), + => Literal::boolean(value, typ), + => Literal::string(unescape_quoted(value), typ), +}; + +BooleanLiteral: bool = { + "true" => true, + "false" => false, +}; + +TypeOpt: Option = { + => None, + ":" => Some(typ), +}; + +// -------------------- +// Types +// -------------------- + +Type: TypeExpr = { + => simple, + => list, + => user_defined, +}; + +SimpleType: TypeExpr = { + => TypeExpr::simple(name, nullability), +}; + +ListType: TypeExpr = { + "list" "<" ">" => TypeExpr::list(nullability, inner), +}; + +UserDefinedType: TypeExpr = { + + => TypeExpr::user_defined(base.0, base.1, base.2, nullability, params), +}; + +UserTypeBase: (String, Option, Option) = { + <_prefix:"u!"?> => (name, refs.0, refs.1), +}; + +AnchorRefs: (Option, Option) = { + => (None, None), + => (Some(anchor), None), + => (None, Some(urn_anchor)), + => (Some(anchor), Some(urn_anchor)), +}; + +NullabilitySuffix: Nullability = { + => Nullability::Required, + "?" => Nullability::Nullable, + "⁉" => Nullability::Unspecified, +}; + +TypeParamsOpt: Vec = { + => vec![], + "<" > ">" => params, +}; + +SimpleTypeName: &'input str = { + "boolean", + "i8", + "i16", + "i32", + "i64", + "fp32", + "fp64", + "string", + "binary", + "timestamp_tz", + "timestamp", + "date", + "time", + "interval_year", + "uuid", +}; + +Anchor: u32 = { + // TODO(parser-hardening): large anchors can panic parse::() on + // overflow; prefer returning a structured parse error. + "#" => value.parse::().expect("anchor"), +}; + +UrnAnchor: u32 = { + // TODO(parser-hardening): large URN anchors can panic parse::() on + // overflow; prefer returning a structured parse error. + "@" => value.parse::().expect("urn anchor"), +}; + +Name: String = { + => value.to_string(), + => unescape_quoted(value), +}; + +/// A dot-separated sequence of names, e.g. `schema.table` or `"sales.q1"`. 
+DottedName: Vec = { + )*> => { + let mut names = vec![first]; + names.extend(rest); + names + }, +}; + +// -------------------- +// Generic List Helpers +// -------------------- + +// A non-empty comma-separated list. +CommaSep1: Vec = { + > "," => { + let mut values = left; + values.push(right); + values + }, + => vec![value], +}; diff --git a/src/parser/lower/expr.rs b/src/parser/lower/expr.rs new file mode 100644 index 0000000..beaf92f --- /dev/null +++ b/src/parser/lower/expr.rs @@ -0,0 +1,170 @@ +//! Expression lowering helpers. + +use substrait::proto::expression::field_reference::ReferenceType; +use substrait::proto::expression::if_then::IfClause; +use substrait::proto::expression::literal::LiteralType; +use substrait::proto::expression::{ + FieldReference, IfThen, ReferenceSegment, RexType, ScalarFunction, reference_segment, +}; +use substrait::proto::function_argument::ArgType; +use substrait::proto::{Expression, FunctionArgument}; + +use super::types::resolve_anchor; +use super::{Lower, LowerCtx}; +use crate::extensions::simple::ExtensionKind; +use crate::parser::ast; +use crate::parser::errors::ParseError; + +pub(crate) fn lower_arg_as_expression( + ctx: &LowerCtx<'_>, + arg: &ast::Arg, + message: &'static str, +) -> Result { + match arg { + ast::Arg::Expr(expr) => expr.lower(ctx, message), + _ => ctx.invalid( + message, + format!("expected expression argument, got '{arg}'"), + ), + } +} + +pub(crate) fn fetch_value_expression( + ctx: &LowerCtx<'_>, + value: &ast::Arg, + field: &'static str, +) -> Result { + match value { + ast::Arg::Expr(ast::Expr::Literal(ast::Literal { + value: ast::LiteralValue::Integer(number), + typ: None, + })) => { + // Keep integer limit/offset as direct I64 literals. 
+ if *number < 0 { + return ctx.invalid( + "Fetch", + format!("Fetch {field} must be non-negative, got: {number}"), + ); + } + Ok(i64_literal_expr(*number)) + } + ast::Arg::Expr(expr) => expr.lower(ctx, "Fetch"), + _ => ctx.invalid( + "Fetch", + format!("Fetch {field} must be an expression, got '{value}'"), + ), + } +} + +pub(crate) fn i64_literal_expr(value: i64) -> Expression { + Expression { + rex_type: Some(RexType::Literal(substrait::proto::expression::Literal { + nullable: false, + type_variation_reference: 0, + literal_type: Some(LiteralType::I64(value)), + })), + } +} + +pub(crate) fn field_reference(field: i32) -> FieldReference { + FieldReference { + reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { + reference_type: Some(reference_segment::ReferenceType::StructField(Box::new( + reference_segment::StructField { field, child: None }, + ))), + })), + root_type: None, + } +} + +pub(crate) fn field_ref_expression(field: i32) -> Expression { + Expression { + rex_type: Some(RexType::Selection(Box::new(field_reference(field)))), + } +} + +impl Lower for ast::Expr { + type Output = Expression; + + fn lower(&self, ctx: &LowerCtx<'_>, message: &'static str) -> Result { + let rex_type = match self { + ast::Expr::FieldRef(index) => RexType::Selection(Box::new(field_reference(*index))), + ast::Expr::Literal(literal) => RexType::Literal(literal.lower(ctx, message)?), + ast::Expr::FunctionCall(call) => RexType::ScalarFunction(call.lower(ctx, message)?), + ast::Expr::IfThen(if_then) => RexType::IfThen(Box::new(if_then.lower(ctx, message)?)), + ast::Expr::Identifier(name) => { + // Bare identifiers are valid for names in the AST, but not as + // scalar expressions in relational operators. 
+ return ctx.invalid(message, format!("unexpected bare identifier '{name}'")); + } + ast::Expr::DottedName(names) => { + return ctx.invalid( + message, + format!("unexpected dotted name '{}'", names.join(".")), + ); + } + }; + + Ok(Expression { + rex_type: Some(rex_type), + }) + } +} + +impl Lower for ast::IfThenExpr { + type Output = IfThen; + + fn lower(&self, ctx: &LowerCtx<'_>, message: &'static str) -> Result { + if self.clauses.is_empty() { + return ctx.invalid(message, "if_then requires at least one clause"); + } + + let mut clauses = Vec::with_capacity(self.clauses.len()); + for (condition, result) in &self.clauses { + clauses.push(IfClause { + r#if: Some(condition.lower(ctx, message)?), + then: Some(result.lower(ctx, message)?), + }); + } + + Ok(IfThen { + ifs: clauses, + r#else: Some(Box::new(self.else_expr.lower(ctx, message)?)), + }) + } +} + +impl Lower for ast::FunctionCall { + type Output = ScalarFunction; + + fn lower(&self, ctx: &LowerCtx<'_>, message: &'static str) -> Result { + let function_reference = resolve_anchor( + ctx, + ExtensionKind::Function, + self.anchor, + &self.name, + message, + )?; + + let mut arguments = Vec::with_capacity(self.args.len()); + for arg in &self.args { + arguments.push(FunctionArgument { + arg_type: Some(ArgType::Value(arg.lower(ctx, message)?)), + }); + } + + let output_type = match &self.output_type { + Some(t) => Some(t.lower(ctx, message)?), + None => None, + }; + + Ok(ScalarFunction { + function_reference, + arguments, + options: vec![], + output_type, + #[allow(deprecated)] + args: vec![], + }) + } +} diff --git a/src/parser/lower/extensions.rs b/src/parser/lower/extensions.rs new file mode 100644 index 0000000..8604dd0 --- /dev/null +++ b/src/parser/lower/extensions.rs @@ -0,0 +1,165 @@ +//! Extension-relation lowering and AST -> `ExtensionArgs` conversion. 
+
+use std::collections::HashSet;
+
+use substrait::proto::rel::RelType;
+use substrait::proto::{ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel, Rel};
+
+use super::LowerCtx;
+use super::validate::ensure_unique_named_arg;
+use crate::extensions::any::Any;
+use crate::extensions::{
+    ExtensionArgs, ExtensionColumn, ExtensionRelationType, ExtensionValue, RawExpression,
+};
+use crate::parser::ast;
+use crate::parser::errors::{MessageParseError, ParseError};
+use crate::parser::extensions::resolve_extension_detail;
+
+impl From<ast::ExtensionType> for ExtensionRelationType {
+    fn from(value: ast::ExtensionType) -> Self {
+        match value {
+            ast::ExtensionType::Leaf => ExtensionRelationType::Leaf,
+            ast::ExtensionType::Single => ExtensionRelationType::Single,
+            ast::ExtensionType::Multi => ExtensionRelationType::Multi,
+        }
+    }
+}
+
+#[allow(clippy::vec_box)]
+pub(crate) fn lower_extension(
+    ctx: &LowerCtx<'_>,
+    relation: &ast::Relation,
+    extension_type: ast::ExtensionType,
+    extension_name: &str,
+    child_relations: Vec<Box<Rel>>,
+) -> Result<Rel, ParseError> {
+    let relation_type: ExtensionRelationType = extension_type.into();
+    let args = build_extension_args_from_relation(ctx, relation, relation_type)?;
+
+    relation_type
+        .validate_child_count(child_relations.len())
+        .map_err(|e| plan_error(ctx, MessageParseError::invalid("extension_relation", e)))?;
+
+    let detail =
+        resolve_extension_detail(ctx.registry, ctx.line_no, ctx.line, extension_name, &args)?;
+
+    Ok(build_extension_relation(
+        relation_type,
+        detail,
+        child_relations,
+    ))
+}
+
+pub(crate) fn build_extension_args_from_relation(
+    ctx: &LowerCtx<'_>,
+    relation: &ast::Relation,
+    relation_type: ExtensionRelationType,
+) -> Result<ExtensionArgs, ParseError> {
+    let mut args = ExtensionArgs::new(relation_type);
+    let mut seen_named = HashSet::new();
+
+    for arg in &relation.args.positional {
+        args.positional.push(extension_value_from_arg(arg).map_err(|e| plan_error(ctx, e))?);
+    }
+
+    for named in &relation.args.named {
+        ensure_unique_named_arg(ctx, &mut 
seen_named, &named.name, "extension_relation")?; + args.named.insert( + named.name.clone(), + extension_value_from_arg(&named.value).map_err(|e| plan_error(ctx, e))?, + ); + } + + for arg in &relation.outputs { + args.output_columns.push(extension_column_from_arg(arg).map_err(|e| plan_error(ctx, e))?); + } + + Ok(args) +} + +pub(crate) fn extension_value_from_arg( + arg: &ast::Arg, +) -> Result { + match arg { + ast::Arg::Expr(ast::Expr::FieldRef(idx)) => Ok(ExtensionValue::Reference(*idx)), + ast::Arg::Expr(ast::Expr::Literal(lit)) => match (&lit.value, &lit.typ) { + (ast::LiteralValue::String(value), None) => Ok(ExtensionValue::String(value.clone())), + (ast::LiteralValue::Integer(value), None) => Ok(ExtensionValue::Integer(*value)), + (ast::LiteralValue::Float(value), None) => Ok(ExtensionValue::Float(*value)), + (ast::LiteralValue::Boolean(value), None) => Ok(ExtensionValue::Boolean(*value)), + _ => Ok(ExtensionValue::Expression(RawExpression::new( + lit.to_string(), + ))), + }, + ast::Arg::Expr(expr) => Ok(ExtensionValue::Expression(RawExpression::new( + expr.to_string(), + ))), + ast::Arg::NamedColumn(_, _) + | ast::Arg::Enum(_) + | ast::Arg::Tuple(_) + | ast::Arg::Wildcard => Ok(ExtensionValue::Expression(RawExpression::new( + arg.to_string(), + ))), + } +} + +pub(crate) fn extension_column_from_arg( + arg: &ast::Arg, +) -> Result { + match arg { + ast::Arg::NamedColumn(name, typ) => Ok(ExtensionColumn::Named { + name: name.clone(), + type_spec: typ.to_string(), + }), + ast::Arg::Expr(ast::Expr::FieldRef(idx)) => Ok(ExtensionColumn::Reference(*idx)), + ast::Arg::Expr(expr) => Ok(ExtensionColumn::Expression(RawExpression::new( + expr.to_string(), + ))), + ast::Arg::Enum(_) | ast::Arg::Tuple(_) | ast::Arg::Wildcard => Ok( + ExtensionColumn::Expression(RawExpression::new(arg.to_string())), + ), + } +} + +#[allow(clippy::vec_box)] +fn build_extension_relation( + relation_type: ExtensionRelationType, + detail: Option, + children: Vec>, +) -> Rel { + 
debug_assert!( + relation_type.validate_child_count(children.len()).is_ok(), + "child count should be validated before relation construction" + ); + + let rel_type = match relation_type { + ExtensionRelationType::Leaf => RelType::ExtensionLeaf(ExtensionLeafRel { + common: None, + detail: detail.map(Into::into), + }), + ExtensionRelationType::Single => { + let input = children.into_iter().next().map(|child| *child); + RelType::ExtensionSingle(Box::new(ExtensionSingleRel { + common: None, + detail: detail.map(Into::into), + input: input.map(Box::new), + })) + } + ExtensionRelationType::Multi => { + let inputs = children.into_iter().map(|child| *child).collect(); + RelType::ExtensionMulti(ExtensionMultiRel { + common: None, + detail: detail.map(Into::into), + inputs, + }) + } + }; + + Rel { + rel_type: Some(rel_type), + } +} + +fn plan_error(ctx: &LowerCtx<'_>, error: MessageParseError) -> ParseError { + ParseError::Plan(ctx.parse_context(), error) +} diff --git a/src/parser/lower/literals.rs b/src/parser/lower/literals.rs new file mode 100644 index 0000000..1232c22 --- /dev/null +++ b/src/parser/lower/literals.rs @@ -0,0 +1,257 @@ +//! Literal lowering and typed literal coercion. 
+ +use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime}; +use substrait::proto::expression::Literal; +use substrait::proto::expression::literal::LiteralType; +use substrait::proto::r#type::{self as ptype, Kind, Nullability}; + +use super::types::lower_type_kind_no_box; +use super::{Lower, LowerCtx}; +use crate::parser::ast; +use crate::parser::errors::ParseError; + +impl Lower for ast::Literal { + type Output = Literal; + + fn lower(&self, ctx: &LowerCtx<'_>, message: &'static str) -> Result { + match (&self.value, &self.typ) { + (ast::LiteralValue::Integer(value), typ) => { + lower_int_literal(ctx, *value, typ.as_ref(), message) + } + (ast::LiteralValue::Float(value), typ) => { + lower_float_literal(ctx, *value, typ.as_ref(), message) + } + (ast::LiteralValue::Boolean(value), _typ) => Ok(Literal { + literal_type: Some(LiteralType::Boolean(*value)), + nullable: false, + type_variation_reference: 0, + }), + (ast::LiteralValue::String(value), typ) => { + lower_string_literal(ctx, value, typ.as_ref(), message) + } + } + } +} + +fn lower_int_literal( + ctx: &LowerCtx<'_>, + parsed_value: i64, + typ: Option<&ast::TypeExpr>, + message: &'static str, +) -> Result { + let kind = match typ { + None => Kind::I64(ptype::I64 { + type_variation_reference: 0, + nullability: Nullability::Required as i32, + }), + Some(typ) => lower_type_kind_no_box(ctx, typ, message)?, + }; + + let (lit, nullability, tvar) = match kind { + // TODO(parser-hardening): validate numeric bounds before narrowing. + // Values like `1000:i8` currently cast via `as` and silently truncate. 
+ Kind::I8(i) => ( + LiteralType::I8(parsed_value as i32), + i.nullability, + i.type_variation_reference, + ), + Kind::I16(i) => ( + LiteralType::I16(parsed_value as i32), + i.nullability, + i.type_variation_reference, + ), + Kind::I32(i) => ( + LiteralType::I32(parsed_value as i32), + i.nullability, + i.type_variation_reference, + ), + Kind::I64(i) => ( + LiteralType::I64(parsed_value), + i.nullability, + i.type_variation_reference, + ), + other => { + return ctx.invalid( + message, + format!("Invalid type for integer literal: {other:?}"), + ); + } + }; + + Ok(Literal { + literal_type: Some(lit), + nullable: nullability != Nullability::Required as i32, + type_variation_reference: tvar, + }) +} + +fn lower_float_literal( + ctx: &LowerCtx<'_>, + parsed_value: f64, + typ: Option<&ast::TypeExpr>, + message: &'static str, +) -> Result { + let kind = match typ { + None => Kind::Fp64(ptype::Fp64 { + type_variation_reference: 0, + nullability: Nullability::Required as i32, + }), + Some(typ) => lower_type_kind_no_box(ctx, typ, message)?, + }; + + let (lit, nullability, tvar) = match kind { + Kind::Fp32(f) => ( + LiteralType::Fp32(parsed_value as f32), + f.nullability, + f.type_variation_reference, + ), + Kind::Fp64(f) => ( + LiteralType::Fp64(parsed_value), + f.nullability, + f.type_variation_reference, + ), + other => { + return ctx.invalid( + message, + format!("Invalid type for float literal: {other:?}"), + ); + } + }; + + Ok(Literal { + literal_type: Some(lit), + nullable: nullability != Nullability::Required as i32, + type_variation_reference: tvar, + }) +} + +fn lower_string_literal( + ctx: &LowerCtx<'_>, + string_value: &str, + typ: Option<&ast::TypeExpr>, + message: &'static str, +) -> Result { + let Some(typ) = typ else { + return Ok(Literal { + literal_type: Some(LiteralType::String(string_value.to_string())), + nullable: false, + type_variation_reference: 0, + }); + }; + + let kind = lower_type_kind_no_box(ctx, typ, message)?; + match kind { + Kind::String(s) => 
Ok(Literal { + literal_type: Some(LiteralType::String(string_value.to_string())), + nullable: s.nullability != Nullability::Required as i32, + type_variation_reference: s.type_variation_reference, + }), + Kind::Date(d) => { + let date_days = parse_date_to_days(ctx, string_value, message)?; + Ok(Literal { + literal_type: Some(LiteralType::Date(date_days)), + nullable: d.nullability != Nullability::Required as i32, + type_variation_reference: d.type_variation_reference, + }) + } + Kind::Time(t) => { + let micros = parse_time_to_microseconds(ctx, string_value, message)?; + Ok(Literal { + literal_type: Some(LiteralType::Time(micros)), + nullable: t.nullability != Nullability::Required as i32, + type_variation_reference: t.type_variation_reference, + }) + } + #[allow(deprecated)] + Kind::Timestamp(ts) => { + let micros = parse_timestamp_to_microseconds(ctx, string_value, message)?; + Ok(Literal { + literal_type: Some(LiteralType::Timestamp(micros)), + nullable: ts.nullability != Nullability::Required as i32, + type_variation_reference: ts.type_variation_reference, + }) + } + other => ctx.invalid( + message, + format!("Invalid type for string literal: {other:?}"), + ), + } +} + +fn parse_date_to_days( + ctx: &LowerCtx<'_>, + date_str: &str, + message: &'static str, +) -> Result { + let formats = ["%Y-%m-%d", "%Y/%m/%d"]; + + for format in &formats { + if let Ok(date) = NaiveDate::parse_from_str(date_str, format) { + // Constant date; construction cannot fail. + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).expect("epoch date"); + return Ok(date.signed_duration_since(epoch).num_days() as i32); + } + } + + ctx.invalid( + message, + format!("Invalid date format: '{date_str}'. 
 Expected YYYY-MM-DD or YYYY/MM/DD"),
+    )
+}
+
+fn parse_time_to_microseconds(
+    ctx: &LowerCtx<'_>,
+    time_str: &str,
+    message: &'static str,
+) -> Result<i64, ParseError> {
+    let formats = ["%H:%M:%S%.f", "%H:%M:%S"];
+
+    for format in &formats {
+        if let Ok(time) = NaiveTime::parse_from_str(time_str, format) {
+            // Constant time; construction cannot fail.
+            let midnight = NaiveTime::from_hms_opt(0, 0, 0).expect("midnight");
+            let duration = time.signed_duration_since(midnight);
+            return Ok(duration.num_microseconds().unwrap_or(0));
+        }
+    }
+
+    ctx.invalid(
+        message,
+        format!("Invalid time format: '{time_str}'. Expected HH:MM:SS or HH:MM:SS.fff"),
+    )
+}
+
+fn parse_timestamp_to_microseconds(
+    ctx: &LowerCtx<'_>,
+    timestamp_str: &str,
+    message: &'static str,
+) -> Result<i64, ParseError> {
+    let formats = [
+        "%Y-%m-%dT%H:%M:%S%.f",
+        "%Y-%m-%dT%H:%M:%S",
+        "%Y-%m-%d %H:%M:%S%.f",
+        "%Y-%m-%d %H:%M:%S",
+        "%Y/%m/%dT%H:%M:%S%.f",
+        "%Y/%m/%dT%H:%M:%S",
+        "%Y/%m/%d %H:%M:%S%.f",
+        "%Y/%m/%d %H:%M:%S",
+    ];
+
+    for format in &formats {
+        if let Ok(datetime) = NaiveDateTime::parse_from_str(timestamp_str, format) {
+            // Constant timestamp; construction cannot fail.
+            let epoch = DateTime::from_timestamp(0, 0)
+                .expect("epoch timestamp")
+                .naive_utc();
+            let duration = datetime.signed_duration_since(epoch);
+            return Ok(duration.num_microseconds().unwrap_or(0));
+        }
+    }
+
+    ctx.invalid(
+        message,
+        format!(
+            "Invalid timestamp format: '{timestamp_str}'. Expected YYYY-MM-DDTHH:MM:SS or YYYY-MM-DD HH:MM:SS"
+        ),
+    )
+}
diff --git a/src/parser/lower/mod.rs b/src/parser/lower/mod.rs
new file mode 100644
index 0000000..3ae76f1
--- /dev/null
+++ b/src/parser/lower/mod.rs
@@ -0,0 +1,98 @@
+//! AST -> protobuf lowering.
+//!
+//! The AST captures the text format's syntax and semantics in a Rust-native form;
+//! it is assembled during the `parse` stage.
+//!
+//! Relation lowering is explicit (per relation type), while leaf AST nodes
+//! 
([`ast::Expr`], [`ast::Literal`], [`ast::TypeExpr`]) use the internal +//! `Lower` trait for uniform conversion. + +mod expr; +mod extensions; +mod literals; +mod relations; +mod types; +mod validate; + +use substrait::proto::Rel; + +use crate::extensions::{ExtensionRegistry, SimpleExtensions}; +use crate::parser::ast; +use crate::parser::errors::{MessageParseError, ParseContext, ParseError}; + +trait Lower { + /// Protobuf type that this trait converts into + type Output; + + /// Convert Self into that protobuf type + fn lower(&self, ctx: &LowerCtx<'_>, message: &'static str) -> Result; +} + +pub(crate) struct LowerCtx<'a> { + pub(crate) extensions: &'a SimpleExtensions, + pub(crate) registry: &'a ExtensionRegistry, + pub(crate) line: &'a str, + pub(crate) line_no: i64, +} + +impl<'a> LowerCtx<'a> { + pub(crate) fn new( + extensions: &'a SimpleExtensions, + registry: &'a ExtensionRegistry, + line: &'a str, + line_no: i64, + ) -> Self { + Self { + extensions, + registry, + line, + line_no, + } + } + + pub(crate) fn parse_context(&self) -> ParseContext { + // ParseContext owns line text, so errors can be surfaced without + // borrowing constraints. 
+        ParseContext::new(self.line_no, self.line.to_string())
+    }
+
+    pub(crate) fn invalid<T>(
+        &self,
+        message: &'static str,
+        description: impl ToString,
+    ) -> Result<T, ParseError> {
+        Err(ParseError::Plan(
+            self.parse_context(),
+            MessageParseError::invalid(message, description),
+        ))
+    }
+}
+
+#[allow(clippy::vec_box)]
+pub fn lower_relation(
+    extensions: &SimpleExtensions,
+    registry: &ExtensionRegistry,
+    relation: &ast::Relation,
+    line: &str,
+    line_no: i64,
+    child_relations: Vec<Box<Rel>>,
+    input_field_count: usize,
+) -> Result<Rel, ParseError> {
+    let ctx = LowerCtx::new(extensions, registry, line, line_no);
+
+    if relation.args.invalid_order {
+        return ctx.invalid(
+            "relation_arguments",
+            "named arguments must follow positional arguments",
+        );
+    }
+
+    match &relation.name {
+        ast::RelationName::Standard(name) => {
+            relations::lower_standard(&ctx, relation, name, child_relations, input_field_count)
+        }
+        ast::RelationName::Extension(kind, name) => {
+            extensions::lower_extension(&ctx, relation, *kind, name, child_relations)
+        }
+    }
+}
diff --git a/src/parser/lower/relations.rs b/src/parser/lower/relations.rs
new file mode 100644
index 0000000..730070f
--- /dev/null
+++ b/src/parser/lower/relations.rs
@@ -0,0 +1,449 @@
+//! Relation-specific lowering logic.
+ +use substrait::proto::aggregate_rel::Measure; +use substrait::proto::fetch_rel::{CountMode, OffsetMode}; +use substrait::proto::rel::RelType; +use substrait::proto::rel_common::{Emit, EmitKind}; +use substrait::proto::sort_field::SortKind; +use substrait::proto::r#type::{self as ptype, Nullability}; +use substrait::proto::{ + AggregateFunction, AggregateRel, FetchRel, FilterRel, JoinRel, NamedStruct, ProjectRel, + ReadRel, Rel, RelCommon, SortField, SortRel, read_rel, +}; + +use super::expr::{fetch_value_expression, field_ref_expression, lower_arg_as_expression}; +use super::validate::{ + ensure_exact_child_count, ensure_exact_positional_count, ensure_no_children, + ensure_no_named_args, expect_one_child, output_mapping_from_args, parse_join_type, + parse_sort_direction, +}; +use super::{Lower, LowerCtx}; +use crate::parser::ast; +use crate::parser::errors::ParseError; + +#[allow(clippy::vec_box)] +pub(crate) fn lower_standard( + ctx: &LowerCtx<'_>, + relation: &ast::Relation, + name: &str, + child_relations: Vec>, + input_field_count: usize, +) -> Result { + match name { + "Read" => lower_read(ctx, relation, child_relations), + "Filter" => lower_filter(ctx, relation, child_relations), + "Project" => lower_project(ctx, relation, child_relations, input_field_count), + "Aggregate" => lower_aggregate(ctx, relation, child_relations), + "Sort" => lower_sort(ctx, relation, child_relations), + "Fetch" => lower_fetch(ctx, relation, child_relations), + "Join" => lower_join(ctx, relation, child_relations), + _ => Err(ParseError::RelationParse( + ctx.parse_context(), + format!("unsupported relation '{name}'"), + )), + } +} + +#[allow(clippy::vec_box)] +fn lower_read( + ctx: &LowerCtx<'_>, + relation: &ast::Relation, + child_relations: Vec>, +) -> Result { + ensure_no_named_args(ctx, relation, "Read")?; + ensure_no_children(ctx, child_relations.len(), "Read")?; + + ensure_exact_positional_count( + ctx, + relation.args.positional.len(), + 1, + "Read", + "Read expects exactly 
one positional argument (table name)", + )?; + + let names = match &relation.args.positional[0] { + ast::Arg::Expr(ast::Expr::Identifier(name)) => vec![name.clone()], + ast::Arg::Expr(ast::Expr::DottedName(names)) => names.clone(), + arg => { + return ctx.invalid( + "Read", + format!("Read table name must be an identifier, got '{arg}'"), + ); + } + }; + + let mut column_names = Vec::with_capacity(relation.outputs.len()); + let mut column_types = Vec::with_capacity(relation.outputs.len()); + for col in &relation.outputs { + match col { + ast::Arg::NamedColumn(name, typ) => { + column_names.push(name.clone()); + column_types.push(typ.lower(ctx, "Read")?); + } + _ => { + return ctx.invalid("Read", "Read outputs must be named columns, e.g. col:i32"); + } + } + } + + let struct_ = ptype::Struct { + types: column_types, + type_variation_reference: 0, + nullability: Nullability::Required as i32, + }; + + let read = ReadRel { + base_schema: Some(NamedStruct { + names: column_names, + r#struct: Some(struct_), + }), + read_type: Some(read_rel::ReadType::NamedTable(read_rel::NamedTable { + names, + advanced_extension: None, + })), + ..Default::default() + }; + + Ok(Rel { + rel_type: Some(RelType::Read(Box::new(read))), + }) +} + +#[allow(clippy::vec_box)] +fn lower_filter( + ctx: &LowerCtx<'_>, + relation: &ast::Relation, + mut child_relations: Vec>, +) -> Result { + ensure_no_named_args(ctx, relation, "Filter")?; + let input = expect_one_child(ctx, &mut child_relations, "Filter")?; + + ensure_exact_positional_count( + ctx, + relation.args.positional.len(), + 1, + "Filter", + "Filter expects one positional condition argument", + )?; + + let condition = lower_arg_as_expression(ctx, &relation.args.positional[0], "Filter")?; + + let emit = EmitKind::Emit(Emit { + output_mapping: output_mapping_from_args(ctx, &relation.outputs, "Filter")?, + }); + + Ok(Rel { + rel_type: Some(RelType::Filter(Box::new(FilterRel { + input: Some(input), + condition: Some(Box::new(condition)), + 
common: Some(RelCommon { + emit_kind: Some(emit), + ..Default::default() + }), + advanced_extension: None, + }))), + }) +} + +#[allow(clippy::vec_box)] +fn lower_project( + ctx: &LowerCtx<'_>, + relation: &ast::Relation, + mut child_relations: Vec>, + input_field_count: usize, +) -> Result { + ensure_no_named_args(ctx, relation, "Project")?; + let input = expect_one_child(ctx, &mut child_relations, "Project")?; + + if !relation.outputs.is_empty() { + return ctx.invalid( + "Project", + "Project does not accept output columns after '=>'; use positional arguments only", + ); + } + + let mut expressions = Vec::new(); + let mut output_mapping = Vec::new(); + + for arg in &relation.args.positional { + match arg { + ast::Arg::Expr(ast::Expr::FieldRef(index)) => output_mapping.push(*index), + _ => { + let expr = lower_arg_as_expression(ctx, arg, "Project")?; + expressions.push(expr); + // Computed expressions are appended after input fields in the + // projected schema. + output_mapping.push(input_field_count as i32 + (expressions.len() as i32 - 1)); + } + } + } + + Ok(Rel { + rel_type: Some(RelType::Project(Box::new(ProjectRel { + input: Some(input), + expressions, + common: Some(RelCommon { + emit_kind: Some(EmitKind::Emit(Emit { output_mapping })), + ..Default::default() + }), + advanced_extension: None, + }))), + }) +} + +#[allow(clippy::vec_box)] +fn lower_aggregate( + ctx: &LowerCtx<'_>, + relation: &ast::Relation, + mut child_relations: Vec>, +) -> Result { + ensure_no_named_args(ctx, relation, "Aggregate")?; + let input = expect_one_child(ctx, &mut child_relations, "Aggregate")?; + + let mut grouping_expressions = Vec::new(); + for arg in &relation.args.positional { + match arg { + ast::Arg::Wildcard => {} + ast::Arg::Expr(ast::Expr::FieldRef(index)) => { + grouping_expressions.push(field_ref_expression(*index)); + } + _ => { + return ctx.invalid( + "Aggregate", + format!( + "Aggregate grouping argument must be field reference or '_', got '{arg}'" + ), + ); + } + } 
+ } + + let group_by_count = grouping_expressions.len(); + let mut measures = Vec::new(); + let mut output_mapping = Vec::new(); + + for arg in &relation.outputs { + match arg { + ast::Arg::Expr(ast::Expr::FieldRef(index)) => output_mapping.push(*index), + ast::Arg::Expr(ast::Expr::FunctionCall(call)) => { + let scalar = call.lower(ctx, "Aggregate")?; + measures.push(Measure { + measure: Some(AggregateFunction { + function_reference: scalar.function_reference, + arguments: scalar.arguments, + options: scalar.options, + output_type: scalar.output_type, + invocation: 0, + phase: 0, + sorts: vec![], + #[allow(deprecated)] + args: scalar.args, + }), + filter: None, + }); + // Aggregate outputs are grouped fields followed by measures. + output_mapping.push(group_by_count as i32 + (measures.len() as i32 - 1)); + } + _ => { + return ctx.invalid( + "Aggregate", + format!("Aggregate output must be field reference or aggregate function, got '{arg}'"), + ); + } + } + } + + Ok(Rel { + rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { + input: Some(input), + grouping_expressions, + // We use `grouping_expressions` (newer API) and leave legacy + // `groupings` empty for compatibility with current text format. 
+ groupings: vec![], + measures, + common: Some(RelCommon { + emit_kind: Some(EmitKind::Emit(Emit { output_mapping })), + ..Default::default() + }), + advanced_extension: None, + }))), + }) +} + +#[allow(clippy::vec_box)] +fn lower_sort( + ctx: &LowerCtx<'_>, + relation: &ast::Relation, + mut child_relations: Vec>, +) -> Result { + ensure_no_named_args(ctx, relation, "Sort")?; + let input = expect_one_child(ctx, &mut child_relations, "Sort")?; + + let mut sorts = Vec::with_capacity(relation.args.positional.len()); + for arg in &relation.args.positional { + let ast::Arg::Tuple(values) = arg else { + return ctx.invalid( + "Sort", + "Sort fields must be tuples like ($0, &AscNullsFirst)", + ); + }; + if values.len() != 2 { + return ctx.invalid("Sort", "Sort tuple must contain exactly 2 items"); + } + + let field = match &values[0] { + ast::Arg::Expr(ast::Expr::FieldRef(idx)) => *idx, + _ => { + return ctx.invalid("Sort", "Sort tuple first item must be a field reference"); + } + }; + + let direction = match &values[1] { + ast::Arg::Enum(v) => parse_sort_direction(ctx, v)?, + _ => { + return ctx.invalid( + "Sort", + "Sort tuple second item must be a sort direction enum", + ); + } + }; + + sorts.push(SortField { + expr: Some(field_ref_expression(field)), + sort_kind: Some(SortKind::Direction(direction as i32)), + }); + } + + let output_mapping = output_mapping_from_args(ctx, &relation.outputs, "Sort")?; + + Ok(Rel { + rel_type: Some(RelType::Sort(Box::new(SortRel { + input: Some(input), + sorts, + common: Some(RelCommon { + emit_kind: Some(EmitKind::Emit(Emit { output_mapping })), + ..Default::default() + }), + advanced_extension: None, + }))), + }) +} + +#[allow(clippy::vec_box)] +fn lower_fetch( + ctx: &LowerCtx<'_>, + relation: &ast::Relation, + mut child_relations: Vec>, +) -> Result { + let input = expect_one_child(ctx, &mut child_relations, "Fetch")?; + + let mut count_mode = None; + let mut offset_mode = None; + + if relation.args.positional.len() == 1 + && 
matches!(relation.args.positional[0], ast::Arg::Wildcard) + { + if !relation.args.named.is_empty() { + return ctx.invalid("Fetch", "Fetch cannot mix '_' with named arguments"); + } + } else { + if !relation.args.positional.is_empty() { + return ctx.invalid( + "Fetch", + "Fetch accepts only named arguments (limit/offset) or '_'", + ); + } + + for named in &relation.args.named { + match named.name.as_str() { + "limit" => { + if count_mode.is_some() { + return ctx.invalid("Fetch", "Duplicate limit argument"); + } + count_mode = Some(CountMode::CountExpr(Box::new(fetch_value_expression( + ctx, + &named.value, + "limit", + )?))); + } + "offset" => { + if offset_mode.is_some() { + return ctx.invalid("Fetch", "Duplicate offset argument"); + } + offset_mode = Some(OffsetMode::OffsetExpr(Box::new(fetch_value_expression( + ctx, + &named.value, + "offset", + )?))); + } + other => { + return ctx.invalid("Fetch", format!("Unknown named argument '{other}'")); + } + } + } + } + + let output_mapping = output_mapping_from_args(ctx, &relation.outputs, "Fetch")?; + + Ok(Rel { + rel_type: Some(RelType::Fetch(Box::new(FetchRel { + input: Some(input), + common: Some(RelCommon { + emit_kind: Some(EmitKind::Emit(Emit { output_mapping })), + ..Default::default() + }), + advanced_extension: None, + offset_mode, + count_mode, + }))), + }) +} + +#[allow(clippy::vec_box)] +fn lower_join( + ctx: &LowerCtx<'_>, + relation: &ast::Relation, + mut child_relations: Vec>, +) -> Result { + ensure_no_named_args(ctx, relation, "Join")?; + + ensure_exact_child_count(ctx, child_relations.len(), 2, "Join")?; + + ensure_exact_positional_count( + ctx, + relation.args.positional.len(), + 2, + "Join", + "Join expects two positional args: join type and condition", + )?; + + let join_type = match &relation.args.positional[0] { + ast::Arg::Enum(value) => parse_join_type(ctx, value)?, + _ => { + return ctx.invalid("Join", "Join first argument must be enum (e.g. 
&Inner)"); + } + }; + + let condition = lower_arg_as_expression(ctx, &relation.args.positional[1], "Join")?; + + // Structural parsing preserves child order: first child is left input, second is right. + let left = child_relations.remove(0); + let right = child_relations.remove(0); + + let output_mapping = output_mapping_from_args(ctx, &relation.outputs, "Join")?; + + Ok(Rel { + rel_type: Some(RelType::Join(Box::new(JoinRel { + common: Some(RelCommon { + emit_kind: Some(EmitKind::Emit(Emit { output_mapping })), + ..Default::default() + }), + left: Some(left), + right: Some(right), + expression: Some(Box::new(condition)), + post_join_filter: None, + r#type: join_type as i32, + advanced_extension: None, + }))), + }) +} diff --git a/src/parser/lower/types.rs b/src/parser/lower/types.rs new file mode 100644 index 0000000..02628a2 --- /dev/null +++ b/src/parser/lower/types.rs @@ -0,0 +1,194 @@ +//! Type lowering and extension-anchor resolution. + +use substrait::proto::Type; +use substrait::proto::r#type::{self as ptype, Kind, Nullability, Parameter}; + +use super::{Lower, LowerCtx}; +use crate::extensions::simple::ExtensionKind; +use crate::parser::ast; +use crate::parser::errors::{ErrorKind, MessageParseError, ParseError}; + +impl Lower for ast::TypeExpr { + type Output = Type; + + fn lower(&self, ctx: &LowerCtx<'_>, message: &'static str) -> Result { + Ok(Type { + kind: Some(lower_type_kind(ctx, self, message)?), + }) + } +} + +pub(crate) fn lower_type_kind( + ctx: &LowerCtx<'_>, + typ: &ast::TypeExpr, + message: &'static str, +) -> Result { + match typ { + ast::TypeExpr::Simple { name, nullability } => { + lower_simple_type(ctx, name, *nullability, message) + } + ast::TypeExpr::List { nullability, inner } => Ok(Kind::List(Box::new(ptype::List { + nullability: to_nullability(*nullability) as i32, + r#type: Some(Box::new(inner.lower(ctx, message)?)), + type_variation_reference: 0, + }))), + ast::TypeExpr::UserDefined { + name, + anchor, + urn_anchor: _, + nullability, 
+ parameters, + } => { + let type_reference = resolve_anchor(ctx, ExtensionKind::Type, *anchor, name, message)?; + let mut type_parameters = Vec::with_capacity(parameters.len()); + for parameter in parameters { + type_parameters.push(Parameter { + parameter: Some(ptype::parameter::Parameter::DataType( + parameter.lower(ctx, message)?, + )), + }); + } + Ok(Kind::UserDefined(ptype::UserDefined { + type_reference, + nullability: to_nullability(*nullability) as i32, + type_parameters, + type_variation_reference: 0, + })) + } + } +} + +pub(crate) fn lower_type_kind_no_box( + ctx: &LowerCtx<'_>, + typ: &ast::TypeExpr, + message: &'static str, +) -> Result { + // Type annotations in literals currently only support simple types. + match typ { + ast::TypeExpr::Simple { name, nullability } => { + lower_simple_type(ctx, name, *nullability, message) + } + _ => ctx.invalid( + message, + format!("literal type annotation must be simple type, got '{typ}'"), + ), + } +} + +pub(crate) fn lower_simple_type( + ctx: &LowerCtx<'_>, + name: &str, + nullability: ast::Nullability, + message: &'static str, +) -> Result { + let nullability = to_nullability(nullability) as i32; + let kind = match name { + "boolean" => Kind::Bool(ptype::Boolean { + nullability, + type_variation_reference: 0, + }), + "i8" => Kind::I8(ptype::I8 { + nullability, + type_variation_reference: 0, + }), + "i16" => Kind::I16(ptype::I16 { + nullability, + type_variation_reference: 0, + }), + "i32" => Kind::I32(ptype::I32 { + nullability, + type_variation_reference: 0, + }), + "i64" => Kind::I64(ptype::I64 { + nullability, + type_variation_reference: 0, + }), + "fp32" => Kind::Fp32(ptype::Fp32 { + nullability, + type_variation_reference: 0, + }), + "fp64" => Kind::Fp64(ptype::Fp64 { + nullability, + type_variation_reference: 0, + }), + "string" => Kind::String(ptype::String { + nullability, + type_variation_reference: 0, + }), + "binary" => Kind::Binary(ptype::Binary { + nullability, + type_variation_reference: 0, + }), 
+ #[allow(deprecated)] + "timestamp" => Kind::Timestamp(ptype::Timestamp { + nullability, + type_variation_reference: 0, + }), + #[allow(deprecated)] + "timestamp_tz" => Kind::TimestampTz(ptype::TimestampTz { + nullability, + type_variation_reference: 0, + }), + "date" => Kind::Date(ptype::Date { + nullability, + type_variation_reference: 0, + }), + "time" => Kind::Time(ptype::Time { + nullability, + type_variation_reference: 0, + }), + "interval_year" => Kind::IntervalYear(ptype::IntervalYear { + nullability, + type_variation_reference: 0, + }), + "uuid" => Kind::Uuid(ptype::Uuid { + nullability, + type_variation_reference: 0, + }), + // Unsupported parity-first simple types; keep explicit validation error. + "u8" | "u16" | "u32" | "u64" => { + return ctx.invalid(message, format!("unsupported type '{name}'")); + } + _ => { + return ctx.invalid(message, format!("unknown type '{name}'")); + } + }; + Ok(kind) +} + +pub(crate) fn resolve_anchor( + ctx: &LowerCtx<'_>, + kind: ExtensionKind, + anchor: Option, + name: &str, + message: &'static str, +) -> Result { + let resolved = match anchor { + // When an explicit anchor is present, validate that it matches the + // requested name and extension kind. 
+ Some(anchor) => ctx + .extensions + .is_name_unique(kind, anchor, name) + .map(|_| anchor), + None => ctx.extensions.find_by_name(kind, name), + }; + + resolved.map_err(|missing| { + ParseError::Plan( + ctx.parse_context(), + MessageParseError::new( + message, + ErrorKind::Lookup(missing), + format!("unable to resolve {kind} '{name}'"), + ), + ) + }) +} + +pub(crate) fn to_nullability(n: ast::Nullability) -> Nullability { + match n { + ast::Nullability::Required => Nullability::Required, + ast::Nullability::Nullable => Nullability::Nullable, + ast::Nullability::Unspecified => Nullability::Unspecified, + } +} diff --git a/src/parser/lower/validate.rs b/src/parser/lower/validate.rs new file mode 100644 index 0000000..0a4282e --- /dev/null +++ b/src/parser/lower/validate.rs @@ -0,0 +1,157 @@ +//! Shared relation-level validation helpers used by lowering. + +use std::collections::HashSet; + +use substrait::proto::sort_field::SortDirection; +use substrait::proto::{Rel, join_rel}; + +use super::LowerCtx; +use crate::parser::ast; +use crate::parser::errors::ParseError; + +pub(crate) fn ensure_no_named_args( + ctx: &LowerCtx<'_>, + relation: &ast::Relation, + message: &'static str, +) -> Result<(), ParseError> { + if relation.args.named.is_empty() { + return Ok(()); + } + ctx.invalid(message, "named arguments are not allowed here") +} + +pub(crate) fn ensure_no_children( + ctx: &LowerCtx<'_>, + child_count: usize, + message: &'static str, +) -> Result<(), ParseError> { + if child_count == 0 { + return Ok(()); + } + ctx.invalid( + message, + format!("{message} should have no input children, found {child_count}"), + ) +} + +pub(crate) fn ensure_exact_child_count( + ctx: &LowerCtx<'_>, + child_count: usize, + expected: usize, + message: &'static str, +) -> Result<(), ParseError> { + if child_count == expected { + return Ok(()); + } + + let suffix = if expected == 1 { "" } else { "ren" }; + ctx.invalid( + message, + format!( + "{message} should have exactly {expected} input 
child{suffix}, found {child_count}" + ), + ) +} + +#[allow(clippy::vec_box)] +pub(crate) fn expect_one_child( + ctx: &LowerCtx<'_>, + child_relations: &mut Vec>, + message: &'static str, +) -> Result, ParseError> { + match child_relations.len() { + 1 => Ok(child_relations.remove(0)), + n => ctx.invalid( + message, + format!("{message} should have exactly one input child, found {n}"), + ), + } +} + +pub(crate) fn ensure_exact_positional_count( + ctx: &LowerCtx<'_>, + positional_count: usize, + expected: usize, + message: &'static str, + detail: &'static str, +) -> Result<(), ParseError> { + if positional_count == expected { + return Ok(()); + } + ctx.invalid(message, detail) +} + +pub(crate) fn ensure_unique_named_arg( + ctx: &LowerCtx<'_>, + seen: &mut HashSet, + name: &str, + message: &'static str, +) -> Result<(), ParseError> { + if seen.insert(name.to_string()) { + return Ok(()); + } + + ctx.invalid(message, format!("duplicate named argument '{name}'")) +} + +pub(crate) fn output_mapping_from_args( + ctx: &LowerCtx<'_>, + args: &[ast::Arg], + message: &'static str, +) -> Result, ParseError> { + // Output mappings are index-only in the text format and refer to field + // positions after relation-specific projection/aggregation semantics. 
+ let mut mapping = Vec::with_capacity(args.len()); + for arg in args { + match arg { + ast::Arg::Expr(ast::Expr::FieldRef(index)) => mapping.push(*index), + _ => { + return ctx.invalid( + message, + format!("output mapping expects field references, got '{arg}'"), + ); + } + } + } + Ok(mapping) +} + +pub(crate) fn parse_sort_direction( + ctx: &LowerCtx<'_>, + value: &str, +) -> Result { + let direction = match value { + "AscNullsFirst" => SortDirection::AscNullsFirst, + "AscNullsLast" => SortDirection::AscNullsLast, + "DescNullsFirst" => SortDirection::DescNullsFirst, + "DescNullsLast" => SortDirection::DescNullsLast, + _ => { + return ctx.invalid("Sort", format!("unknown sort direction '&{value}'")); + } + }; + Ok(direction) +} + +pub(crate) fn parse_join_type( + ctx: &LowerCtx<'_>, + value: &str, +) -> Result { + let join_type = match value { + "Inner" => join_rel::JoinType::Inner, + "Left" => join_rel::JoinType::Left, + "Right" => join_rel::JoinType::Right, + "Outer" => join_rel::JoinType::Outer, + "LeftSemi" => join_rel::JoinType::LeftSemi, + "RightSemi" => join_rel::JoinType::RightSemi, + "LeftAnti" => join_rel::JoinType::LeftAnti, + "RightAnti" => join_rel::JoinType::RightAnti, + "LeftSingle" => join_rel::JoinType::LeftSingle, + "RightSingle" => join_rel::JoinType::RightSingle, + "LeftMark" => join_rel::JoinType::LeftMark, + "RightMark" => join_rel::JoinType::RightMark, + _ => { + return ctx.invalid("Join", format!("unknown join type '&{value}'")); + } + }; + Ok(join_type) +} diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 03fdfe5..f11b886 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -1,12 +1,23 @@ -pub mod common; +//! Text → Substrait protobuf parsing. +//! +//! Parsing proceeds in two phases: +//! +//! 1. **Syntax** — each line is parsed into an [`ast`] node by the LALRPOP +//! grammar (invoked through `lalrpop_line`). The [`structural`] module +//! drives this, handling indentation, section headers, and the extensions +//! preamble. 
+//! 2. **Lowering** — the AST is converted to a [`substrait::proto::Plan`] by +//! the [`lower`] module, which performs semantic validation (anchor +//! resolution, column counts, type checking). +//! +//! The main entry point is [`Parser::parse`]. + +pub mod ast; pub mod errors; -pub mod expressions; pub mod extensions; -pub mod relations; +pub(crate) mod lalrpop_line; +pub mod lower; pub mod structural; -pub mod types; -pub use common::*; -pub use errors::{ParseContext, ParseError, ParseResult}; -pub use relations::RelationParsePair; +pub use errors::{ErrorKind, MessageParseError, ParseContext, ParseError, ParseResult}; pub use structural::{PLAN_HEADER, Parser}; diff --git a/src/parser/relations.rs b/src/parser/relations.rs index 12c8a36..e69de29 100644 --- a/src/parser/relations.rs +++ b/src/parser/relations.rs @@ -1,1505 +0,0 @@ -use std::collections::HashMap; - -use pest::iterators::Pair; -use prost::Message; -use substrait::proto::aggregate_rel::Grouping; -use substrait::proto::expression::literal::LiteralType; -use substrait::proto::expression::{Literal, RexType}; -use substrait::proto::fetch_rel::{CountMode, OffsetMode}; -use substrait::proto::rel::RelType; -use substrait::proto::rel_common::{Emit, EmitKind}; -use substrait::proto::sort_field::{SortDirection, SortKind}; -use substrait::proto::{ - AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct, ProjectRel, ReadRel, Rel, - RelCommon, SortField, SortRel, Type, aggregate_rel, join_rel, read_rel, r#type, -}; - -use super::{ - ErrorKind, MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unwrap_single_pair, -}; -use crate::extensions::any::Any; -use crate::extensions::registry::ExtensionError; -use crate::extensions::{ExtensionArgs, ExtensionRegistry, SimpleExtensions}; -use crate::parser::errors::{ParseContext, ParseError}; -use crate::parser::expressions::{FieldIndex, Name}; - -/// Parsing context for relations that includes extensions, registry, and optional warning collection 
-pub struct RelationParsingContext<'a> { - pub extensions: &'a SimpleExtensions, - pub registry: &'a ExtensionRegistry, - pub line_no: i64, - pub line: &'a str, -} - -impl<'a> RelationParsingContext<'a> { - /// Resolve extension detail using registry. Any failure is treated as a hard parse error. - pub fn resolve_extension_detail( - &self, - extension_name: &str, - extension_args: &ExtensionArgs, - ) -> Result, ParseError> { - let detail = self - .registry - .parse_extension(extension_name, extension_args); - - match detail { - Ok(any) => Ok(Some(any)), - Err(ExtensionError::NotFound { .. }) => Err(ParseError::UnregisteredExtension { - name: extension_name.to_string(), - context: ParseContext::new(self.line_no, self.line.to_string()), - }), - Err(err) => Err(ParseError::ExtensionDetail( - ParseContext::new(self.line_no, self.line.to_string()), - err, - )), - } - } -} - -/// A trait for parsing relations with full context needed for tree building. -/// This includes extensions, the parsed pair, input children, and output field count. -pub trait RelationParsePair: Sized { - fn rule() -> Rule; - - fn message() -> &'static str; - - /// Parse a relation with full context for tree building. 
- /// - /// Args: - /// - extensions: The extensions context - /// - pair: The parsed pest pair - /// - input_children: The input relations (for wiring) - /// - input_field_count: Number of output fields from input children (for output mapping) - fn parse_pair_with_context( - extensions: &SimpleExtensions, - pair: Pair, - input_children: Vec>, - _input_field_count: usize, - ) -> Result; - - fn into_rel(self) -> Rel; -} - -pub struct TableName(Vec); - -impl ParsePair for TableName { - fn rule() -> Rule { - Rule::table_name - } - - fn message() -> &'static str { - "TableName" - } - - fn parse_pair(pair: Pair) -> Self { - assert_eq!(pair.as_rule(), Self::rule()); - let pairs = pair.into_inner(); - let mut names = Vec::with_capacity(pairs.len()); - let mut iter = RuleIter::from(pairs); - while let Some(name) = iter.parse_if_next::() { - names.push(name.0); - } - iter.done(); - Self(names) - } -} - -#[derive(Debug, Clone)] -pub struct Column { - pub name: String, - pub typ: Type, -} - -impl ScopedParsePair for Column { - fn rule() -> Rule { - Rule::named_column - } - - fn message() -> &'static str { - "Column" - } - - fn parse_pair( - extensions: &SimpleExtensions, - pair: Pair, - ) -> Result { - assert_eq!(pair.as_rule(), Self::rule()); - let mut iter = RuleIter::from(pair.into_inner()); - let name = iter.parse_next::().0; - let typ = iter.parse_next_scoped(extensions)?; - iter.done(); - Ok(Self { name, typ }) - } -} - -pub struct NamedColumnList(Vec); - -impl ScopedParsePair for NamedColumnList { - fn rule() -> Rule { - Rule::named_column_list - } - - fn message() -> &'static str { - "NamedColumnList" - } - - fn parse_pair( - extensions: &SimpleExtensions, - pair: Pair, - ) -> Result { - assert_eq!(pair.as_rule(), Self::rule()); - let mut columns = Vec::new(); - for col in pair.into_inner() { - columns.push(Column::parse_pair(extensions, col)?); - } - Ok(Self(columns)) - } -} - -/// This is a utility function for extracting a single child from the list of -/// 
children, to be used in the RelationParsePair trait. The RelationParsePair -/// trait passes a Vec of children, because some relations have multiple -/// children - but most accept exactly one child. -#[allow(clippy::vec_box)] -pub(crate) fn expect_one_child( - message: &'static str, - pair: &Pair, - mut input_children: Vec>, -) -> Result, MessageParseError> { - match input_children.len() { - 0 => Err(MessageParseError::invalid( - message, - pair.as_span(), - format!("{message} missing child"), - )), - 1 => Ok(input_children.pop().unwrap()), - n => Err(MessageParseError::invalid( - message, - pair.as_span(), - format!("{message} should have 1 input child, got {n}"), - )), - } -} - -/// Parse a reference list Pair and return an EmitKind::Emit. -fn parse_reference_emit(pair: Pair) -> EmitKind { - assert_eq!(pair.as_rule(), Rule::reference_list); - let output_mapping = pair - .into_inner() - .map(|p| FieldIndex::parse_pair(p).0) - .collect::>(); - EmitKind::Emit(Emit { output_mapping }) -} - -/// Extracts named arguments from pest pairs with duplicate detection and completeness checking. -/// -/// Usage: `extractor.pop("limit", Rule::fetch_value).0.pop("offset", Rule::fetch_value).0.done()` -/// -/// The fluent API ensures all arguments are processed exactly once and none are forgotten. 
-pub struct ParsedNamedArgs<'a> { - map: HashMap<&'a str, Pair<'a, Rule>>, -} - -impl<'a> ParsedNamedArgs<'a> { - pub fn new( - pairs: pest::iterators::Pairs<'a, Rule>, - rule: Rule, - ) -> Result { - let mut map = HashMap::new(); - for pair in pairs { - assert_eq!(pair.as_rule(), rule); - let mut inner = pair.clone().into_inner(); - let name_pair = inner.next().unwrap(); - let value_pair = inner.next().unwrap(); - assert_eq!(inner.next(), None); - let name = name_pair.as_str(); - if map.contains_key(name) { - return Err(MessageParseError::invalid( - "NamedArg", - name_pair.as_span(), - format!("Duplicate argument: {name}"), - )); - } - map.insert(name, value_pair); - } - Ok(Self { map }) - } - - // Returns the pair if it exists and matches the rule, otherwise None. - // Asserts that the rule must match the rule of the pair (and therefore - // panics in non-release-mode if not) - pub fn pop(mut self, name: &str, rule: Rule) -> (Self, Option>) { - let pair = self.map.remove(name).inspect(|pair| { - assert_eq!(pair.as_rule(), rule, "Rule mismatch for argument {name}"); - }); - (self, pair) - } - - // Returns an error if there are any unused arguments. - pub fn done(self) -> Result<(), MessageParseError> { - if let Some((name, pair)) = self.map.iter().next() { - return Err(MessageParseError::invalid( - "NamedArgExtractor", - // No span available for all unused args; use default. 
- pair.as_span(), - format!("Unknown argument: {name}"), - )); - } - Ok(()) - } -} - -impl RelationParsePair for ReadRel { - fn rule() -> Rule { - Rule::read_relation - } - - fn message() -> &'static str { - "ReadRel" - } - - fn into_rel(self) -> Rel { - Rel { - rel_type: Some(RelType::Read(Box::new(self))), - } - } - - fn parse_pair_with_context( - extensions: &SimpleExtensions, - pair: Pair, - input_children: Vec>, - _input_field_count: usize, - ) -> Result { - assert_eq!(pair.as_rule(), Self::rule()); - // ReadRel is a leaf node - it should have no input children and 0 input fields - if !input_children.is_empty() { - return Err(MessageParseError::invalid( - Self::message(), - pair.as_span(), - "ReadRel should have no input children", - )); - } - if _input_field_count != 0 { - let error = pest::error::Error::new_from_span( - pest::error::ErrorVariant::CustomError { - message: "ReadRel should have 0 input fields".to_string(), - }, - pair.as_span(), - ); - return Err(MessageParseError::new( - "ReadRel", - ErrorKind::InvalidValue, - Box::new(error), - )); - } - - let mut iter = RuleIter::from(pair.into_inner()); - let table = iter.parse_next::().0; - let columns = iter.parse_next_scoped::(extensions)?.0; - iter.done(); - - let (names, types): (Vec<_>, Vec<_>) = columns.into_iter().map(|c| (c.name, c.typ)).unzip(); - let struct_ = r#type::Struct { - types, - type_variation_reference: 0, - nullability: r#type::Nullability::Required as i32, - }; - let named_struct = NamedStruct { - names, - r#struct: Some(struct_), - }; - - let read_rel = ReadRel { - base_schema: Some(named_struct), - read_type: Some(read_rel::ReadType::NamedTable(read_rel::NamedTable { - names: table, - advanced_extension: None, - })), - ..Default::default() - }; - - Ok(read_rel) - } -} - -impl RelationParsePair for FilterRel { - fn rule() -> Rule { - Rule::filter_relation - } - - fn message() -> &'static str { - "FilterRel" - } - - fn into_rel(self) -> Rel { - Rel { - rel_type: 
Some(RelType::Filter(Box::new(self))), - } - } - - fn parse_pair_with_context( - extensions: &SimpleExtensions, - pair: Pair, - input_children: Vec>, - _input_field_count: usize, - ) -> Result { - // Form: Filter[condition => references] - - assert_eq!(pair.as_rule(), Self::rule()); - let input = expect_one_child(Self::message(), &pair, input_children)?; - let mut iter = RuleIter::from(pair.into_inner()); - // condition - let condition = iter.parse_next_scoped::(extensions)?; - // references (which become the emit) - let references_pair = iter.pop(Rule::reference_list); - iter.done(); - - let emit = parse_reference_emit(references_pair); - let common = RelCommon { - emit_kind: Some(emit), - ..Default::default() - }; - - Ok(FilterRel { - input: Some(input), - condition: Some(Box::new(condition)), - common: Some(common), - advanced_extension: None, - }) - } -} - -impl RelationParsePair for ProjectRel { - fn rule() -> Rule { - Rule::project_relation - } - - fn message() -> &'static str { - "ProjectRel" - } - - fn into_rel(self) -> Rel { - Rel { - rel_type: Some(RelType::Project(Box::new(self))), - } - } - - fn parse_pair_with_context( - extensions: &SimpleExtensions, - pair: Pair, - input_children: Vec>, - _input_field_count: usize, - ) -> Result { - assert_eq!(pair.as_rule(), Self::rule()); - let input = expect_one_child(Self::message(), &pair, input_children)?; - - // Get the argument list (contains references and expressions) - let arguments_pair = unwrap_single_pair(pair); - - let mut expressions = Vec::new(); - let mut output_mapping = Vec::new(); - - // Process each argument (can be either a reference or expression) - for arg in arguments_pair.into_inner() { - let inner_arg = unwrap_single_pair(arg); - match inner_arg.as_rule() { - Rule::reference => { - // Parse reference like "$0" -> 0 - let field_index = FieldIndex::parse_pair(inner_arg); - output_mapping.push(field_index.0); - } - Rule::expression => { - // Parse as expression (e.g., 42, add($0, $1)) - let 
_expr = Expression::parse_pair(extensions, inner_arg)?; - expressions.push(_expr); - // Expression: index after all input fields - output_mapping.push(_input_field_count as i32 + (expressions.len() as i32 - 1)); - } - _ => panic!("Unexpected inner argument rule: {:?}", inner_arg.as_rule()), - } - } - - let emit = EmitKind::Emit(Emit { output_mapping }); - let common = RelCommon { - emit_kind: Some(emit), - ..Default::default() - }; - - Ok(ProjectRel { - input: Some(input), - expressions, - common: Some(common), - advanced_extension: None, - }) - } -} - -impl RelationParsePair for AggregateRel { - fn rule() -> Rule { - Rule::aggregate_relation - } - - fn message() -> &'static str { - "AggregateRel" - } - - fn into_rel(self) -> Rel { - Rel { - rel_type: Some(RelType::Aggregate(Box::new(self))), - } - } - - fn parse_pair_with_context( - extensions: &SimpleExtensions, - pair: Pair, - input_children: Vec>, - _input_field_count: usize, - ) -> Result { - assert_eq!(pair.as_rule(), Self::rule()); - let input = expect_one_child(Self::message(), &pair, input_children)?; - let mut iter = RuleIter::from(pair.into_inner()); - let group_by_pair = iter.pop(Rule::aggregate_group_by); - let output_pair = iter.pop(Rule::aggregate_output); - iter.done(); - - let mut expression_index_map: HashMap, u32> = HashMap::new(); - let grouping_sets: Vec; - let grouping_expressions: Vec; - - // The grammar matches either a single `expression_list` (no parens, single grouping set) - // or a `grouping_set_list` (one or more sets, each wrapped in parens, or a set `_` with no parens). - let inner = group_by_pair - .into_inner() - .next() - .expect("aggregate_group_by must have one inner item"); - - let group: Grouping; - - match inner.as_rule() { - Rule::expression_list => { - // Single non-empty grouping set written without parens: e.g. 
"$0, $1" - (group, grouping_expressions) = - get_single_group(extensions, inner, &mut expression_index_map); - grouping_sets = vec![group]; - } - Rule::grouping_set_list => { - (grouping_sets, grouping_expressions) = get_grouping_sets(extensions, inner); - } - _ => unreachable!( - "Unexpected rule in aggregate_group_by: {:?}", - inner.as_rule() - ), - } - - // Parse output items (can be references or aggregate measures) - let mut measures = Vec::new(); - let mut output_mapping = Vec::new(); - let mut measure_count = 0; - - for aggregate_output_item in output_pair.into_inner() { - let inner_item = unwrap_single_pair(aggregate_output_item); - match inner_item.as_rule() { - Rule::reference => { - let field_index = FieldIndex::parse_pair(inner_item); - output_mapping.push(field_index.0); - } - Rule::aggregate_measure => { - let measure = aggregate_rel::Measure::parse_pair(extensions, inner_item)?; - measures.push(measure); - output_mapping.push(grouping_expressions.len() as i32 + measure_count); - measure_count += 1; - } - _ => panic!( - "Unexpected inner output item rule: {:?}", - inner_item.as_rule() - ), - } - } - - let emit = EmitKind::Emit(Emit { output_mapping }); - let common = RelCommon { - emit_kind: Some(emit), - ..Default::default() - }; - - Ok(AggregateRel { - input: Some(input), - grouping_expressions, - groupings: grouping_sets, - measures, - common: Some(common), - advanced_extension: None, - }) - } -} - -fn get_single_group( - extensions: &SimpleExtensions, - inner: Pair<'_, Rule>, - expression_index_map: &mut HashMap, u32>, -) -> (Grouping, Vec) { - let mut expression_references: Vec = vec![]; - let mut grouping_expressions: Vec = Vec::new(); - let mut i = 0; - - // for each expression in the expression list we map to an index in the order they are parsed - for expr_pair in inner.into_inner() { - match Expression::parse_pair(extensions, expr_pair) { - Ok(exp) => { - // TODO: use a better key here than encoding to bytes. 
- // Ideally, substrait-rs would support `PartialEq` and `Hash`, - // but as there isn't an easy way to do that now, we'll skip. - let key = exp.encode_to_vec(); - expression_index_map.entry(key.clone()).or_insert_with(|| { - grouping_expressions.push(exp); - let index = i; - i += 1; - index // is expression returned by this closure and inserted into map - }); - expression_references.push(expression_index_map[&key]); - } - Err(_) => unreachable!( - "By the grammar rule, only expressions should be parsed so erroring can be dropped silently" - ), - } - } - - let group = Grouping { - expression_references, - #[allow(deprecated)] - grouping_expressions: vec![], // this is a deprecated proto field replaced by expression_references - }; - - (group, grouping_expressions) -} - -fn get_grouping_sets( - extensions: &SimpleExtensions, - inner: Pair<'_, Rule>, -) -> (Vec, Vec) { - let mut expression_index_map: HashMap, u32> = HashMap::new(); - - let mut grouping_sets: Vec = Vec::new(); - let mut grouping_expressions: Vec = Vec::new(); - - // One or more grouping sets, each written as `(expr_list)` or `_` - for grouping_set_pair in inner.into_inner() { - assert_eq!(grouping_set_pair.as_rule(), Rule::grouping_set); - - for item in grouping_set_pair.into_inner() { - match item.as_rule() { - Rule::empty => { - // empty grouping set — expression_references stays empty - grouping_sets.push(Grouping { - expression_references: vec![], - #[allow(deprecated)] - grouping_expressions: vec![], // this is a deprecated proto field replaced by expression_references - }); - } - Rule::expression_list => { - let (group, mut new_grouping_expressions) = - get_single_group(extensions, item, &mut expression_index_map); - grouping_expressions.append(&mut new_grouping_expressions); - grouping_sets.push(group); - } - _ => unreachable!("Unexpected item in grouping_set: {:?}", item.as_rule()), - } - } - } - (grouping_sets, grouping_expressions) -} - -impl ScopedParsePair for SortField { - fn rule() -> 
Rule { - Rule::sort_field - } - - fn message() -> &'static str { - "SortField" - } - - fn parse_pair( - _extensions: &SimpleExtensions, - pair: Pair, - ) -> Result { - assert_eq!(pair.as_rule(), Self::rule()); - let mut iter = RuleIter::from(pair.into_inner()); - let reference_pair = iter.pop(Rule::reference); - let field_index = FieldIndex::parse_pair(reference_pair); - let direction_pair = iter.pop(Rule::sort_direction); - // Strip the '&' prefix from enum syntax (e.g., "&AscNullsFirst" -> - // "AscNullsFirst") The grammar includes '&' to distinguish enums from - // identifiers, but the enum variant names don't include it - let direction = match direction_pair.as_str().trim_start_matches('&') { - "AscNullsFirst" => SortDirection::AscNullsFirst, - "AscNullsLast" => SortDirection::AscNullsLast, - "DescNullsFirst" => SortDirection::DescNullsFirst, - "DescNullsLast" => SortDirection::DescNullsLast, - other => { - return Err(MessageParseError::invalid( - "SortDirection", - direction_pair.as_span(), - format!("Unknown sort direction: {other}"), - )); - } - }; - iter.done(); - Ok(SortField { - expr: Some(Expression { - rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new( - field_index.to_field_reference(), - ))), - }), - // TODO: Add support for SortKind::ComparisonFunctionReference - sort_kind: Some(SortKind::Direction(direction as i32)), - }) - } -} - -impl RelationParsePair for SortRel { - fn rule() -> Rule { - Rule::sort_relation - } - - fn message() -> &'static str { - "SortRel" - } - - fn into_rel(self) -> Rel { - Rel { - rel_type: Some(RelType::Sort(Box::new(self))), - } - } - - fn parse_pair_with_context( - extensions: &SimpleExtensions, - pair: Pair, - input_children: Vec>, - _input_field_count: usize, - ) -> Result { - assert_eq!(pair.as_rule(), Self::rule()); - let input = expect_one_child(Self::message(), &pair, input_children)?; - let mut iter = RuleIter::from(pair.into_inner()); - let sort_field_list_pair = 
iter.pop(Rule::sort_field_list); - let reference_list_pair = iter.pop(Rule::reference_list); - let mut sorts = Vec::new(); - for sort_field_pair in sort_field_list_pair.into_inner() { - let sort_field = SortField::parse_pair(extensions, sort_field_pair)?; - sorts.push(sort_field); - } - let emit = parse_reference_emit(reference_list_pair); - let common = RelCommon { - emit_kind: Some(emit), - ..Default::default() - }; - iter.done(); - Ok(SortRel { - input: Some(input), - sorts, - common: Some(common), - advanced_extension: None, - }) - } -} - -impl ScopedParsePair for CountMode { - fn rule() -> Rule { - Rule::fetch_value - } - fn message() -> &'static str { - "CountMode" - } - fn parse_pair( - extensions: &SimpleExtensions, - pair: Pair, - ) -> Result { - assert_eq!(pair.as_rule(), Self::rule()); - let mut arg_inner = RuleIter::from(pair.into_inner()); - let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) { - int_pair - } else { - arg_inner.pop(Rule::expression) - }; - match value_pair.as_rule() { - Rule::integer => { - let value = value_pair.as_str().parse::().map_err(|e| { - MessageParseError::invalid( - Self::message(), - value_pair.as_span(), - format!("Invalid integer: {e}"), - ) - })?; - if value < 0 { - return Err(MessageParseError::invalid( - Self::message(), - value_pair.as_span(), - format!("Fetch limit must be non-negative, got: {value}"), - )); - } - Ok(CountMode::CountExpr(i64_literal_expr(value))) - } - Rule::expression => { - let expr = Expression::parse_pair(extensions, value_pair)?; - Ok(CountMode::CountExpr(Box::new(expr))) - } - _ => Err(MessageParseError::invalid( - Self::message(), - value_pair.as_span(), - format!("Unexpected rule for CountMode: {:?}", value_pair.as_rule()), - )), - } - } -} - -fn i64_literal_expr(value: i64) -> Box { - Box::new(Expression { - rex_type: Some(RexType::Literal(Literal { - nullable: false, - type_variation_reference: 0, - literal_type: Some(LiteralType::I64(value)), - })), - }) -} - -impl 
ScopedParsePair for OffsetMode { - fn rule() -> Rule { - Rule::fetch_value - } - fn message() -> &'static str { - "OffsetMode" - } - fn parse_pair( - extensions: &SimpleExtensions, - pair: Pair, - ) -> Result { - assert_eq!(pair.as_rule(), Self::rule()); - let mut arg_inner = RuleIter::from(pair.into_inner()); - let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) { - int_pair - } else { - arg_inner.pop(Rule::expression) - }; - match value_pair.as_rule() { - Rule::integer => { - let value = value_pair.as_str().parse::().map_err(|e| { - MessageParseError::invalid( - Self::message(), - value_pair.as_span(), - format!("Invalid integer: {e}"), - ) - })?; - if value < 0 { - return Err(MessageParseError::invalid( - Self::message(), - value_pair.as_span(), - format!("Fetch offset must be non-negative, got: {value}"), - )); - } - Ok(OffsetMode::OffsetExpr(i64_literal_expr(value))) - } - Rule::expression => { - let expr = Expression::parse_pair(extensions, value_pair)?; - Ok(OffsetMode::OffsetExpr(Box::new(expr))) - } - _ => Err(MessageParseError::invalid( - Self::message(), - value_pair.as_span(), - format!("Unexpected rule for OffsetMode: {:?}", value_pair.as_rule()), - )), - } - } -} - -impl RelationParsePair for FetchRel { - fn rule() -> Rule { - Rule::fetch_relation - } - - fn message() -> &'static str { - "FetchRel" - } - - fn into_rel(self) -> Rel { - Rel { - rel_type: Some(RelType::Fetch(Box::new(self))), - } - } - - fn parse_pair_with_context( - extensions: &SimpleExtensions, - pair: Pair, - input_children: Vec>, - _input_field_count: usize, - ) -> Result { - assert_eq!(pair.as_rule(), Self::rule()); - let input = expect_one_child(Self::message(), &pair, input_children)?; - let mut iter = RuleIter::from(pair.into_inner()); - - // Extract all pairs first, then do validation - let (limit_pair, offset_pair) = match iter.try_pop(Rule::fetch_named_arg_list) { - None => { - // If there are no arguments, it should be empty - iter.pop(Rule::empty); - 
(None, None) - } - Some(fetch_args_pair) => { - let extractor = - ParsedNamedArgs::new(fetch_args_pair.into_inner(), Rule::fetch_named_arg)?; - let (extractor, limit_pair) = extractor.pop("limit", Rule::fetch_value); - let (extractor, offset_pair) = extractor.pop("offset", Rule::fetch_value); - extractor.done()?; - (limit_pair, offset_pair) - } - }; - - let reference_list_pair = iter.pop(Rule::reference_list); - let emit = parse_reference_emit(reference_list_pair); - let common = RelCommon { - emit_kind: Some(emit), - ..Default::default() - }; - iter.done(); - - // Now do validation after iterator is fully consumed - let count_mode = limit_pair - .map(|pair| CountMode::parse_pair(extensions, pair)) - .transpose()?; - let offset_mode = offset_pair - .map(|pair| OffsetMode::parse_pair(extensions, pair)) - .transpose()?; - Ok(FetchRel { - input: Some(input), - common: Some(common), - advanced_extension: None, - offset_mode, - count_mode, - }) - } -} - -impl ParsePair for join_rel::JoinType { - fn rule() -> Rule { - Rule::join_type - } - - fn message() -> &'static str { - "JoinType" - } - - fn parse_pair(pair: Pair) -> Self { - assert_eq!(pair.as_rule(), Self::rule()); - let join_type_str = pair.as_str().trim_start_matches('&'); - match join_type_str { - "Inner" => join_rel::JoinType::Inner, - "Left" => join_rel::JoinType::Left, - "Right" => join_rel::JoinType::Right, - "Outer" => join_rel::JoinType::Outer, - "LeftSemi" => join_rel::JoinType::LeftSemi, - "RightSemi" => join_rel::JoinType::RightSemi, - "LeftAnti" => join_rel::JoinType::LeftAnti, - "RightAnti" => join_rel::JoinType::RightAnti, - "LeftSingle" => join_rel::JoinType::LeftSingle, - "RightSingle" => join_rel::JoinType::RightSingle, - "LeftMark" => join_rel::JoinType::LeftMark, - "RightMark" => join_rel::JoinType::RightMark, - _ => panic!("Unknown join type: {join_type_str} (this should be caught by grammar)"), - } - } -} - -impl RelationParsePair for JoinRel { - fn rule() -> Rule { - Rule::join_relation - } - 
- fn message() -> &'static str { - "JoinRel" - } - - fn into_rel(self) -> Rel { - Rel { - rel_type: Some(RelType::Join(Box::new(self))), - } - } - - fn parse_pair_with_context( - extensions: &SimpleExtensions, - pair: Pair, - input_children: Vec>, - _input_field_count: usize, - ) -> Result { - assert_eq!(pair.as_rule(), Self::rule()); - - // Join requires exactly 2 input children - if input_children.len() != 2 { - return Err(MessageParseError::invalid( - Self::message(), - pair.as_span(), - format!( - "JoinRel should have exactly 2 input children, got {}", - input_children.len() - ), - )); - } - - let mut children_iter = input_children.into_iter(); - let left = children_iter.next().unwrap(); - let right = children_iter.next().unwrap(); - - let mut iter = RuleIter::from(pair.into_inner()); - - // Parse join type - let join_type = iter.parse_next::(); - - // Parse join condition expression - let condition = iter.parse_next_scoped::(extensions)?; - - // Parse output references (which become the emit) - let reference_list_pair = iter.pop(Rule::reference_list); - iter.done(); - - let emit = parse_reference_emit(reference_list_pair); - let common = RelCommon { - emit_kind: Some(emit), - ..Default::default() - }; - - Ok(JoinRel { - common: Some(common), - left: Some(left), - right: Some(right), - expression: Some(Box::new(condition)), - post_join_filter: None, // Not supported in grammar yet - r#type: join_type as i32, - advanced_extension: None, - }) - } -} - -#[cfg(test)] -mod tests { - use pest::Parser; - - use super::*; - use crate::fixtures::TestContext; - use crate::parser::{ExpressionParser, Rule}; - - #[test] - fn test_parse_relation() { - // Removed: test_parse_relation for old Relation struct - } - - #[test] - fn test_parse_read_relation() { - let extensions = SimpleExtensions::default(); - let read = ReadRel::parse_pair_with_context( - &extensions, - parse_exact(Rule::read_relation, "Read[ab.cd.ef => a:i32, b:string?]"), - vec![], - 0, - ) - .unwrap(); - let 
names = match &read.read_type { - Some(read_rel::ReadType::NamedTable(table)) => &table.names, - _ => panic!("Expected NamedTable"), - }; - assert_eq!(names, &["ab", "cd", "ef"]); - let columns = &read - .base_schema - .as_ref() - .unwrap() - .r#struct - .as_ref() - .unwrap() - .types; - assert_eq!(columns.len(), 2); - } - - /// Produces a ReadRel with 3 columns: a:i32, b:string?, c:i64 - fn example_read_relation() -> ReadRel { - let extensions = SimpleExtensions::default(); - ReadRel::parse_pair_with_context( - &extensions, - parse_exact( - Rule::read_relation, - "Read[ab.cd.ef => a:i32, b:string?, c:i64]", - ), - vec![], - 0, - ) - .unwrap() - } - - #[test] - fn test_parse_filter_relation() { - let extensions = SimpleExtensions::default(); - let filter = FilterRel::parse_pair_with_context( - &extensions, - parse_exact(Rule::filter_relation, "Filter[$1 => $0, $1, $2]"), - vec![Box::new(example_read_relation().into_rel())], - 3, - ) - .unwrap(); - let emit_kind = &filter.common.as_ref().unwrap().emit_kind.as_ref().unwrap(); - let emit = match emit_kind { - EmitKind::Emit(emit) => &emit.output_mapping, - _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"), - }; - assert_eq!(emit, &[0, 1, 2]); - } - - #[test] - fn test_parse_project_relation() { - let extensions = SimpleExtensions::default(); - let project = ProjectRel::parse_pair_with_context( - &extensions, - parse_exact(Rule::project_relation, "Project[$0, $1, 42]"), - vec![Box::new(example_read_relation().into_rel())], - 3, - ) - .unwrap(); - - // Should have 1 expression (42) and 2 references ($0, $1) - assert_eq!(project.expressions.len(), 1); - - let emit_kind = &project.common.as_ref().unwrap().emit_kind.as_ref().unwrap(); - let emit = match emit_kind { - EmitKind::Emit(emit) => &emit.output_mapping, - _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"), - }; - // Output mapping should be [0, 1, 3]. References are 0-2; expression is 3. 
- assert_eq!(emit, &[0, 1, 3]); - } - - #[test] - fn test_parse_project_relation_complex() { - let extensions = SimpleExtensions::default(); - let project = ProjectRel::parse_pair_with_context( - &extensions, - parse_exact(Rule::project_relation, "Project[42, $0, 100, $2, $1]"), - vec![Box::new(example_read_relation().into_rel())], - 5, // Assume 5 input fields - ) - .unwrap(); - - // Should have 2 expressions (42, 100) and 3 references ($0, $2, $1) - assert_eq!(project.expressions.len(), 2); - - let emit_kind = &project.common.as_ref().unwrap().emit_kind.as_ref().unwrap(); - let emit = match emit_kind { - EmitKind::Emit(emit) => &emit.output_mapping, - _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"), - }; - // Direct mapping: [input_fields..., 42, 100] (input fields first, then expressions) - // Output mapping: [5, 0, 6, 2, 1] (to get: 42, $0, 100, $2, $1) - assert_eq!(emit, &[5, 0, 6, 2, 1]); - } - - #[test] - fn test_parse_aggregate_relation() { - let extensions = TestContext::new() - .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml") - .with_function(1, 10, "sum") - .with_function(1, 11, "count") - .extensions; - - let aggregate = AggregateRel::parse_pair_with_context( - &extensions, - parse_exact( - Rule::aggregate_relation, - "Aggregate[($0, $1), _ => sum($2), $0, count($2)]", - ), - vec![Box::new(example_read_relation().into_rel())], - 3, - ) - .unwrap(); - - // Should have 2 group-by sets ($0, $1) and an empty group, and emit 2 measures (sum($2), count($2)) - assert_eq!(aggregate.grouping_expressions.len(), 2); - assert_eq!(aggregate.groupings[0].expression_references.len(), 2); - assert_eq!(aggregate.groupings.len(), 2); - assert_eq!(aggregate.measures.len(), 2); - - let emit_kind = &aggregate - .common - .as_ref() - .unwrap() - .emit_kind - .as_ref() - .unwrap(); - let emit = match emit_kind { - EmitKind::Emit(emit) => &emit.output_mapping, - _ => panic!("Expected EmitKind::Emit, got 
{emit_kind:?}"), - }; - // Output mapping should be [2, 0, 3] (measures and group-by fields in order) - // sum($2) -> 2, $0 -> 0, count($2) -> 3 - assert_eq!(emit, &[2, 0, 3]); - } - - #[test] - fn test_parse_aggregate_relation_maintain_column_order() { - let extensions = TestContext::new() - .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml") - .with_function(1, 10, "sum") - .with_function(1, 11, "count") - .extensions; - - let aggregate = AggregateRel::parse_pair_with_context( - &extensions, - parse_exact( - Rule::aggregate_relation, - "Aggregate[$0 => sum($1), $0, count($1)]", - ), - vec![Box::new(example_read_relation().into_rel())], - 3, - ) - .unwrap(); - - // Should have 1 group-by field ($0) and 2 measures (sum($1), count($1)) - assert_eq!(aggregate.grouping_expressions.len(), 1); - assert_eq!(aggregate.groupings.len(), 1); - assert_eq!(aggregate.measures.len(), 2); - - let emit_kind = &aggregate - .common - .as_ref() - .unwrap() - .emit_kind - .as_ref() - .unwrap(); - let emit = match emit_kind { - EmitKind::Emit(emit) => &emit.output_mapping, - _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"), - }; - // Output mapping should be [1, 0, 2] (grouping fields + measures) - assert_eq!(emit, &[1, 0, 2]); - } - - #[test] - fn test_parse_aggregate_relation_simple() { - let extensions = TestContext::new() - .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml") - .with_function(1, 10, "sum") - .extensions; - - let aggregate = AggregateRel::parse_pair_with_context( - &extensions, - parse_exact(Rule::aggregate_relation, "Aggregate[$2, $0 => sum($1)]"), - vec![Box::new(example_read_relation().into_rel())], - 3, - ) - .unwrap(); - - assert_eq!(aggregate.grouping_expressions.len(), 2); - assert_eq!(aggregate.groupings.len(), 1); - // expression_references must be positions [0, 1], not raw field indices [2, 0] - 
assert_eq!(aggregate.groupings[0].expression_references, vec![0, 1]); - } - - #[test] - fn test_parse_aggregate_relation_global_aggregate() { - let extensions = TestContext::new() - .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml") - .with_function(1, 10, "sum") - .with_function(1, 11, "count") - .extensions; - - let aggregate = AggregateRel::parse_pair_with_context( - &extensions, - parse_exact( - Rule::aggregate_relation, - "Aggregate[_ => sum($0), count($1)]", - ), - vec![Box::new(example_read_relation().into_rel())], - 3, - ) - .unwrap(); - - // Should have 0 group-by fields and 2 measures - assert_eq!(aggregate.grouping_expressions.len(), 0); - assert_eq!(aggregate.groupings.len(), 1); - assert_eq!(aggregate.groupings[0].expression_references.len(), 0); - assert_eq!(aggregate.measures.len(), 2); - - let emit_kind = &aggregate - .common - .as_ref() - .unwrap() - .emit_kind - .as_ref() - .unwrap(); - let emit = match emit_kind { - EmitKind::Emit(emit) => &emit.output_mapping, - _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"), - }; - // Output mapping should be [0, 1] (measures only, no group-by fields) - assert_eq!(emit, &[0, 1]); - } - - #[test] - fn test_parse_aggregate_relation_grouping_sets() { - let extensions = TestContext::new() - .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml") - .with_function(1, 11, "count") - .extensions; - - let read_rel = ReadRel::parse_pair_with_context( - &extensions, - parse_exact( - Rule::read_relation, - "Read[ab.cd.ef => a:i32, b:string?, c:i64, d:i64]", - ), - vec![], - 0, - ) - .unwrap(); - - let aggregate = AggregateRel::parse_pair_with_context( - &extensions, - parse_exact( - Rule::aggregate_relation, - "Aggregate[($0, $1, $2), ($2, $0), ($1), _ => $0, $1, $2, count($3)]", - ), - vec![Box::new(read_rel.into_rel())], - 4, - ) - .unwrap(); - - // Should have 4 group-by fields and 1 measures - 
assert_eq!(aggregate.grouping_expressions.len(), 3); - assert_eq!(aggregate.groupings.len(), 4); - assert_eq!(aggregate.groupings[0].expression_references.len(), 3); - assert_eq!(aggregate.groupings[1].expression_references.len(), 2); - assert_eq!(aggregate.groupings[2].expression_references.len(), 1); - assert_eq!(aggregate.groupings[3].expression_references.len(), 0); - assert_eq!(aggregate.measures.len(), 1); - } - - #[test] - fn test_fetch_relation_positive_values() { - let extensions = SimpleExtensions::default(); - - // Test valid positive values should work - let fetch_rel = FetchRel::parse_pair_with_context( - &extensions, - parse_exact(Rule::fetch_relation, "Fetch[limit=10, offset=5 => $0]"), - vec![Box::new(example_read_relation().into_rel())], - 3, - ) - .unwrap(); - - // Verify the limit and offset values are correct - assert_eq!( - fetch_rel.count_mode, - Some(CountMode::CountExpr(i64_literal_expr(10))) - ); - assert_eq!( - fetch_rel.offset_mode, - Some(OffsetMode::OffsetExpr(i64_literal_expr(5))) - ); - } - - #[test] - fn test_fetch_relation_negative_limit_rejected() { - let extensions = SimpleExtensions::default(); - - // Test that fetch relations with negative limits are properly rejected - let parsed_result = ExpressionParser::parse(Rule::fetch_relation, "Fetch[limit=-5 => $0]"); - if let Ok(mut pairs) = parsed_result { - let pair = pairs.next().unwrap(); - if pair.as_str() == "Fetch[limit=-5 => $0]" { - // Full parse succeeded, now test that validation catches the negative value - let result = FetchRel::parse_pair_with_context( - &extensions, - pair, - vec![Box::new(example_read_relation().into_rel())], - 3, - ); - assert!(result.is_err()); - let error_msg = result.unwrap_err().to_string(); - assert!(error_msg.contains("Fetch limit must be non-negative")); - } else { - // If grammar doesn't fully support negative values, that's also acceptable - // since it would prevent negative values at parse time - println!("Grammar prevents negative limit 
values at parse time"); - } - } else { - // Grammar doesn't support negative values in fetch context - println!("Grammar prevents negative limit values at parse time"); - } - } - - #[test] - fn test_fetch_relation_negative_offset_rejected() { - let extensions = SimpleExtensions::default(); - - // Test that fetch relations with negative offsets are properly rejected - let parsed_result = - ExpressionParser::parse(Rule::fetch_relation, "Fetch[offset=-10 => $0]"); - if let Ok(mut pairs) = parsed_result { - let pair = pairs.next().unwrap(); - if pair.as_str() == "Fetch[offset=-10 => $0]" { - // Full parse succeeded, now test that validation catches the negative value - let result = FetchRel::parse_pair_with_context( - &extensions, - pair, - vec![Box::new(example_read_relation().into_rel())], - 3, - ); - assert!(result.is_err()); - let error_msg = result.unwrap_err().to_string(); - assert!(error_msg.contains("Fetch offset must be non-negative")); - } else { - // If grammar doesn't fully support negative values, that's also acceptable - println!("Grammar prevents negative offset values at parse time"); - } - } else { - // Grammar doesn't support negative values in fetch context - println!("Grammar prevents negative offset values at parse time"); - } - } - - #[test] - fn test_parse_join_relation() { - let extensions = TestContext::new() - .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml") - .with_function(1, 10, "eq") - .extensions; - - let left_rel = example_read_relation().into_rel(); - let right_rel = example_read_relation().into_rel(); - - let join = JoinRel::parse_pair_with_context( - &extensions, - parse_exact( - Rule::join_relation, - "Join[&Inner, eq($0, $3) => $0, $1, $3, $4]", - ), - vec![Box::new(left_rel), Box::new(right_rel)], - 6, // left (3) + right (3) = 6 total input fields - ) - .unwrap(); - - // Should be an Inner join - assert_eq!(join.r#type, join_rel::JoinType::Inner as i32); - - // Should have 
left and right relations - assert!(join.left.is_some()); - assert!(join.right.is_some()); - - // Should have a join condition - assert!(join.expression.is_some()); - - let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap(); - let emit = match emit_kind { - EmitKind::Emit(emit) => &emit.output_mapping, - _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"), - }; - // Output mapping should be [0, 1, 3, 4] (selected columns) - assert_eq!(emit, &[0, 1, 3, 4]); - } - - #[test] - fn test_parse_join_relation_left_outer() { - let extensions = TestContext::new() - .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml") - .with_function(1, 10, "eq") - .extensions; - - let left_rel = example_read_relation().into_rel(); - let right_rel = example_read_relation().into_rel(); - - let join = JoinRel::parse_pair_with_context( - &extensions, - parse_exact(Rule::join_relation, "Join[&Left, eq($0, $3) => $0, $1, $2]"), - vec![Box::new(left_rel), Box::new(right_rel)], - 6, - ) - .unwrap(); - - // Should be a Left join - assert_eq!(join.r#type, join_rel::JoinType::Left as i32); - - let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap(); - let emit = match emit_kind { - EmitKind::Emit(emit) => &emit.output_mapping, - _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"), - }; - // Output mapping should be [0, 1, 2] - assert_eq!(emit, &[0, 1, 2]); - } - - #[test] - fn test_parse_join_relation_left_semi() { - let extensions = TestContext::new() - .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml") - .with_function(1, 10, "eq") - .extensions; - - let left_rel = example_read_relation().into_rel(); - let right_rel = example_read_relation().into_rel(); - - let join = JoinRel::parse_pair_with_context( - &extensions, - parse_exact(Rule::join_relation, "Join[&LeftSemi, eq($0, $3) => $0, $1]"), - vec![Box::new(left_rel), Box::new(right_rel)], - 6, - 
) - .unwrap(); - - // Should be a LeftSemi join - assert_eq!(join.r#type, join_rel::JoinType::LeftSemi as i32); - - let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap(); - let emit = match emit_kind { - EmitKind::Emit(emit) => &emit.output_mapping, - _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"), - }; - // Output mapping should be [0, 1] (only left columns for semi join) - assert_eq!(emit, &[0, 1]); - } - - #[test] - fn test_parse_join_relation_requires_two_children() { - let extensions = SimpleExtensions::default(); - - // Test with 0 children - let result = JoinRel::parse_pair_with_context( - &extensions, - parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"), - vec![], - 0, - ); - assert!(result.is_err()); - - // Test with 1 child - let result = JoinRel::parse_pair_with_context( - &extensions, - parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"), - vec![Box::new(example_read_relation().into_rel())], - 3, - ); - assert!(result.is_err()); - - // Test with 3 children - let result = JoinRel::parse_pair_with_context( - &extensions, - parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"), - vec![ - Box::new(example_read_relation().into_rel()), - Box::new(example_read_relation().into_rel()), - Box::new(example_read_relation().into_rel()), - ], - 9, - ); - assert!(result.is_err()); - } - - fn parse_exact(rule: Rule, input: &'_ str) -> pest::iterators::Pair<'_, Rule> { - let mut pairs = ExpressionParser::parse(rule, input).unwrap(); - assert_eq!(pairs.as_str(), input); - let pair = pairs.next().unwrap(); - assert_eq!(pairs.next(), None); - pair - } -} diff --git a/src/parser/structural.rs b/src/parser/structural.rs index b2d513c..6160a66 100644 --- a/src/parser/structural.rs +++ b/src/parser/structural.rs @@ -1,29 +1,63 @@ //! Parser for the structural part of the Substrait file format. -//! -//! This is the overall parser for parsing the text format. It is responsible -//! 
for tracking which section of the file we are currently parsing, and parsing -//! each line separately. +use std::borrow::Cow; use std::fmt; use substrait::proto::rel::RelType; -use substrait::proto::{ - AggregateRel, FetchRel, FilterRel, JoinRel, Plan, PlanRel, ProjectRel, ReadRel, Rel, RelRoot, - SortRel, plan_rel, -}; +use substrait::proto::{Plan, PlanRel, Rel, RelRoot, plan_rel}; use crate::extensions::{ExtensionRegistry, SimpleExtensions, simple}; -use crate::parser::common::{MessageParseError, ParsePair}; -use crate::parser::errors::{ParseContext, ParseError, ParseResult}; -use crate::parser::expressions::Name; -use crate::parser::extensions::{ExtensionInvocation, ExtensionParseError, ExtensionParser}; -use crate::parser::relations::RelationParsingContext; -use crate::parser::{ErrorKind, ExpressionParser, RelationParsePair, Rule, unwrap_single_pair}; +use crate::parser::errors::{ + MessageParseError, ParseContext, ParseError, ParseResult, SyntaxErrorDetail, +}; +use crate::parser::extensions::{ExtensionParseError, ExtensionParser}; +use crate::parser::{ast, lalrpop_line, lower}; pub const PLAN_HEADER: &str = "=== Plan"; -/// Represents an input line, trimmed of leading two-space indents and final -/// whitespace. Contains the number of indents and the trimmed line. +/// Strip `//` comments while preserving comment markers inside quoted strings +/// and URL-like tokens such as `https://...`. 
+fn strip_inline_comment(line: &str) -> Cow<'_, str> { + let mut in_single = false; + let mut in_double = false; + let mut escaped = false; + + for (idx, ch) in line.char_indices() { + if escaped { + escaped = false; + continue; + } + + if (in_single || in_double) && ch == '\\' { + escaped = true; + continue; + } + + if ch == '\'' && !in_double { + in_single = !in_single; + continue; + } + + if ch == '"' && !in_single { + in_double = !in_double; + continue; + } + + if !in_single && !in_double && line[idx..].starts_with("//") { + // Treat `//` as comment start only at beginning of line or when + // preceded by whitespace. This keeps tokens like `https://...` + // intact in extension URNs. + let starts_comment = + idx == 0 || line[..idx].chars().last().is_some_and(char::is_whitespace); + if starts_comment { + return Cow::Owned(line[..idx].trim_end().to_string()); + } + } + } + + Cow::Borrowed(line) +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct IndentedLine<'a>(pub usize, pub &'a str); @@ -40,123 +74,204 @@ impl<'a> From<&'a str> for IndentedLine<'a> { } let indents = spaces / 2; - let (_, trimmed) = line.split_at(indents * 2); IndentedLine(indents, trimmed) } } -/// Represents a line in the [`Plan`] tree structure before it's converted to a -/// relation. This allows us to build the tree structure first, then convert to -/// relations with proper parent-child relationships. 
#[derive(Debug, Clone)] -pub struct LineNode<'a> { - pub pair: pest::iterators::Pair<'a, Rule>, +pub enum LineData { + Relation(ast::Relation), + Root(Vec), +} + +#[derive(Debug, Clone)] +pub struct LineNode { + pub data: LineData, pub line_no: i64, - pub children: Vec>, + pub line: String, + pub children: Vec, } -impl<'a> LineNode<'a> { +impl LineNode { pub fn context(&self) -> ParseContext { ParseContext { line_no: self.line_no, - line: self.pair.as_str().to_string(), + line: self.line.clone(), } } - pub fn parse(line: &'a str, line_no: i64) -> Result { - // Parse the line immediately to catch syntax errors - let mut pairs: pest::iterators::Pairs<'a, Rule> = - >::parse(Rule::relation, line).map_err(|e| { - ParseError::Plan( - ParseContext { - line_no, - line: line.to_string(), - }, - MessageParseError::new("relation", ErrorKind::InvalidValue, Box::new(e)), - ) - })?; - - let pair = pairs.next().unwrap(); - assert!(pairs.next().is_none()); // Should be exactly one pair + pub fn parse(line: &str, line_no: i64) -> Result { + let relation = lalrpop_line::parse_relation(line) + .map_err(|e| parse_syntax_error("relation", line_no, line, e))?; Ok(Self { - pair, + data: LineData::Relation(relation), line_no, + line: line.to_string(), children: Vec::new(), }) } - /// Parse the root relation of a plan, at depth 0. 
- pub fn parse_root(line: &'a str, line_no: i64) -> Result { - // Parse the line as a top-level relation (either root_relation or regular relation) - let mut pairs: pest::iterators::Pairs<'a, Rule> = >::parse( - Rule::top_level_relation, line - ) - .map_err(|e| { - ParseError::Plan( - ParseContext::new(line_no, line.to_string()), - MessageParseError::new("top_level_relation", ErrorKind::Syntax, Box::new(e)), - ) - })?; - - let pair = pairs.next().unwrap(); - assert!(pairs.next().is_none()); + pub fn parse_root(line: &str, line_no: i64) -> Result { + match lalrpop_line::parse_root_names(line) { + Ok(names) => { + return Ok(Self { + data: LineData::Root(names), + line_no, + line: line.to_string(), + children: Vec::new(), + }); + } + Err(root_error) if starts_with_root_keyword(line) => { + return Err(parse_syntax_error("root_names", line_no, line, root_error)); + } + Err(_) => {} + } - // Get the inner pair, which is either a root relation or a regular relation - let inner_pair = unwrap_single_pair(pair); + let relation = lalrpop_line::parse_relation(line) + .map_err(|e| parse_syntax_error("top_level_relation", line_no, line, e))?; Ok(Self { - pair: inner_pair, + data: LineData::Relation(relation), line_no, + line: line.to_string(), children: Vec::new(), }) } } -/// Helper function to get the number of input fields from a relation. -/// This is needed for Project relations to calculate output mapping indices. 
-fn get_input_field_count(rel: &Rel) -> usize { +fn starts_with_root_keyword(line: &str) -> bool { + let trimmed = line.trim_start(); + let Some(rest) = trimmed.strip_prefix("Root") else { + return false; + }; + rest.starts_with('[') || rest.starts_with(char::is_whitespace) +} + +fn parse_syntax_error( + target: &'static str, + line_no: i64, + line: &str, + detail: SyntaxErrorDetail, +) -> ParseError { + ParseError::Syntax { + target, + context: ParseContext::new(line_no, line.to_string()), + detail: Box::new(detail), + } +} + +fn get_output_field_count(rel: &Rel) -> usize { match &rel.rel_type { - Some(RelType::Read(read_rel)) => { - // For Read relations, count the fields in the base schema - read_rel - .base_schema + Some(RelType::Read(read_rel)) => read_rel + .base_schema + .as_ref() + .and_then(|schema| schema.r#struct.as_ref()) + .map(|s| s.types.len()) + .unwrap_or_else(|| { + read_rel + .base_schema + .as_ref() + .map(|s| s.names.len()) + .unwrap_or(0) + }), + Some(RelType::Filter(filter_rel)) => rel_common_count( + filter_rel + .common .as_ref() - .and_then(|schema| schema.r#struct.as_ref()) - .map(|struct_| struct_.types.len()) - .unwrap_or(0) - } - Some(RelType::Filter(filter_rel)) => { - // For Filter relations, get the count from the input + .and_then(|c| c.emit_kind.as_ref()), filter_rel .input .as_ref() - .map(|input| get_input_field_count(input)) - .unwrap_or(0) - } - Some(RelType::Project(project_rel)) => { - // For Project relations, get the count from the input + .map(|input| get_output_field_count(input)) + .unwrap_or(0), + ), + Some(RelType::Project(project_rel)) => rel_common_count( + project_rel + .common + .as_ref() + .and_then(|c| c.emit_kind.as_ref()), project_rel .input .as_ref() - .map(|input| get_input_field_count(input)) + .map(|input| get_output_field_count(input)) .unwrap_or(0) - } + + project_rel.expressions.len(), + ), + Some(RelType::Aggregate(aggregate_rel)) => rel_common_count( + aggregate_rel + .common + .as_ref() + 
.and_then(|c| c.emit_kind.as_ref()), + aggregate_rel.grouping_expressions.len() + aggregate_rel.measures.len(), + ), + Some(RelType::Sort(sort_rel)) => rel_common_count( + sort_rel.common.as_ref().and_then(|c| c.emit_kind.as_ref()), + sort_rel + .input + .as_ref() + .map(|input| get_output_field_count(input)) + .unwrap_or(0), + ), + Some(RelType::Fetch(fetch_rel)) => rel_common_count( + fetch_rel.common.as_ref().and_then(|c| c.emit_kind.as_ref()), + fetch_rel + .input + .as_ref() + .map(|input| get_output_field_count(input)) + .unwrap_or(0), + ), + Some(RelType::Join(join_rel)) => rel_common_count( + join_rel.common.as_ref().and_then(|c| c.emit_kind.as_ref()), + join_rel + .left + .as_ref() + .map(|left| get_output_field_count(left)) + .unwrap_or(0) + + join_rel + .right + .as_ref() + .map(|right| get_output_field_count(right)) + .unwrap_or(0), + ), + Some(RelType::ExtensionLeaf(ext_rel)) => rel_common_count( + ext_rel.common.as_ref().and_then(|c| c.emit_kind.as_ref()), + 0, + ), + Some(RelType::ExtensionSingle(ext_rel)) => rel_common_count( + ext_rel.common.as_ref().and_then(|c| c.emit_kind.as_ref()), + ext_rel + .input + .as_ref() + .map(|input| get_output_field_count(input)) + .unwrap_or(0), + ), + Some(RelType::ExtensionMulti(ext_rel)) => rel_common_count( + ext_rel.common.as_ref().and_then(|c| c.emit_kind.as_ref()), + ext_rel.inputs.iter().map(get_output_field_count).sum(), + ), _ => 0, } } +fn rel_common_count( + emit_kind: Option<&substrait::proto::rel_common::EmitKind>, + default: usize, +) -> usize { + match emit_kind { + // When Emit is explicit, output mapping is authoritative. + Some(substrait::proto::rel_common::EmitKind::Emit(emit)) => emit.output_mapping.len(), + Some(substrait::proto::rel_common::EmitKind::Direct(_)) => default, + None => default, + } +} + #[derive(Copy, Clone, Debug)] pub enum State { - // The initial state, before we have parsed any lines. 
Initial, - // The extensions section, after parsing the header and any other Extension lines. Extensions, - // The plan section, after parsing the header and any other Plan lines. Plan, } @@ -166,19 +281,14 @@ impl fmt::Display for State { } } -// An in-progress tree builder, building the tree of relations. #[derive(Debug, Clone, Default)] -pub struct TreeBuilder<'a> { - // Current tree of nodes being built. These have been successfully parsed - // into Pest pairs, but have not yet been converted to substrait plans. - current: Option>, - // Completed trees that have been built. - completed: Vec>, +pub struct TreeBuilder { + current: Option, + completed: Vec, } -impl<'a> TreeBuilder<'a> { - /// Traverse down the tree, always taking the last child at each level, until reaching the specified depth. - pub fn get_at_depth(&mut self, depth: usize) -> Option<&mut LineNode<'a>> { +impl TreeBuilder { + pub fn get_at_depth(&mut self, depth: usize) -> Option<&mut LineNode> { let mut node = self.current.as_mut()?; for _ in 0..depth { node = node.children.last_mut()?; @@ -186,7 +296,7 @@ impl<'a> TreeBuilder<'a> { Some(node) } - pub fn add_line(&mut self, depth: usize, node: LineNode<'a>) -> Result<(), ParseError> { + pub fn add_line(&mut self, depth: usize, node: LineNode) -> Result<(), ParseError> { if depth == 0 { if let Some(prev) = self.current.take() { self.completed.push(prev) @@ -201,7 +311,6 @@ impl<'a> TreeBuilder<'a> { node.context(), MessageParseError::invalid( "relation", - node.pair.as_span(), format!("No parent found for depth {depth}"), ), )); @@ -209,15 +318,11 @@ impl<'a> TreeBuilder<'a> { Some(parent) => parent, }; - parent.children.push(node.clone()); + parent.children.push(node); Ok(()) } - /// End of input - move any remaining nodes from stack to completed and - /// return any trees in progress. 
Resets the builder to its initial state - /// (empty) - pub fn finish(&mut self) -> Vec> { - // Move any remaining nodes from stack to completed + pub fn finish(&mut self) -> Vec { if let Some(node) = self.current.take() { self.completed.push(node); } @@ -225,17 +330,15 @@ impl<'a> TreeBuilder<'a> { } } -// Relation parsing component - handles converting LineNodes to Relations #[derive(Debug, Clone, Default)] -pub struct RelationParser<'a> { - tree: TreeBuilder<'a>, +pub struct RelationParser { + tree: TreeBuilder, } -impl<'a> RelationParser<'a> { - pub fn parse_line(&mut self, line: IndentedLine<'a>, line_no: i64) -> Result<(), ParseError> { +impl RelationParser { + pub fn parse_line(&mut self, line: IndentedLine<'_>, line_no: i64) -> Result<(), ParseError> { let IndentedLine(depth, line) = line; - // Use parse_root for depth 0 (top-level relations), parse for other depths let node = if depth == 0 { LineNode::parse_root(line, line_no)? } else { @@ -245,202 +348,82 @@ impl<'a> RelationParser<'a> { self.tree.add_line(depth, node) } - /// Parse a relation from a Pest pair of rule 'relation' into a Substrait - /// Rel. 
- // - // Clippy says a Vec> is unnecessary, as the Vec is already on the - // heap, but this is what the protobuf requires so we allow it here - #[allow(clippy::vec_box)] - fn parse_relation( - &self, - extensions: &SimpleExtensions, - registry: &ExtensionRegistry, - line_no: i64, - pair: pest::iterators::Pair, - child_relations: Vec>, - input_field_count: usize, - ) -> Result { - assert_eq!(pair.as_rule(), Rule::relation); - let p = unwrap_single_pair(pair); - - let (e, r, l, p_inner, cr, ic) = ( - extensions, - registry, - line_no, - p, - child_relations, - input_field_count, - ); - - match p_inner.as_rule() { - Rule::read_relation => self.parse_rel::(e, l, p_inner, cr, ic), - Rule::filter_relation => self.parse_rel::(e, l, p_inner, cr, ic), - Rule::project_relation => self.parse_rel::(e, l, p_inner, cr, ic), - Rule::aggregate_relation => self.parse_rel::(e, l, p_inner, cr, ic), - Rule::sort_relation => self.parse_rel::(e, l, p_inner, cr, ic), - Rule::fetch_relation => self.parse_rel::(e, l, p_inner, cr, ic), - Rule::join_relation => self.parse_rel::(e, l, p_inner, cr, ic), - Rule::extension_relation => self.parse_extension_relation(e, r, l, p_inner, cr), - _ => todo!(), - } - } - - /// Parse a specific relation type. - // Box is needed because Rel is a large enum and we need to pass ownership - // through the RelationParsePair trait, which requires Box. - #[allow(clippy::vec_box)] - fn parse_rel( - &self, - extensions: &SimpleExtensions, - line_no: i64, - pair: pest::iterators::Pair, - child_relations: Vec>, - input_field_count: usize, - ) -> Result { - assert_eq!(pair.as_rule(), T::rule()); - - let line = pair.as_str(); - - let rel_type = - T::parse_pair_with_context(extensions, pair, child_relations, input_field_count); - - match rel_type { - Ok(rel) => Ok(rel.into_rel()), - Err(e) => Err(ParseError::Plan( - ParseContext::new(line_no, line.to_string()), - e, - )), - } - } - - /// Parse extension relations. 
- /// Extension relations need the full parsing context for registry lookups and warning handling. - // Box is needed because Rel is a large enum and we need to pass ownership - // through the RelationParsePair trait, which requires Box. - #[allow(clippy::vec_box)] - fn parse_extension_relation( - &self, - extensions: &SimpleExtensions, - registry: &ExtensionRegistry, - line_no: i64, - pair: pest::iterators::Pair, - child_relations: Vec>, - ) -> Result { - assert_eq!(pair.as_rule(), Rule::extension_relation); - - let line = pair.as_str(); - let pair_span = pair.as_span(); - - // Parse extension invocation, which includes the user-provided name - let ExtensionInvocation { - name, - args: extension_args, - } = ExtensionInvocation::parse_pair(pair); - - // Validate child count matches relation type - let child_count = child_relations.len(); - extension_args - .relation_type - .validate_child_count(child_count) - .map_err(|e| { - ParseError::Plan( - ParseContext::new(line_no, line.to_string()), - MessageParseError::invalid("extension_relation", pair_span, e), - ) - })?; - - let context = RelationParsingContext { - extensions, - registry, - line_no, - line, - }; - - let detail = context.resolve_extension_detail(&name, &extension_args)?; - - extension_args - .relation_type - .create_rel(detail, child_relations) - .map_err(|e| { - ParseError::Plan( - ParseContext::new(line_no, line.to_string()), - MessageParseError::invalid("extension_relation", pair_span, e), - ) - }) - } - - /// Convert a LineNode into a Substrait Rel. 
fn build_rel( &self, extensions: &SimpleExtensions, registry: &ExtensionRegistry, node: LineNode, - ) -> Result { - // Parse children first to get their output schemas + ) -> Result { + let context = node.context(); + let line = node.line.clone(); + let line_no = node.line_no; + let child_relations = node .children .into_iter() .map(|c| self.build_rel(extensions, registry, c).map(Box::new)) .collect::>, ParseError>>()?; - // Get the input field count from all the children let input_field_count = child_relations .iter() - .map(|r| get_input_field_count(r.as_ref())) - .reduce(|a, b| a + b) - .unwrap_or(0); + .map(|r| get_output_field_count(r.as_ref())) + .sum(); + + let relation = match node.data { + LineData::Relation(relation) => relation, + LineData::Root(_) => { + return Err(ParseError::Plan( + context, + MessageParseError::invalid( + "relation", + "Root cannot be used as non-root relation", + ), + )); + } + }; - // Parse this node using the stored pair - self.parse_relation( + lower::lower_relation( extensions, registry, - node.line_no, - node.pair, + &relation, + &line, + line_no, child_relations, input_field_count, ) } - /// Build a tree of relations. fn build_plan_rel( &self, extensions: &SimpleExtensions, registry: &ExtensionRegistry, mut node: LineNode, ) -> Result { - // Plain relations are allowed as root relations, they just don't have names. - if node.pair.as_rule() == Rule::relation { + let context = node.context(); + let data = node.data.clone(); + + if let LineData::Relation(_) = data { let rel = self.build_rel(extensions, registry, node)?; return Ok(PlanRel { rel_type: Some(plan_rel::RelType::Rel(rel)), }); } - // Otherwise, it must be a root relation. 
- assert_eq!(node.pair.as_rule(), Rule::root_relation); - let context = node.context(); - let span = node.pair.as_span(); - - // Parse the column names - let column_names_pair = unwrap_single_pair(node.pair); - assert_eq!(column_names_pair.as_rule(), Rule::root_name_list); - - let names: Vec = column_names_pair - .into_inner() - .map(|name_pair| { - assert_eq!(name_pair.as_rule(), Rule::name); - Name::parse_pair(name_pair).0 - }) - .collect(); + let LineData::Root(names) = data else { + unreachable!() + }; let child = match node.children.len() { - 1 => self.build_rel(extensions, registry, node.children.pop().unwrap())?, + 1 => self.build_rel( + extensions, + registry, + node.children.pop().expect("one child"), + )?, n => { return Err(ParseError::Plan( context, MessageParseError::invalid( "root_relation", - span, format!("Root relation must have exactly one child, found {n}"), ), )); @@ -457,7 +440,6 @@ impl<'a> RelationParser<'a> { }) } - /// Build all the trees. fn build( mut self, extensions: &SimpleExtensions, @@ -471,163 +453,26 @@ impl<'a> RelationParser<'a> { } } -/// A parser for Substrait query plans in text format. -/// -/// The `Parser` converts human-readable Substrait text format into Substrait -/// protobuf plans. It handles both the extensions section (which defines -/// functions, types, etc.) and the plan section (which defines the actual query -/// structure). -/// -/// ## Usage -/// -/// The simplest entry point is the static `parse()` method: -/// -/// ```rust -/// use substrait_explain::parser::Parser; -/// -/// let plan_text = r#" -/// === Plan -/// Root[c, d] -/// Project[$1, 42] -/// Read[schema.table => a:i64, b:string?] 
-/// "#; -/// -/// let plan = Parser::parse(plan_text).unwrap(); -/// ``` -/// -/// ## Input Format -/// -/// The parser expects input in the following format: -/// -/// ```text -/// === Extensions -/// URNs: -/// @ 1: https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml -/// Functions: -/// # 10 @ 1: add -/// === Plan -/// Root[columns] -/// Relation[arguments => columns] -/// ChildRelation[arguments => columns] -/// ``` -/// -/// - **Extensions section** (optional): Defines URNs and function/type declarations -/// - **Plan section** (required): Defines the query structure with indented relations -/// -/// ## Error Handling -/// -/// The parser provides detailed error information including: -/// - Line number where the error occurred -/// - The actual line content that failed to parse -/// - Specific error type and description -/// -/// ```rust -/// use substrait_explain::parser::Parser; -/// -/// let invalid_plan = r#" -/// === Plan -/// InvalidRelation[invalid syntax] -/// "#; -/// -/// match Parser::parse(invalid_plan) { -/// Ok(plan) => println!("Successfully parsed"), -/// Err(e) => eprintln!("Parse error: {}", e), -/// } -/// ``` -/// -/// ## Supported Relations -/// -/// The parser supports all standard Substrait relations: -/// - `Read[table => columns]` - Read from a table -/// - `Project[expressions]` - Project columns/expressions -/// - `Filter[condition => columns]` - Filter rows -/// - `Root[columns]` - Root relation with output columns -/// - And more... 
-/// -/// ## Extensions Support -/// -/// The parser fully supports Substrait Simple Extensions, allowing you to: -/// - Define custom functions with URNs and anchors -/// - Reference functions by name in expressions -/// - Use custom types and type variations -/// -/// ```rust -/// use substrait_explain::parser::Parser; -/// -/// let plan_with_extensions = r#" -/// === Extensions -/// URNs: -/// @ 1: https://example.com/functions.yaml -/// Functions: -/// ## 10 @ 1: my_custom_function -/// === Plan -/// Root[result] -/// Project[my_custom_function($0, $1)] -/// Read[table => col1:i32, col2:i32] -/// "#; -/// -/// let plan = Parser::parse(plan_with_extensions).unwrap(); -/// ``` -/// -/// ## Performance -/// -/// The parser is designed for efficiency: -/// - Single-pass parsing with minimal allocations -/// - Early error detection and reporting -/// - Memory-efficient tree building -/// -/// ## Thread Safety -/// -/// `Parser` instances are not thread-safe and should not be shared between threads. -/// However, the static `parse()` method is safe to call from multiple threads. #[derive(Debug)] -pub struct Parser<'a> { +pub struct Parser { line_no: i64, state: State, extension_parser: ExtensionParser, extension_registry: ExtensionRegistry, - relation_parser: RelationParser<'a>, + relation_parser: RelationParser, } -impl<'a> Default for Parser<'a> { + +impl Default for Parser { fn default() -> Self { Self::new() } } -impl<'a> Parser<'a> { - /// Parse a Substrait plan from text format. - /// - /// This is the main entry point for parsing. 
- /// - /// The input should be in the Substrait text format, which consists of: - /// - An optional extensions section starting with "=== Extensions" - /// - A plan section starting with "=== Plan" - /// - Indented relation definitions - /// - /// # Examples - /// - /// Simple parsing: - /// ```rust - /// use substrait_explain::parser::Parser; - /// - /// let plan_text = r#" - /// === Plan - /// Root[result] - /// Read[table => col:i32] - /// "#; - /// - /// let plan = Parser::parse(plan_text).unwrap(); - /// assert_eq!(plan.relations.len(), 1); - /// ``` - /// - /// # Errors - /// - /// Returns a [`ParseError`] if the input cannot be parsed. +impl Parser { pub fn parse(input: &str) -> ParseResult { Self::new().parse_plan(input) } - /// Create a new parser with default configuration. pub fn new() -> Self { Self { line_no: 1, @@ -638,35 +483,32 @@ impl<'a> Parser<'a> { } } - /// Configure the parser to use the specified extension registry. pub fn with_extension_registry(mut self, registry: ExtensionRegistry) -> Self { self.extension_registry = registry; self } - /// Parse a Substrait plan with the current parser configuration. - pub fn parse_plan(mut self, input: &'a str) -> ParseResult { - for line in input.lines() { + pub fn parse_plan(mut self, input: &str) -> ParseResult { + for raw_line in input.lines() { + let line = strip_inline_comment(raw_line); if line.trim().is_empty() { self.line_no += 1; continue; } - self.parse_line(line)?; + self.parse_line(line.as_ref(), raw_line)?; self.line_no += 1; } - let plan = self.build_plan()?; - Ok(plan) + self.build_plan() } - /// Parse a single line of input. 
- fn parse_line(&mut self, line: &'a str) -> Result<(), ParseError> { + fn parse_line(&mut self, line: &str, original_line: &str) -> Result<(), ParseError> { let indented_line = IndentedLine::from(line); let line_no = self.line_no; let ctx = || ParseContext { line_no, - line: line.to_string(), + line: original_line.to_string(), }; match self.state { @@ -674,48 +516,28 @@ impl<'a> Parser<'a> { State::Extensions => self .parse_extensions(indented_line) .map_err(|e| ParseError::Extension(ctx(), e)), - State::Plan => { - let IndentedLine(depth, line_str) = indented_line; - - // Parse the line - let node = if depth == 0 { - LineNode::parse_root(line_str, line_no)? - } else { - LineNode::parse(line_str, line_no)? - }; - - self.relation_parser.tree.add_line(depth, node) - } + State::Plan => self.relation_parser.parse_line(indented_line, line_no), } } - /// Parse the initial line(s) of the input, which is either a blank line or - /// the extensions or plan header. - fn parse_initial(&mut self, line: IndentedLine) -> Result<(), ParseError> { + fn parse_initial(&mut self, line: IndentedLine<'_>) -> Result<(), ParseError> { match line { - IndentedLine(0, l) if l.trim().is_empty() => {} + IndentedLine(0, l) if l.trim().is_empty() => Ok(()), IndentedLine(0, simple::EXTENSIONS_HEADER) => { self.state = State::Extensions; + Ok(()) } IndentedLine(0, PLAN_HEADER) => { self.state = State::Plan; + Ok(()) } - IndentedLine(n, l) => { - return Err(ParseError::Initial( - ParseContext::new(n as i64, l.to_string()), - MessageParseError::invalid( - "initial", - pest::Span::new(l, 0, l.len()).expect("Invalid span?!"), - format!("Unknown initial line: {l:?}"), - ), - )); - } + IndentedLine(_, l) => Err(ParseError::Initial( + ParseContext::new(self.line_no, l.to_string()), + MessageParseError::invalid("initial", format!("Unknown initial line: {l:?}")), + )), } - Ok(()) } - /// Parse a single line from the extensions section of the input, updating - /// the parser state. 
fn parse_extensions(&mut self, line: IndentedLine<'_>) -> Result<(), ExtensionParseError> { if line == IndentedLine(0, PLAN_HEADER) { self.state = State::Plan; @@ -724,7 +546,6 @@ impl<'a> Parser<'a> { self.extension_parser.parse_line(line) } - /// Build the plan from the parser state with warning collection. fn build_plan(self) -> Result { let Parser { relation_parser, @@ -734,11 +555,8 @@ impl<'a> Parser<'a> { } = self; let extensions = extension_parser.extensions(); - - // Parse the tree into relations let root_relations = relation_parser.build(extensions, &extension_registry)?; - // Build the final plan Ok(Plan { extension_urns: extensions.to_extension_urns(), extensions: extensions.to_extension_declarations(), @@ -750,363 +568,71 @@ impl<'a> Parser<'a> { #[cfg(test)] mod tests { - use substrait::proto::extensions::simple_extension_declaration::MappingType; - use super::*; - use crate::extensions::simple::ExtensionKind; - use crate::parser::extensions::ExtensionParserState; - - #[test] - fn test_parse_basic_block() { - let mut expected_extensions = SimpleExtensions::new(); - expected_extensions - .add_extension_urn("/urn/common".to_string(), 1) - .unwrap(); - expected_extensions - .add_extension_urn("/urn/specific_funcs".to_string(), 2) - .unwrap(); - expected_extensions - .add_extension(ExtensionKind::Function, 1, 10, "func_a".to_string()) - .unwrap(); - expected_extensions - .add_extension(ExtensionKind::Function, 2, 11, "func_b_special".to_string()) - .unwrap(); - expected_extensions - .add_extension(ExtensionKind::Type, 1, 20, "SomeType".to_string()) - .unwrap(); - expected_extensions - .add_extension(ExtensionKind::TypeVariation, 2, 30, "VarX".to_string()) - .unwrap(); - - let mut parser = ExtensionParser::default(); - let input_block = r#" -URNs: - @ 1: /urn/common - @ 2: /urn/specific_funcs -Functions: - # 10 @ 1: func_a - # 11 @ 2: func_b_special -Types: - # 20 @ 1: SomeType -Type Variations: - # 30 @ 2: VarX -"#; - - for line_str in 
input_block.trim().lines() { - parser - .parse_line(IndentedLine::from(line_str)) - .unwrap_or_else(|e| panic!("Failed to parse line \'{line_str}\': {e:?}")); - } - - assert_eq!(*parser.extensions(), expected_extensions); - - let extensions_str = parser.extensions().to_string(" "); - // The writer adds the header; the ExtensionParser does not parse the - // header, so we add it here for comparison. - let expected_str = format!( - "{}\n{}", - simple::EXTENSIONS_HEADER, - input_block.trim_start() - ); - assert_eq!(extensions_str.trim(), expected_str.trim()); - // Check final state after all lines are processed. - // The last significant line in input_block is a TypeVariation declaration. - assert_eq!( - parser.state(), - ExtensionParserState::ExtensionDeclarations(ExtensionKind::TypeVariation) - ); - - // Check that a subsequent blank line correctly resets state to Extensions. - parser.parse_line(IndentedLine(0, "")).unwrap(); - assert_eq!(parser.state(), ExtensionParserState::Extensions); - } - - /// Test that we can parse a larger extensions block and it matches the input. - #[test] - fn test_parse_complete_extension_block() { - let mut parser = ExtensionParser::default(); - let input_block = r#" -URNs: - @ 1: /urn/common - @ 2: /urn/specific_funcs - @ 3: /urn/types_lib - @ 4: /urn/variations_lib -Functions: - # 10 @ 1: func_a - # 11 @ 2: func_b_special - # 12 @ 1: func_c_common -Types: - # 20 @ 1: CommonType - # 21 @ 3: LibraryType - # 22 @ 1: AnotherCommonType -Type Variations: - # 30 @ 4: VarX - # 31 @ 4: VarY -"#; - - for line_str in input_block.trim().lines() { - parser - .parse_line(IndentedLine::from(line_str)) - .unwrap_or_else(|e| panic!("Failed to parse line \'{line_str}\': {e:?}")); - } - - let extensions_str = parser.extensions().to_string(" "); - // The writer adds the header; the ExtensionParser does not parse the - // header, so we add it here for comparison. 
- let expected_str = format!( - "{}\n{}", - simple::EXTENSIONS_HEADER, - input_block.trim_start() - ); - assert_eq!(extensions_str.trim(), expected_str.trim()); - } - - #[test] - fn test_parse_relation_tree() { - // Example plan with a Project, a Filter, and a Read, nested by indentation - let plan = r#"=== Plan -Project[$0, $1, 42, 84] - Filter[$2 => $0, $1] - Read[my.table => a:i32, b:string?, c:boolean] -"#; - let mut parser = Parser::default(); - for line in plan.lines() { - parser.parse_line(line).unwrap(); - } - - // Complete the current tree to convert it to relations - let plan = parser.build_plan().unwrap(); - - let root_rel = &plan.relations[0].rel_type; - let first_rel = match root_rel { - Some(plan_rel::RelType::Rel(rel)) => rel, - _ => panic!("Expected Rel type, got {root_rel:?}"), - }; - // Root should be Project - let project = match &first_rel.rel_type { - Some(RelType::Project(p)) => p, - other => panic!("Expected Project at root, got {other:?}"), - }; - - // Check that Project has Filter as input - assert!(project.input.is_some()); - let filter_input = project.input.as_ref().unwrap(); - - // Check that Filter has Read as input - match &filter_input.rel_type { - Some(RelType::Filter(_)) => { - match &filter_input.rel_type { - Some(RelType::Filter(filter)) => { - assert!(filter.input.is_some()); - let read_input = filter.input.as_ref().unwrap(); - - // Check that Read has no input (it's a leaf) - match &read_input.rel_type { - Some(RelType::Read(_)) => {} - other => panic!("Expected Read relation, got {other:?}"), - } - } - other => panic!("Expected Filter relation, got {other:?}"), - } - } - other => panic!("Expected Filter relation, got {other:?}"), - } - } #[test] - fn test_parse_root_relation() { - // Test a plan with a Root relation - let plan = r#"=== Plan -Root[result] - Project[$0, $1] - Read[my.table => a:i32, b:string?] 
+ fn parse_simple_plan() { + let input = r#" +=== Plan +Root[a] + Read[t => a:i32] "#; - let mut parser = Parser::default(); - for line in plan.lines() { - parser.parse_line(line).unwrap(); - } - let plan = parser.build_plan().unwrap(); - - // Check that we have exactly one relation + let plan = Parser::parse(input).expect("parse plan"); assert_eq!(plan.relations.len(), 1); - - let root_rel = &plan.relations[0].rel_type; - let rel_root = match root_rel { - Some(plan_rel::RelType::Root(rel_root)) => rel_root, - other => panic!("Expected Root type, got {other:?}"), - }; - - // Check that the root has the correct name - assert_eq!(rel_root.names, vec!["result"]); - - // Check that the root has a Project as input - let project_input = match &rel_root.input { - Some(rel) => rel, - None => panic!("Root should have an input"), - }; - - let project = match &project_input.rel_type { - Some(RelType::Project(p)) => p, - other => panic!("Expected Project as root input, got {other:?}"), - }; - - // Check that Project has Read as input - let read_input = match &project.input { - Some(rel) => rel, - None => panic!("Project should have an input"), - }; - - match &read_input.rel_type { - Some(RelType::Read(_)) => {} - other => panic!("Expected Read relation, got {other:?}"), - } } #[test] - fn test_parse_root_relation_no_names() { - // Test a plan with a Root relation with no names - let plan = r#"=== Plan -Root[] - Project[$0, $1] - Read[my.table => a:i32, b:string?] 
+ fn parse_plan_with_inline_comments() { + let input = r#" +=== Plan +Root[a] // top-level result + Read[t => a:i32] // leaf relation "#; - let mut parser = Parser::default(); - for line in plan.lines() { - parser.parse_line(line).unwrap(); - } - - let plan = parser.build_plan().unwrap(); - - let root_rel = &plan.relations[0].rel_type; - let rel_root = match root_rel { - Some(plan_rel::RelType::Root(rel_root)) => rel_root, - other => panic!("Expected Root type, got {other:?}"), - }; - // Check that the root has no names - assert_eq!(rel_root.names, Vec::::new()); + let plan = Parser::parse(input).expect("parse plan"); + assert_eq!(plan.relations.len(), 1); } #[test] - fn test_parse_full_plan() { - // Test a complete Substrait plan with extensions and relations + fn parse_extensions_with_url_no_comment_strip() { let input = r#" === Extensions URNs: - @ 1: /urn/common - @ 2: /urn/specific_funcs + @ 1: https://example.com/functions.yaml // keep trailing comment Functions: - # 10 @ 1: func_a - # 11 @ 2: func_b_special -Types: - # 20 @ 1: SomeType -Type Variations: - # 30 @ 2: VarX + # 10 @ 1: foo === Plan -Project[$0, $1, 42, 84] - Filter[$2 => $0, $1] - Read[my.table => a:i32, b:string?, c:boolean] +Root[a] + Project[foo($0)] + Read[t => a:i32] "#; - let plan = Parser::parse(input).unwrap(); - - // Verify the plan structure - assert_eq!(plan.extension_urns.len(), 2); - assert_eq!(plan.extensions.len(), 4); + let plan = Parser::parse(input).expect("parse plan"); + assert_eq!(plan.extension_urns.len(), 1); + assert_eq!(plan.extensions.len(), 1); assert_eq!(plan.relations.len(), 1); + } - // Verify extension URIs - let urn1 = &plan.extension_urns[0]; - assert_eq!(urn1.extension_urn_anchor, 1); - assert_eq!(urn1.urn, "/urn/common"); - - let urn2 = &plan.extension_urns[1]; - assert_eq!(urn2.extension_urn_anchor, 2); - assert_eq!(urn2.urn, "/urn/specific_funcs"); - - // Verify extensions - let func1 = &plan.extensions[0]; - match &func1.mapping_type { - 
Some(MappingType::ExtensionFunction(f)) => { - assert_eq!(f.function_anchor, 10); - assert_eq!(f.extension_urn_reference, 1); - assert_eq!(f.name, "func_a"); - } - other => panic!("Expected ExtensionFunction, got {other:?}"), - } - - let func2 = &plan.extensions[1]; - match &func2.mapping_type { - Some(MappingType::ExtensionFunction(f)) => { - assert_eq!(f.function_anchor, 11); - assert_eq!(f.extension_urn_reference, 2); - assert_eq!(f.name, "func_b_special"); - } - other => panic!("Expected ExtensionFunction, got {other:?}"), - } - - let type1 = &plan.extensions[2]; - match &type1.mapping_type { - Some(MappingType::ExtensionType(t)) => { - assert_eq!(t.type_anchor, 20); - assert_eq!(t.extension_urn_reference, 1); - assert_eq!(t.name, "SomeType"); - } - other => panic!("Expected ExtensionType, got {other:?}"), - } - - let var1 = &plan.extensions[3]; - match &var1.mapping_type { - Some(MappingType::ExtensionTypeVariation(v)) => { - assert_eq!(v.type_variation_anchor, 30); - assert_eq!(v.extension_urn_reference, 2); - assert_eq!(v.name, "VarX"); - } - other => panic!("Expected ExtensionTypeVariation, got {other:?}"), - } + #[test] + fn parse_root_reports_root_names_errors() { + let input = r#" +=== Plan +Root[42] + Read[t => a:i32] +"#; - // Verify the relation tree structure - let root_rel = &plan.relations[0]; - match &root_rel.rel_type { - Some(plan_rel::RelType::Rel(rel)) => { - match &rel.rel_type { - Some(RelType::Project(project)) => { - // Verify Project relation - assert_eq!(project.expressions.len(), 2); // 42 and 84 - assert!(project.input.is_some()); // Should have Filter as input - - // Check the Filter input - let filter_input = project.input.as_ref().unwrap(); - match &filter_input.rel_type { - Some(RelType::Filter(filter)) => { - assert!(filter.input.is_some()); // Should have Read as input - - // Check the Read input - let read_input = filter.input.as_ref().unwrap(); - match &read_input.rel_type { - Some(RelType::Read(read)) => { - // Verify Read 
relation - let schema = read.base_schema.as_ref().unwrap(); - assert_eq!(schema.names.len(), 3); - assert_eq!(schema.names[0], "a"); - assert_eq!(schema.names[1], "b"); - assert_eq!(schema.names[2], "c"); - - let struct_ = schema.r#struct.as_ref().unwrap(); - assert_eq!(struct_.types.len(), 3); - } - other => panic!("Expected Read relation, got {other:?}"), - } - } - other => panic!("Expected Filter relation, got {other:?}"), - } - } - other => panic!("Expected Project relation, got {other:?}"), - } + let err = Parser::parse(input).expect_err("parse should fail"); + match err { + ParseError::Syntax { + target, context, .. + } => { + assert_eq!(target, "root_names"); + assert_eq!(context.line_no, 3); + assert_eq!(context.line, "Root[42]"); } - other => panic!("Expected Rel type, got {other:?}"), + other => panic!("expected ParseError::Syntax, got {other:?}"), } } } diff --git a/src/parser/types.rs b/src/parser/types.rs deleted file mode 100644 index 13c1781..0000000 --- a/src/parser/types.rs +++ /dev/null @@ -1,428 +0,0 @@ -use pest::iterators::Pair; -use substrait::proto::r#type::{Kind, Nullability, Parameter}; -use substrait::proto::{self, Type}; - -use super::{ParsePair, Rule, ScopedParsePair, iter_pairs, unwrap_single_pair}; -use crate::extensions::SimpleExtensions; -use crate::extensions::simple::ExtensionKind; -use crate::parser::{ErrorKind, MessageParseError}; - -// Given a name and an optional anchor, get the anchor and validate it. Errors will be pushed to the Scope error accumulator, -// and the anchor will be returned if it is valid. 
-pub(crate) fn get_and_validate_anchor( - extensions: &SimpleExtensions, - kind: ExtensionKind, - anchor: Option, - name: &str, - span: pest::Span, -) -> Result { - match anchor { - Some(a) => match extensions.is_name_unique(kind, a, name) { - Ok(_) => Ok(a), - Err(e) => { - let message = "Error matching name to anchor".to_string(); - let error = MessageParseError { - message: kind.name(), - kind: ErrorKind::Lookup(e), - error: Box::new(pest::error::Error::new_from_span( - pest::error::ErrorVariant::CustomError { message }, - span, - )), - }; - Err(error) - } - }, - None => match extensions.find_by_name(kind, name) { - Ok(a) => Ok(a), - Err(e) => { - let message = "Error finding extension for name".to_string(); - let error = MessageParseError { - message: kind.name(), - kind: ErrorKind::Lookup(e), - error: Box::new(pest::error::Error::new_from_span( - pest::error::ErrorVariant::CustomError { message }, - span, - )), - }; - Err(error) - } - }, - } -} - -impl ParsePair for Nullability { - fn rule() -> Rule { - Rule::nullability - } - - fn message() -> &'static str { - "Nullability" - } - - fn parse_pair(pair: Pair) -> Self { - assert_eq!(pair.as_rule(), Rule::nullability); - match pair.as_str() { - "?" 
=> Nullability::Nullable, - "" => Nullability::Required, - "⁉" => Nullability::Unspecified, - _ => panic!("Invalid nullability: {}", pair.as_str()), - } - } -} - -impl ScopedParsePair for Parameter { - fn rule() -> Rule { - Rule::parameter - } - - fn message() -> &'static str { - "Parameter" - } - - fn parse_pair( - extensions: &SimpleExtensions, - pair: Pair, - ) -> Result { - assert_eq!(pair.as_rule(), Rule::parameter); - let inner = unwrap_single_pair(pair); - match inner.as_rule() { - Rule::r#type => Ok(Parameter { - parameter: Some(proto::r#type::parameter::Parameter::DataType( - Type::parse_pair(extensions, inner)?, - )), - }), - _ => unimplemented!("{:?}", inner.as_rule()), - } - } -} - -fn parse_simple_type(pair: Pair) -> Type { - assert_eq!(pair.as_rule(), Rule::simple_type); - let mut iter = iter_pairs(pair.into_inner()); - let name = iter.pop(Rule::simple_type_name).as_str(); - let nullability = iter.parse_next::(); - iter.done(); - - let kind = match name { - "boolean" => Kind::Bool(proto::r#type::Boolean { - nullability: nullability.into(), - type_variation_reference: 0, - }), - "i64" => Kind::I64(proto::r#type::I64 { - nullability: nullability.into(), - type_variation_reference: 0, - }), - "i32" => Kind::I32(proto::r#type::I32 { - nullability: nullability.into(), - type_variation_reference: 0, - }), - "i16" => Kind::I16(proto::r#type::I16 { - nullability: nullability.into(), - type_variation_reference: 0, - }), - "i8" => Kind::I8(proto::r#type::I8 { - nullability: nullability.into(), - type_variation_reference: 0, - }), - "fp32" => Kind::Fp32(proto::r#type::Fp32 { - nullability: nullability.into(), - type_variation_reference: 0, - }), - "fp64" => Kind::Fp64(proto::r#type::Fp64 { - nullability: nullability.into(), - type_variation_reference: 0, - }), - "string" => Kind::String(proto::r#type::String { - nullability: nullability.into(), - type_variation_reference: 0, - }), - "binary" => Kind::Binary(proto::r#type::Binary { - nullability: 
nullability.into(), - type_variation_reference: 0, - }), - #[allow(deprecated)] - "timestamp" => Kind::Timestamp(proto::r#type::Timestamp { - nullability: nullability.into(), - type_variation_reference: 0, - }), - #[allow(deprecated)] - "timestamp_tz" => Kind::TimestampTz(proto::r#type::TimestampTz { - nullability: nullability.into(), - type_variation_reference: 0, - }), - "date" => Kind::Date(proto::r#type::Date { - nullability: nullability.into(), - type_variation_reference: 0, - }), - "time" => Kind::Time(proto::r#type::Time { - nullability: nullability.into(), - type_variation_reference: 0, - }), - "interval_year" => Kind::IntervalYear(proto::r#type::IntervalYear { - nullability: nullability.into(), - type_variation_reference: 0, - }), - "uuid" => Kind::Uuid(proto::r#type::Uuid { - nullability: nullability.into(), - type_variation_reference: 0, - }), - _ => unreachable!("Type {} exists in parser but not implemented in code", name), - }; - Type { kind: Some(kind) } -} - -fn parse_compound_type( - extensions: &SimpleExtensions, - pair: Pair, -) -> Result { - assert_eq!(pair.as_rule(), Rule::compound_type); - let inner = unwrap_single_pair(pair); - match inner.as_rule() { - Rule::list_type => parse_list_type(extensions, inner), - // Rule::map_type => parse_map_type(inner), - // Rule::struct_type => parse_struct_type(inner), - _ => unimplemented!("{:?}", inner.as_rule()), - } -} - -fn parse_list_type( - extensions: &SimpleExtensions, - pair: Pair, -) -> Result { - assert_eq!(pair.as_rule(), Rule::list_type); - let mut iter = iter_pairs(pair.into_inner()); - let nullability = iter.parse_next::(); - let inner = iter.parse_next_scoped::(extensions)?; - iter.done(); - - Ok(Type { - kind: Some(Kind::List(Box::new(proto::r#type::List { - nullability: nullability.into(), - r#type: Some(Box::new(inner)), - type_variation_reference: 0, - }))), - }) -} - -fn parse_parameters( - extensions: &SimpleExtensions, - pair: Pair, -) -> Result, MessageParseError> { - 
assert_eq!(pair.as_rule(), Rule::parameters); - let mut iter = iter_pairs(pair.into_inner()); - let mut params = Vec::new(); - while let Some(param) = iter.parse_if_next_scoped::(extensions) { - params.push(param?); - } - iter.done(); - Ok(params) -} - -fn parse_user_defined_type( - extensions: &SimpleExtensions, - pair: Pair, -) -> Result { - let span = pair.as_span(); - assert_eq!(pair.as_rule(), Rule::user_defined_type); - let mut iter = iter_pairs(pair.into_inner()); - let name = iter.pop(Rule::name).as_str(); - let anchor = iter - .try_pop(Rule::anchor) - .map(|n| unwrap_single_pair(n).as_str().parse::().unwrap()); - - // TODO: Handle urn_anchor; validate that it matches the anchor - let _urn_anchor = iter - .try_pop(Rule::urn_anchor) - .map(|n| unwrap_single_pair(n).as_str().parse::().unwrap()); - - let nullability = iter.parse_next::(); - let parameters = match iter.try_pop(Rule::parameters) { - Some(p) => parse_parameters(extensions, p)?, - None => Vec::new(), - }; - iter.done(); - - let anchor = get_and_validate_anchor(extensions, ExtensionKind::Type, anchor, name, span)?; - - Ok(Type { - kind: Some(Kind::UserDefined(proto::r#type::UserDefined { - type_reference: anchor, - nullability: nullability.into(), - type_parameters: parameters, - type_variation_reference: 0, - })), - }) -} - -impl ScopedParsePair for Type { - fn rule() -> Rule { - Rule::r#type - } - - fn message() -> &'static str { - "Type" - } - - fn parse_pair( - extensions: &SimpleExtensions, - pair: Pair, - ) -> Result { - assert_eq!(pair.as_rule(), Rule::r#type); - let inner = unwrap_single_pair(pair); - match inner.as_rule() { - Rule::simple_type => Ok(parse_simple_type(inner)), - Rule::compound_type => parse_compound_type(extensions, inner), - Rule::user_defined_type => parse_user_defined_type(extensions, inner), - _ => unreachable!( - "Grammar guarantees type can only be simple_type, compound_type, or user_defined_type, got: {:?}", - inner.as_rule() - ), - } - } -} - -#[cfg(test)] -mod 
tests { - use pest::Parser; - use substrait::proto::r#type::{I64, Kind, Nullability}; - - use super::*; - use crate::parser::ExpressionParser; - - #[test] - fn test_parse_simple_type() { - let mut pairs = ExpressionParser::parse(Rule::simple_type, "i64").unwrap(); - let pair = pairs.next().unwrap(); - assert_eq!(pairs.next(), None); - let t = parse_simple_type(pair); - assert_eq!( - t, - Type { - kind: Some(Kind::I64(I64 { - nullability: Nullability::Required as i32, - type_variation_reference: 0, - })), - } - ); - - let mut pairs = ExpressionParser::parse(Rule::simple_type, "string?").unwrap(); - let pair = pairs.next().unwrap(); - assert_eq!(pairs.next(), None); - let t = parse_simple_type(pair); - assert_eq!( - t, - Type { - kind: Some(Kind::String(proto::r#type::String { - nullability: Nullability::Nullable as i32, - type_variation_reference: 0, - })), - } - ); - } - - #[test] - fn test_parse_type() { - let extensions = SimpleExtensions::default(); - let mut pairs = ExpressionParser::parse(Rule::r#type, "i64").unwrap(); - let pair = pairs.next().unwrap(); - assert_eq!(pairs.next(), None); - let t = Type::parse_pair(&extensions, pair).unwrap(); - assert_eq!( - t, - Type { - kind: Some(Kind::I64(I64 { - nullability: Nullability::Required as i32, - type_variation_reference: 0, - })) - } - ); - } - - #[test] - fn test_parse_list_type() { - let extensions = SimpleExtensions::default(); - let mut pairs = ExpressionParser::parse(Rule::list_type, "list").unwrap(); - let pair = pairs.next().unwrap(); - assert_eq!(pairs.next(), None); - let t = parse_list_type(&extensions, pair).unwrap(); - assert_eq!( - t, - Type { - kind: Some(Kind::List(Box::new(proto::r#type::List { - nullability: Nullability::Required as i32, - r#type: Some(Box::new(Type { - kind: Some(Kind::I64(I64 { - nullability: Nullability::Required as i32, - type_variation_reference: 0, - })) - })), - type_variation_reference: 0, - }))) - } - ); - } - - #[test] - fn test_parse_parameters() { - let extensions = 
SimpleExtensions::default(); - let mut pairs = ExpressionParser::parse(Rule::parameters, "").unwrap(); - let pair = pairs.next().unwrap(); - assert_eq!(pairs.next(), None); - let t = parse_parameters(&extensions, pair).unwrap(); - assert_eq!( - t, - vec![ - Parameter { - parameter: Some(proto::r#type::parameter::Parameter::DataType(Type { - kind: Some(Kind::I64(proto::r#type::I64 { - nullability: Nullability::Nullable as i32, - type_variation_reference: 0, - })), - })), - }, - Parameter { - parameter: Some(proto::r#type::parameter::Parameter::DataType(Type { - kind: Some(Kind::String(proto::r#type::String { - nullability: Nullability::Required as i32, - type_variation_reference: 0, - })), - })), - }, - ] - ); - } - - #[test] - fn test_udts() { - let mut extensions = SimpleExtensions::default(); - extensions - .add_extension_urn("some_source".to_string(), 4) - .unwrap(); - extensions - .add_extension(ExtensionKind::Type, 4, 42, "udt".to_string()) - .unwrap(); - let mut pairs = ExpressionParser::parse(Rule::user_defined_type, "udt#42").unwrap(); - let pair = pairs.next().unwrap(); - assert_eq!(pairs.next(), None); - - let t = parse_user_defined_type(&extensions, pair).unwrap(); - assert_eq!( - t, - Type { - kind: Some(Kind::UserDefined(proto::r#type::UserDefined { - type_reference: 42, - type_variation_reference: 0, - nullability: Nullability::Required as i32, - type_parameters: vec![Parameter { - parameter: Some(proto::r#type::parameter::Parameter::DataType(Type { - kind: Some(Kind::I64(proto::r#type::I64 { - nullability: Nullability::Nullable as i32, - type_variation_reference: 0, - })), - })), - }], - })) - } - ); - } -} diff --git a/src/structural/mod.rs b/src/structural/mod.rs deleted file mode 100644 index eb38c9b..0000000 --- a/src/structural/mod.rs +++ /dev/null @@ -1,438 +0,0 @@ -//! Structural types for unified value representation -//! -//! This module provides a unified type system that bridges the gap between -//! 
extension values and textify values, using proper lifetime management -//! with `Cow<'a, str>` for flexible string handling. - -use std::borrow::Cow; -use std::fmt; - -use crate::extensions::simple::{ExtensionFunction, SimpleExtensions}; - -/// A unified value type that can represent both extension arguments and textify values -/// with proper lifetime management -#[derive(Debug, Clone)] -pub enum StructuralValue<'a> { - /// String value with flexible ownership (borrowed or owned) - String(Cow<'a, str>), - /// Integer literal value - Integer(i64), - /// Float literal value - Float(f64), - /// Boolean literal value - Boolean(bool), - /// Field reference ($0, $1, etc.) - Reference(i32), - /// Function call or extension reference - Function(FunctionRef<'a>), - /// Table name reference - TableName(Vec>), - /// Tuple of multiple values - Tuple(Vec>), - /// List of values - List(Vec>), - /// Missing or invalid value with error description - Missing(Cow<'a, str>), - /// Enum value representation - Enum(Cow<'a, str>), -} - -/// Represents a function call or extension reference with resolution state -#[derive(Debug, Clone)] -pub enum FunctionRef<'a> { - /// Unresolved reference by name only - Unresolved(Cow<'a, str>), - /// Unresolved reference with anchor notation (e.g., "add#10") - UnresolvedWithAnchor { name: Cow<'a, str>, anchor: u32 }, - /// Unresolved reference with URI anchor (e.g., "add@1") - UnresolvedWithUri { name: Cow<'a, str>, uri_anchor: u32 }, - /// Unresolved reference with both anchors (e.g., "add#10@1") - UnresolvedWithBoth { - name: Cow<'a, str>, - anchor: u32, - uri_anchor: u32, - }, - /// Resolved to a simple extension function - Resolved { - name: Cow<'a, str>, - extension: &'a ExtensionFunction, - }, -} - -/// Named argument pair for function calls and relations -#[derive(Debug, Clone)] -pub struct NamedArg<'a> { - pub name: Cow<'a, str>, - pub value: StructuralValue<'a>, -} - -/// Arguments collection for relations and function calls -#[derive(Debug, 
Clone)] -pub struct Arguments<'a> { - /// Positional arguments - pub positional: Vec>, - /// Named arguments - pub named: Vec>, -} - -/// Column specification with type information -#[derive(Debug, Clone)] -pub enum ColumnSpec<'a> { - /// Named column with optional type (name or name:type) - Named { - name: Cow<'a, str>, - type_spec: Option>, - }, - /// Field reference ($0, $1, etc.) - Reference(i32), - /// Expression column - Expression(StructuralValue<'a>), -} - -impl<'a> StructuralValue<'a> { - /// Create a new string value from any string-like type - pub fn string(s: impl Into>) -> Self { - Self::String(s.into()) - } - - /// Create a new owned string value - pub fn owned_string(s: String) -> Self { - Self::String(Cow::Owned(s)) - } - - /// Create a new borrowed string value - pub fn borrowed_string(s: &'a str) -> Self { - Self::String(Cow::Borrowed(s)) - } - - /// Create a function reference from a name - pub fn function(name: impl Into>) -> Self { - Self::Function(FunctionRef::Unresolved(name.into())) - } - - /// Create a function reference with anchor - pub fn function_with_anchor(name: impl Into>, anchor: u32) -> Self { - Self::Function(FunctionRef::UnresolvedWithAnchor { - name: name.into(), - anchor, - }) - } - - /// Create a function reference with URI anchor - pub fn function_with_uri(name: impl Into>, uri_anchor: u32) -> Self { - Self::Function(FunctionRef::UnresolvedWithUri { - name: name.into(), - uri_anchor, - }) - } - - /// Create a function reference with both anchors - pub fn function_with_both(name: impl Into>, anchor: u32, uri_anchor: u32) -> Self { - Self::Function(FunctionRef::UnresolvedWithBoth { - name: name.into(), - anchor, - uri_anchor, - }) - } - - /// Try to resolve function references using simple extensions - pub fn resolve_functions(&mut self, extensions: &'a SimpleExtensions) -> Result<(), String> { - match self { - StructuralValue::Function(func_ref) => { - *func_ref = func_ref.try_resolve(extensions)?; - } - 
StructuralValue::Tuple(values) | StructuralValue::List(values) => { - for value in values { - value.resolve_functions(extensions)?; - } - } - _ => {} // Other types don't need resolution - } - Ok(()) - } - - /// Convert to an owned version (removes all borrowed references) - pub fn into_owned(self) -> StructuralValue<'static> { - match self { - StructuralValue::String(s) => StructuralValue::String(Cow::Owned(s.into_owned())), - StructuralValue::Integer(i) => StructuralValue::Integer(i), - StructuralValue::Float(f) => StructuralValue::Float(f), - StructuralValue::Boolean(b) => StructuralValue::Boolean(b), - StructuralValue::Reference(r) => StructuralValue::Reference(r), - StructuralValue::Function(f) => StructuralValue::Function(f.into_owned()), - StructuralValue::TableName(names) => StructuralValue::TableName( - names - .into_iter() - .map(|n| Cow::Owned(n.into_owned())) - .collect(), - ), - StructuralValue::Tuple(values) => { - StructuralValue::Tuple(values.into_iter().map(|v| v.into_owned()).collect()) - } - StructuralValue::List(values) => { - StructuralValue::List(values.into_iter().map(|v| v.into_owned()).collect()) - } - StructuralValue::Missing(msg) => StructuralValue::Missing(Cow::Owned(msg.into_owned())), - StructuralValue::Enum(e) => StructuralValue::Enum(Cow::Owned(e.into_owned())), - } - } -} - -impl<'a> FunctionRef<'a> { - /// Try to resolve this function reference using simple extensions - pub fn try_resolve(self, extensions: &'a SimpleExtensions) -> Result { - match self { - FunctionRef::Unresolved(name) => { - // Try to find by name in extensions - if let Some(ext_fn) = extensions.functions().find_by_name(&name) { - Ok(FunctionRef::Resolved { - name, - extension: ext_fn, - }) - } else { - // Keep unresolved - Ok(FunctionRef::Unresolved(name)) - } - } - FunctionRef::UnresolvedWithAnchor { name, anchor } => { - // Try to resolve by anchor - if let Some(ext_fn) = extensions.functions().get(anchor) { - Ok(FunctionRef::Resolved { - name, - extension: 
ext_fn, - }) - } else { - Ok(FunctionRef::UnresolvedWithAnchor { name, anchor }) - } - } - // TODO: Add resolution logic for URI anchors and combined anchors - other => Ok(other), - } - } - - /// Convert to owned version - pub fn into_owned(self) -> FunctionRef<'static> { - match self { - FunctionRef::Unresolved(name) => FunctionRef::Unresolved(Cow::Owned(name.into_owned())), - FunctionRef::UnresolvedWithAnchor { name, anchor } => { - FunctionRef::UnresolvedWithAnchor { - name: Cow::Owned(name.into_owned()), - anchor, - } - } - FunctionRef::UnresolvedWithUri { name, uri_anchor } => FunctionRef::UnresolvedWithUri { - name: Cow::Owned(name.into_owned()), - uri_anchor, - }, - FunctionRef::UnresolvedWithBoth { - name, - anchor, - uri_anchor, - } => FunctionRef::UnresolvedWithBoth { - name: Cow::Owned(name.into_owned()), - anchor, - uri_anchor, - }, - FunctionRef::Resolved { name, extension: _ } => { - // Can't preserve the reference in owned version, convert to unresolved - FunctionRef::Unresolved(Cow::Owned(name.into_owned())) - } - } - } - - /// Get the function name regardless of resolution state - pub fn name(&self) -> &str { - match self { - FunctionRef::Unresolved(name) => name, - FunctionRef::UnresolvedWithAnchor { name, .. } => name, - FunctionRef::UnresolvedWithUri { name, .. } => name, - FunctionRef::UnresolvedWithBoth { name, .. } => name, - FunctionRef::Resolved { name, .. 
} => name, - } - } -} - -impl<'a> Arguments<'a> { - /// Create new empty arguments - pub fn new() -> Self { - Self { - positional: Vec::new(), - named: Vec::new(), - } - } - - /// Add a positional argument - pub fn add_positional(&mut self, value: StructuralValue<'a>) { - self.positional.push(value); - } - - /// Add a named argument - pub fn add_named(&mut self, name: impl Into>, value: StructuralValue<'a>) { - self.named.push(NamedArg { - name: name.into(), - value, - }); - } - - /// Check if arguments are empty - pub fn is_empty(&self) -> bool { - self.positional.is_empty() && self.named.is_empty() - } - - /// Resolve all function references in arguments - pub fn resolve_functions(&mut self, extensions: &'a SimpleExtensions) -> Result<(), String> { - for value in &mut self.positional { - value.resolve_functions(extensions)?; - } - for arg in &mut self.named { - arg.value.resolve_functions(extensions)?; - } - Ok(()) - } -} - -impl<'a> Default for Arguments<'a> { - fn default() -> Self { - Self::new() - } -} - -impl<'a> ColumnSpec<'a> { - /// Create a named column - pub fn named(name: impl Into>) -> Self { - Self::Named { - name: name.into(), - type_spec: None, - } - } - - /// Create a named column with type - pub fn named_with_type( - name: impl Into>, - type_spec: impl Into>, - ) -> Self { - Self::Named { - name: name.into(), - type_spec: Some(type_spec.into()), - } - } - - /// Create a reference column - pub fn reference(index: i32) -> Self { - Self::Reference(index) - } - - /// Create an expression column - pub fn expression(expr: StructuralValue<'a>) -> Self { - Self::Expression(expr) - } -} - -// Display implementations for debugging and textification -impl<'a> fmt::Display for StructuralValue<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - StructuralValue::String(s) => write!(f, "'{}'", s), - StructuralValue::Integer(i) => write!(f, "{}", i), - StructuralValue::Float(fl) => write!(f, "{}", fl), - 
StructuralValue::Boolean(b) => write!(f, "{}", b), - StructuralValue::Reference(r) => write!(f, "${}", r), - StructuralValue::Function(func) => write!(f, "{}", func), - StructuralValue::TableName(names) => { - write!( - f, - "{}", - names - .iter() - .map(|n| n.as_ref()) - .collect::>() - .join(".") - ) - } - StructuralValue::Tuple(values) => { - write!( - f, - "({})", - values - .iter() - .map(|v| v.to_string()) - .collect::>() - .join(", ") - ) - } - StructuralValue::List(values) => { - write!( - f, - "[{}]", - values - .iter() - .map(|v| v.to_string()) - .collect::>() - .join(", ") - ) - } - StructuralValue::Missing(msg) => write!(f, "!{{{}}}", msg), - StructuralValue::Enum(e) => write!(f, "&{}", e), - } - } -} - -impl<'a> fmt::Display for FunctionRef<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - FunctionRef::Unresolved(name) => write!(f, "{}", name), - FunctionRef::UnresolvedWithAnchor { name, anchor } => write!(f, "{}#{}", name, anchor), - FunctionRef::UnresolvedWithUri { name, uri_anchor } => { - write!(f, "{}@{}", name, uri_anchor) - } - FunctionRef::UnresolvedWithBoth { - name, - anchor, - uri_anchor, - } => { - write!(f, "{}#{}@{}", name, anchor, uri_anchor) - } - FunctionRef::Resolved { name, .. 
} => write!(f, "{}", name), - } - } -} - -impl<'a> fmt::Display for NamedArg<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}={}", self.name, self.value) - } -} - -impl<'a> fmt::Display for Arguments<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut parts = Vec::new(); - - // Add positional arguments - for value in &self.positional { - parts.push(value.to_string()); - } - - // Add named arguments - for arg in &self.named { - parts.push(arg.to_string()); - } - - write!(f, "{}", parts.join(", ")) - } -} - -impl<'a> fmt::Display for ColumnSpec<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - ColumnSpec::Named { name, type_spec } => { - if let Some(type_spec) = type_spec { - write!(f, "{}:{}", name, type_spec) - } else { - write!(f, "{}", name) - } - } - ColumnSpec::Reference(r) => write!(f, "${}", r), - ColumnSpec::Expression(expr) => write!(f, "{}", expr), - } - } -} diff --git a/src/textify/expressions.rs b/src/textify/expressions.rs index 7315ea0..3b2c9c9 100644 --- a/src/textify/expressions.rs +++ b/src/textify/expressions.rs @@ -1,3 +1,18 @@ +//! Expression and literal formatting — everything that appears inside the +//! `[...]` brackets of a relation line. +//! +//! Syntax conventions used in the text format: +//! +//! - `(…)` function call, e.g. `add($0, $1)` +//! - `[…]` relation brackets +//! - `<…>` type parameters, e.g. `list` +//! - `!{…}` missing/error placeholder +//! - `$…` field reference, e.g. `$0` +//! - `#…` extension anchor, e.g. `#10` +//! - `@…` URN anchor, e.g. `@1` +//! - `…:…` type annotation, e.g. `42:i64` +//! - `&…` enum value, e.g. 
`&AscNullsFirst` + use std::fmt::{self}; use chrono::{DateTime, NaiveDate}; @@ -18,18 +33,6 @@ use crate::extensions::SimpleExtensions; use crate::extensions::simple::ExtensionKind; use crate::textify::types::{Name, NamedAnchor, OutputType, escaped}; -// …(…) for function call -// […] for variant -// <…> for parameters -// !{…} for missing value - -// $… for field reference -// #… for anchor -// @… for URN anchor -// …::… for cast -// …:… for specifying type -// &… for enum - pub fn textify_binary(items: &[u8], ctx: &S, w: &mut W) -> fmt::Result { if ctx.options().show_literal_binaries { write!(w, "0x")?; @@ -516,6 +519,53 @@ impl Textify for expr::Literal { } fn textify(&self, ctx: &S, w: &mut W) -> fmt::Result { + if self.nullable + && let Some(literal) = self.literal_type.as_ref() + { + match literal { + LiteralType::String(s) => { + let kind = Kind::String(ptype::String { + type_variation_reference: self.type_variation_reference, + nullability: Nullability::Nullable.into(), + }); + return textify_literal_from_string(s, kind, ctx, w); + } + LiteralType::Date(days) => { + let kind = Kind::Date(ptype::Date { + type_variation_reference: self.type_variation_reference, + nullability: Nullability::Nullable.into(), + }); + return textify_literal_from_string(&days_to_date_string(*days), kind, ctx, w); + } + LiteralType::Time(microseconds) => { + let kind = Kind::Time(ptype::Time { + type_variation_reference: self.type_variation_reference, + nullability: Nullability::Nullable.into(), + }); + return textify_literal_from_string( + &microseconds_to_time_string(*microseconds), + kind, + ctx, + w, + ); + } + #[allow(deprecated)] + LiteralType::Timestamp(microseconds) => { + let kind = Kind::Timestamp(ptype::Timestamp { + type_variation_reference: self.type_variation_reference, + nullability: Nullability::Nullable.into(), + }); + return textify_literal_from_string( + &microseconds_to_timestamp_string(*microseconds), + kind, + ctx, + w, + ); + } + _ => {} + } + } + write!(w, "{}", 
ctx.expect(self.literal_type.as_ref())) } } diff --git a/src/textify/foundation.rs b/src/textify/foundation.rs index 01c5551..9ee46c6 100644 --- a/src/textify/foundation.rs +++ b/src/textify/foundation.rs @@ -91,10 +91,15 @@ impl OutputOptions { } } } +/// Collects non-fatal formatting errors so that textification can continue +/// and produce best-effort output. Errors are later returned alongside the +/// formatted text. pub trait ErrorAccumulator: Clone { fn push(&self, e: FormatError); } +/// Channel-backed [`ErrorAccumulator`]. Cloning shares the same underlying +/// channel, so errors from any clone are visible to the consumer. #[derive(Debug, Clone)] pub struct ErrorQueue { sender: mpsc::Sender, @@ -219,6 +224,9 @@ impl<'a> fmt::Display for IndentStack<'a> { } } +/// Concrete [`Scope`] implementation that carries formatting options, +/// extension lookups, error accumulation, and the current indentation level +/// through the textify call tree. #[derive(Debug, Copy, Clone)] pub struct ScopedContext<'a, Err: ErrorAccumulator> { errors: &'a Err, diff --git a/src/textify/mod.rs b/src/textify/mod.rs index a8d4b27..9b02b1b 100644 --- a/src/textify/mod.rs +++ b/src/textify/mod.rs @@ -1,4 +1,9 @@ -//! Output a plan in text format. +//! Substrait protobuf → human-readable text conversion. +//! +//! The core pattern: each protobuf type implements [`Textify`], writing itself +//! into a [`std::fmt::Write`] with a [`Scope`] context that provides formatting +//! options, extension lookups, and error accumulation. [`plan::PlanWriter`] +//! is the top-level entry point. pub mod expressions; pub mod extensions; diff --git a/src/textify/plan.rs b/src/textify/plan.rs index 01773e8..4e5481b 100644 --- a/src/textify/plan.rs +++ b/src/textify/plan.rs @@ -1,3 +1,10 @@ +//! Top-level plan formatting — orchestrates the conversion of a full +//! [`proto::Plan`] into its text representation. +//! +//! [`PlanWriter`] is the entry point: it extracts extensions from the plan, +//! 
creates a [`ScopedContext`] that carries options, error accumulation, and +//! extension lookups, then delegates to the relation and expression formatters. + use std::fmt; use substrait::proto; @@ -8,6 +15,10 @@ use crate::parser::PLAN_HEADER; use crate::textify::foundation::ErrorAccumulator; use crate::textify::{OutputOptions, ScopedContext}; +/// Formats a [`proto::Plan`] to its text representation. +/// +/// Use [`Display`](std::fmt::Display) or call [`write_extensions`](Self::write_extensions) and +/// [`write_relations`](Self::write_relations) separately for more control. #[derive(Debug, Clone)] pub struct PlanWriter<'a, E: ErrorAccumulator + Default> { options: &'a OutputOptions, @@ -18,6 +29,10 @@ pub struct PlanWriter<'a, E: ErrorAccumulator + Default> { } impl<'a, E: ErrorAccumulator + Default + Clone> PlanWriter<'a, E> { + /// Create a writer for the given plan. Returns the [`PlanWriter`] and an + /// output channel for accumulated errors, an [`ErrorAccumulator`] — output + /// is best-effort and will run to completion, despite incomplete plans, + /// unresolved extensions or references, or other issues. pub fn new( options: &'a OutputOptions, plan: &'a proto::Plan, @@ -99,8 +114,8 @@ mod tests { }; use super::*; - use crate::parser::expressions::FieldIndex; use crate::textify::ErrorQueue; + use crate::textify::expressions::Reference; /// Test a fairly basic plan with an extension, read, and project. 
/// @@ -161,18 +176,10 @@ mod tests { function_reference: 10, arguments: vec![ FunctionArgument { - arg_type: Some(ArgType::Value(Expression { - rex_type: Some(RexType::Selection(Box::new( - FieldIndex(0).to_field_reference(), - ))), - })), + arg_type: Some(ArgType::Value(Reference(0).into())), }, FunctionArgument { - arg_type: Some(ArgType::Value(Expression { - rex_type: Some(RexType::Selection(Box::new( - FieldIndex(1).to_field_reference(), - ))), - })), + arg_type: Some(ArgType::Value(Reference(1).into())), }, ], options: vec![], diff --git a/src/textify/rels.rs b/src/textify/rels.rs index 1ff16d7..20e9249 100644 --- a/src/textify/rels.rs +++ b/src/textify/rels.rs @@ -1,3 +1,11 @@ +//! Relation formatting — converts protobuf relation types to the +//! `Name[args => outputs]` text format. +//! +//! Each Substrait relation (Read, Filter, Project, …) is converted to a +//! [`Relation`] struct containing its name, [`Arguments`], and child +//! relations. Arguments and outputs are represented as [`Value`] variants, +//! which bridge between the typed protobuf world and the flat text format. + use std::borrow::Cow; use std::collections::HashMap; use std::convert::TryFrom; @@ -90,6 +98,9 @@ pub struct NamedArg<'a> { } #[derive(Debug, Clone)] +/// A single argument or output value in a relation's `[args => outputs]` +/// bracket. Covers field references, expressions, literals, enum values, +/// named columns, and error placeholders ([`Value::Missing`]). 
pub enum Value<'a> { Name(Name<'a>), TableName(Vec>), @@ -973,7 +984,7 @@ mod tests { use super::*; use crate::fixtures::TestContext; - use crate::parser::expressions::FieldIndex; + use crate::textify::expressions::Reference; #[test] fn test_read_rel() { @@ -1140,11 +1151,7 @@ Filter[gt($0, 10:i32) => $0, $1] let agg_fn1 = get_aggregate_func(10, 1); let agg_fn2 = get_aggregate_func(11, 1); - let grouping_expressions = vec![Expression { - rex_type: Some(RexType::Selection(Box::new( - FieldIndex(0).to_field_reference(), - ))), - }]; + let grouping_expressions = vec![Reference(0).into()]; let measures = vec![ aggregate_rel::Measure { @@ -1256,16 +1263,8 @@ Filter[gt($0, 10:i32) => $0, $1] let agg_fn2 = get_aggregate_func(11, 2); let grouping_expressions = vec![ - Expression { - rex_type: Some(RexType::Selection(Box::new( - FieldIndex(0).to_field_reference(), - ))), - }, - Expression { - rex_type: Some(RexType::Selection(Box::new( - FieldIndex(1).to_field_reference(), - ))), - }, + Reference(0).into(), + Reference(1).into(), ]; let grouping_sets = vec![ @@ -1482,11 +1481,7 @@ Filter[gt($0, 10:i32) => $0, $1] AggregateFunction { function_reference: func_ref, arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(Expression { - rex_type: Some(RexType::Selection(Box::new( - FieldIndex(column_ind).to_field_reference(), - ))), - })), + arg_type: Some(ArgType::Value(Reference(column_ind).into())), }], options: vec![], output_type: None, @@ -1564,10 +1559,6 @@ Filter[gt($0, 10:i32) => $0, $1] } fn create_exp(column_ind: i32) -> Expression { - Expression { - rex_type: Some(RexType::Selection(Box::new( - FieldIndex(column_ind).to_field_reference(), - ))), - } + Reference(column_ind).into() } } diff --git a/src/textify/types.rs b/src/textify/types.rs index 9f831d1..6253e34 100644 --- a/src/textify/types.rs +++ b/src/textify/types.rs @@ -1,3 +1,7 @@ +//! Type formatting — nullability suffixes, parameterized types, user-defined +//! 
types, and the [`Name`]/[`Anchor`] helpers used by relation and expression +//! formatting. + use std::fmt; use std::ops::Deref; @@ -91,9 +95,14 @@ impl<'a> Textify for Name<'a> { } } +/// An extension anchor (e.g. `#10`) that may be shown or hidden based on +/// [`OutputOptions::show_simple_extension_anchors`](super::OutputOptions::show_simple_extension_anchors). #[derive(Debug, Copy, Clone)] pub struct Anchor { reference: u32, + /// When `true`, the anchor is shown even at `Visibility::Required` + /// (i.e. it's needed for disambiguation). When `false`, it's only + /// shown at `Visibility::Always`. required: bool, } diff --git a/tests/literal_roundtrip.rs b/tests/literal_roundtrip.rs index 7ebbddd..8c5fdf1 100644 --- a/tests/literal_roundtrip.rs +++ b/tests/literal_roundtrip.rs @@ -1,6 +1,11 @@ //! Test roundtrip functionality for literal parsing and formatting +use substrait::proto::expression::RexType; +use substrait::proto::expression::literal::LiteralType; +use substrait::proto::plan_rel::RelType as PlanRelType; +use substrait::proto::rel::RelType; use substrait_explain::fixtures::roundtrip_plan; +use substrait_explain::parser::Parser; #[test] fn test_float_literal_roundtrip() { @@ -90,3 +95,65 @@ Root[statusq] "#; roundtrip_plan(plan); } + +#[test] +fn test_nullable_string_typed_literal_sets_nullable() { + let plan = r#" +=== Plan +Root[result] + Project['hello':string?] 
+ Read[data => a:i64] +"#; + + roundtrip_plan(plan); + + let plan = Parser::parse(plan).expect("parse should succeed"); + let plan_rel = plan.relations.first().expect("plan relation"); + let PlanRelType::Root(root) = plan_rel.rel_type.as_ref().expect("plan rel type") else { + panic!("expected root relation"); + }; + let RelType::Project(project) = root + .input + .as_ref() + .and_then(|r| r.rel_type.as_ref()) + .expect("project relation") + else { + panic!("expected project relation"); + }; + + let literal = project + .expressions + .first() + .and_then(|e| e.rex_type.as_ref()) + .and_then(|rex| match rex { + RexType::Literal(lit) => Some(lit), + _ => None, + }) + .expect("literal expression"); + + assert!( + literal.nullable, + "typed string nullability should be preserved" + ); + assert!(matches!( + literal.literal_type.as_ref(), + Some(LiteralType::String(s)) if s == "hello" + )); +} + +#[test] +fn test_unsupported_typed_string_literal_errors() { + let plan = r#" +=== Plan +Root[result] + Project['abc':uuid] + Read[data => a:i64] +"#; + + let err = Parser::parse(plan).expect_err("parse should fail"); + let message = err.to_string(); + assert!( + message.contains("Invalid type for string literal"), + "unexpected error: {message}" + ); +} diff --git a/tests/plan_roundtrip.rs b/tests/plan_roundtrip.rs index 2e992fb..d925778 100644 --- a/tests/plan_roundtrip.rs +++ b/tests/plan_roundtrip.rs @@ -82,7 +82,7 @@ Functions: === Plan Root[name, parent, sum, count] Project[$0, $3, $1, $2] - Read[schema.table => name:string?, parent:string?, sum:fp64?, count:fp64?]"#; + Read[schema."some.table" => name:string?, parent:string?, sum:fp64?, count:fp64?]"#; roundtrip_plan(plan); } @@ -98,7 +98,7 @@ Functions: === Plan Root[name, num] Project[$1, coalesce($1, $2)] - Read[schema.table => name:string?, num:fp64?, other_num:fp64?, id:i64]"#; + Read[schema."db.with.dot".table => name:string?, num:fp64?, other_num:fp64?, id:i64]"#; let verbose_plan = r#"=== Extensions URNs: @@ -109,7 
+109,7 @@ Functions: === Plan Root[name, num] Project[$1, coalesce#10($1, $2)] - Read[schema.table => name:string?, num:fp64?, other_num:fp64?, id:i64]"#; + Read[schema."db.with.dot".table => name:string?, num:fp64?, other_num:fp64?, id:i64]"#; roundtrip_plan_with_verbose(simple_plan, verbose_plan); } diff --git a/tests/types.rs b/tests/types.rs index ac06a95..1088361 100644 --- a/tests/types.rs +++ b/tests/types.rs @@ -1,168 +1,101 @@ -use substrait::proto::expression::{FieldReference, Literal, ScalarFunction}; -use substrait::proto::{Expression, Type}; -use substrait_explain::extensions::SimpleExtensions; -use substrait_explain::extensions::simple::ExtensionKind; -use substrait_explain::fixtures::TestContext; -use substrait_explain::parser::{Parse, ScopedParse}; -use substrait_explain::textify::{ErrorQueue, OutputOptions, Textify}; +use substrait_explain::fixtures::{roundtrip_plan, roundtrip_plan_with_verbose}; -/// Helper function to parse and check for errors, panicking if either fails -fn must_parse(result: Result, input: &str) -> T { - let errors = ErrorQueue::default(); - let t = match result { - Ok(t) => t, - Err(e) => { - println!("Error parsing {input}:\n{e}"); - panic!("{}", e); - } - }; - errors.errs().expect("Failure while parsing"); - t -} - -fn roundtrip_parse(ctx: &TestContext, input: &str) { - let t = must_parse(T::parse(input), input); - - let actual = ctx.textify_no_errors(&t); - assert_eq!(actual, input); -} - -fn assert_roundtrip(ctx: &TestContext, input: &str) { - let t = must_parse(T::parse(&ctx.extensions, input), input); - - let actual = ctx.textify_no_errors(&t); - assert_eq!(actual, input); -} - -fn roundtrip_with_simple_output( - extensions: SimpleExtensions, - verbose: &str, - simple: &str, -) { - let ctx1 = TestContext { - options: OutputOptions::verbose(), - extensions, - extension_registry: Default::default(), - }; - - let t = must_parse(T::parse(&ctx1.extensions, verbose), verbose); - - let actual = ctx1.textify_no_errors(&t); - 
assert_eq!( - actual, verbose, - "Expected verbose output: {verbose}, Actual: {actual}" - ); - - let ctx2 = TestContext { - options: OutputOptions::default(), - extensions: ctx1.extensions, - extension_registry: Default::default(), - }; - - let actual = ctx2.textify_no_errors(&t); - assert_eq!( - actual, simple, - "Expected simple output: {simple}, Actual: {actual}" - ); +#[test] +fn test_types_roundtrip() { + let cases = [ + "i32", + "i32?", + "binary", + "binary?", + "timestamp", + "timestamp?", + "timestamp_tz", + "timestamp_tz?", + "date", + "date?", + "time", + "time?", + "interval_year", + "interval_year?", + "uuid", + "uuid?", + "list", + "list?", + ]; + + for (idx, typ) in cases.iter().enumerate() { + let plan = format!("=== Plan\nRoot[c{idx}]\n Read[t{idx} => c{idx}:{typ}]"); + roundtrip_plan(&plan); + } } #[test] -fn test_types() { - let options = OutputOptions::verbose(); - let mut extensions = SimpleExtensions::default(); - extensions - .add_extension_urn("some_source".to_string(), 4) - .unwrap(); - extensions - .add_extension(ExtensionKind::Type, 4, 12, "MyType".to_string()) - .unwrap(); - - let ctx = TestContext { - options, - extensions, - extension_registry: Default::default(), - }; - - assert_roundtrip::(&ctx, "i32"); - assert_roundtrip::(&ctx, "i32?"); - - // Test new simple types added in issue #20 - assert_roundtrip::(&ctx, "binary"); - assert_roundtrip::(&ctx, "binary?"); - assert_roundtrip::(&ctx, "timestamp"); - assert_roundtrip::(&ctx, "timestamp?"); - assert_roundtrip::(&ctx, "timestamp_tz"); - assert_roundtrip::(&ctx, "timestamp_tz?"); - assert_roundtrip::(&ctx, "date"); - assert_roundtrip::(&ctx, "date?"); - assert_roundtrip::(&ctx, "time"); - assert_roundtrip::(&ctx, "time?"); - assert_roundtrip::(&ctx, "interval_year"); - assert_roundtrip::(&ctx, "interval_year?"); - assert_roundtrip::(&ctx, "uuid"); - assert_roundtrip::(&ctx, "uuid?"); - - assert_roundtrip::(&ctx, "MyType#12"); - assert_roundtrip::(&ctx, "MyType#12?"); - 
assert_roundtrip::(&ctx, "MyType#12"); - assert_roundtrip::(&ctx, "MyType#12"); - - assert_roundtrip::(&ctx, "MyType#12"); - assert_roundtrip::(&ctx, "MyType#12?"); +fn test_user_defined_type_roundtrip() { + let simple = r#"=== Extensions +URNs: + @ 4: some_source +Types: + # 12 @ 4: MyType + +=== Plan +Root[result] + Read[t => col:MyType]"#; + + let verbose = r#"=== Extensions +URNs: + @ 4: some_source +Types: + # 12 @ 4: MyType + +=== Plan +Root[result] + Read[t => col:MyType#12]"#; + + roundtrip_plan_with_verbose(simple, verbose); } #[test] -fn test_expression() { - let options = OutputOptions::default(); - let mut extensions = SimpleExtensions::default(); - extensions - .add_extension_urn("some_source".to_string(), 4) - .unwrap(); - extensions - .add_extension(ExtensionKind::Function, 4, 12, "foo".to_string()) - .unwrap(); - extensions - .add_extension(ExtensionKind::Function, 4, 14, "bar".to_string()) - .unwrap(); - - let ctx = TestContext { - options, - extensions, - extension_registry: Default::default(), - }; - - assert_roundtrip::(&ctx, "12"); - assert_roundtrip::(&ctx, "12:i32"); - roundtrip_parse::(&ctx, "$1"); - assert_roundtrip::(&ctx, "foo()"); - assert_roundtrip::(&ctx, "bar(12)"); - assert_roundtrip::(&ctx, "foo(bar(12:i16, 18), -4:i16)"); - assert_roundtrip::(&ctx, "foo($2, bar($5, 18), -4:i16)"); +fn test_expression_roundtrip() { + let plan = r#"=== Extensions +URNs: + @ 4: some_source +Functions: + # 12 @ 4: foo + # 14 @ 4: bar + +=== Plan +Root[result] + Project[foo(bar(12:i16, 18), -4:i16), foo($2, bar($5, 18), -4:i16)] + Read[t => a:i64, b:i64, c:i64, d:i64, e:i64, f:i64]"#; + + roundtrip_plan(plan); } #[test] fn test_verbose_and_simple_output() { - let mut extensions = SimpleExtensions::default(); - extensions - .add_extension_urn("some_source".to_string(), 4) - .unwrap(); - extensions - .add_extension(ExtensionKind::Function, 4, 12, "foo".to_string()) - .unwrap(); - extensions - .add_extension(ExtensionKind::Function, 4, 14, 
"bar".to_string()) - .unwrap(); - - roundtrip_with_simple_output::(extensions.clone(), "foo#12():i64", "foo()"); - // Check type is included in verbose output - roundtrip_with_simple_output::( - extensions.clone(), - "foo#12(3:i32, 5:i64):i64", - "foo(3:i32, 5)", - ); - roundtrip_with_simple_output::( - extensions.clone(), - "foo#12(bar#14(5:i64, $2)):i64", - "foo(bar(5, $2))", - ); + let simple = r#"=== Extensions +URNs: + @ 4: some_source +Functions: + # 12 @ 4: foo + # 14 @ 4: bar + +=== Plan +Root[result] + Project[foo(), foo(3:i32, 5), foo(bar(5, $2))] + Read[t => a:i64, b:i64, c:i64]"#; + + let verbose = r#"=== Extensions +URNs: + @ 4: some_source +Functions: + # 12 @ 4: foo + # 14 @ 4: bar + +=== Plan +Root[result] + Project[foo#12(), foo#12(3:i32, 5:i64), foo#12(bar#14(5:i64, $2))] + Read[t => a:i64, b:i64, c:i64]"#; + + roundtrip_plan_with_verbose(simple, verbose); } From 3782b7360d3ddec0194d07d475f463a67781bc5a Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Thu, 26 Feb 2026 12:10:50 -0500 Subject: [PATCH 3/6] docs: codify project conventions in CONTRIBUTING.md Add sections for import organization, pattern matching depth, unused parameter handling, parser pipeline pattern, extension handling guidelines, and rustdoc markdown formatting. --- CONTRIBUTING.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f2f2528..5d38bb7 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -90,6 +90,14 @@ This project implements Substrait protobuf plan parsing and formatting. 
See the This project uses different error handling patterns depending on the context: +##### Repository-wide error expectation + +Error messages should include enough context to act on quickly: + +- Include location/context where relevant (for parser paths: line number and source line) +- Include a clear category or target (what failed) +- Keep syntax and semantic errors distinct when both layers exist + ##### Parser Context Use structured parse and lowering errors with explicit context (line number + source line). Syntax errors should come from LALRPOP parse entrypoints, and semantic validation errors should come from lowering: @@ -160,6 +168,18 @@ Keep parser internals split into two phases: This separation keeps grammar concerns and semantic checks independent and easier to test. +For extension declaration lines, use parser-layer entrypoints in +`parser/lalrpop_line.rs`: + +- Prefer explicit helpers (`parse_extension_urn_declaration`, + `parse_extension_declaration`) at parser call sites. +- `FromStr` is also available on parser AST declaration types + (`ExtensionUrnDeclaration`, `ExtensionDeclaration`) for ergonomic tests and + small adapters. + +Keep extension handlers parser-independent: lowering converts parser AST +relation arguments into `ExtensionArgs` before invoking extension resolution. 
+ #### Documentation Formatting ##### Rustdoc Markdown Formatting @@ -181,6 +201,12 @@ When working with [`GRAMMAR.md`](GRAMMAR.md): - Keep grammar notation in `GRAMMAR.md` implementation-independent (PEG/EBNF style), while ensuring examples remain accepted by the implemented parser (`line_grammar.lalrpop`) - Run `cargo test --doc` to verify all code examples compile - Ensure all referenced identifiers are properly defined somewhere in the grammar +- Keep implementation, grammar, and docs aligned: behavior changes that affect syntax/grammar must be an explicit goal of the change and must be documented in both grammar/docs and implementation notes + +#### Readability and Rustic Conventions + +- Prefer standard Rust conversion traits where appropriate (`From`, `TryFrom`, `FromStr`) instead of ad hoc converter functions +- Extract helper functions when there are 2+ call sites and the extraction meaningfully reduces duplication or clarifies intent #### Testing Guidelines From b68d02d67a2878a322a343b18da8ef39b68121ab Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Fri, 13 Mar 2026 11:03:16 -0400 Subject: [PATCH 4/6] ci: add cargo doc checks to CI and Justfile --- .github/workflows/rust.yml | 7 +++++++ Justfile | 1 + 2 files changed, 8 insertions(+) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 9218956..862360e 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -44,6 +44,13 @@ jobs: steps: - uses: actions/checkout@v4 - uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # 2.8.1 + - name: Check documentation + run: cargo doc --no-deps --all-features + env: + RUSTDOCFLAGS: "-D warnings" + + - name: Build + run: cargo build --features protoc --verbose - name: Run tests run: cargo test --all-features --verbose diff --git a/Justfile b/Justfile index 3f1ae16..19c0ae6 100644 --- a/Justfile +++ b/Justfile @@ -25,6 +25,7 @@ test: # Run clippy to check for linting errors. 
check: cargo clippy --examples --tests --all-features + RUSTDOCFLAGS="-D warnings" cargo doc --no-deps --all-features examples: cargo run --example basic_usage From 5a6d7f020b92d3a7b25660e4c6dfb872b0e62f91 Mon Sep 17 00:00:00 2001 From: Wendell Smith Date: Fri, 13 Mar 2026 14:21:28 -0400 Subject: [PATCH 5/6] refactor: derive Eq + Hash on all AST types Add ordered_float dependency and implement manual Eq/Hash for LiteralValue (delegating to OrderedFloat for f64), then derive Eq + Hash on the full AST type chain. --- Cargo.lock | 10 +++++++++ Cargo.toml | 1 + src/parser/ast.rs | 57 +++++++++++++++++++++++++++++++++++------------ 3 files changed, 54 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 98e8a0d..cb6ffe3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -551,6 +551,15 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "ordered-float" +version = "5.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f4779c6901a562440c3786d08192c6fbda7c1c2060edd10006b05ee35d10f2d" +dependencies = [ + "num-traits", +] + [[package]] name = "parking_lot" version = "0.12.5" @@ -1030,6 +1039,7 @@ dependencies = [ "indexmap", "lalrpop", "lalrpop-util", + "ordered-float", "pbjson-types", "prost", "prost-types", diff --git a/Cargo.toml b/Cargo.toml index 3599a91..c83b82f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,6 +47,7 @@ anyhow = { version = "1.0", optional = true } chrono = "0.4.41" clap = { version = "4.5", features = ["derive"], optional = true } indexmap = "2" +ordered-float = "5" lalrpop-util = { version = "0.22.2", features = ["lexer"] } pbjson-types = { version = "0.8.0", optional = true } prost = "0.14" diff --git a/src/parser/ast.rs b/src/parser/ast.rs index 93f812d..dcacb17 100644 --- a/src/parser/ast.rs +++ b/src/parser/ast.rs @@ -13,6 +13,9 @@ //! 
semantic constraints; semantic validation happens in lowering. use std::fmt; +use std::hash::{Hash, Hasher}; + +use ordered_float::OrderedFloat; #[derive(Debug, Clone, PartialEq, Eq)] pub struct ExtensionUrnDeclaration { @@ -46,7 +49,7 @@ impl ExtensionDeclaration { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Relation { pub name: RelationName, pub args: ArgList, @@ -71,7 +74,7 @@ impl Relation { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum RelationName { Standard(String), Extension(ExtensionType, String), @@ -87,14 +90,14 @@ impl RelationName { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum ExtensionType { Leaf, Single, Multi, } -#[derive(Debug, Clone, PartialEq, Default)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)] pub struct ArgList { /// Positional arguments in source order. pub positional: Vec, @@ -106,7 +109,7 @@ pub struct ArgList { pub invalid_order: bool, } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Arg { Expr(Expr), Enum(String), @@ -137,7 +140,7 @@ impl Arg { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct NamedArg { pub name: String, pub value: Arg, @@ -152,7 +155,7 @@ impl NamedArg { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Expr { FieldRef(i32), Literal(Literal), @@ -194,7 +197,7 @@ impl Expr { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct FunctionCall { pub name: String, pub anchor: Option, @@ -221,7 +224,7 @@ impl FunctionCall { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct IfThenExpr { pub clauses: Vec<(Expr, Expr)>, pub else_expr: Box, @@ -236,7 +239,7 @@ impl IfThenExpr { } } -#[derive(Debug, Clone, PartialEq)] 
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub struct Literal {
     pub value: LiteralValue,
     pub typ: Option<TypeExpr>,
@@ -264,7 +267,7 @@
     }
 }
 
-#[derive(Debug, Clone, PartialEq)]
+#[derive(Debug, Clone)]
 pub enum LiteralValue {
     Integer(i64),
     Float(f64),
@@ -272,7 +275,33 @@
     String(String),
 }
 
-#[derive(Debug, Clone, PartialEq)]
+impl PartialEq for LiteralValue {
+    fn eq(&self, other: &Self) -> bool {
+        match (self, other) {
+            (Self::Integer(a), Self::Integer(b)) => a == b,
+            (Self::Float(a), Self::Float(b)) => OrderedFloat(*a) == OrderedFloat(*b),
+            (Self::Boolean(a), Self::Boolean(b)) => a == b,
+            (Self::String(a), Self::String(b)) => a == b,
+            _ => false,
+        }
+    }
+}
+
+impl Eq for LiteralValue {}
+
+impl Hash for LiteralValue {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        std::mem::discriminant(self).hash(state);
+        match self {
+            Self::Integer(v) => v.hash(state),
+            Self::Float(v) => OrderedFloat(*v).hash(state),
+            Self::Boolean(v) => v.hash(state),
+            Self::String(v) => v.hash(state),
+        }
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub enum TypeExpr {
     Simple {
         name: String,
@@ -323,7 +352,7 @@
     }
 }
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
 pub enum Nullability {
     #[default]
     Required,
@@ -331,7 +360,7 @@
     Unspecified,
 }
 
-#[derive(Debug, Clone, PartialEq)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub enum ArgEntry {
     Positional(Arg),
     Named(NamedArg),

From 3901c44d4dcd480212f1324ebc2f3d4a14345231 Mon Sep 17 00:00:00 2001
From: Wendell Smith
Date: Fri, 13 Mar 2026 13:57:59 -0400
Subject: [PATCH 6/6] feat: add grouping-set parsing support for aggregate
 relations

---
 src/parser/lower/relations.rs | 119 +++++++++++++++++++++++++++-------
 tests/plan_roundtrip.rs       |  17 +++++
 2 files changed, 112 insertions(+), 24 deletions(-)

diff --git a/src/parser/lower/relations.rs b/src/parser/lower/relations.rs
index 730070f..4dae3ee 100644
--- a/src/parser/lower/relations.rs +++ b/src/parser/lower/relations.rs @@ -1,14 +1,16 @@ //! Relation-specific lowering logic. -use substrait::proto::aggregate_rel::Measure; +use std::collections::HashMap; + +use substrait::proto::aggregate_rel::{Grouping, Measure}; use substrait::proto::fetch_rel::{CountMode, OffsetMode}; use substrait::proto::rel::RelType; use substrait::proto::rel_common::{Emit, EmitKind}; use substrait::proto::sort_field::SortKind; use substrait::proto::r#type::{self as ptype, Nullability}; use substrait::proto::{ - AggregateFunction, AggregateRel, FetchRel, FilterRel, JoinRel, NamedStruct, ProjectRel, - ReadRel, Rel, RelCommon, SortField, SortRel, read_rel, + AggregateFunction, AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct, + ProjectRel, ReadRel, Rel, RelCommon, SortField, SortRel, read_rel, }; use super::expr::{fetch_value_expression, field_ref_expression, lower_arg_as_expression}; @@ -200,23 +202,8 @@ fn lower_aggregate( ensure_no_named_args(ctx, relation, "Aggregate")?; let input = expect_one_child(ctx, &mut child_relations, "Aggregate")?; - let mut grouping_expressions = Vec::new(); - for arg in &relation.args.positional { - match arg { - ast::Arg::Wildcard => {} - ast::Arg::Expr(ast::Expr::FieldRef(index)) => { - grouping_expressions.push(field_ref_expression(*index)); - } - _ => { - return ctx.invalid( - "Aggregate", - format!( - "Aggregate grouping argument must be field reference or '_', got '{arg}'" - ), - ); - } - } - } + let (grouping_expressions, groupings) = + lower_grouping_sets(ctx, &relation.args.positional)?; let group_by_count = grouping_expressions.len(); let mut measures = Vec::new(); @@ -241,7 +228,6 @@ fn lower_aggregate( }), filter: None, }); - // Aggregate outputs are grouped fields followed by measures. 
                 output_mapping.push(group_by_count as i32 + (measures.len() as i32 - 1));
             }
             _ => {
@@ -257,9 +243,7 @@
         rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
             input: Some(input),
             grouping_expressions,
-            // We use `grouping_expressions` (newer API) and leave legacy
-            // `groupings` empty for compatibility with current text format.
-            groupings: vec![],
+            groupings,
             measures,
             common: Some(RelCommon {
                 emit_kind: Some(EmitKind::Emit(Emit { output_mapping })),
@@ -270,6 +254,93 @@
     })
 }
 
+/// Lower grouping arguments into deduplicated `grouping_expressions` and `Grouping`s.
+///
+/// Handles both flat and tuple-based grouping:
+/// - `$0, $1 => ...` — single grouping set with refs `[0, 1]`
+/// - `($0, $1), ($0), _ => ...` — multiple grouping sets with deduplication
+/// - `_ => ...` — global aggregate (empty expressions and groupings)
+///
+/// Accepts arbitrary expressions (not just field refs) within grouping positions.
+fn lower_grouping_sets(
+    ctx: &LowerCtx<'_>,
+    positional: &[ast::Arg],
+) -> Result<(Vec<Expression>, Vec<Grouping>), ParseError> {
+    let mut dedup: HashMap<ast::Arg, u32> = HashMap::new();
+    let mut grouping_expressions: Vec<Expression> = Vec::new();
+    let mut groupings: Vec<Grouping> = Vec::new();
+
+    for arg in positional {
+        match arg {
+            ast::Arg::Tuple(items) => {
+                let refs = lower_grouping_items(ctx, &mut dedup, &mut grouping_expressions, items)?;
+                groupings.push(make_grouping(refs));
+            }
+            ast::Arg::Wildcard => {
+                groupings.push(make_grouping(vec![]));
+            }
+            ast::Arg::Expr(_) => {
+                // Bare expression: accumulate into an implicit single grouping set.
+                lower_grouping_item(ctx, &mut dedup, &mut grouping_expressions, arg)?;
+            }
+            _ => {
+                return ctx.invalid(
+                    "Aggregate",
+                    format!("Aggregate grouping argument must be expression, tuple, or '_', got '{arg}'"),
+                );
+            }
+        }
+    }
+
+    // If no explicit groupings were created, build one from all accumulated expressions.
+    if groupings.is_empty() && !dedup.is_empty() {
+        let expression_references: Vec<u32> = (0..grouping_expressions.len() as u32).collect();
+        groupings.push(make_grouping(expression_references));
+    }
+
+    Ok((grouping_expressions, groupings))
+}
+
+/// Lower a slice of grouping items, returning their `expression_references`.
+fn lower_grouping_items(
+    ctx: &LowerCtx<'_>,
+    dedup: &mut HashMap<ast::Arg, u32>,
+    grouping_expressions: &mut Vec<Expression>,
+    items: &[ast::Arg],
+) -> Result<Vec<u32>, ParseError> {
+    let mut refs = Vec::with_capacity(items.len());
+    for item in items {
+        let pos = lower_grouping_item(ctx, dedup, grouping_expressions, item)?;
+        refs.push(pos);
+    }
+    Ok(refs)
+}
+
+/// Lower a single grouping item, deduplicating against previously seen items.
+fn lower_grouping_item(
+    ctx: &LowerCtx<'_>,
+    dedup: &mut HashMap<ast::Arg, u32>,
+    grouping_expressions: &mut Vec<Expression>,
+    item: &ast::Arg,
+) -> Result<u32, ParseError> {
+    if let Some(&pos) = dedup.get(item) {
+        return Ok(pos);
+    }
+    let expr = lower_arg_as_expression(ctx, item, "Aggregate")?;
+    let pos = grouping_expressions.len() as u32;
+    grouping_expressions.push(expr);
+    dedup.insert(item.clone(), pos);
+    Ok(pos)
+}
+
+fn make_grouping(expression_references: Vec<u32>) -> Grouping {
+    Grouping {
+        #[allow(deprecated)]
+        grouping_expressions: vec![],
+        expression_references,
+    }
+}
+
 #[allow(clippy::vec_box)]
 fn lower_sort(
     ctx: &LowerCtx<'_>,
diff --git a/tests/plan_roundtrip.rs b/tests/plan_roundtrip.rs
index d925778..51707c7 100644
--- a/tests/plan_roundtrip.rs
+++ b/tests/plan_roundtrip.rs
@@ -182,6 +182,23 @@ Root[sum, category, count]
     roundtrip_plan(plan);
 }
 
+#[test]
+fn test_aggregate_complex_grouping_sets_roundtrip() {
+    // 4 grouping sets, 3 unique expressions deduplicated across sets, 1 measure
+    let plan = r#"=== Extensions
+URNs:
+  @  1: https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml
+Functions:
+  # 10 @  1: count
+
+=== Plan
+Root[category, region, amount, count]
+  Aggregate[($0, $1, $2), ($2, $0), ($1), _ => $0, $1, $2,
count($3)] + Read[orders => category:string?, region:string?, amount:fp64?, quantity:i32?]"#; + + roundtrip_plan(plan); +} + #[test] fn test_global_aggregate_relation_roundtrip() { let plan = r#"=== Extensions