Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,20 @@ def SetOpKind : I32EnumAttr<"SetOpKind", "", [
let cppNamespace = "::mlir::substrait";
}


/// Represents the `SimpleComparisonType` protobuf enum.
///
/// The enum values correspond exactly to those in the `SimpleComparisonType` enum,
/// i.e., conversion through integers is possible.
def SimpleComparisonType : I32EnumAttr<"SimpleComparisonType", "", [
// clang-format off
I32EnumAttrCase<"unspecified", 0>,
I32EnumAttrCase<"equal", 1>,
I32EnumAttrCase<"is_not_distinct_from", 2>,
I32EnumAttrCase<"might_equal", 3>,
// clang-format on
]> {
let cppNamespace = "::mlir::substrait";
}

#endif // SUBSTRAIT_DIALECT_SUBSTRAIT_IR_SUBSTRAITENUMS
106 changes: 105 additions & 1 deletion include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@ def Substrait_YieldOp : Substrait_Op<"yield", [
"::mlir::substrait::AggregateOp",
"::mlir::substrait::FilterOp",
"::mlir::substrait::PlanRelOp",
"::mlir::substrait::ProjectOp"
"::mlir::substrait::ProjectOp",
"::mlir::substrait::HashJoinOp"
]>
]> {
let summary = "Yields the result of a `PlanRelOp`";
Expand Down Expand Up @@ -716,6 +717,109 @@ def Substrait_FilterOp : Substrait_RelOp<"filter", [
}];
}

def Substrait_KeyComparisonOp : Substrait_ExpressionOp<"compare", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Key comparison expression";
let description = [{
Represents a comparison between two field references or expressions
used in join conditions.

Example:

```mlir
%3 = field_reference %arg0[0] : tuple<si32, si32>
%4 = field_reference %arg1[0] : tuple<si32, si32, si32>
%5 = compare not_distinct_from %3, %4 : (si32, si32) -> si1
```
}];

let arguments = (ins
Substrait_ExpressionType:$lhs,
Substrait_ExpressionType:$rhs,
OptionalAttr<SimpleComparisonType>:$comparison_type,
OptionalAttr<UI32Attr>:$custom_function_id
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If custom_function_id is deliberately not implemented, i'd say remove this argument and write a TODO in the description field.

);

let results = (outs SI1:$result);

let assemblyFormat = [{
$comparison_type $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($result)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`(` type($lhs) `,` type($rhs) `)` `->` type($result)

could be

functional-type(operands, $result)

}];

let builders = [
OpBuilder<(ins
"::mlir::Value":$lhs,
"::mlir::Value":$rhs,
"SimpleComparisonType":$comparison_type), [{
// Convert SimpleComparisonType to SimpleComparisonTypeAttr
auto comparisonAttr = ::mlir::substrait::SimpleComparisonTypeAttr::get($_builder.getContext(), comparison_type);
auto si1Type = ::mlir::IntegerType::get($_builder.getContext(), 1, ::mlir::IntegerType::Signed);
build($_builder, $_state, si1Type, lhs, rhs, comparisonAttr, {});
}]>
];

let hasVerifier = 1;
}

def Substrait_HashJoinOp : Substrait_RelOp<"hash_join", [
DeclareOpInterfaceMethods<InferTypeOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getDefaultDialect"]>,
SingleBlockImplicitTerminator<"::mlir::substrait::YieldOp">,
DeclareOpInterfaceMethods<Substrait_ExtensibleOpInterface>
]> {
let summary = "hash join operation";
let description = [{
Represents a `HashJoinRel` message together with the `RelCommon`, left and
right `Rel` messages and `JoinType` enumeration it contains. The join condition
is represented as a region where expressions can compare fields from both sides.
Join keys are stored as field references used for the hash join implementation.

Example:

```mlir
%0 = ...
%1 = ...
%2 = hash_join inner %0, %1 on {
^bb0(%arg0: tuple<si32, si32>, %arg1: tuple<si32, si32, si32>):
%3 = field_reference %arg0[0] : tuple<si32, si32>
%4 = field_reference %arg1[0] : tuple<si32, si32, si32>
%5 = compare not_distinct_from %3, %4 : (si32, si32) -> si1
// or depending if custom function is provided.
//TODO: the custom function is not yet supported.
%5 = call @cmp(%3, %4) : (si32, si32) -> si1
yield %5 : si1
}
```
}];
let arguments = (ins
Substrait_Relation:$left,
Substrait_Relation:$right,
JoinType:$join_type,
OptionalAttr<Substrait_AdvancedExtensionAttr>:$advanced_extension
);
let regions = (region AnyRegion:$condition);
let hasRegionVerifier = 1;
let results = (outs Substrait_Relation:$result);
let assemblyFormat = [{
$join_type $left `,` $right
(`advanced_extension` `` $advanced_extension^)?
attr-dict `:` type($left) `,` type($right) `->` type($result) $condition
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as above (functional-type).

}];
let builders = [
OpBuilder<(ins "::mlir::Value":$left, "::mlir::Value":$right,
"JoinType":$join_type), [{
build($_builder, $_state, left, right, join_type, /*advanced_extension=*/{});
}]>,
];
let extraClassDefinition = [{
/// Implement OpAsmOpInterface.
::llvm::StringRef $cppClass::getDefaultDialect() {
return SubstraitDialect::getDialectNamespace();
}
}];
}

