-
Notifications
You must be signed in to change notification settings - Fork 5
feat: HashJoinOp mlir implementation
#129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -260,7 +260,8 @@ def Substrait_YieldOp : Substrait_Op<"yield", [ | |
| "::mlir::substrait::AggregateOp", | ||
| "::mlir::substrait::FilterOp", | ||
| "::mlir::substrait::PlanRelOp", | ||
| "::mlir::substrait::ProjectOp" | ||
| "::mlir::substrait::ProjectOp", | ||
| "::mlir::substrait::HashJoinOp" | ||
| ]> | ||
| ]> { | ||
| let summary = "Yields the result of a `PlanRelOp`"; | ||
|
|
@@ -716,6 +717,109 @@ def Substrait_FilterOp : Substrait_RelOp<"filter", [ | |
| }]; | ||
| } | ||
|
|
||
def Substrait_KeyComparisonOp : Substrait_ExpressionOp<"compare", [
    DeclareOpInterfaceMethods<InferTypeOpInterface>
  ]> {
  let summary = "Key comparison expression";
  let description = [{
    Represents a comparison between two field references or expressions
    used in join conditions. The result is always an `si1` value.

    The two operands do not need to have identical types: the verifier accepts
    a string or varchar `rhs` for a string or varchar `lhs`, and an integer or
    decimal `rhs` for an integer `lhs`.

    TODO: `custom_function_id` is accepted as an attribute but comparisons
    through a custom function are not yet supported; only `comparison_type`
    is exported.

    Example:

    ```mlir
    %3 = field_reference %arg0[0] : tuple<si32, si32>
    %4 = field_reference %arg1[0] : tuple<si32, si32, si32>
    %5 = compare not_distinct_from %3, %4 : (si32, si32) -> si1
    ```
  }];

  let arguments = (ins
    Substrait_ExpressionType:$lhs,
    Substrait_ExpressionType:$rhs,
    OptionalAttr<SimpleComparisonType>:$comparison_type,
    OptionalAttr<UI32Attr>:$custom_function_id
  );

  let results = (outs SI1:$result);

  let assemblyFormat = [{
    $comparison_type $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($result)
  }];

  let builders = [
    // Convenience builder for a simple comparison: wraps the enum value in its
    // attribute, fixes the result type to `si1`, and leaves
    // `custom_function_id` unset.
    OpBuilder<(ins
        "::mlir::Value":$lhs,
        "::mlir::Value":$rhs,
        "SimpleComparisonType":$comparison_type), [{
      // Convert SimpleComparisonType to SimpleComparisonTypeAttr
      auto comparisonAttr = ::mlir::substrait::SimpleComparisonTypeAttr::get($_builder.getContext(), comparison_type);
      auto si1Type = ::mlir::IntegerType::get($_builder.getContext(), 1, ::mlir::IntegerType::Signed);
      build($_builder, $_state, si1Type, lhs, rhs, comparisonAttr, {});
    }]>
  ];

  let hasVerifier = 1;
}
|
|
||
def Substrait_HashJoinOp : Substrait_RelOp<"hash_join", [
    DeclareOpInterfaceMethods<InferTypeOpInterface>,
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getDefaultDialect"]>,
    SingleBlockImplicitTerminator<"::mlir::substrait::YieldOp">,
    DeclareOpInterfaceMethods<Substrait_ExtensibleOpInterface>
  ]> {
  let summary = "hash join operation";
  let description = [{
    Represents a `HashJoinRel` message together with the `RelCommon`, left and
    right `Rel` messages and `JoinType` enumeration it contains. The join
    condition is represented as a region where expressions can compare fields
    from both sides: the region takes two block arguments, typed like the left
    and right inputs, and must yield a single `si1` value produced by a
    `compare` op. Join keys are stored as field references used for the hash
    join implementation.

    TODO: comparing keys through a custom function (e.g. a `call` to a
    user-provided comparison) is not yet supported; the condition must be a
    `compare` op.

    Example:

    ```mlir
    %0 = ...
    %1 = ...
    %2 = hash_join inner %0, %1
           : tuple<si32, si32>, tuple<si32, si32, si32>
           -> tuple<si32, si32, si32, si32, si32> {
    ^bb0(%arg0: tuple<si32, si32>, %arg1: tuple<si32, si32, si32>):
      %3 = field_reference %arg0[0] : tuple<si32, si32>
      %4 = field_reference %arg1[0] : tuple<si32, si32, si32>
      %5 = compare not_distinct_from %3, %4 : (si32, si32) -> si1
      yield %5 : si1
    }
    ```
  }];
  let arguments = (ins
    Substrait_Relation:$left,
    Substrait_Relation:$right,
    JoinType:$join_type,
    OptionalAttr<Substrait_AdvancedExtensionAttr>:$advanced_extension
  );
  let regions = (region AnyRegion:$condition);
  let hasRegionVerifier = 1;
  let results = (outs Substrait_Relation:$result);
  let assemblyFormat = [{
    $join_type $left `,` $right
    (`advanced_extension` `` $advanced_extension^)?
    attr-dict `:` type($left) `,` type($right) `->` type($result) $condition
  }];
  let builders = [
    // Convenience builder that leaves `advanced_extension` unset.
    OpBuilder<(ins "::mlir::Value":$left, "::mlir::Value":$right,
                   "JoinType":$join_type), [{
      build($_builder, $_state, left, right, join_type, /*advanced_extension=*/{});
    }]>,
  ];
  let extraClassDefinition = [{
    /// Implement OpAsmOpInterface.
    ::llvm::StringRef $cppClass::getDefaultDialect() {
      return SubstraitDialect::getDialectNamespace();
    }
  }];
}
|
|
||
| def Substrait_JoinOp : Substrait_RelOp<"join", [ | ||
| DeclareOpInterfaceMethods<InferTypeOpInterface>, | ||
| DeclareOpInterfaceMethods<Substrait_ExtensibleOpInterface>, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1035,6 +1035,107 @@ LiteralOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc, | |
| return success(); | ||
| } | ||
|
|
||
| LogicalResult | ||
| HashJoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc, | ||
| ValueRange operands, DictionaryAttr attributes, | ||
| OpaqueProperties properties, RegionRange regions, | ||
| llvm::SmallVectorImpl<Type> &inferredReturnTypes) { | ||
| Value leftInput = operands[0]; | ||
| Value rightInput = operands[1]; | ||
|
|
||
| TypeRange leftFieldTypes = cast<TupleType>(leftInput.getType()).getTypes(); | ||
| TypeRange rightFieldTypes = cast<TupleType>(rightInput.getType()).getTypes(); | ||
|
|
||
| // Get accessor to `join_type`. | ||
| Adaptor adaptor(operands, attributes, properties, regions); | ||
| JoinType join_type = adaptor.getJoinType(); | ||
|
|
||
| SmallVector<mlir::Type> fieldTypes; | ||
|
|
||
| switch (join_type) { | ||
| case JoinType::unspecified: | ||
| case JoinType::inner: | ||
| case JoinType::outer: | ||
| case JoinType::right: | ||
| case JoinType::left: | ||
| llvm::append_range(fieldTypes, leftFieldTypes); | ||
| llvm::append_range(fieldTypes, rightFieldTypes); | ||
| break; | ||
| case JoinType::semi: | ||
| case JoinType::anti: | ||
| llvm::append_range(fieldTypes, leftFieldTypes); | ||
| break; | ||
| case JoinType::single: | ||
| llvm::append_range(fieldTypes, rightFieldTypes); | ||
| break; | ||
| } | ||
|
|
||
| auto resultType = TupleType::get(context, fieldTypes); | ||
|
|
||
| inferredReturnTypes = SmallVector<Type>{resultType}; | ||
|
|
||
| return success(); | ||
| } | ||
|
|
||
| LogicalResult KeyComparisonOp::verify() { | ||
| auto &res = | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this may be a bit unintuitive at first (i.e. one might expect that the comparison op should always have identical types for both operands), a comment might be warranted here. Do we have some documentation about substraits stance here? That is, this code is assuming that the comparison op is able to perform comparisons against types that are "cast-compatible" (int <=> decimal, string <=> varchar).
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: Storing |
||
| llvm::TypeSwitch<mlir::Type, LogicalResult>(getLhs().getType()) | ||
| .Case<substrait::VarCharType, substrait::StringType>([&](auto) { | ||
| if (!mlir::isa<substrait::VarCharType, substrait::StringType>( | ||
| getRhs())) | ||
| return this->emitError("Invalid rhs type for string comparison, " | ||
| "expected string-like type but got") | ||
| << getRhs().getType(); | ||
| return success(); | ||
| }) | ||
| .Case<IntegerType>([&](auto) { | ||
| if (!mlir::isa<IntegerType, substrait::DecimalType>(getRhs())) | ||
| return this->emitError("Invalid rhs type for integer comparison, " | ||
| "expected integer-like type but got") | ||
| << getRhs().getType(); | ||
| return success(); | ||
| }); | ||
| if (failed(res)) | ||
| return res; | ||
|
|
||
| return success(); | ||
|
Comment on lines
+1098
to
+1101
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could just |
||
| } | ||
|
|
||
| LogicalResult HashJoinOp::verifyRegions() { | ||
| if (getCondition().empty()) | ||
| return emitOpError() << "hash join must have a condition region"; | ||
|
|
||
| // Verify that we have exactly two arguments in the region, matching the left | ||
| // and right relations | ||
| Block &block = getCondition().front(); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since you've added the |
||
| if (block.getNumArguments() != 2) | ||
| return emitOpError() << "condition region must have exactly two arguments"; | ||
|
|
||
| if (block.getArgument(0).getType() != getLeft().getType()) | ||
| return emitOpError() | ||
| << "first block argument type must match left relation type"; | ||
| if (block.getArgument(1).getType() != getRight().getType()) | ||
| return emitOpError() | ||
| << "second block argument type must match right relation type"; | ||
|
|
||
| auto yieldOp = cast<YieldOp>(block.getTerminator()); | ||
| if (yieldOp.getValue().size() != 1) | ||
| return emitOpError() << "condition region must yield exactly one value"; | ||
|
|
||
| Type yieldedType = yieldOp.getValue()[0].getType(); | ||
| MLIRContext *context = getContext(); | ||
| Type si1Type = IntegerType::get(context, 1, IntegerType::Signed); | ||
| if (yieldedType != si1Type) | ||
| return emitOpError() << "condition region must yield a boolean (si1) value"; | ||
|
|
||
| Value conditionValue = yieldOp.getValue()[0]; | ||
| if (!isa<KeyComparisonOp>(conditionValue.getDefiningOp())) { | ||
| return emitOpError() << "condition must be produced by a comparison"; | ||
| } | ||
|
Comment on lines
+1132
to
+1134
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: remove braces. |
||
|
|
||
| return success(); | ||
| } | ||
|
|
||
| LogicalResult | ||
| JoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc, | ||
| ValueRange operands, DictionaryAttr attributes, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -63,6 +63,7 @@ class SubstraitExporter { | |
| DECLARE_EXPORT_FUNC(FieldReferenceOp, Expression) | ||
| DECLARE_EXPORT_FUNC(FetchOp, Rel) | ||
| DECLARE_EXPORT_FUNC(FilterOp, Rel) | ||
| DECLARE_EXPORT_FUNC(HashJoinOp, Rel) | ||
| DECLARE_EXPORT_FUNC(JoinOp, Rel) | ||
| DECLARE_EXPORT_FUNC(LiteralOp, Expression) | ||
| DECLARE_EXPORT_FUNC(ModuleOp, pb::Message) | ||
|
|
@@ -757,6 +758,86 @@ FailureOr<std::unique_ptr<Rel>> SubstraitExporter::exportOperation(JoinOp op) { | |
| return rel; | ||
| } | ||
|
|
||
| FailureOr<std::unique_ptr<Rel>> | ||
| SubstraitExporter::exportOperation(HashJoinOp op) { | ||
| auto relCommon = std::make_unique<RelCommon>(); | ||
| auto direct = std::make_unique<RelCommon::Direct>(); | ||
| relCommon->set_allocated_direct(direct.release()); | ||
|
|
||
| auto leftOp = | ||
| llvm::dyn_cast_if_present<RelOpInterface>(op.getLeft().getDefiningOp()); | ||
| if (!leftOp) | ||
| return op->emitOpError( | ||
| "left input was not produced by Substrait relation op"); | ||
|
|
||
| FailureOr<std::unique_ptr<Rel>> leftRel = exportOperation(leftOp); | ||
| if (failed(leftRel)) | ||
| return failure(); | ||
|
|
||
| auto rightOp = | ||
| llvm::dyn_cast_if_present<RelOpInterface>(op.getRight().getDefiningOp()); | ||
| if (!rightOp) | ||
| return op->emitOpError( | ||
| "right input was not produced by Substrait relation op"); | ||
|
|
||
| FailureOr<std::unique_ptr<Rel>> rightRel = exportOperation(rightOp); | ||
| if (failed(rightRel)) | ||
| return failure(); | ||
|
|
||
| auto hashJoinRel = std::make_unique<HashJoinRel>(); | ||
| hashJoinRel->set_allocated_common(relCommon.release()); | ||
| hashJoinRel->set_allocated_left(leftRel->release()); | ||
| hashJoinRel->set_allocated_right(rightRel->release()); | ||
| hashJoinRel->set_type(static_cast<HashJoinRel::JoinType>(op.getJoinType())); | ||
|
|
||
| if (op.getCondition().empty()) { | ||
| return op->emitOpError("missing join condition"); | ||
| } | ||
|
Comment on lines
+793
to
+795
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: remove braces |
||
|
|
||
| Block &conditionBlock = op.getCondition().front(); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As above ( |
||
| Operation *terminator = conditionBlock.getTerminator(); | ||
| Value conditionValue = cast<YieldOp>(terminator).getValue()[0]; | ||
|
|
||
| auto compareOp = | ||
| dyn_cast_or_null<KeyComparisonOp>(conditionValue.getDefiningOp()); | ||
| if (!compareOp) { | ||
| return op->emitOpError("join condition must be a KeyComparisonOp"); | ||
| } | ||
|
Comment on lines
+803
to
+805
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This could be removed, since you're already checking for this in your verifier. |
||
|
|
||
| Value leftKey = compareOp.getLhs(); | ||
| Value rightKey = compareOp.getRhs(); | ||
| FailureOr<std::unique_ptr<Expression>> leftKeyExpr; | ||
| if (auto leftFieldRef = | ||
| dyn_cast_or_null<FieldReferenceOp>(leftKey.getDefiningOp())) { | ||
| leftKeyExpr = exportOperation(leftFieldRef); | ||
| } else { | ||
| return op->emitOpError() << "left key must be a field reference"; | ||
| } | ||
|
|
||
| FailureOr<std::unique_ptr<Expression>> rightKeyExpr; | ||
| if (auto rightFieldRef = | ||
| dyn_cast_or_null<FieldReferenceOp>(rightKey.getDefiningOp())) { | ||
| rightKeyExpr = exportOperation(rightFieldRef); | ||
| } else { | ||
| return op->emitOpError() << "right key must be a field reference"; | ||
| } | ||
|
Comment on lines
+810
to
+823
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This logic should be moved to a verifier of the |
||
|
|
||
| auto keyComparison = hashJoinRel->add_keys(); | ||
| keyComparison->set_allocated_left(leftKeyExpr->get()->release_selection()); | ||
| keyComparison->set_allocated_right(rightKeyExpr->get()->release_selection()); | ||
|
|
||
| // TODO(trion): support custom function comparison types. | ||
| keyComparison->mutable_comparison()->set_simple( | ||
| static_cast<ComparisonJoinKey_SimpleComparisonType>( | ||
| compareOp.getComparisonType().value())); | ||
|
|
||
| exportAdvancedExtension(op, *hashJoinRel); | ||
|
|
||
| auto rel = std::make_unique<Rel>(); | ||
| rel->set_allocated_hash_join(hashJoinRel.release()); | ||
| return rel; | ||
| } | ||
|
|
||
| FailureOr<std::unique_ptr<Expression>> | ||
| SubstraitExporter::exportOperation(ExpressionOpInterface op) { | ||
| return llvm::TypeSwitch<Operation *, FailureOr<std::unique_ptr<Expression>>>( | ||
|
|
@@ -1576,6 +1657,7 @@ SubstraitExporter::exportOperation(RelOpInterface op) { | |
| FieldReferenceOp, | ||
| FilterOp, | ||
| JoinOp, | ||
| HashJoinOp, | ||
| NamedTableOp, | ||
| ProjectOp, | ||
| SetOp | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If `custom_function_id` is deliberately not implemented, I'd say remove this argument and write a TODO in the `description` field.