15 | 15 | // implement shape inference for the decomposed operation. Hence, it is expected |
16 | 16 | // that there is no knowledge about tensor shape at this point. |
17 | 17 | // |
| 18 | +// TODO: This file is getting crowded as the number of decomposed ops increases. |
| 19 | +// It would be better to move the decomposition of each operation into a separate file. |
| 20 | +// |
18 | 21 | //===----------------------------------------------------------------------===// |
19 | 22 |
20 | 23 | #include "mlir/IR/Matchers.h" |
26 | 29 | #include "src/Dialect/ONNX/ONNXOps.hpp" |
27 | 30 | #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" |
28 | 31 | #include "src/Pass/Passes.hpp" |
| 32 | +#include "src/Support/TypeUtilities.hpp" |
29 | 33 | #include "src/Transform/ONNX/DecomposeEinsum.hpp" |
30 | 34 |
31 | 35 | using namespace mlir; |
@@ -577,6 +581,175 @@ struct ConcatFusePattern : public ConversionPattern { |
577 | 581 | } |
578 | 582 | }; |
579 | 583 |
| 584 | +// Decompose the custom op FusedMatMul that is produced by ONNXRuntime. |
| 585 | +// According to the FusedMatMul specification, it is the result of fusing |
| 586 | +// MatMul and Transpose: |
| 587 | +// https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.FusedMatMul |
| 588 | +// |
| 589 | +// To decompose FusedMatMul, we need to know the ranks of inputs A and B so |
| 590 | +// that we can emit Transpose operations. In general, however, we have no |
| 591 | +// information about those ranks. |
| 592 | +// |
| 593 | +// The rewriting here only applies when the transposed input is ranked, or |
| 594 | +// comes from another Transpose whose `perm` attribute reveals its rank. For |
| 595 | +// example, if `transA = 1` and A is unranked, A must be produced by a |
| 596 | +// Transpose so that the rank of A can be inferred from `perm`. |
| 597 | +// |
| 598 | +// Example of onnx.Custom: |
| 599 | +// ``` |
| 600 | +// "onnx.Custom"(%0, %1) {alpha = 1.250000e-01 : f32, |
| 601 | +// domain_name = "com.microsoft", |
| 602 | +// function_name = "FusedMatMul", |
| 603 | +// transA = 0 : si64, transB = 1 : si64} : |
| 604 | +// (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> |
| 605 | +// ``` |
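| | +// |
| | +// A sketch of the IR this pattern would emit for the example above, assuming |
| | +// B is known to have rank 2 (SSA names and the exact constant form are |
| | +// illustrative): |
| | +// ``` |
| | +// %Bt = "onnx.Transpose"(%1) {perm = [1, 0]} |
| | +//       : (tensor<*xf32>) -> tensor<*xf32> |
| | +// %mm = "onnx.MatMul"(%0, %Bt) |
| | +//       : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> |
| | +// %alpha = "onnx.Constant"() {value = dense<1.250000e-01> : tensor<f32>} |
| | +//       : () -> tensor<f32> |
| | +// %res = "onnx.Mul"(%alpha, %mm) |
| | +//       : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32> |
| | +// ``` |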
| 606 | +struct CustomOpFuseMatMulPattern : public OpConversionPattern<ONNXCustomOp> { |
| 607 | + CustomOpFuseMatMulPattern(MLIRContext *context) |
| 608 | + : OpConversionPattern(context) {} |
| 609 | + LogicalResult matchAndRewrite(ONNXCustomOp customOp, |
| 610 | + ONNXCustomOp::Adaptor adaptor, |
| 611 | + ConversionPatternRewriter &rewriter) const final { |
| 612 | + using namespace onnx_mlir; |
| 613 | + Location loc = customOp.getLoc(); |
| 614 | + |
| 615 | + // Match |
| 616 | + FloatAttr alphaAttr; |
| 617 | + int64_t rankA, rankB; |
| 618 | + if (!isCustomOpFusedMatMulMatched(customOp, alphaAttr, rankA, rankB)) |
| 619 | + return failure(); |
| 620 | + |
| 621 | +    // Rewrite ONNXCustomOp {alpha} (A, B) into `Mul(alpha, MatMul(A, B))`. |
| 622 | + Value A = customOp.getOperands()[0]; |
| 623 | + Value B = customOp.getOperands()[1]; |
| 624 | + |
| 625 | + MultiDialectBuilder<OnnxBuilder> create(rewriter, loc); |
| 626 | + Type resType = customOp.getResult(0).getType(); |
| 627 | + Type elementType = onnx_mlir::getElementType(resType); |
| 628 | + UnrankedTensorType unrankedType = UnrankedTensorType::get(elementType); |
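| | +    // Intermediate values are deliberately unranked; shape inference runs |
| | +    // after decomposition and refines these types. |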
| 629 | + |
| 630 | + Value matmulA = A; |
| 631 | + Value matmulB = B; |
| 632 | +    // Transpose A if transA = 1, i.e. whenever rankA was determined. |
| 633 | + if (rankA != -1) { |
| 634 | + // Prepare permutation attribute. |
| 635 | + SmallVector<int64_t, 4> indices; |
| 636 | + for (int64_t i = 0; i < rankA - 2; ++i) |
| 637 | + indices.emplace_back(i); |
| 638 | + // Permute the last two dimensions. |
| 639 | + indices.emplace_back(rankA - 1); |
| 640 | + indices.emplace_back(rankA - 2); |
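| | +      // For example, rankA = 4 yields perm = [0, 1, 3, 2]. |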
| 641 | + ArrayAttr permAttr = rewriter.getI64ArrayAttr(llvm::ArrayRef(indices)); |
| 642 | + matmulA = create.onnx.transpose(unrankedType, A, permAttr); |
| 643 | + } |
| 644 | +    // Transpose B if transB = 1, i.e. whenever rankB was determined. |
| 645 | + if (rankB != -1) { |
| 646 | + // Prepare permutation attribute. |
| 647 | + SmallVector<int64_t, 4> indices; |
| 648 | + for (int64_t i = 0; i < rankB - 2; ++i) |
| 649 | + indices.emplace_back(i); |
| 650 | + // Permute the last two dimensions. |
| 651 | + indices.emplace_back(rankB - 1); |
| 652 | + indices.emplace_back(rankB - 2); |
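| | +      // For example, rankB = 3 yields perm = [0, 2, 1]. |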
| 653 | + ArrayAttr permAttr = rewriter.getI64ArrayAttr(llvm::ArrayRef(indices)); |
| 654 | + matmulB = create.onnx.transpose(unrankedType, B, permAttr); |
| 655 | + } |
| 656 | +    // Materialize alpha as a scalar constant. |
| 657 | + DenseElementsAttr alphaDenseAttr = |
| 658 | + onnx_mlir::createDenseElementsAttrFromFloatAttr( |
| 659 | + rewriter, elementType, alphaAttr); |
| 660 | + Value alpha = create.onnx.constant(alphaDenseAttr); |
| 661 | + |
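| | +    // res = alpha * MatMul(A', B'); Mul broadcasts the scalar alpha. |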
| 662 | + Value res = create.onnx.matmul(resType, matmulA, matmulB); |
| 663 | + res = create.onnx.mul(alpha, res); |
| 664 | + |
| 665 | + rewriter.replaceOp(customOp, res); |
| 666 | + return success(); |
| 667 | + } |
| 668 | + |
| 669 | +public: |
| 670 | + static bool isCustomOpFusedMatMulMatched(ONNXCustomOp customOp, |
| 671 | + FloatAttr &alphaAttr, int64_t &rankA, int64_t &rankB) { |
| 672 | + Operation *genericOp = customOp.getOperation(); |
| 673 | + // CustomOp has two operands. |
| 674 | + if (customOp.getNumOperands() != 2) |
| 675 | + return false; |
| 676 | + Value A = genericOp->getOperands()[0]; |
| 677 | + Value B = genericOp->getOperands()[1]; |
| 678 | + |
| 679 | + // function_name is FusedMatMul. |
| 680 | + StringRef funcName = customOp.getFunctionName(); |
| 681 | + if (!funcName.equals_insensitive("FusedMatMul")) |
| 682 | + return false; |
| 683 | + |
| 684 | +    // domain_name exists and is "com.microsoft". |
| 685 | + StringAttr domAttr = genericOp->getAttrOfType<StringAttr>("domain_name"); |
| 686 | + if (!domAttr) |
| 687 | + return false; |
| 688 | + if (!domAttr.getValue().equals_insensitive("com.microsoft")) |
| 689 | + return false; |
| 690 | + |
| 691 | + // transA and transB exist. |
| 692 | + IntegerAttr transA = genericOp->getAttrOfType<IntegerAttr>("transA"); |
| 693 | + IntegerAttr transB = genericOp->getAttrOfType<IntegerAttr>("transB"); |
| 694 | + if (!transA || !transB) |
| 695 | + return false; |
| 696 | + bool isTransA = (transA.getValue().getSExtValue() == 1); |
| 697 | + bool isTransB = (transB.getValue().getSExtValue() == 1); |
| 698 | + |
| 699 | +    // If transA = 1, we must know A's rank in order to generate an |
| 700 | +    // ONNXTransposeOp for A. Ideally, A is ranked and its rank is available. |
| 701 | + // |
| 702 | +    // If A is unranked, we check whether A is the result of another |
| 703 | +    // ONNXTransposeOp whose `perm` attribute can be used to infer A's rank. |
| 704 | + // For example, |
| 705 | + // %A = "onnx.Transpose"(%0) {perm = [0, 2, 1, 3]} : |
| 706 | + // (tensor<*xf32>) -> tensor<*xf32> |
| 707 | + // A must have rank 4 as perm has 4 indices. |
| 708 | + if (isTransA) { |
| 709 | + if (onnx_mlir::hasShapeAndRank(A)) { |
| 710 | + rankA = A.getType().cast<ShapedType>().getRank(); |
| 711 | + } else { |
| 712 | + if (isa<BlockArgument>(A)) |
| 713 | + return false; |
| 714 | + if (auto transOp = dyn_cast<ONNXTransposeOp>(A.getDefiningOp())) { |
| 715 | + if (transOp.getPermAttr()) |
| 716 | + rankA = transOp.getPermAttr().size(); |
| 717 | + else |
| 718 | + return false; |
| 719 | + } else |
| 720 | + // Cannot determine the rank of A. |
| 721 | + return false; |
| 722 | + } |
| 723 | + } else |
| 724 | + rankA = -1; |
| 725 | + if (isTransB) { |
| 726 | + if (onnx_mlir::hasShapeAndRank(B)) { |
| 727 | + rankB = B.getType().cast<ShapedType>().getRank(); |
| 728 | + } else { |
| 729 | + if (isa<BlockArgument>(B)) |
| 730 | + return false; |
| 731 | + if (auto transOp = dyn_cast<ONNXTransposeOp>(B.getDefiningOp())) { |
| 732 | + if (transOp.getPermAttr()) |
| 733 | + rankB = transOp.getPermAttr().size(); |
| 734 | + else |
| 735 | + return false; |
| 736 | + } else |
| 737 | + // Cannot determine the rank of B. |
| 738 | + return false; |
| 739 | + } |
| 740 | + } else |
| 741 | + rankB = -1; |
| 742 | + |
| 743 | + // Get alpha. |
| 744 | + alphaAttr = genericOp->getAttrOfType<FloatAttr>("alpha"); |
| 745 | + if (!alphaAttr) |
| 746 | + return false; |
| 747 | + |
| 748 | + // CustomOp is in a good form to rewrite. |
| 749 | + return true; |
| 750 | + } |
| 751 | +}; |
| 752 | + |
580 | 753 | struct DecomposeONNXToONNXPass |
581 | 754 | : public PassWrapper<DecomposeONNXToONNXPass, OperationPass<func::FuncOp>> { |
582 | 755 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DecomposeONNXToONNXPass) |
@@ -640,6 +813,14 @@ void DecomposeONNXToONNXPass::runOnOperation() { |
640 | 813 | ONNXTransposeOp transposeOp = NULL; |
641 | 814 | return !isConcatFuseMatched(op, shapeOp, transposeOp); |
642 | 815 | }); |
| 816 | + // Decompose CustomOp FusedMatMul introduced by onnxruntime: |
| 817 | + // https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.FusedMatMul |
| 818 | + target.addDynamicallyLegalOp<ONNXCustomOp>([](ONNXCustomOp op) { |
| 819 | + int64_t rankA, rankB; |
| 820 | + FloatAttr alpha; |
| 821 | + return !CustomOpFuseMatMulPattern::isCustomOpFusedMatMulMatched( |
| 822 | + op, alpha, rankA, rankB); |
| 823 | + }); |
643 | 824 |
644 | 825 | #ifdef ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE |
645 | 826 | #ifdef ONNX_MLIR_ENABLE_MHLO |
@@ -669,6 +850,9 @@ void DecomposeONNXToONNXPass::runOnOperation() { |
669 | 850 | populateWithGenerated(patterns); |
670 | 851 | patterns.insert<onnx_mlir::DecomposeEinsumPattern>(&getContext()); |
671 | 852 | patterns.insert<ConcatFusePattern>(&getContext()); |
| 853 | + // Decompose CustomOp FusedMatMul introduced by onnxruntime: |
| 854 | + // https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.FusedMatMul |
| 855 | + patterns.insert<CustomOpFuseMatMulPattern>(&getContext()); |
672 | 856 |
673 | 857 | #ifdef ONNX_MLIR_ENABLE_MHLO |
674 | 858 | if (this->target == "mhlo") { |