def Substrait_JoinOp : Substrait_RelOp<"join", [
DeclareOpInterfaceMethods<InferTypeOpInterface>,
DeclareOpInterfaceMethods<Substrait_ExtensibleOpInterface>,
Expand Down
101 changes: 101 additions & 0 deletions lib/Dialect/Substrait/IR/Substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,107 @@ LiteralOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
return success();
}

LogicalResult
HashJoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
ValueRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
Value leftInput = operands[0];
Value rightInput = operands[1];

TypeRange leftFieldTypes = cast<TupleType>(leftInput.getType()).getTypes();
TypeRange rightFieldTypes = cast<TupleType>(rightInput.getType()).getTypes();

// Get accessor to `join_type`.
Adaptor adaptor(operands, attributes, properties, regions);
JoinType join_type = adaptor.getJoinType();

SmallVector<mlir::Type> fieldTypes;

switch (join_type) {
case JoinType::unspecified:
case JoinType::inner:
case JoinType::outer:
case JoinType::right:
case JoinType::left:
llvm::append_range(fieldTypes, leftFieldTypes);
llvm::append_range(fieldTypes, rightFieldTypes);
break;
case JoinType::semi:
case JoinType::anti:
llvm::append_range(fieldTypes, leftFieldTypes);
break;
case JoinType::single:
llvm::append_range(fieldTypes, rightFieldTypes);
break;
}

auto resultType = TupleType::get(context, fieldTypes);

inferredReturnTypes = SmallVector<Type>{resultType};

return success();
}

LogicalResult KeyComparisonOp::verify() {
auto &res =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this may be a bit unintuitive at first (i.e. one might expect that the comparison op should always have identical types for both operands), a comment might be warranted here.

Do we have some documentation about substraits stance here? That is, this code is assuming that the comparison op is able to perform comparisons against types that are "cast-compatible" (int <=> decimal, string <=> varchar).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Storing LogicalResult by reference is atypical; i'd just do auto res here.

llvm::TypeSwitch<mlir::Type, LogicalResult>(getLhs().getType())
.Case<substrait::VarCharType, substrait::StringType>([&](auto) {
if (!mlir::isa<substrait::VarCharType, substrait::StringType>(
getRhs()))
return this->emitError("Invalid rhs type for string comparison, "
"expected string-like type but got")
<< getRhs().getType();
return success();
})
.Case<IntegerType>([&](auto) {
if (!mlir::isa<IntegerType, substrait::DecimalType>(getRhs()))
return this->emitError("Invalid rhs type for integer comparison, "
"expected integer-like type but got")
<< getRhs().getType();
return success();
});
if (failed(res))
return res;

return success();
Comment on lines +1098 to +1101
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could just return res; here.

}

LogicalResult HashJoinOp::verifyRegions() {
if (getCondition().empty())
return emitOpError() << "hash join must have a condition region";

// Verify that we have exactly two arguments in the region, matching the left
// and right relations
Block &block = getCondition().front();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you've added the SingleBlockImplicitTerminator trait, you should be able to get the body block by just saying Block* body = getBody().

if (block.getNumArguments() != 2)
return emitOpError() << "condition region must have exactly two arguments";

if (block.getArgument(0).getType() != getLeft().getType())
return emitOpError()
<< "first block argument type must match left relation type";
if (block.getArgument(1).getType() != getRight().getType())
return emitOpError()
<< "second block argument type must match right relation type";

auto yieldOp = cast<YieldOp>(block.getTerminator());
if (yieldOp.getValue().size() != 1)
return emitOpError() << "condition region must yield exactly one value";

Type yieldedType = yieldOp.getValue()[0].getType();
MLIRContext *context = getContext();
Type si1Type = IntegerType::get(context, 1, IntegerType::Signed);
if (yieldedType != si1Type)
return emitOpError() << "condition region must yield a boolean (si1) value";

Value conditionValue = yieldOp.getValue()[0];
if (!isa<KeyComparisonOp>(conditionValue.getDefiningOp())) {
return emitOpError() << "condition must be produced by a comparison";
}
Comment on lines +1132 to +1134
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: remove braces.


return success();
}

