
Commit 38b16b0

Use DimAnalysis in lowering MatMul (#2195)
Signed-off-by: Tung D. Le <[email protected]>
Co-authored-by: Alexandre Eichenberger <[email protected]>
1 parent dea431d commit 38b16b0

File tree

4 files changed: 53 additions & 15 deletions

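
In brief: the MatMul lowering previously used its tiled, non-broadcast code path only when A and B had identical, fully static batch shapes. This commit threads a DimAnalysis pointer through populateLoweringONNXMatMulOpPattern into ONNXMatMulOpLowering and replaces the static shape comparison with DimAnalysis::sameDim queries, so batch dimensions that are dynamic but provably equal also qualify (exercised by the new test at the bottom).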

src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp

Lines changed: 1 addition & 1 deletion

@@ -194,7 +194,7 @@ void populateONNXToKrnlConversionPattern(RewritePatternSet &patterns,
   populateLoweringONNXReductionOpPattern(patterns, typeConverter, ctx);
   populateLoweringONNXSoftmaxOpPattern(patterns, typeConverter, ctx);
   populateLoweringONNXTopKOpPattern(patterns, typeConverter, ctx);
-  populateLoweringONNXMatMulOpPattern(patterns, typeConverter, ctx, enableTiling);
+  populateLoweringONNXMatMulOpPattern(patterns, typeConverter, ctx, dimAnalysis, enableTiling);
   populateLoweringONNXMatMulIntegerOpPattern(patterns, typeConverter, ctx);
   populateLoweringONNXRandomNormalOpPattern(patterns, typeConverter, ctx);
   populateLoweringONNXRandomNormalLikeOpPattern(patterns, typeConverter, ctx);
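
As the call above shows, populateONNXToKrnlConversionPattern already has a dimAnalysis in scope when it registers the pattern. Below is a minimal sketch of how a driver could build the analysis and register the pattern, assuming DimAnalysis is constructed from a ModuleOp and run via analyze(); the function buildMatMulLowering, the include paths, and the surrounding setup are illustrative, not part of this commit.

#include "mlir/Transforms/DialectConversion.h"

// Assumed header locations for onnx-mlir's DimAnalysis and the populate
// function; adjust to the actual paths in the repository.
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/ONNX/ONNXDimAnalysis.hpp"

using namespace mlir;

// Illustrative driver: run DimAnalysis once over the whole module, then hand
// it to the MatMul lowering pattern so matchAndRewrite can query dimension
// equalities later.
void buildMatMulLowering(ModuleOp module, RewritePatternSet &patterns,
    TypeConverter &typeConverter, bool enableTiling) {
  MLIRContext *ctx = module.getContext();
  // Assumption: DimAnalysis takes a ModuleOp and exposes analyze(); the
  // analysis object must outlive the conversion that uses the pattern.
  onnx_mlir::DimAnalysis *dimAnalysis = new onnx_mlir::DimAnalysis(module);
  dimAnalysis->analyze();
  onnx_mlir::populateLoweringONNXMatMulOpPattern(
      patterns, typeConverter, ctx, dimAnalysis, enableTiling);
}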

src/Conversion/ONNXToKrnl/Math/MatMul.cpp

Lines changed: 17 additions & 13 deletions

@@ -29,9 +29,11 @@ using namespace mlir;
 namespace onnx_mlir {
 
 struct ONNXMatMulOpLowering : public OpConversionPattern<ONNXMatMulOp> {
-  ONNXMatMulOpLowering(
-      TypeConverter &typeConverter, MLIRContext *ctx, bool enableTiling)
-      : OpConversionPattern(typeConverter, ctx), enableTiling(enableTiling) {}
+  ONNXMatMulOpLowering(TypeConverter &typeConverter, MLIRContext *ctx,
+      DimAnalysis *dimAnalysis, bool enableTiling)
+      : OpConversionPattern(typeConverter, ctx), dimAnalysis(dimAnalysis),
+        enableTiling(enableTiling) {}
+  DimAnalysis *dimAnalysis;
   bool enableTiling;
   // Handle the generic cases, including when there are broadcasts.
   void replaceGenericMatmul(ONNXMatMulOpAdaptor &operandAdaptor,
@@ -433,20 +435,20 @@ struct ONNXMatMulOpLowering : public OpConversionPattern<ONNXMatMulOp> {
           /*broadcasting B*/ false,
           /*same static broadcast*/ false, alloc, zero, rewriter, loc);
     } else {
-      // Test if A and B have identical static broadcast shapes.
-      bool sameStaticBroadcast = (enableTiling && aRank > 2 && aRank == bRank);
-      if (sameStaticBroadcast) {
-        auto aShape = A.getType().cast<MemRefType>().getShape();
-        auto bShape = B.getType().cast<MemRefType>().getShape();
+      // Test if A and B have identical batch sizes.
+      bool sameBatchsize = (enableTiling && aRank > 2 && aRank == bRank);
+      if (sameBatchsize) {
        for (int i = 0; i < aRank - 2; ++i)
-          if (aShape[i] == ShapedType::kDynamic || aShape[i] != bShape[i]) {
-            sameStaticBroadcast = false;
+          // Note that we use A and B from the operation instead of the
+          // adaptor, because DimAnalysis has been done on operations.
+          if (!dimAnalysis->sameDim(matMulOp.getA(), i, matMulOp.getB(), i)) {
+            sameBatchsize = false;
             break;
           }
      }
      // While there is technically no broadcasting there, we can use nearly the
      // same logic as in replace2x2Matmul2dBroadcasting. So reuse that code.
-      if (sameStaticBroadcast) {
+      if (sameBatchsize) {
        assert(cRank == aRank && "expected IxK * *xKxJ = *xIxJ result");
        replace2x2Matmul2dBroadcasting(adaptor, elementType, shapeHelper,
            /*broadcasting B*/ true,
@@ -463,8 +465,10 @@ struct ONNXMatMulOpLowering : public OpConversionPattern<ONNXMatMulOp> {
 }; // namespace onnx_mlir
 
 void populateLoweringONNXMatMulOpPattern(RewritePatternSet &patterns,
-    TypeConverter &typeConverter, MLIRContext *ctx, bool enableTiling) {
-  patterns.insert<ONNXMatMulOpLowering>(typeConverter, ctx, enableTiling);
+    TypeConverter &typeConverter, MLIRContext *ctx, DimAnalysis *dimAnalysis,
+    bool enableTiling) {
+  patterns.insert<ONNXMatMulOpLowering>(
+      typeConverter, ctx, dimAnalysis, enableTiling);
 }
 
 } // namespace onnx_mlir
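
The core of the change is the batch-size test in the second hunk: rather than requiring every batch dimension to be static and equal, the pattern asks DimAnalysis whether each pair of batch dimensions is provably the same, which also covers dynamic dimensions (for instance, when both operands are the same SSA value). A sketch of that test as a standalone helper, using only the sameDim query that appears in the diff; the helper itself and the include paths are illustrative.

#include "src/Dialect/ONNX/ONNXDimAnalysis.hpp" // assumed path: DimAnalysis
#include "src/Dialect/ONNX/ONNXOps.hpp"         // assumed path: ONNXMatMulOp

// Illustrative helper, not part of the commit: returns true when the tiled,
// non-broadcast path applies because every batch dimension of A and B is
// provably the same, even if dynamic.
static bool haveSameBatchSize(mlir::ONNXMatMulOp matMulOp,
    onnx_mlir::DimAnalysis *dimAnalysis, int64_t aRank, int64_t bRank,
    bool enableTiling) {
  // Only N-D x N-D (N > 2) with equal ranks qualifies, and only when tiling
  // is enabled (this mirrors the condition in the diff above).
  if (!enableTiling || aRank <= 2 || aRank != bRank)
    return false;
  for (int64_t i = 0; i < aRank - 2; ++i)
    // Query the operation's operands rather than the adaptor's values:
    // DimAnalysis was computed over operations. sameDim can prove equality
    // even when both dimensions are dynamic (?).
    if (!dimAnalysis->sameDim(matMulOp.getA(), i, matMulOp.getB(), i))
      return false;
  return true;
}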

src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp

Lines changed: 2 additions & 1 deletion

@@ -308,7 +308,8 @@ void populateLoweringONNXHardmaxOpPattern(
 void populateLoweringONNXLRNOpPattern(
     mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
 void populateLoweringONNXMatMulOpPattern(mlir::RewritePatternSet &,
-    mlir::TypeConverter &, mlir::MLIRContext *, bool enableTiling);
+    mlir::TypeConverter &, mlir::MLIRContext *, DimAnalysis *,
+    bool enableTiling);
 void populateLoweringONNXMatMulIntegerOpPattern(
     mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
 void populateLoweringONNXRandomNormalOpPattern(

test/mlir/onnx/onnx_lowering_with_canonicalize_O3.mlir

Lines changed: 33 additions & 0 deletions

@@ -1591,6 +1591,39 @@ func.func private @test_matmul7(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
 
 // -----
 
+// N-D x N-D
+func.func private @test_matmul8(%arg0 : tensor<?x10x10xf32>) -> tensor<*xf32> {
+  %0 ="onnx.MatMul"(%arg0, %arg0) : (tensor<?x10x10xf32>, tensor<?x10x10xf32>) -> tensor<*xf32>
+  "func.return"(%0) : (tensor<*xf32>) -> ()
+
+// mlir2FileCheck.py -a'["A", "B"]' -n'{"1": "RES"}'
+// CHECK-LABEL:  func.func private @test_matmul8
+// CHECK-SAME:   ([[A_:%.+]]: memref<?x10x10xf32>) -> memref<?x10x10xf32> {
+// CHECK-DAG:       [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:       [[CST_10_:%.+]] = arith.constant 10 : index
+// CHECK-DAG:       [[CST_0_:%.+]] = arith.constant 0 : index
+// CHECK:           [[VAR_dim_:%.+]] = memref.dim [[A_]], [[CST_0_]] : memref<?x10x10xf32>
+// CHECK:           [[RES_:%.+]] = memref.alloc([[VAR_dim_]]) {{.*}}: memref<?x10x10xf32>
+// CHECK:           krnl.memset [[RES_]], [[CST_0_dot_000000_]] : memref<?x10x10xf32>
+// CHECK:           [[LOOP_0_:%.+]] = krnl.define_loops 1
+// CHECK:           krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[B_:%.+]] = [[CST_0_]] to [[VAR_dim_]]){
+// CHECK-DAG:         [[RES_1_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index
+// CHECK-DAG:         [[LOOP_1_:%.+]]:3 = krnl.define_loops 3
+// CHECK:             [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_1_]]#0 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
+// CHECK:             [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]]#1 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
+// CHECK:             [[BLOCK_TILE__2_:%.+]], [[BLOCK_IN__2_:%.+]] = krnl.block [[LOOP_1_]]#2 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
+// CHECK:             krnl.permute([[BLOCK_TILE__0_]], [[BLOCK_IN__0_]], [[BLOCK_TILE__0_]]_0, [[BLOCK_IN__0_]]_1, [[BLOCK_TILE__0_]]_2, [[BLOCK_IN__0_]]_3) [0, 3, 1, 4, 2, 5] : !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop
+// CHECK:             krnl.iterate([[BLOCK_TILE__0_]], [[BLOCK_TILE__0_]]_0, [[BLOCK_TILE__0_]]_2) with ([[LOOP_1_]]#0 -> [[I_0_:%.+]] = [[CST_0_]] to [[CST_10_]], [[LOOP_1_]]#1 -> [[I_1_:%.+]] = [[CST_0_]] to [[CST_10_]], [[LOOP_1_]]#2 -> [[I_2_:%.+]] = [[CST_0_]] to [[CST_10_]]){
+// CHECK:               [[VAR_3_:%.+]]:3 = krnl.get_induction_var_value([[BLOCK_TILE__0_]], [[BLOCK_TILE__0_]]_0, [[BLOCK_TILE__0_]]_2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
+// CHECK:               krnl.matmul [[A_]]{{.}}[[RES_1_]], [[CST_0_]], [[CST_0_]]{{.}}, [[A_]]{{.}}[[RES_1_]], [[CST_0_]], [[CST_0_]]{{.}}, [[RES_]]{{.}}[[RES_1_]], [[CST_0_]], [[CST_0_]]{{.}}, ([[BLOCK_IN__0_]], [[BLOCK_IN__0_]]_1, [[BLOCK_IN__0_]]_3), ([[VAR_3_]]#0, [[VAR_3_]]#1, [[VAR_3_]]#2), ([[CST_10_]], [[CST_10_]], [[CST_10_]]) {aTileSize = [], bTileSize = [], cTileSize = [], computeTileSize = [4, 8, 8]} : memref<?x10x10xf32>, memref<?x10x10xf32>, memref<?x10x10xf32>, (!krnl.loop, !krnl.loop, !krnl.loop)
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return [[RES_]] : memref<?x10x10xf32>
+// CHECK:         }
+}
+
+// -----
+
 func.func private @test_pool_unknown_dimensions(%arg0 : tensor<1x3x?x32xf32>) -> tensor<*xf32> {
   %0 = "onnx.AveragePool"(%arg0) {auto_pad = "NOTSET", kernel_shape = [2, 2]} : (tensor<1x3x?x32xf32>) -> tensor<*xf32>
   "func.return"(%0) : (tensor<*xf32>) -> ()
