
Commit f002bbc

[TorchToArith] Implement conversion patterns for AtenSubOp, AtenMulOp, AtenDivOp.
Implement conversion patterns for `AtenSubOp`, `AtenMulOp`, and `AtenDivOp`: a unified template pattern, `ConvertAtenBinaryScalarOp`, handles `AtenAddOp`, `AtenSubOp`, and `AtenMulOp`; a second template pattern, `ConvertAtenDivOp`, handles `AtenDivOp`, `AtenDivIntOp`, and `AtenDivFloatOp`.
1 parent 42c3c29 commit f002bbc
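
To make the change easier to read at a glance, here is the unified binary-scalar pattern assembled from the three hunks below into one piece. The `resultType` initializer and the trailing `return success();` are not visible in the truncated hunks and are filled in with the standard idiom, so treat this as a sketch rather than verbatim source:

// Sketch: one templated pattern replaces the per-op ConvertAtenAddOp class.
// AtenOp is the Torch scalar op; ArithFOp/ArithIOp are the arith dialect ops
// selected by the converted result type.
template <typename AtenOp, typename ArithFOp, typename ArithIOp>
class ConvertAtenBinaryScalarOp : public OpConversionPattern<AtenOp> {
public:
  using OpConversionPattern<AtenOp>::OpConversionPattern;
  using OpAdaptor = typename AtenOp::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Assumed idiom: convert the Torch result type via the type converter.
    Type resultType = this->getTypeConverter()->convertType(op.getType());
    Value operandA =
        convertScalarToDtype(rewriter, loc, adaptor.getA(), resultType);
    Value operandB =
        convertScalarToDtype(rewriter, loc, adaptor.getB(), resultType);
    if (isa<mlir::FloatType>(resultType)) {
      rewriter.replaceOpWithNewOp<ArithFOp>(op, operandA, operandB);
    } else if (isa<mlir::IntegerType>(resultType)) {
      rewriter.replaceOpWithNewOp<ArithIOp>(op, operandA, operandB);
    } else {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: only support integer or float result type");
    }
    return success();
  }
};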

File tree

2 files changed (+86, -18 lines)

lib/Conversion/TorchToArith/TorchToArith.cpp (29 additions, 16 deletions)
@@ -134,12 +134,13 @@ class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern<AtenOp> {
 } // namespace
 
 namespace {
-class ConvertAtenDivIntOp : public OpConversionPattern<AtenDivIntOp> {
+template <typename AtenOp>
+class ConvertAtenDivOp : public OpConversionPattern<AtenOp> {
 public:
-  using OpConversionPattern<AtenDivIntOp>::OpConversionPattern;
+  using OpConversionPattern<AtenOp>::OpConversionPattern;
+  using OpAdaptor = typename AtenOp::Adaptor;
   LogicalResult
-  matchAndRewrite(AtenDivIntOp op,
-                  typename OpConversionPattern<AtenDivIntOp>::OpAdaptor adaptor,
+  matchAndRewrite(AtenOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value a = convertScalarToDtype(rewriter, loc, adaptor.getA(),
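
This hunk cuts off inside `matchAndRewrite`. Judging from the old `ConvertAtenDivIntOp` behavior this pattern generalizes, and from the `arith.divf`-on-`f64` checks in the tests below, the remainder of the body most likely widens both scalars to `f64` and emits a float division; a hedged reconstruction of the missing lines:

    // Hedged reconstruction of the truncated body: both operands are widened
    // to f64 (convertScalarToDtype inserts arith.sitofp for integer inputs),
    // so every div variant lowers to arith.divf. This matches Python
    // semantics, where int / int produces a float.
    Value a = convertScalarToDtype(rewriter, loc, adaptor.getA(),
                                   rewriter.getF64Type());
    Value b = convertScalarToDtype(rewriter, loc, adaptor.getB(),
                                   rewriter.getF64Type());
    rewriter.replaceOpWithNewOp<arith::DivFOp>(op, a, b);
    return success();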
@@ -306,11 +307,13 @@ class ConvertAtenScalarArithOp : public OpConversionPattern<AtenOp> {
 } // namespace
 
 namespace {
-class ConvertAtenAddOp : public OpConversionPattern<AtenAddOp> {
+template <typename AtenOp, typename ArithFOp, typename ArithIOp>
+class ConvertAtenBinaryScalarOp : public OpConversionPattern<AtenOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using OpConversionPattern<AtenOp>::OpConversionPattern;
+  using OpAdaptor = typename AtenOp::Adaptor;
   LogicalResult
-  matchAndRewrite(AtenAddOp op, OpAdaptor adaptor,
+  matchAndRewrite(AtenOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Type resultType =
@@ -320,9 +323,9 @@ class ConvertAtenAddOp : public OpConversionPattern<AtenAddOp> {
     Value operandB =
         convertScalarToDtype(rewriter, loc, adaptor.getB(), resultType);
     if (isa<mlir::FloatType>(resultType)) {
-      rewriter.replaceOpWithNewOp<arith::AddFOp>(op, operandA, operandB);
+      rewriter.replaceOpWithNewOp<ArithFOp>(op, operandA, operandB);
     } else if (isa<mlir::IntegerType>(resultType)) {
-      rewriter.replaceOpWithNewOp<arith::AddIOp>(op, operandA, operandB);
+      rewriter.replaceOpWithNewOp<ArithIOp>(op, operandA, operandB);
     } else {
       return rewriter.notifyMatchFailure(
           op, "unimplemented: only support integer or float result type");
@@ -497,8 +500,17 @@ class ConvertTorchToArith
     patterns.add<ConvertAtenCastOp<AtenFloatScalarOp>>(typeConverter, context);
     patterns.add<ConvertAtenCastOp<AtenIntScalarOp>>(typeConverter, context);
 
-    target.addIllegalOp<AtenAddOp>();
-    patterns.add<ConvertAtenAddOp>(typeConverter, context);
+    target.addIllegalOp<AtenAddOp, AtenSubOp, AtenMulOp>();
+    patterns.add<
+        ConvertAtenBinaryScalarOp<AtenAddOp, arith::AddFOp, arith::AddIOp>>(
+        typeConverter, context);
+    patterns.add<
+        ConvertAtenBinaryScalarOp<AtenSubOp, arith::SubFOp, arith::SubIOp>>(
+        typeConverter, context);
+    patterns.add<
+        ConvertAtenBinaryScalarOp<AtenMulOp, arith::MulFOp, arith::MulIOp>>(
+        typeConverter, context);
+
     target.addIllegalOp<AtenNegIntOp>();
     patterns.add<ConvertAtenNegIntOp>(typeConverter, context);
     target.addIllegalOp<AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp,
@@ -523,11 +535,12 @@ class ConvertTorchToArith
         typeConverter, context);
     patterns.add<ConvertAtenBinaryOp<AtenMulFloatOp, arith::MulFOp>>(
         typeConverter, context);
-    target.addIllegalOp<AtenDivIntOp>();
-    patterns.add<ConvertAtenDivIntOp>(typeConverter, context);
-    target.addIllegalOp<AtenDivFloatOp>();
-    patterns.add<ConvertAtenBinaryOp<AtenDivFloatOp, arith::DivFOp>>(
-        typeConverter, context);
+
+    target.addIllegalOp<AtenDivOp, AtenDivIntOp, AtenDivFloatOp>();
+    patterns.add<ConvertAtenDivOp<AtenDivOp>>(typeConverter, context);
+    patterns.add<ConvertAtenDivOp<AtenDivIntOp>>(typeConverter, context);
+    patterns.add<ConvertAtenDivOp<AtenDivFloatOp>>(typeConverter, context);
+
     target.addIllegalOp<AtenFloordivIntOp>();
     patterns.add<ConvertAtenBinaryOp<AtenFloordivIntOp, arith::FloorDivSIOp>>(
         typeConverter, context);
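
To exercise the new patterns outside of the FileCheck tests, the conversion is normally scheduled as a nested pass on `func.func`. A minimal driver sketch, assuming torch-mlir's usual `createConvertTorchToArithPass()` factory (verify the exact include path and namespace against your checkout):

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"

// Schedule the TorchToArith conversion on every func.func in the module.
void addTorchToArith(mlir::PassManager &pm) {
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::torch::createConvertTorchToArithPass());
}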

test/Conversion/TorchToArith/basic.mlir (57 additions, 2 deletions)
@@ -269,8 +269,8 @@ func.func @torch.aten.mul.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !
 // CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float {
 // CHECK-DAG: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
 // CHECK-DAG: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]]
-// CHECK: %[[SUB:.*]] = arith.divf %[[LHS_F64:.*]], [[RHS_F64:.*]] : f64
-// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[SUB:.*]]
+// CHECK: %[[DIV:.*]] = arith.divf %[[LHS_F64:.*]], [[RHS_F64:.*]] : f64
+// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[DIV:.*]]
 // CHECK: return %[[OUT:.*]] : !torch.float
 func.func @torch.aten.div.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.float {
   %0 = torch.aten.div.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.float
@@ -407,3 +407,58 @@ func.func @torch.aten.Int.Scalar(%arg0: !torch.float) -> !torch.int {
   %0 = torch.aten.Int.Scalar %arg0 : !torch.float -> !torch.int
   return %0 : !torch.int
 }
+
+// CHECK-LABEL: func.func @torch.aten.add(
+// CHECK-SAME: %[[LHS:.*]]: !torch.float,
+// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.int {
+// CHECK-DAG: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
+// CHECK-DAG: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]]
+// CHECK: %[[LHS_FPTOSI:.*]] = arith.fptosi %[[LHS_F64]] : f64 to i64
+// CHECK: %[[RHS_FPTOSI:.*]] = arith.fptosi %[[RHS_F64]] : f64 to i64
+// CHECK: %[[ADD:.*]] = arith.addi %[[LHS_FPTOSI]], %[[RHS_FPTOSI]] : i64
+// CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[ADD]]
+// CHECK: return %[[OUT]] : !torch.int
+func.func @torch.aten.add(%arg0: !torch.float, %arg1: !torch.float) -> !torch.int {
+  %0 = torch.aten.add %arg0, %arg1 : !torch.float, !torch.float -> !torch.int
+  return %0 : !torch.int
+}
+
+// CHECK-LABEL: func.func @torch.aten.sub(
+// CHECK-SAME: %[[LHS:.*]]: !torch.float,
+// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float {
+// CHECK-DAG: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
+// CHECK-DAG: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]]
+// CHECK: %[[SUB:.*]] = arith.subf %[[LHS_F64]], %[[RHS_F64]] : f64
+// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[SUB]]
+// CHECK: return %[[OUT]] : !torch.float
+func.func @torch.aten.sub(%arg0: !torch.float, %arg1: !torch.float) -> !torch.float {
+  %0 = torch.aten.sub %arg0, %arg1 : !torch.float, !torch.float -> !torch.float
+  return %0 : !torch.float
+}
+
+// CHECK-LABEL: func.func @torch.aten.mul(
+// CHECK-SAME: %[[LHS:.*]]: !torch.int,
+// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float {
+// CHECK-DAG: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
+// CHECK-DAG: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]]
+// CHECK: %[[LHS_F64:.*]] = arith.sitofp %[[LHS_I64]] : i64 to f64
+// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_F64]], %[[RHS_F64]] : f64
+// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[MUL]]
+// CHECK: return %[[OUT]] : !torch.float
+func.func @torch.aten.mul(%arg0: !torch.int, %arg1: !torch.float) -> !torch.float {
+  %0 = torch.aten.mul %arg0, %arg1 : !torch.int, !torch.float -> !torch.float
+  return %0 : !torch.float
+}
+
+// CHECK-LABEL: func.func @torch.aten.div(
+// CHECK-SAME: %[[LHS:.*]]: !torch.float,
+// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float {
+// CHECK-DAG: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
+// CHECK-DAG: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]]
+// CHECK: %[[DIV:.*]] = arith.divf %[[LHS_F64]], %[[RHS_F64]] : f64
+// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[DIV]]
+// CHECK: return %[[OUT]] : !torch.float
+func.func @torch.aten.div(%arg0: !torch.float, %arg1: !torch.float) -> !torch.float {
+  %0 = torch.aten.div %arg0, %arg1 : !torch.float, !torch.float -> !torch.float
+  return %0 : !torch.float
+}
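
Two of the new tests pin down the implicit casts rather than the arithmetic itself: `torch.aten.add` takes float inputs but returns `!torch.int`, so the checks expect `arith.fptosi` before `arith.addi`, and `torch.aten.mul` takes an int LHS with a float result, so the checks expect `arith.sitofp` before `arith.mulf`. Both casts come from `convertScalarToDtype` aligning each operand with the converted result type. A hedged sketch of just that behavior (the real helper lives in torch-mlir's utilities and covers more cases):

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Sketch of the int/float alignment the tests rely on. Hypothetical helper;
// the real convertScalarToDtype handles bit widths, unsigned ints, and more.
static Value castScalar(OpBuilder &b, Location loc, Value v, Type dstType) {
  Type srcType = v.getType();
  if (srcType == dstType)
    return v;
  if (isa<FloatType>(srcType) && isa<IntegerType>(dstType))
    return b.create<arith::FPToSIOp>(loc, dstType, v); // e.g. f64 -> i64
  if (isa<IntegerType>(srcType) && isa<FloatType>(dstType))
    return b.create<arith::SIToFPOp>(loc, dstType, v); // e.g. i64 -> f64
  return v; // other conversions elided in this sketch
}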