LogicalResult
JoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
ValueRange operands, DictionaryAttr attributes,
Expand Down
82 changes: 82 additions & 0 deletions lib/Target/SubstraitPB/Export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class SubstraitExporter {
DECLARE_EXPORT_FUNC(FieldReferenceOp, Expression)
DECLARE_EXPORT_FUNC(FetchOp, Rel)
DECLARE_EXPORT_FUNC(FilterOp, Rel)
DECLARE_EXPORT_FUNC(HashJoinOp, Rel)
DECLARE_EXPORT_FUNC(JoinOp, Rel)
DECLARE_EXPORT_FUNC(LiteralOp, Expression)
DECLARE_EXPORT_FUNC(ModuleOp, pb::Message)
Expand Down Expand Up @@ -757,6 +758,86 @@ FailureOr<std::unique_ptr<Rel>> SubstraitExporter::exportOperation(JoinOp op) {
return rel;
}

FailureOr<std::unique_ptr<Rel>>
SubstraitExporter::exportOperation(HashJoinOp op) {
auto relCommon = std::make_unique<RelCommon>();
auto direct = std::make_unique<RelCommon::Direct>();
relCommon->set_allocated_direct(direct.release());

auto leftOp =
llvm::dyn_cast_if_present<RelOpInterface>(op.getLeft().getDefiningOp());
if (!leftOp)
return op->emitOpError(
"left input was not produced by Substrait relation op");

FailureOr<std::unique_ptr<Rel>> leftRel = exportOperation(leftOp);
if (failed(leftRel))
return failure();

auto rightOp =
llvm::dyn_cast_if_present<RelOpInterface>(op.getRight().getDefiningOp());
if (!rightOp)
return op->emitOpError(
"right input was not produced by Substrait relation op");

FailureOr<std::unique_ptr<Rel>> rightRel = exportOperation(rightOp);
if (failed(rightRel))
return failure();

auto hashJoinRel = std::make_unique<HashJoinRel>();
hashJoinRel->set_allocated_common(relCommon.release());
hashJoinRel->set_allocated_left(leftRel->release());
hashJoinRel->set_allocated_right(rightRel->release());
hashJoinRel->set_type(static_cast<HashJoinRel::JoinType>(op.getJoinType()));

if (op.getCondition().empty()) {
return op->emitOpError("missing join condition");
}
Comment on lines +793 to +795
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove braces


Block &conditionBlock = op.getCondition().front();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As above (op.getBody()).

Operation *terminator = conditionBlock.getTerminator();
Value conditionValue = cast<YieldOp>(terminator).getValue()[0];

auto compareOp =
dyn_cast_or_null<KeyComparisonOp>(conditionValue.getDefiningOp());
if (!compareOp) {
return op->emitOpError("join condition must be a KeyComparisonOp");
}
Comment on lines +803 to +805
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be removed, since you're already checking for this in your verifier.
nit: remove braces.


Value leftKey = compareOp.getLhs();
Value rightKey = compareOp.getRhs();
FailureOr<std::unique_ptr<Expression>> leftKeyExpr;
if (auto leftFieldRef =
dyn_cast_or_null<FieldReferenceOp>(leftKey.getDefiningOp())) {
leftKeyExpr = exportOperation(leftFieldRef);
} else {
return op->emitOpError() << "left key must be a field reference";
}

FailureOr<std::unique_ptr<Expression>> rightKeyExpr;
if (auto rightFieldRef =
dyn_cast_or_null<FieldReferenceOp>(rightKey.getDefiningOp())) {
rightKeyExpr = exportOperation(rightFieldRef);
} else {
return op->emitOpError() << "right key must be a field reference";
}
Comment on lines +810 to +823
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic should be moved to a verifier of the KeyComparisonOp (i.e. set hasVerifier = 1 for the op).


auto keyComparison = hashJoinRel->add_keys();
keyComparison->set_allocated_left(leftKeyExpr->get()->release_selection());
keyComparison->set_allocated_right(rightKeyExpr->get()->release_selection());

// TODO(trion): support custom function comparison types.
keyComparison->mutable_comparison()->set_simple(
static_cast<ComparisonJoinKey_SimpleComparisonType>(
compareOp.getComparisonType().value()));

exportAdvancedExtension(op, *hashJoinRel);

auto rel = std::make_unique<Rel>();
rel->set_allocated_hash_join(hashJoinRel.release());
return rel;
}

FailureOr<std::unique_ptr<Expression>>
SubstraitExporter::exportOperation(ExpressionOpInterface op) {
return llvm::TypeSwitch<Operation *, FailureOr<std::unique_ptr<Expression>>>(
Expand Down Expand Up @@ -1576,6 +1657,7 @@ SubstraitExporter::exportOperation(RelOpInterface op) {
FieldReferenceOp,
FilterOp,
JoinOp,
HashJoinOp,
NamedTableOp,
ProjectOp,
SetOp
Expand Down
Loading
Loading