@@ -134,12 +134,13 @@ class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern<AtenOp> {
134134} // namespace
135135
136136namespace {
137- class ConvertAtenDivIntOp : public OpConversionPattern <AtenDivIntOp> {
137+ template <typename AtenOp>
138+ class ConvertAtenDivOp : public OpConversionPattern <AtenOp> {
138139public:
139- using OpConversionPattern<AtenDivIntOp>::OpConversionPattern;
140+ using OpConversionPattern<AtenOp>::OpConversionPattern;
141+ using OpAdaptor = typename AtenOp::Adaptor;
140142 LogicalResult
141- matchAndRewrite (AtenDivIntOp op,
142- typename OpConversionPattern<AtenDivIntOp>::OpAdaptor adaptor,
143+ matchAndRewrite (AtenOp op, OpAdaptor adaptor,
143144 ConversionPatternRewriter &rewriter) const override {
144145 Location loc = op.getLoc ();
145146 Value a = convertScalarToDtype (rewriter, loc, adaptor.getA (),
@@ -306,11 +307,13 @@ class ConvertAtenScalarArithOp : public OpConversionPattern<AtenOp> {
306307} // namespace
307308
308309namespace {
309- class ConvertAtenAddOp : public OpConversionPattern <AtenAddOp> {
310+ template <typename AtenOp, typename ArithFOp, typename ArithIOp>
311+ class ConvertAtenBinaryScalarOp : public OpConversionPattern <AtenOp> {
310312public:
311- using OpConversionPattern::OpConversionPattern;
313+ using OpConversionPattern<AtenOp>::OpConversionPattern;
314+ using OpAdaptor = typename AtenOp::Adaptor;
312315 LogicalResult
313- matchAndRewrite (AtenAddOp op, OpAdaptor adaptor,
316+ matchAndRewrite (AtenOp op, OpAdaptor adaptor,
314317 ConversionPatternRewriter &rewriter) const override {
315318 Location loc = op.getLoc ();
316319 Type resultType =
@@ -320,9 +323,9 @@ class ConvertAtenAddOp : public OpConversionPattern<AtenAddOp> {
320323 Value operandB =
321324 convertScalarToDtype (rewriter, loc, adaptor.getB (), resultType);
322325 if (isa<mlir::FloatType>(resultType)) {
323- rewriter.replaceOpWithNewOp <arith::AddFOp >(op, operandA, operandB);
326+ rewriter.replaceOpWithNewOp <ArithFOp >(op, operandA, operandB);
324327 } else if (isa<mlir::IntegerType>(resultType)) {
325- rewriter.replaceOpWithNewOp <arith::AddIOp >(op, operandA, operandB);
328+ rewriter.replaceOpWithNewOp <ArithIOp >(op, operandA, operandB);
326329 } else {
327330 return rewriter.notifyMatchFailure (
328331 op, " unimplemented: only support integer or float result type" );
@@ -497,8 +500,17 @@ class ConvertTorchToArith
497500 patterns.add <ConvertAtenCastOp<AtenFloatScalarOp>>(typeConverter, context);
498501 patterns.add <ConvertAtenCastOp<AtenIntScalarOp>>(typeConverter, context);
499502
500- target.addIllegalOp <AtenAddOp>();
501- patterns.add <ConvertAtenAddOp>(typeConverter, context);
503+ target.addIllegalOp <AtenAddOp, AtenSubOp, AtenMulOp>();
504+ patterns.add <
505+ ConvertAtenBinaryScalarOp<AtenAddOp, airth::AddFOp, airth::AddIOp>>(
506+ typeConverter, context);
507+ patterns.add <
508+ ConvertAtenBinaryScalarOp<AtenSubOp, airth::SubFOp, airth::SubIOp>>(
509+ typeConverter, context);
510+ patterns.add <
511+ ConvertAtenBinaryScalarOp<AtenMulOp, airth::MulFOp, airth::MulIOp>>(
512+ typeConverter, context);
513+
502514 target.addIllegalOp <AtenNegIntOp>();
503515 patterns.add <ConvertAtenNegIntOp>(typeConverter, context);
504516 target.addIllegalOp <AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp,
@@ -523,11 +535,12 @@ class ConvertTorchToArith
523535 typeConverter, context);
524536 patterns.add <ConvertAtenBinaryOp<AtenMulFloatOp, arith::MulFOp>>(
525537 typeConverter, context);
526- target.addIllegalOp <AtenDivIntOp>();
527- patterns.add <ConvertAtenDivIntOp>(typeConverter, context);
528- target.addIllegalOp <AtenDivFloatOp>();
529- patterns.add <ConvertAtenBinaryOp<AtenDivFloatOp, arith::DivFOp>>(
530- typeConverter, context);
538+
539+ target.addIllegalOp <AtenDivOp, AtenDivIntOp, AtenDivFloatOp>();
540+ patterns.add <ConvertAtenDivOp<AtenDivOp>>(typeConverter, context);
541+ patterns.add <ConvertAtenDivOp<AtenDivIntOp>>(typeConverter, context);
542+ patterns.add <ConvertAtenDivOp<AtenDivFloatOp>>(typeConverter, context);
543+
531544 target.addIllegalOp <AtenFloordivIntOp>();
532545 patterns.add <ConvertAtenBinaryOp<AtenFloordivIntOp, arith::FloorDivSIOp>>(
533546 typeConverter, context);
0 commit comments