From 5ebb426298f0ff4d8f8d98aa38feb34bdcb6bea7 Mon Sep 17 00:00:00 2001 From: Trion129 Date: Thu, 17 Apr 2025 14:18:23 +0000 Subject: [PATCH 1/4] initial changes --- .../Dialect/Substrait/IR/SubstraitEnums.td | 16 +++ .../Dialect/Substrait/IR/SubstraitOps.td | 78 ++++++++++++++ lib/Dialect/Substrait/IR/Substrait.cpp | 101 ++++++++++++++++++ lib/Target/SubstraitPB/Export.cpp | 70 ++++++++++++ lib/Target/SubstraitPB/Import.cpp | 101 ++++++++++++++++++ 5 files changed, 366 insertions(+) diff --git a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitEnums.td b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitEnums.td index 8a5f9c33..87330d50 100644 --- a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitEnums.td +++ b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitEnums.td @@ -94,4 +94,20 @@ def SetOpKind : I32EnumAttr<"SetOpKind", "", [ let cppNamespace = "::mlir::substrait"; } + +/// Represents the `SimpleComparisonType` protobuf enum. +/// +/// The enum values correspond exactly to those in the `SimpleComparisonType` enum, +/// i.e., conversion through integers is possible. +def SimpleComparisonType : I32EnumAttr<"SimpleComparisonType", "", [ + // clang-format off + I32EnumAttrCase<"unspecified", 0>, + I32EnumAttrCase<"equal", 1>, + I32EnumAttrCase<"is_not_distinct_from", 2>, + I32EnumAttrCase<"might_equal", 3>, + // clang-format on + ]> { + let cppNamespace = "::mlir::substrait"; +} + #endif // SUBSTRAIT_DIALECT_SUBSTRAIT_IR_SUBSTRAITENUMS diff --git a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td index 6f8eb861..ae021156 100644 --- a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td +++ b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td @@ -716,6 +716,84 @@ def Substrait_FilterOp : Substrait_RelOp<"filter", [ }]; } +def Substrait_HashJoinOp : Substrait_RelOp<"hash_join", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + SingleBlockImplicitTerminator<"::mlir::substrait::YieldOp">, + DeclareOpInterfaceMethods + ]> { + let summary = "hash join operation"; + let description = [{ + Represents a `HashJoinRel` message together with the `RelCommon`, left and + right `Rel` messages and `JoinType` enumeration it contains. The join condition + is represented as a region where expressions can compare fields from both sides. + Join keys are stored as field references used for the hash join implementation. + + Example: + + ```mlir + %0 = ... + %1 = ... + %2 = hash_join inner %0, %1 on { + ^bb0(%arg0: tuple, %arg1: tuple): + %3 = field_reference %arg0[0] : tuple + %4 = field_reference %arg1[0] : tuple + %5 = compare not_distinct_from %3, %4 : (si32, si32) -> si1 + // or depending if custom function is provided. + //TODO: the custom function is not yet supported. + %5 = call @cmp(%3, %4) : (si32, si32) -> si1 + yield %5 : si1 + } + ``` + }]; + let arguments = (ins + Substrait_Relation:$left, + Substrait_Relation:$right, + JoinType:$join_type, + OptionalAttr:$simple_comparison_type, + OptionalAttr:$custom_function_id, + Substrait_ExpressionType:$left_keys, + Substrait_ExpressionType:$right_keys, + OptionalAttr:$advanced_extension + ); + let regions = (region AnyRegion:$condition); + let hasRegionVerifier = 1; + let results = (outs Substrait_Relation:$result); + let assemblyFormat = [{ + $join_type $left `,` $right `on` + `left_keys` `` $left_keys `:` type($left_keys) `,` `right_keys` `` $right_keys `:` type($right_keys) + (`advanced_extension` `` $advanced_extension^)? + attr-dict `:` type($left) `,` type($right) `->` type($result) $condition + }]; + let builders = [ + OpBuilder<(ins "::mlir::Value":$left, "::mlir::Value":$right, + "JoinType":$join_type, + "::mlir::substrait::SimpleComparisonTypeAttr":$simple_comparison_type, + "::mlir::Value":$left_keys, + "::mlir::Value":$right_keys, + "::mlir::substrait::AdvancedExtensionAttr":$advanced_extension), [{ + build($_builder, $_state, left, right, join_type, simple_comparison_type, + /*custom_function_id=*/{}, left_keys, right_keys, advanced_extension); + }]>, + // Without advanced extension. + OpBuilder<(ins "::mlir::Value":$left, "::mlir::Value":$right, + "JoinType":$join_type, + "::mlir::substrait::SimpleComparisonTypeAttr":$simple_comparison_type, + "::mlir::Value":$left_keys, + "::mlir::Value":$right_keys), [{ + build($_builder, $_state, left, right, join_type, simple_comparison_type, + /*custom_function_id=*/{}, left_keys, right_keys, /*advanced_extension=*/{}); + }]>, + // TODO: Add support for custom function id. + ]; + let extraClassDefinition = [{ + /// Implement OpAsmOpInterface. + ::llvm::StringRef $cppClass::getDefaultDialect() { + return SubstraitDialect::getDialectNamespace(); + } + }]; +} + def Substrait_JoinOp : Substrait_RelOp<"join", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, diff --git a/lib/Dialect/Substrait/IR/Substrait.cpp b/lib/Dialect/Substrait/IR/Substrait.cpp index 222f9976..1101b654 100644 --- a/lib/Dialect/Substrait/IR/Substrait.cpp +++ b/lib/Dialect/Substrait/IR/Substrait.cpp @@ -1035,6 +1035,107 @@ LiteralOp::inferReturnTypes(MLIRContext *context, std::optional loc, return success(); } +LogicalResult +HashJoinOp::inferReturnTypes(MLIRContext *context, std::optional loc, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + llvm::SmallVectorImpl &inferredReturnTypes) { + Value leftInput = operands[0]; + Value rightInput = operands[1]; + + TypeRange leftFieldTypes = cast(leftInput.getType()).getTypes(); + TypeRange rightFieldTypes = cast(rightInput.getType()).getTypes(); + + // Get accessor to `join_type`. + Adaptor adaptor(operands, attributes, properties, regions); + JoinType join_type = adaptor.getJoinType(); + + SmallVector fieldTypes; + + switch (join_type) { + case JoinType::unspecified: + case JoinType::inner: + case JoinType::outer: + case JoinType::right: + case JoinType::left: + llvm::append_range(fieldTypes, leftFieldTypes); + llvm::append_range(fieldTypes, rightFieldTypes); + break; + case JoinType::semi: + case JoinType::anti: + llvm::append_range(fieldTypes, leftFieldTypes); + break; + case JoinType::single: + llvm::append_range(fieldTypes, rightFieldTypes); + break; + } + + auto resultType = TupleType::get(context, fieldTypes); + + inferredReturnTypes = SmallVector{resultType}; + + return success(); +} + +LogicalResult HashJoinOp::verifyRegions() { + if (getCondition().empty()) + return success(); + + // Verify that we have exactly two arguments in the region, matching the left + // and right relations + Block &block = getCondition().front(); + if (block.getNumArguments() != 2) + return emitOpError() << "condition region must have exactly two arguments"; + + // Verify block argument types match input relation types + if (block.getArgument(0).getType() != getLeft().getType()) + return emitOpError() + << "first block argument type must match left relation type"; + if (block.getArgument(1).getType() != getRight().getType()) + return emitOpError() + << "second block argument type must match right relation type"; + + // Verify the condition region yields a boolean (si1) value + auto yieldOp = cast(block.getTerminator()); + if (yieldOp.getValue().size() != 1) + return emitOpError() << "condition region must yield exactly one value"; + + Type yieldedType = yieldOp.getValue()[0].getType(); + MLIRContext *context = getContext(); + Type si1Type = IntegerType::get(context, 1, IntegerType::Signed); + if (yieldedType != si1Type) + return emitOpError() << "condition region must yield a boolean (si1) value"; + + // Verify join keys + if (!getLeftKeys() || !getRightKeys()) + return emitOpError() << "must specify both left_keys and right_keys"; + + // Verify that left_keys references fields from left relation + Value leftKeys = getLeftKeys(); + Value leftContainer = leftKeys.getDefiningOp()->getOperand(0); + if (leftContainer != block.getArgument(0)) + return emitOpError() + << "left_keys must reference fields from the left relation"; + + // Verify that right_keys references fields from right relation + Value rightKeys = getRightKeys(); + Value rightContainer = rightKeys.getDefiningOp()->getOperand(0); + if (rightContainer != block.getArgument(1)) + return emitOpError() + << "right_keys must reference fields from the right relation"; + + // Verify the keys are compatible (if a simple comparison type is provided) + if (getSimpleComparisonType()) { + Type leftKeyType = leftKeys.getType(); + Type rightKeyType = rightKeys.getType(); + if (leftKeyType != rightKeyType) + return emitOpError() << "join keys must have the same type when using " + "simple comparison"; + } + + return success(); +} + LogicalResult JoinOp::inferReturnTypes(MLIRContext *context, std::optional loc, ValueRange operands, DictionaryAttr attributes, diff --git a/lib/Target/SubstraitPB/Export.cpp b/lib/Target/SubstraitPB/Export.cpp index 510fcc83..31922151 100644 --- a/lib/Target/SubstraitPB/Export.cpp +++ b/lib/Target/SubstraitPB/Export.cpp @@ -63,6 +63,7 @@ class SubstraitExporter { DECLARE_EXPORT_FUNC(FieldReferenceOp, Expression) DECLARE_EXPORT_FUNC(FetchOp, Rel) DECLARE_EXPORT_FUNC(FilterOp, Rel) + DECLARE_EXPORT_FUNC(HashJoinOp, Rel) DECLARE_EXPORT_FUNC(JoinOp, Rel) DECLARE_EXPORT_FUNC(LiteralOp, Expression) DECLARE_EXPORT_FUNC(ModuleOp, pb::Message) @@ -757,6 +758,74 @@ FailureOr> SubstraitExporter::exportOperation(JoinOp op) { return rel; } +FailureOr> +SubstraitExporter::exportOperation(HashJoinOp op) { + auto relCommon = std::make_unique(); + auto direct = std::make_unique(); + relCommon->set_allocated_direct(direct.release()); + + auto leftOp = + llvm::dyn_cast_if_present(op.getLeft().getDefiningOp()); + if (!leftOp) + return op->emitOpError( + "left input was not produced by Substrait relation op"); + + FailureOr> leftRel = exportOperation(leftOp); + if (failed(leftRel)) + return failure(); + + auto rightOp = + llvm::dyn_cast_if_present(op.getRight().getDefiningOp()); + if (!rightOp) + return op->emitOpError( + "right input was not produced by Substrait relation op"); + + FailureOr> rightRel = exportOperation(rightOp); + if (failed(rightRel)) + return failure(); + + auto hashJoinRel = std::make_unique(); + hashJoinRel->set_allocated_common(relCommon.release()); + hashJoinRel->set_allocated_left(leftRel->release()); + hashJoinRel->set_allocated_right(rightRel->release()); + hashJoinRel->set_type(static_cast(op.getJoinType())); + + Value leftKeys = op.getLeftKeys(); + FailureOr> leftKeyExpr = + exportOperation(cast(leftKeys.getDefiningOp())); + if (failed(leftKeyExpr)) + return failure(); + if (!leftKeyExpr->get()->has_selection()) { + return op->emitOpError("left key must be a field reference"); + } + + Value rightKeys = op.getRightKeys(); + FailureOr> rightKeyExpr = + exportOperation(cast(rightKeys.getDefiningOp())); + if (failed(rightKeyExpr)) + return failure(); + if (!rightKeyExpr->get()->has_selection()) { + return op->emitOpError("left key must be a field reference"); + } + + // Create ComparisonJoinKey for key_comparisons + auto keyComparison = hashJoinRel->add_keys(); + keyComparison->set_allocated_left(leftKeyExpr->get()->release_selection()); + keyComparison->set_allocated_right(rightKeyExpr->get()->release_selection()); + + // TODO(trion): support custom function comparison types. + keyComparison->mutable_comparison()->set_simple( + static_cast( + op.getSimpleComparisonTypeAttr().getValue())); + + // Attach the `AdvancedExtension` message if the attribute exists. + exportAdvancedExtension(op, *hashJoinRel); + + auto rel = std::make_unique(); + rel->set_allocated_hash_join(hashJoinRel.release()); + return rel; +} + FailureOr> SubstraitExporter::exportOperation(ExpressionOpInterface op) { return llvm::TypeSwitch>>( @@ -1576,6 +1645,7 @@ SubstraitExporter::exportOperation(RelOpInterface op) { FieldReferenceOp, FilterOp, JoinOp, + HashJoinOp, NamedTableOp, ProjectOp, SetOp diff --git a/lib/Target/SubstraitPB/Import.cpp b/lib/Target/SubstraitPB/Import.cpp index b0bb4048..fbc033a7 100644 --- a/lib/Target/SubstraitPB/Import.cpp +++ b/lib/Target/SubstraitPB/Import.cpp @@ -91,6 +91,7 @@ DECLARE_IMPORT_FUNC(ExtensionTable, Rel, ExtensionTableOp) DECLARE_IMPORT_FUNC(FieldReference, Expression::FieldReference, FieldReferenceOp) DECLARE_IMPORT_FUNC(JoinRel, Rel, JoinOp) +DECLARE_IMPORT_FUNC(HashJoinRel, Rel, RelOpInterface) DECLARE_IMPORT_FUNC(Literal, Expression::Literal, LiteralOp) DECLARE_IMPORT_FUNC(NamedStruct, NamedStruct, ImportedNamedStruct) DECLARE_IMPORT_FUNC(NamedTable, Rel, NamedTableOp) @@ -102,6 +103,11 @@ DECLARE_IMPORT_FUNC(ScalarFunction, Expression::ScalarFunction, CallOp) DECLARE_IMPORT_FUNC(TopLevel, Plan, PlanOp) DECLARE_IMPORT_FUNC(TopLevel, PlanVersion, PlanVersionOp) +// If post join filter is present, wrap the given `hashJoin` op in a filter op +static mlir::FailureOr +wrapFilterOnJoin(ImplicitLocOpBuilder builder, RelOpInterface hashJoin, + const HashJoinRel &rel); + /// If present, imports the `advanced_extension` or `advanced_extensions` field /// from the given message and sets the obtained attribute on the given op. template @@ -633,6 +639,98 @@ static mlir::FailureOr importJoinRel(ImplicitLocOpBuilder builder, return joinOp; } +static mlir::FailureOr +wrapFilterOnJoin(ImplicitLocOpBuilder builder, RelOpInterface hashJoin, + const HashJoinRel &rel) { + if (!rel.has_post_join_filter()) { + return hashJoin; + } + // Create filter op. + auto filterOp = builder.create(hashJoin->getResult(0)); + filterOp.getCondition().push_back(new Block); + Block &conditionBlock = filterOp.getCondition().front(); + conditionBlock.addArgument(filterOp.getResult().getType(), + filterOp->getLoc()); + + // Create condition region. + const Expression &expression = rel.post_join_filter(); + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToEnd(&conditionBlock); + + FailureOr conditionOp = + importExpression(builder, expression); + if (failed(conditionOp)) + return failure(); + + builder.create(conditionOp.value()->getResult(0)); + } + + return {filterOp}; +} + +static mlir::FailureOr +importHashJoinRel(ImplicitLocOpBuilder builder, const Rel &message) { + const HashJoinRel &hashJoinRel = message.hash_join(); + + // Import left and right inputs. + const Rel &leftRel = hashJoinRel.left(); + const Rel &rightRel = hashJoinRel.right(); + + mlir::FailureOr leftOp = importRel(builder, leftRel); + mlir::FailureOr rightOp = importRel(builder, rightRel); + + if (failed(leftOp) || failed(rightOp)) + return failure(); + + // Build `HashJoinOp`. + Value leftVal = leftOp.value()->getResult(0); + Value rightVal = rightOp.value()->getResult(0); + + std::optional join_type = static_cast(hashJoinRel.type()); + + FailureOr leftKeysExpr = + importFieldReference(builder, hashJoinRel.keys()[0].left()); + if (failed(leftKeysExpr)) + return failure(); + + FailureOr rightKeysExpr = + importFieldReference(builder, hashJoinRel.keys()[0].right()); + if (failed(rightKeysExpr)) + return failure(); + + // Check for unsupported set operations. + if (!join_type) + return mlir::emitError(builder.getLoc(), "unexpected 'operation' found"); + + auto simpleComparisonType = static_cast( + hashJoinRel.keys()[0].comparison().simple()); + + auto simpleComparisonTypeAttr = + SimpleComparisonTypeAttr::get(builder.getContext(), simpleComparisonType); + + mlir::FailureOr hashJoinOp = builder.create( + leftVal, rightVal, *join_type, simpleComparisonTypeAttr, + leftKeysExpr.value()->getResult(0), rightKeysExpr.value()->getResult(0)); + + if (failed(hashJoinOp)) { + return failure(); + } + + // Import advanced extension if it is present. + if (auto extensibleOp = + dyn_cast(hashJoinOp->getOperation())) { + importAdvancedExtension(builder, extensibleOp, hashJoinRel); + } + + mlir::FailureOr wrappedOp = wrapFilterOnJoin( + builder, cast(hashJoinOp->getOperation()), hashJoinRel); + if (failed(wrappedOp)) { + return failure(); + } + return wrappedOp; +} + static mlir::FailureOr importLiteral(ImplicitLocOpBuilder builder, const Expression::Literal &message) { @@ -1136,6 +1234,9 @@ static mlir::FailureOr importRel(ImplicitLocOpBuilder builder, case Rel::RelTypeCase::kJoin: maybeOp = importJoinRel(builder, message); break; + case Rel::RelTypeCase::kHashJoin: + maybeOp = importHashJoinRel(builder, message); + break; case Rel::RelTypeCase::kProject: maybeOp = importProjectRel(builder, message); break; From af9aded61f6ae9b8b49e0d280f1a1ab5d325389b Mon Sep 17 00:00:00 2001 From: Trion129 Date: Fri, 18 Apr 2025 11:25:32 +0000 Subject: [PATCH 2/4] Remove left and right keys to avoid redundancy --- .../Dialect/Substrait/IR/SubstraitOps.td | 70 +++++++++++++------ lib/Dialect/Substrait/IR/Substrait.cpp | 44 +++++------- lib/Target/SubstraitPB/Export.cpp | 46 +++++++----- lib/Target/SubstraitPB/Import.cpp | 57 ++++++++++----- 4 files changed, 131 insertions(+), 86 deletions(-) diff --git a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td index ae021156..26769b8c 100644 --- a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td +++ b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td @@ -716,6 +716,50 @@ def Substrait_FilterOp : Substrait_RelOp<"filter", [ }]; } +def Substrait_KeyComparisonOp : Substrait_ExpressionOp<"compare", [ + DeclareOpInterfaceMethods + ]> { + let summary = "Key comparison expression"; + let description = [{ + Represents a comparison between two field references or expressions + used in join conditions. + + Example: + + ```mlir + %3 = field_reference %arg0[0] : tuple + %4 = field_reference %arg1[0] : tuple + %5 = compare not_distinct_from %3, %4 : (si32, si32) -> si1 + ``` + }]; + + let arguments = (ins + Substrait_ExpressionType:$lhs, + Substrait_ExpressionType:$rhs, + OptionalAttr:$comparison_type, + OptionalAttr:$custom_function_id + ); + + let results = (outs I1:$result); + + let assemblyFormat = [{ + $comparison_type $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($result) + }]; + + let builders = [ + OpBuilder<(ins + "::mlir::Value":$lhs, + "::mlir::Value":$rhs, + "SimpleComparisonType":$comparison_type), [{ + // Convert SimpleComparisonType to SimpleComparisonTypeAttr + auto comparisonAttr = ::mlir::substrait::SimpleComparisonTypeAttr::get($_builder.getContext(), comparison_type); + build($_builder, $_state, $_builder.getI1Type(), lhs, rhs, comparisonAttr, {}); + }]> + ]; + + let hasVerifier = 1; +} + def Substrait_HashJoinOp : Substrait_RelOp<"hash_join", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, @@ -750,41 +794,21 @@ def Substrait_HashJoinOp : Substrait_RelOp<"hash_join", [ Substrait_Relation:$left, Substrait_Relation:$right, JoinType:$join_type, - OptionalAttr:$simple_comparison_type, - OptionalAttr:$custom_function_id, - Substrait_ExpressionType:$left_keys, - Substrait_ExpressionType:$right_keys, OptionalAttr:$advanced_extension ); let regions = (region AnyRegion:$condition); let hasRegionVerifier = 1; let results = (outs Substrait_Relation:$result); let assemblyFormat = [{ - $join_type $left `,` $right `on` - `left_keys` `` $left_keys `:` type($left_keys) `,` `right_keys` `` $right_keys `:` type($right_keys) + $join_type $left `,` $right (`advanced_extension` `` $advanced_extension^)? attr-dict `:` type($left) `,` type($right) `->` type($result) $condition }]; let builders = [ OpBuilder<(ins "::mlir::Value":$left, "::mlir::Value":$right, - "JoinType":$join_type, - "::mlir::substrait::SimpleComparisonTypeAttr":$simple_comparison_type, - "::mlir::Value":$left_keys, - "::mlir::Value":$right_keys, - "::mlir::substrait::AdvancedExtensionAttr":$advanced_extension), [{ - build($_builder, $_state, left, right, join_type, simple_comparison_type, - /*custom_function_id=*/{}, left_keys, right_keys, advanced_extension); - }]>, - // Without advanced extension. - OpBuilder<(ins "::mlir::Value":$left, "::mlir::Value":$right, - "JoinType":$join_type, - "::mlir::substrait::SimpleComparisonTypeAttr":$simple_comparison_type, - "::mlir::Value":$left_keys, - "::mlir::Value":$right_keys), [{ - build($_builder, $_state, left, right, join_type, simple_comparison_type, - /*custom_function_id=*/{}, left_keys, right_keys, /*advanced_extension=*/{}); + "JoinType":$join_type), [{ + build($_builder, $_state, left, right, join_type, /*advanced_extension=*/{}); }]>, - // TODO: Add support for custom function id. ]; let extraClassDefinition = [{ /// Implement OpAsmOpInterface. diff --git a/lib/Dialect/Substrait/IR/Substrait.cpp b/lib/Dialect/Substrait/IR/Substrait.cpp index 1101b654..be450778 100644 --- a/lib/Dialect/Substrait/IR/Substrait.cpp +++ b/lib/Dialect/Substrait/IR/Substrait.cpp @@ -1077,9 +1077,21 @@ HashJoinOp::inferReturnTypes(MLIRContext *context, std::optional loc, return success(); } +LogicalResult KeyComparisonOp::verify() { + Type lhsType = getLhs().getType(); + Type rhsType = getRhs().getType(); + + if (lhsType != rhsType) { + return emitOpError() << "operands must have the same type but got " + << lhsType << " and " << rhsType; + } + + return success(); +} + LogicalResult HashJoinOp::verifyRegions() { if (getCondition().empty()) - return success(); + return emitOpError() << "hash join must have a condition region"; // Verify that we have exactly two arguments in the region, matching the left // and right relations @@ -1087,7 +1099,6 @@ LogicalResult HashJoinOp::verifyRegions() { if (block.getNumArguments() != 2) return emitOpError() << "condition region must have exactly two arguments"; - // Verify block argument types match input relation types if (block.getArgument(0).getType() != getLeft().getType()) return emitOpError() << "first block argument type must match left relation type"; @@ -1095,7 +1106,6 @@ LogicalResult HashJoinOp::verifyRegions() { return emitOpError() << "second block argument type must match right relation type"; - // Verify the condition region yields a boolean (si1) value auto yieldOp = cast(block.getTerminator()); if (yieldOp.getValue().size() != 1) return emitOpError() << "condition region must yield exactly one value"; @@ -1106,31 +1116,9 @@ LogicalResult HashJoinOp::verifyRegions() { if (yieldedType != si1Type) return emitOpError() << "condition region must yield a boolean (si1) value"; - // Verify join keys - if (!getLeftKeys() || !getRightKeys()) - return emitOpError() << "must specify both left_keys and right_keys"; - - // Verify that left_keys references fields from left relation - Value leftKeys = getLeftKeys(); - Value leftContainer = leftKeys.getDefiningOp()->getOperand(0); - if (leftContainer != block.getArgument(0)) - return emitOpError() - << "left_keys must reference fields from the left relation"; - - // Verify that right_keys references fields from right relation - Value rightKeys = getRightKeys(); - Value rightContainer = rightKeys.getDefiningOp()->getOperand(0); - if (rightContainer != block.getArgument(1)) - return emitOpError() - << "right_keys must reference fields from the right relation"; - - // Verify the keys are compatible (if a simple comparison type is provided) - if (getSimpleComparisonType()) { - Type leftKeyType = leftKeys.getType(); - Type rightKeyType = rightKeys.getType(); - if (leftKeyType != rightKeyType) - return emitOpError() << "join keys must have the same type when using " - "simple comparison"; + Value conditionValue = yieldOp.getValue()[0]; + if (!isa(conditionValue.getDefiningOp())) { + return emitOpError() << "condition must be produced by a comparison"; } return success(); diff --git a/lib/Target/SubstraitPB/Export.cpp b/lib/Target/SubstraitPB/Export.cpp index 31922151..28269677 100644 --- a/lib/Target/SubstraitPB/Export.cpp +++ b/lib/Target/SubstraitPB/Export.cpp @@ -790,25 +790,38 @@ SubstraitExporter::exportOperation(HashJoinOp op) { hashJoinRel->set_allocated_right(rightRel->release()); hashJoinRel->set_type(static_cast(op.getJoinType())); - Value leftKeys = op.getLeftKeys(); - FailureOr> leftKeyExpr = - exportOperation(cast(leftKeys.getDefiningOp())); - if (failed(leftKeyExpr)) - return failure(); - if (!leftKeyExpr->get()->has_selection()) { - return op->emitOpError("left key must be a field reference"); + if (op.getCondition().empty()) { + return op->emitOpError("missing join condition"); } - Value rightKeys = op.getRightKeys(); - FailureOr> rightKeyExpr = - exportOperation(cast(rightKeys.getDefiningOp())); - if (failed(rightKeyExpr)) - return failure(); - if (!rightKeyExpr->get()->has_selection()) { - return op->emitOpError("left key must be a field reference"); + Block &conditionBlock = op.getCondition().front(); + Operation *terminator = conditionBlock.getTerminator(); + Value conditionValue = cast(terminator).getValue()[0]; + + auto compareOp = + dyn_cast_or_null(conditionValue.getDefiningOp()); + if (!compareOp) { + return op->emitOpError("join condition must be a KeyComparisonOp"); + } + + Value leftKey = compareOp.getLhs(); + Value rightKey = compareOp.getRhs(); + FailureOr> leftKeyExpr; + if (auto leftFieldRef = + dyn_cast_or_null(leftKey.getDefiningOp())) { + leftKeyExpr = exportOperation(leftFieldRef); + } else { + return op->emitOpError() << "left key must be a field reference"; + } + + FailureOr> rightKeyExpr; + if (auto rightFieldRef = + dyn_cast_or_null(rightKey.getDefiningOp())) { + rightKeyExpr = exportOperation(rightFieldRef); + } else { + return op->emitOpError() << "right key must be a field reference"; } - // Create ComparisonJoinKey for key_comparisons auto keyComparison = hashJoinRel->add_keys(); keyComparison->set_allocated_left(leftKeyExpr->get()->release_selection()); keyComparison->set_allocated_right(rightKeyExpr->get()->release_selection()); @@ -816,9 +829,8 @@ SubstraitExporter::exportOperation(HashJoinOp op) { // TODO(trion): support custom function comparison types. keyComparison->mutable_comparison()->set_simple( static_cast( - op.getSimpleComparisonTypeAttr().getValue())); + compareOp.getComparisonType().value())); - // Attach the `AdvancedExtension` message if the attribute exists. exportAdvancedExtension(op, *hashJoinRel); auto rel = std::make_unique(); diff --git a/lib/Target/SubstraitPB/Import.cpp b/lib/Target/SubstraitPB/Import.cpp index fbc033a7..1d24e325 100644 --- a/lib/Target/SubstraitPB/Import.cpp +++ b/lib/Target/SubstraitPB/Import.cpp @@ -689,34 +689,55 @@ importHashJoinRel(ImplicitLocOpBuilder builder, const Rel &message) { std::optional join_type = static_cast(hashJoinRel.type()); - FailureOr leftKeysExpr = - importFieldReference(builder, hashJoinRel.keys()[0].left()); - if (failed(leftKeysExpr)) - return failure(); - - FailureOr rightKeysExpr = - importFieldReference(builder, hashJoinRel.keys()[0].right()); - if (failed(rightKeysExpr)) - return failure(); - // Check for unsupported set operations. if (!join_type) return mlir::emitError(builder.getLoc(), "unexpected 'operation' found"); - auto simpleComparisonType = static_cast( - hashJoinRel.keys()[0].comparison().simple()); + mlir::FailureOr hashJoinOp = + builder.create(leftVal, rightVal, *join_type); - auto simpleComparisonTypeAttr = - SimpleComparisonTypeAttr::get(builder.getContext(), simpleComparisonType); + if (failed(hashJoinOp)) { + return failure(); + } - mlir::FailureOr hashJoinOp = builder.create( - leftVal, rightVal, *join_type, simpleComparisonTypeAttr, - leftKeysExpr.value()->getResult(0), rightKeysExpr.value()->getResult(0)); + hashJoinOp->getCondition().push_back(new Block); + Block &conditionBlock = hashJoinOp->getCondition().front(); + conditionBlock.addArgument(leftVal.getType(), hashJoinOp->getLoc()); + conditionBlock.addArgument(rightVal.getType(), hashJoinOp->getLoc()); - if (failed(hashJoinOp)) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToEnd(&conditionBlock); + + FailureOr leftKeyOp = + importFieldReference(builder, hashJoinRel.keys()[0].left()); + if (failed(leftKeyOp)) + return failure(); + + FailureOr rightKeyOp = + importFieldReference(builder, hashJoinRel.keys()[0].right()); + if (failed(rightKeyOp)) return failure(); + + // Create the comparison operation + Value condition; + if (hashJoinRel.keys()[0].comparison().has_simple()) { + auto simpleComparisonType = static_cast( + hashJoinRel.keys()[0].comparison().simple()); + + Value leftKey = leftKeyOp.value(); + Value rightKey = rightKeyOp.value(); + + condition = builder.create(leftKey, rightKey, + simpleComparisonType); + } else { + // TODO(trion): Handle custom function if present + return mlir::emitError(builder.getLoc(), + "custom comparison functions not yet supported"); } + // Yield the condition from the region + builder.create(condition); + // Import advanced extension if it is present. if (auto extensibleOp = dyn_cast(hashJoinOp->getOperation())) { From 934f8cd441c14ab1fcb317078437464aa2ed3ebf Mon Sep 17 00:00:00 2001 From: Trion129 Date: Wed, 23 Apr 2025 14:47:54 +0000 Subject: [PATCH 3/4] Fix Yield MLIR generation issues --- .../substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td | 8 +++++--- lib/Target/SubstraitPB/Import.cpp | 5 ++++- lib/Target/SubstraitPB/ProtobufUtils.cpp | 4 ++++ 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td index 26769b8c..0f553ede 100644 --- a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td +++ b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td @@ -260,7 +260,8 @@ def Substrait_YieldOp : Substrait_Op<"yield", [ "::mlir::substrait::AggregateOp", "::mlir::substrait::FilterOp", "::mlir::substrait::PlanRelOp", - "::mlir::substrait::ProjectOp" + "::mlir::substrait::ProjectOp", + "::mlir::substrait::HashJoinOp" ]> ]> { let summary = "Yields the result of a `PlanRelOp`"; @@ -740,7 +741,7 @@ def Substrait_KeyComparisonOp : Substrait_ExpressionOp<"compare", [ OptionalAttr:$custom_function_id ); - let results = (outs I1:$result); + let results = (outs SI1:$result); let assemblyFormat = [{ $comparison_type $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($result) @@ -753,7 +754,8 @@ def Substrait_KeyComparisonOp : Substrait_ExpressionOp<"compare", [ "SimpleComparisonType":$comparison_type), [{ // Convert SimpleComparisonType to SimpleComparisonTypeAttr auto comparisonAttr = ::mlir::substrait::SimpleComparisonTypeAttr::get($_builder.getContext(), comparison_type); - build($_builder, $_state, $_builder.getI1Type(), lhs, rhs, comparisonAttr, {}); + auto si1Type = ::mlir::IntegerType::get($_builder.getContext(), 1, ::mlir::IntegerType::Signed); + build($_builder, $_state, si1Type, lhs, rhs, comparisonAttr, {}); }]> ]; diff --git a/lib/Target/SubstraitPB/Import.cpp b/lib/Target/SubstraitPB/Import.cpp index 1d24e325..771bcc05 100644 --- a/lib/Target/SubstraitPB/Import.cpp +++ b/lib/Target/SubstraitPB/Import.cpp @@ -586,7 +586,10 @@ importFieldReference(ImplicitLocOpBuilder builder, // For the `root_reference` case, that's the current block argument. mlir::Block::BlockArgListType blockArgs = builder.getInsertionBlock()->getArguments(); - assert(blockArgs.size() == 1 && "expected a single block argument"); + if (blockArgs.empty()) { + return emitError(loc) + << "root reference requires at least one block argument"; + } container = blockArgs.front(); } else if (message.has_expression()) { // For the `expression` case, recursively import the expression. diff --git a/lib/Target/SubstraitPB/ProtobufUtils.cpp b/lib/Target/SubstraitPB/ProtobufUtils.cpp index 4db898b0..ce4375a2 100644 --- a/lib/Target/SubstraitPB/ProtobufUtils.cpp +++ b/lib/Target/SubstraitPB/ProtobufUtils.cpp @@ -42,6 +42,8 @@ FailureOr getCommon(const Rel &rel, Location loc) { return getCommon(rel.filter()); case Rel::RelTypeCase::kJoin: return getCommon(rel.join()); + case Rel::RelTypeCase::kHashJoin: + return getCommon(rel.hash_join()); case Rel::RelTypeCase::kProject: return getCommon(rel.project()); case Rel::RelTypeCase::kRead: @@ -73,6 +75,8 @@ FailureOr getMutableCommon(Rel *rel, Location loc) { return getMutableCommon(rel->mutable_filter()); case Rel::RelTypeCase::kJoin: return getMutableCommon(rel->mutable_join()); + case Rel::RelTypeCase::kHashJoin: + return getMutableCommon(rel->mutable_hash_join()); case Rel::RelTypeCase::kProject: return getMutableCommon(rel->mutable_project()); case Rel::RelTypeCase::kRead: From ab304539d5b8e741f6fe652488230e89755e2e9b Mon Sep 17 00:00:00 2001 From: Trion129 Date: Mon, 28 Apr 2025 14:01:36 +0000 Subject: [PATCH 4/4] Fix for comparison of similar key types --- lib/Dialect/Substrait/IR/Substrait.cpp | 26 +++++++++++++++++++------- lib/Target/SubstraitPB/Import.cpp | 4 ---- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/lib/Dialect/Substrait/IR/Substrait.cpp b/lib/Dialect/Substrait/IR/Substrait.cpp index be450778..7d147ed3 100644 --- a/lib/Dialect/Substrait/IR/Substrait.cpp +++ b/lib/Dialect/Substrait/IR/Substrait.cpp @@ -1078,13 +1078,25 @@ HashJoinOp::inferReturnTypes(MLIRContext *context, std::optional loc, } LogicalResult KeyComparisonOp::verify() { - Type lhsType = getLhs().getType(); - Type rhsType = getRhs().getType(); - - if (lhsType != rhsType) { - return emitOpError() << "operands must have the same type but got " - << lhsType << " and " << rhsType; - } + auto &res = + llvm::TypeSwitch(getLhs().getType()) + .Case([&](auto) { + if (!mlir::isa( + getRhs())) + return this->emitError("Invalid rhs type for string comparison, " + "expected string-like type but got") + << getRhs().getType(); + return success(); + }) + .Case([&](auto) { + if (!mlir::isa(getRhs())) + return this->emitError("Invalid rhs type for integer comparison, " + "expected integer-like type but got") + << getRhs().getType(); + return success(); + }); + if (failed(res)) + return res; return success(); } diff --git a/lib/Target/SubstraitPB/Import.cpp b/lib/Target/SubstraitPB/Import.cpp index 771bcc05..b4e49274 100644 --- a/lib/Target/SubstraitPB/Import.cpp +++ b/lib/Target/SubstraitPB/Import.cpp @@ -675,8 +675,6 @@ wrapFilterOnJoin(ImplicitLocOpBuilder builder, RelOpInterface hashJoin, static mlir::FailureOr importHashJoinRel(ImplicitLocOpBuilder builder, const Rel &message) { const HashJoinRel &hashJoinRel = message.hash_join(); - - // Import left and right inputs. const Rel &leftRel = hashJoinRel.left(); const Rel &rightRel = hashJoinRel.right(); @@ -686,13 +684,11 @@ importHashJoinRel(ImplicitLocOpBuilder builder, const Rel &message) { if (failed(leftOp) || failed(rightOp)) return failure(); - // Build `HashJoinOp`. Value leftVal = leftOp.value()->getResult(0); Value rightVal = rightOp.value()->getResult(0); std::optional join_type = static_cast(hashJoinRel.type()); - // Check for unsupported set operations. if (!join_type) return mlir::emitError(builder.getLoc(), "unexpected 'operation' found");