Commit 057f3ad

Decompose CustomOp that is ONNXRuntime FusedMatMul (#2092)

* Decompose CustomOp that is ONNXRuntime FusedMatMul

Signed-off-by: Tung D. Le <[email protected]>

1 parent 6e6ec97 commit 057f3ad
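
In effect, the rewrite turns the ONNXRuntime contrib op `com.microsoft.FusedMatMul`, which computes `alpha * MatMul(op(A), op(B))` where `op` swaps the last two axes of its input when the corresponding `transA`/`transB` flag is 1, into plain ONNX dialect ops. A minimal before/after sketch in ONNX-dialect MLIR, mirroring the first lit test added below (SSA names and shapes are illustrative, not part of the commit):

```mlir
// Before: FusedMatMul imported as onnx.Custom, with transA = 1.
%y = "onnx.Custom"(%A, %B) {alpha = 1.250000e-01 : f32,
    domain_name = "com.microsoft", function_name = "FusedMatMul",
    transA = 1 : si64, transB = 0 : si64}
    : (tensor<3x5x7x9xf32>, tensor<3x5x7x9xf32>) -> tensor<3x5x9x9xf32>

// After --decompose-onnx: transpose of the last two axes, MatMul, then
// multiplication by the alpha constant.
%At = "onnx.Transpose"(%A) {perm = [0, 1, 3, 2]}
    : (tensor<3x5x7x9xf32>) -> tensor<3x5x9x7xf32>
%alpha = onnx.Constant dense<1.250000e-01> : tensor<1xf32>
%mm = "onnx.MatMul"(%At, %B)
    : (tensor<3x5x9x7xf32>, tensor<3x5x7x9xf32>) -> tensor<3x5x9x9xf32>
%y2 = "onnx.Mul"(%alpha, %mm)
    : (tensor<1xf32>, tensor<3x5x9x9xf32>) -> tensor<3x5x9x9xf32>
```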

File tree

2 files changed: +297 −0 lines changed

src/Transform/ONNX/Decompose.cpp

Lines changed: 184 additions & 0 deletions
@@ -15,6 +15,9 @@
 // implement shape inference for the decomposed operation. Hence, it is expected
 // that there is no knowledge about tensor shape at this point.
 //
+// TODO: This file is quite busy as the number of decomposed ops is increasing.
+// It is better to move the decomposition of each operation into a separate
+// file.
+//
 //===----------------------------------------------------------------------===//

 #include "mlir/IR/Matchers.h"
@@ -26,6 +29,7 @@
 #include "src/Dialect/ONNX/ONNXOps.hpp"
 #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
 #include "src/Pass/Passes.hpp"
+#include "src/Support/TypeUtilities.hpp"
 #include "src/Transform/ONNX/DecomposeEinsum.hpp"

 using namespace mlir;
@@ -577,6 +581,175 @@ struct ConcatFusePattern : public ConversionPattern {
   }
 };

+// Decompose the custom op FusedMatMul that is produced by ONNXRuntime.
+// According to the FusedMatMul specification, it is the result of fusing
+// MatMul and Transpose:
+// https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.FusedMatMul
+//
+// To decompose FusedMatMul, we need to know the ranks of inputs A and B, so
+// that we can emit Transpose operations. But, in general, we have no
+// information about the ranks of A and B.
+//
+// The rewriting here only applies to a situation in which the transposed input
+// comes from another Transpose, so that we have rank information by looking at
+// its `perm` attribute. For example, if `transA = 1`, A must come from a
+// Transpose in order to determine the rank of A.
+//
+// Example of onnx.Custom:
+// ```
+// "onnx.Custom"(%0, %1) {alpha = 1.250000e-01 : f32,
+//                        domain_name = "com.microsoft",
+//                        function_name = "FusedMatMul",
+//                        transA = 0 : si64, transB = 1 : si64} :
+//                        (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+// ```
+struct CustomOpFuseMatMulPattern : public OpConversionPattern<ONNXCustomOp> {
+  CustomOpFuseMatMulPattern(MLIRContext *context)
+      : OpConversionPattern(context) {}
+  LogicalResult matchAndRewrite(ONNXCustomOp customOp,
+      ONNXCustomOp::Adaptor adaptor,
+      ConversionPatternRewriter &rewriter) const final {
+    using namespace onnx_mlir;
+    Location loc = customOp.getLoc();
+
+    // Match.
+    FloatAttr alphaAttr;
+    int64_t rankA, rankB;
+    if (!isCustomOpFusedMatMulMatched(customOp, alphaAttr, rankA, rankB))
+      return failure();
+
+    // Rewrite ONNXCustomOp {alpha} (A, B) into `Mul(alpha, MatMul(A, B))`.
+    Value A = customOp.getOperands()[0];
+    Value B = customOp.getOperands()[1];
+
+    MultiDialectBuilder<OnnxBuilder> create(rewriter, loc);
+    Type resType = customOp.getResult(0).getType();
+    Type elementType = onnx_mlir::getElementType(resType);
+    UnrankedTensorType unrankedType = UnrankedTensorType::get(elementType);
+
+    Value matmulA = A;
+    Value matmulB = B;
+    // Transpose A if transA.
+    if (rankA != -1) {
+      // Prepare the permutation attribute.
+      SmallVector<int64_t, 4> indices;
+      for (int64_t i = 0; i < rankA - 2; ++i)
+        indices.emplace_back(i);
+      // Permute the last two dimensions.
+      indices.emplace_back(rankA - 1);
+      indices.emplace_back(rankA - 2);
+      ArrayAttr permAttr = rewriter.getI64ArrayAttr(llvm::ArrayRef(indices));
+      matmulA = create.onnx.transpose(unrankedType, A, permAttr);
+    }
+    // Transpose B if transB.
+    if (rankB != -1) {
+      // Prepare the permutation attribute.
+      SmallVector<int64_t, 4> indices;
+      for (int64_t i = 0; i < rankB - 2; ++i)
+        indices.emplace_back(i);
+      // Permute the last two dimensions.
+      indices.emplace_back(rankB - 1);
+      indices.emplace_back(rankB - 2);
+      ArrayAttr permAttr = rewriter.getI64ArrayAttr(llvm::ArrayRef(indices));
+      matmulB = create.onnx.transpose(unrankedType, B, permAttr);
+    }
+    // alpha
+    DenseElementsAttr alphaDenseAttr =
+        onnx_mlir::createDenseElementsAttrFromFloatAttr(
+            rewriter, elementType, alphaAttr);
+    Value alpha = create.onnx.constant(alphaDenseAttr);
+
+    Value res = create.onnx.matmul(resType, matmulA, matmulB);
+    res = create.onnx.mul(alpha, res);
+
+    rewriter.replaceOp(customOp, res);
+    return success();
+  }
+
+public:
+  static bool isCustomOpFusedMatMulMatched(ONNXCustomOp customOp,
+      FloatAttr &alphaAttr, int64_t &rankA, int64_t &rankB) {
+    Operation *genericOp = customOp.getOperation();
+    // CustomOp has two operands.
+    if (customOp.getNumOperands() != 2)
+      return false;
+    Value A = genericOp->getOperands()[0];
+    Value B = genericOp->getOperands()[1];
+
+    // function_name is FusedMatMul.
+    StringRef funcName = customOp.getFunctionName();
+    if (!funcName.equals_insensitive("FusedMatMul"))
+      return false;
+
+    // domain_name exists and is "com.microsoft".
+    StringAttr domAttr = genericOp->getAttrOfType<StringAttr>("domain_name");
+    if (!domAttr)
+      return false;
+    if (!domAttr.getValue().equals_insensitive("com.microsoft"))
+      return false;
+
+    // transA and transB exist.
+    IntegerAttr transA = genericOp->getAttrOfType<IntegerAttr>("transA");
+    IntegerAttr transB = genericOp->getAttrOfType<IntegerAttr>("transB");
+    if (!transA || !transB)
+      return false;
+    bool isTransA = (transA.getValue().getSExtValue() == 1);
+    bool isTransB = (transB.getValue().getSExtValue() == 1);
+
+    // If transA=true, we have to know A's rank to generate an ONNXTransposeOp
+    // for A. In the good case, A is ranked and its rank is available.
+    //
+    // If A is unranked, we hope that A is the result of another
+    // ONNXTransposeOp whose permutation is available and can be used to infer
+    // the rank of A. For example,
+    //   %A = "onnx.Transpose"(%0) {perm = [0, 2, 1, 3]} :
+    //            (tensor<*xf32>) -> tensor<*xf32>
+    // A must have rank 4 since perm has 4 indices.
+    if (isTransA) {
+      if (onnx_mlir::hasShapeAndRank(A)) {
+        rankA = A.getType().cast<ShapedType>().getRank();
+      } else {
+        if (isa<BlockArgument>(A))
+          return false;
+        if (auto transOp = dyn_cast<ONNXTransposeOp>(A.getDefiningOp())) {
+          if (transOp.getPermAttr())
+            rankA = transOp.getPermAttr().size();
+          else
+            return false;
+        } else
+          // Cannot determine the rank of A.
+          return false;
+      }
+    } else
+      rankA = -1;
+    if (isTransB) {
+      if (onnx_mlir::hasShapeAndRank(B)) {
+        rankB = B.getType().cast<ShapedType>().getRank();
+      } else {
+        if (isa<BlockArgument>(B))
+          return false;
+        if (auto transOp = dyn_cast<ONNXTransposeOp>(B.getDefiningOp())) {
+          if (transOp.getPermAttr())
+            rankB = transOp.getPermAttr().size();
+          else
+            return false;
+        } else
+          // Cannot determine the rank of B.
+          return false;
+      }
+    } else
+      rankB = -1;
+
+    // Get alpha.
+    alphaAttr = genericOp->getAttrOfType<FloatAttr>("alpha");
+    if (!alphaAttr)
+      return false;
+
+    // CustomOp is in a good form to rewrite.
+    return true;
+  }
+};
+
 struct DecomposeONNXToONNXPass
     : public PassWrapper<DecomposeONNXToONNXPass, OperationPass<func::FuncOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DecomposeONNXToONNXPass)
@@ -640,6 +813,14 @@ void DecomposeONNXToONNXPass::runOnOperation() {
     ONNXTransposeOp transposeOp = NULL;
     return !isConcatFuseMatched(op, shapeOp, transposeOp);
   });
+  // Decompose CustomOp FusedMatMul introduced by onnxruntime:
+  // https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.FusedMatMul
+  target.addDynamicallyLegalOp<ONNXCustomOp>([](ONNXCustomOp op) {
+    int64_t rankA, rankB;
+    FloatAttr alpha;
+    return !CustomOpFuseMatMulPattern::isCustomOpFusedMatMulMatched(
+        op, alpha, rankA, rankB);
+  });

 #ifdef ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE
 #ifdef ONNX_MLIR_ENABLE_MHLO
@@ -669,6 +850,9 @@ void DecomposeONNXToONNXPass::runOnOperation() {
   populateWithGenerated(patterns);
   patterns.insert<onnx_mlir::DecomposeEinsumPattern>(&getContext());
   patterns.insert<ConcatFusePattern>(&getContext());
+  // Decompose CustomOp FusedMatMul introduced by onnxruntime:
+  // https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.FusedMatMul
+  patterns.insert<CustomOpFuseMatMulPattern>(&getContext());

 #ifdef ONNX_MLIR_ENABLE_MHLO
   if (this->target == "mhlo") {
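
One subtlety worth calling out: when a transposed input is unranked, `isCustomOpFusedMatMulMatched` recovers the rank from the `perm` attribute of the defining `onnx.Transpose`, since the length of the permutation equals the rank. A sketch of the IR shape this fallback accepts, taken from the transA lit test below (names are illustrative):

```mlir
// %0 is unranked, but perm has 4 entries, so rank(A) = 4 is inferred and the
// pattern can still emit a {perm = [0, 1, 3, 2]} transpose for the
// FusedMatMul input.
%0 = "onnx.Transpose"(%arg0) {perm = [0, 2, 1, 3]}
    : (tensor<*xf32>) -> tensor<*xf32>
%1 = "onnx.Custom"(%0, %arg1) {alpha = 1.250000e-01 : f32,
    domain_name = "com.microsoft", function_name = "FusedMatMul",
    transA = 1 : si64, transB = 0 : si64}
    : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
```

Without such a defining Transpose (or a ranked type), the pattern conservatively declines to rewrite, as the not_rewrite tests below check.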
New test file: Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+// RUN: onnx-mlir-opt --decompose-onnx %s -split-input-file | FileCheck %s
+
+// COM: Decompose CustomOp introduced by onnxruntime.
+
+func.func @customop_fusedmatmul_onnxruntime(%arg0: tensor<3x5x7x9xf32>, %arg1: tensor<3x5x7x9xf32>) -> tensor<3x5x9x9xf32> {
+  %0 = "onnx.Custom"(%arg0, %arg1) {alpha = 1.250000e-01 : f32, domain_name = "com.microsoft", function_name = "FusedMatMul", transA = 1 : si64, transB = 0 : si64} : (tensor<3x5x7x9xf32>, tensor<3x5x7x9xf32>) -> tensor<3x5x9x9xf32>
+  return %0 : tensor<3x5x9x9xf32>
+
+// CHECK-LABEL:  func.func @customop_fusedmatmul_onnxruntime
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<3x5x7x9xf32>, [[PARAM_1_:%.+]]: tensor<3x5x7x9xf32>) -> tensor<3x5x9x9xf32> {
+// CHECK-DAG:    [[VAR_0_:%.+]] = "onnx.Transpose"([[PARAM_0_]]) {perm = [0, 1, 3, 2]} : (tensor<3x5x7x9xf32>) -> tensor<3x5x9x7xf32>
+// CHECK-DAG:    [[VAR_1_:%.+]] = onnx.Constant dense<1.250000e-01> : tensor<1xf32>
+// CHECK:        [[VAR_2_:%.+]] = "onnx.MatMul"([[VAR_0_]], [[PARAM_1_]]) : (tensor<3x5x9x7xf32>, tensor<3x5x7x9xf32>) -> tensor<3x5x9x9xf32>
+// CHECK:        [[VAR_3_:%.+]] = "onnx.Mul"([[VAR_1_]], [[VAR_2_]]) : (tensor<1xf32>, tensor<3x5x9x9xf32>) -> tensor<3x5x9x9xf32>
+// CHECK:        return [[VAR_3_]] : tensor<3x5x9x9xf32>
+// CHECK:        }
+}
+
+// -----
+
+func.func @customop_fusedmatmul_onnxruntime_no_transpose(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Custom"(%arg0, %arg1) {alpha = 1.250000e-01 : f32, domain_name = "com.microsoft", function_name = "FusedMatMul", transA = 0 : si64, transB = 0 : si64} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+
+// CHECK-LABEL:  func.func @customop_fusedmatmul_onnxruntime_no_transpose
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK-DAG:    [[VAR_0_:%.+]] = onnx.Constant dense<1.250000e-01> : tensor<1xf32>
+// CHECK-DAG:    [[VAR_1_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+// CHECK:        [[VAR_2_:%.+]] = "onnx.Mul"([[VAR_0_]], [[VAR_1_]]) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32>
+// CHECK:        return [[VAR_2_]] : tensor<*xf32>
+// CHECK:        }
+}
+
+// -----
+
+func.func @customop_fusedmatmul_onnxruntime_transA(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Transpose"(%arg0) {perm = [0, 2, 1, 3]} : (tensor<*xf32>) -> tensor<*xf32>
+  %1 = "onnx.Custom"(%0, %arg1) {alpha = 1.250000e-01 : f32, domain_name = "com.microsoft", function_name = "FusedMatMul", transA = 1 : si64, transB = 0 : si64} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %1 : tensor<*xf32>
+
+// CHECK-LABEL:  func.func @customop_fusedmatmul_onnxruntime_transA
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK:        [[VAR_0_:%.+]] = "onnx.Transpose"([[PARAM_0_]]) {perm = [0, 2, 1, 3]} : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK-DAG:    [[VAR_1_:%.+]] = "onnx.Transpose"([[VAR_0_]]) {perm = [0, 1, 3, 2]} : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK-DAG:    [[VAR_2_:%.+]] = onnx.Constant dense<1.250000e-01> : tensor<1xf32>
+// CHECK:        [[VAR_3_:%.+]] = "onnx.MatMul"([[VAR_1_]], [[PARAM_1_]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+// CHECK:        [[VAR_4_:%.+]] = "onnx.Mul"([[VAR_2_]], [[VAR_3_]]) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32>
+// CHECK:        return [[VAR_4_]] : tensor<*xf32>
+// CHECK:        }
+}
+
+// -----
+
+func.func @customop_fusedmatmul_onnxruntime_transB(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Transpose"(%arg1) {perm = [0, 2, 1, 3]} : (tensor<*xf32>) -> tensor<*xf32>
+  %1 = "onnx.Custom"(%arg0, %0) {alpha = 1.250000e-01 : f32, domain_name = "com.microsoft", function_name = "FusedMatMul", transA = 0 : si64, transB = 1 : si64} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %1 : tensor<*xf32>
+
+// CHECK-LABEL:  func.func @customop_fusedmatmul_onnxruntime_transB
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK:        [[VAR_0_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [0, 2, 1, 3]} : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK-DAG:    [[VAR_1_:%.+]] = "onnx.Transpose"([[VAR_0_]]) {perm = [0, 1, 3, 2]} : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK-DAG:    [[VAR_2_:%.+]] = onnx.Constant dense<1.250000e-01> : tensor<1xf32>
+// CHECK:        [[VAR_3_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_1_]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+// CHECK:        [[VAR_4_:%.+]] = "onnx.Mul"([[VAR_2_]], [[VAR_3_]]) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32>
+// CHECK:        return [[VAR_4_]] : tensor<*xf32>
+// CHECK:        }
+}
+
+// -----
+
+// COM: Do not rewrite because the domain_name is not "com.microsoft".
+func.func @customop_fusedmatmul_not_rewrite_domain(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Transpose"(%arg1) {perm = [0, 2, 1, 3]} : (tensor<*xf32>) -> tensor<*xf32>
+  %1 = "onnx.Custom"(%arg0, %0) {alpha = 1.250000e-01 : f32, domain_name = "abc.xyz", function_name = "FusedMatMul", transA = 0 : si64, transB = 1 : si64} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %1 : tensor<*xf32>
+
+// CHECK-LABEL:  func.func @customop_fusedmatmul_not_rewrite_domain
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK:        [[VAR_0_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [0, 2, 1, 3]} : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK:        [[VAR_1_:%.+]] = "onnx.Custom"([[PARAM_0_]], [[VAR_0_]]) {alpha = 1.250000e-01 : f32, domain_name = "abc.xyz", function_name = "FusedMatMul", transA = 0 : si64, transB = 1 : si64} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+// CHECK:        return [[VAR_1_]] : tensor<*xf32>
+// CHECK:        }
+}
+
+// -----
+
+// COM: Do not rewrite because A is transposed but its rank is unknown.
+// COM: So, there is no information to generate a transpose op.
+func.func @customop_fusedmatmul_not_rewrite_unranked_transpose(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+  %1 = "onnx.Custom"(%arg0, %arg1) {alpha = 1.250000e-01 : f32, domain_name = "com.microsoft", function_name = "FusedMatMul", transA = 1 : si64, transB = 0 : si64} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %1 : tensor<*xf32>
+
+// CHECK-LABEL:  func.func @customop_fusedmatmul_not_rewrite_unranked_transpose
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK:        [[VAR_0_:%.+]] = "onnx.Custom"([[PARAM_0_]], [[PARAM_1_]]) {alpha = 1.250000e-01 : f32, domain_name = "com.microsoft", function_name = "FusedMatMul", transA = 1 : si64, transB = 0 : si64} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+// CHECK:        return [[VAR_0_]] : tensor<*xf32>
+// CHECK:        }
+}
+
+// -----
+
+// COM: Do not rewrite because alpha is not given.
+func.func @customop_fusedmatmul_not_rewrite_no_alpha(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+  %1 = "onnx.Custom"(%arg0, %arg1) {domain_name = "com.microsoft", function_name = "FusedMatMul", transA = 0 : si64, transB = 0 : si64} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %1 : tensor<*xf32>
+
+// CHECK-LABEL:  func.func @customop_fusedmatmul_not_rewrite_no_alpha
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK:        [[VAR_0_:%.+]] = "onnx.Custom"([[PARAM_0_]], [[PARAM_1_]]) {domain_name = "com.microsoft", function_name = "FusedMatMul", transA = 0 : si64, transB = 0 : si64} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+// CHECK:        return [[VAR_0_]] : tensor<*xf32>
+// CHECK:        }
+}
