Skip to content

Commit 579201a

Browse files
committed
[TorchToArith] Implement conversion patterns for AtenSubOp, AtenMulOp, AtenDivOp.
Implement conversion patterns for `AtenSubOp`, `AtenMulOp`, `AtenDivOp`: Use unified template patterns `ConvertAtenBinaryScalarOp` to handle `AtenAddOp`, `AtenSubOp`, `AtenMulOp`; Use unified template patterns `ConvertAtenDivOp` to handle `AtenDivOp`, `AtenDivIntOp`, `AtenDivFloatOp`.
1 parent 42c3c29 commit 579201a

File tree

1 file changed

+29
-16
lines changed

1 file changed

+29
-16
lines changed

lib/Conversion/TorchToArith/TorchToArith.cpp

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,13 @@ class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern<AtenOp> {
134134
} // namespace
135135

136136
namespace {
137-
class ConvertAtenDivIntOp : public OpConversionPattern<AtenDivIntOp> {
137+
template <typename AtenOp>
138+
class ConvertAtenDivOp : public OpConversionPattern<AtenOp> {
138139
public:
139-
using OpConversionPattern<AtenDivIntOp>::OpConversionPattern;
140+
using OpConversionPattern<AtenOp>::OpConversionPattern;
141+
using OpAdaptor = typename AtenOp::Adaptor;
140142
LogicalResult
141-
matchAndRewrite(AtenDivIntOp op,
142-
typename OpConversionPattern<AtenDivIntOp>::OpAdaptor adaptor,
143+
matchAndRewrite(AtenOp op, OpAdaptor adaptor,
143144
ConversionPatternRewriter &rewriter) const override {
144145
Location loc = op.getLoc();
145146
Value a = convertScalarToDtype(rewriter, loc, adaptor.getA(),
@@ -306,11 +307,13 @@ class ConvertAtenScalarArithOp : public OpConversionPattern<AtenOp> {
306307
} // namespace
307308

308309
namespace {
309-
class ConvertAtenAddOp : public OpConversionPattern<AtenAddOp> {
310+
template <typename AtenOp, typename ArithFOp, typename ArithIOp>
311+
class ConvertAtenBinaryScalarOp : public OpConversionPattern<AtenOp> {
310312
public:
311-
using OpConversionPattern::OpConversionPattern;
313+
using OpConversionPattern<AtenOp>::OpConversionPattern;
314+
using OpAdaptor = typename AtenOp::Adaptor;
312315
LogicalResult
313-
matchAndRewrite(AtenAddOp op, OpAdaptor adaptor,
316+
matchAndRewrite(AtenOp op, OpAdaptor adaptor,
314317
ConversionPatternRewriter &rewriter) const override {
315318
Location loc = op.getLoc();
316319
Type resultType =
@@ -320,9 +323,9 @@ class ConvertAtenAddOp : public OpConversionPattern<AtenAddOp> {
320323
Value operandB =
321324
convertScalarToDtype(rewriter, loc, adaptor.getB(), resultType);
322325
if (isa<mlir::FloatType>(resultType)) {
323-
rewriter.replaceOpWithNewOp<arith::AddFOp>(op, operandA, operandB);
326+
rewriter.replaceOpWithNewOp<ArithFOp>(op, operandA, operandB);
324327
} else if (isa<mlir::IntegerType>(resultType)) {
325-
rewriter.replaceOpWithNewOp<arith::AddIOp>(op, operandA, operandB);
328+
rewriter.replaceOpWithNewOp<ArithIOp>(op, operandA, operandB);
326329
} else {
327330
return rewriter.notifyMatchFailure(
328331
op, "unimplemented: only support integer or float result type");
@@ -497,8 +500,17 @@ class ConvertTorchToArith
497500
patterns.add<ConvertAtenCastOp<AtenFloatScalarOp>>(typeConverter, context);
498501
patterns.add<ConvertAtenCastOp<AtenIntScalarOp>>(typeConverter, context);
499502

500-
target.addIllegalOp<AtenAddOp>();
501-
patterns.add<ConvertAtenAddOp>(typeConverter, context);
503+
target.addIllegalOp<AtenAddOp, AtenSubOp, AtenMulOp>();
504+
patterns.add<
505+
ConvertAtenBinaryScalarOp<AtenAddOp, airth::AddFOp, airth::AddIOp>>(
506+
typeConverter, context);
507+
patterns.add<
508+
ConvertAtenBinaryScalarOp<AtenSubOp, airth::SubFOp, airth::SubIOp>>(
509+
typeConverter, context);
510+
patterns.add<
511+
ConvertAtenBinaryScalarOp<AtenMulOp, airth::MulFOp, airth::MulIOp>>(
512+
typeConverter, context);
513+
502514
target.addIllegalOp<AtenNegIntOp>();
503515
patterns.add<ConvertAtenNegIntOp>(typeConverter, context);
504516
target.addIllegalOp<AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp,
@@ -523,11 +535,12 @@ class ConvertTorchToArith
523535
typeConverter, context);
524536
patterns.add<ConvertAtenBinaryOp<AtenMulFloatOp, arith::MulFOp>>(
525537
typeConverter, context);
526-
target.addIllegalOp<AtenDivIntOp>();
527-
patterns.add<ConvertAtenDivIntOp>(typeConverter, context);
528-
target.addIllegalOp<AtenDivFloatOp>();
529-
patterns.add<ConvertAtenBinaryOp<AtenDivFloatOp, arith::DivFOp>>(
530-
typeConverter, context);
538+
539+
target.addIllegalOp<AtenDivOp, AtenDivIntOp, AtenDivFloatOp>();
540+
patterns.add<ConvertAtenDivOp<AtenDivOp>>(typeConverter, context);
541+
patterns.add<ConvertAtenDivOp<AtenDivIntOp>>(typeConverter, context);
542+
patterns.add<ConvertAtenDivOp<AtenDivFloatOp>>(typeConverter, context);
543+
531544
target.addIllegalOp<AtenFloordivIntOp>();
532545
patterns.add<ConvertAtenBinaryOp<AtenFloordivIntOp, arith::FloorDivSIOp>>(
533546
typeConverter, context);

0 commit comments

Comments
 (0)