@@ -5405,69 +5405,45 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
54055405 dyn_cast<TensorType>(getTypeConverter ()->convertType (op.getType ()));
54065406 auto outElemTy = outType.getElementType ();
54075407
5408- int64_t minInt, maxInt;
5409- double minFloat, maxFloat;
5410- bool isMinNotNone = false ;
5411- bool isMaxNotNone = false ;
5412-
5413- auto isMinInt = matchPattern (op.getMin (), m_TorchConstantInt (&minInt));
5414- auto isMinFloat = matchPattern (op.getMin (), m_TorchConstantFloat (&minFloat));
5415- if (isMinInt) {
5416- minFloat = static_cast <float >(minInt);
5417- isMinNotNone = true ;
5418- } else if (isMinFloat) {
5419- minInt = static_cast <int64_t >(minFloat);
5420- isMinNotNone = true ;
5421- } else {
5422- if (succeeded (checkNotNone (rewriter, op, op.getMin ())))
5408+ std::optional<int64_t > minInt;
5409+ std::optional<double > minFloat;
5410+ {
5411+ int64_t minIntVal;
5412+ double minFloatVal;
5413+ if (matchPattern (op.getMin (), m_TorchConstantInt (&minIntVal))) {
5414+ minInt = minIntVal;
5415+ minFloat = static_cast <double >(minIntVal);
5416+ } else if (matchPattern (op.getMin (), m_TorchConstantFloat (&minFloatVal))) {
5417+ minFloat = minFloatVal;
5418+ minInt = static_cast <int64_t >(minFloatVal);
5419+ } else if (succeeded (checkNotNone (rewriter, op, op.getMin ()))) {
54235420 return rewriter.notifyMatchFailure (op,
54245421 " min attr should be a torch constant" );
5422+ }
54255423 }
54265424
5427- auto isMaxInt = matchPattern (op.getMax (), m_TorchConstantInt (&maxInt));
5428- auto isMaxFloat = matchPattern (op.getMax (), m_TorchConstantFloat (&maxFloat));
5429- if (isMaxInt) {
5430- maxFloat = static_cast <float >(maxInt);
5431- isMaxNotNone = true ;
5432- } else if (isMaxFloat) {
5433- maxInt = static_cast <int64_t >(maxFloat);
5434- isMaxNotNone = true ;
5435- } else {
5436- if (succeeded (checkNotNone (rewriter, op, op.getMax ())))
5425+ std::optional<int64_t > maxInt;
5426+ std::optional<double > maxFloat;
5427+ {
5428+ int64_t maxIntVal;
5429+ double maxFloatVal;
5430+ if (matchPattern (op.getMax (), m_TorchConstantInt (&maxIntVal))) {
5431+ maxInt = maxIntVal;
5432+ maxFloat = static_cast <double >(maxIntVal);
5433+ } else if (matchPattern (op.getMax (), m_TorchConstantFloat (&maxFloatVal))) {
5434+ maxFloat = maxFloatVal;
5435+ maxInt = static_cast <int64_t >(maxFloatVal);
5436+ } else if (succeeded (checkNotNone (rewriter, op, op.getMax ()))) {
54375437 return rewriter.notifyMatchFailure (op,
54385438 " max attr should be a torch constant" );
5439+ }
54395440 }
54405441
54415442 if (!isa<mlir::FloatType>(outElemTy)) {
54425443 IntegerAttr minIntAttr, maxIntAttr;
5443- if (outElemTy.isInteger (8 )) {
5444- minIntAttr = rewriter.getIntegerAttr (
5445- outElemTy,
5446- isMinNotNone ? minInt : std::numeric_limits<int8_t >::min ());
5447- maxIntAttr = rewriter.getIntegerAttr (
5448- outElemTy,
5449- isMaxNotNone ? maxInt : std::numeric_limits<int8_t >::max ());
5450- } else if (outElemTy.isInteger (16 )) {
5451- minIntAttr = rewriter.getIntegerAttr (
5452- outElemTy,
5453- isMinNotNone ? minInt : std::numeric_limits<int16_t >::min ());
5454- maxIntAttr = rewriter.getIntegerAttr (
5455- outElemTy,
5456- isMaxNotNone ? maxInt : std::numeric_limits<int16_t >::max ());
5457- } else if (outElemTy.isInteger (32 )) {
5458- minIntAttr = rewriter.getIntegerAttr (
5459- outElemTy,
5460- isMinNotNone ? minInt : std::numeric_limits<int32_t >::min ());
5461- maxIntAttr = rewriter.getIntegerAttr (
5462- outElemTy,
5463- isMaxNotNone ? maxInt : std::numeric_limits<int32_t >::max ());
5464- } else if (outElemTy.isInteger (64 )) {
5465- minIntAttr = rewriter.getI64IntegerAttr (
5466- isMinNotNone ? minInt : std::numeric_limits<int64_t >::min ());
5467- maxIntAttr = rewriter.getI64IntegerAttr (
5468- isMaxNotNone ? maxInt : std::numeric_limits<int64_t >::max ());
5469- } else {
5470- return rewriter.notifyMatchFailure (op, " Unsupported integer type" );
5444+ if (failed (tosa::getIntegerClampAttrs (rewriter, op, outElemTy, minInt,
5445+ maxInt, minIntAttr, maxIntAttr))) {
5446+ return failure ();
54715447 }
54725448
54735449 rewriter.replaceOpWithNewOp <tosa::ClampOp>(
@@ -5477,28 +5453,10 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
54775453 tosa::NanPropagationMode::PROPAGATE));
54785454 } else {
54795455 FloatAttr minFloatAttr, maxFloatAttr;
5480- if (outElemTy.isF16 ()) {
5481- minFloatAttr =
5482- rewriter.getF16FloatAttr (isMinNotNone ? minFloat : Float16Lowest);
5483- maxFloatAttr =
5484- rewriter.getF16FloatAttr (isMaxNotNone ? maxFloat : Float16Max);
5485- } else if (outElemTy.isBF16 ()) {
5486- minFloatAttr = rewriter.getFloatAttr (
5487- rewriter.getBF16Type (), isMinNotNone ? minFloat : BFloat16Lowest);
5488- maxFloatAttr = rewriter.getFloatAttr (
5489- rewriter.getBF16Type (), isMaxNotNone ? maxFloat : BFloat16Max);
5490- } else if (outElemTy.isF32 ()) {
5491- minFloatAttr = rewriter.getF32FloatAttr (
5492- isMinNotNone ? minFloat : std::numeric_limits<float >::lowest ());
5493- maxFloatAttr = rewriter.getF32FloatAttr (
5494- isMaxNotNone ? maxFloat : std::numeric_limits<float >::max ());
5495- } else if (outElemTy.isF64 ()) {
5496- minFloatAttr = rewriter.getF64FloatAttr (
5497- isMinNotNone ? minFloat : std::numeric_limits<double >::lowest ());
5498- maxFloatAttr = rewriter.getF64FloatAttr (
5499- isMaxNotNone ? maxFloat : std::numeric_limits<double >::max ());
5500- } else {
5501- return rewriter.notifyMatchFailure (op, " Unsupported floating-point type" );
5456+ if (failed (tosa::getFloatClampAttrs (rewriter, op, outElemTy, minFloat,
5457+ maxFloat, minFloatAttr,
5458+ maxFloatAttr))) {
5459+ return failure ();
55025460 }
55035461
55045462 rewriter.replaceOpWithNewOp <tosa::ClampOp>(
@@ -7403,17 +7361,6 @@ template <>
74037361LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
74047362 AtenRoundOp op, OpAdaptor adaptor,
74057363 ConversionPatternRewriter &rewriter) const {
7406- // To round to the nearest integer, we will consider the fractional part of
7407- // the input element (= input element - integer part of element). If the
7408- // fractional part is smaller than 0.5, round the number down. If the
7409- // fractional part is 0.5, apply "round half to even" rule. If the fractional
7410- // part is greater than 0.5, round up.
7411- //
7412- // if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)):
7413- // res = floor(input)
7414- // else:
7415- // res = ceil(input)
7416-
74177364 auto self = adaptor.getSelf ();
74187365
74197366 auto selfTy = dyn_cast<TensorType>(self.getType ());
@@ -7423,67 +7370,13 @@ LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
74237370 auto resultTy =
74247371 cast<RankedTensorType>(getTypeConverter ()->convertType (op.getType ()));
74257372
7426- auto boolTy =
7427- RankedTensorType::get (resultTy.getShape (), rewriter.getIntegerType (1 ));
7428-
7429- auto resultElemTy = resultTy.getElementType ();
7430-
7431- auto oneHalf =
7432- tosa::getConstTensor<float >(rewriter, op, 0.5 , {}, resultElemTy).value ();
7433-
7434- auto two =
7435- tosa::getConstTensor<float >(rewriter, op, 2 , {}, resultElemTy).value ();
7436-
7437- if (mlir::tosa::EqualizeRanks (rewriter, op->getLoc (), self, oneHalf)
7438- .failed () ||
7439- mlir::tosa::EqualizeRanks (rewriter, op->getLoc (), self, two).failed ())
7373+ auto result = tosa::createRoundHalfToEven (rewriter, op, self, resultTy);
7374+ if (!result) {
74407375 return rewriter.notifyMatchFailure (
7441- op, " Failed to equalize ranks among operands and result" );
7442-
7443- auto floorInput =
7444- tosa::FloorOp::create (rewriter, op->getLoc (), resultTy, self);
7445-
7446- // input - floor(input)
7447- auto fractionalPart = tosa::SubOp::create (rewriter, op->getLoc (), resultTy,
7448- self, floorInput.getResult ());
7449-
7450- auto ceilInput = tosa::CeilOp::create (rewriter, op->getLoc (), resultTy, self);
7451-
7452- auto floorInputDivByTwo = tosa::createMulOpAndCast (
7453- rewriter, op, resultTy, floorInput.getResult (), oneHalf, /* shift=*/ 0 );
7454-
7455- auto floorDivResult = tosa::FloorOp::create (rewriter, op->getLoc (), resultTy,
7456- floorInputDivByTwo.getResult ());
7457-
7458- // (floor(input) // 2) * 2
7459- auto evenComparison = tosa::createMulOpAndCast (
7460- rewriter, op, resultTy, floorDivResult.getResult (), two, /* shift=*/ 0 );
7461-
7462- // floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0
7463- auto floorInputEven =
7464- tosa::EqualOp::create (rewriter, op->getLoc (), boolTy,
7465- floorInput.getResult (), evenComparison.getResult ());
7466-
7467- auto fracEqualOneHalf = tosa::EqualOp::create (
7468- rewriter, op->getLoc (), boolTy, fractionalPart.getResult (), oneHalf);
7469-
7470- auto fracLtOneHalf = tosa::GreaterOp::create (
7471- rewriter, op->getLoc (), boolTy, oneHalf, fractionalPart.getResult ());
7472-
7473- // (frac == 0.5) && (floor(input) % 2 == 0)
7474- auto fracEqualOneHalfCond = tosa::LogicalAndOp::create (
7475- rewriter, op->getLoc (), boolTy, fracEqualOneHalf.getResult (),
7476- floorInputEven.getResult ());
7477-
7478- // (frac < 0.5) || ((frac == 0.5) && (floor(input) % 2 == 0))
7479- auto floorResultCond = tosa::LogicalOrOp::create (
7480- rewriter, op->getLoc (), boolTy, fracLtOneHalf.getResult (),
7481- fracEqualOneHalfCond.getResult ());
7482-
7483- rewriter.replaceOpWithNewOp <tosa::SelectOp>(
7484- op, resultTy, floorResultCond.getResult (), floorInput.getResult (),
7485- ceilInput.getResult ());
7376+ op, " failed to implement round-half-to-even with TOSA ops" );
7377+ }
74867378
7379+ rewriter.replaceOp (op, *result);
74877380 return success ();
74887381}
74897382
@@ -9434,6 +9327,86 @@ LogicalResult ConvertAtenOp<AtenDequantizeTensorOp>::matchAndRewrite(
94349327 return success ();
94359328}
94369329
9330+ // Legalization for aten.quantize_per_tensor
9331+ // Implements
9332+ // Q = clamp(round(X / scale) + zero_point)
9333+ template <>
9334+ LogicalResult ConvertAtenOp<AtenQuantizePerTensorOp>::matchAndRewrite(
9335+ AtenQuantizePerTensorOp op, OpAdaptor adaptor,
9336+ ConversionPatternRewriter &rewriter) const {
9337+ Value input = adaptor.getSelf ();
9338+ auto loc = op->getLoc ();
9339+
9340+ // Get scale and zero_point as constants.
9341+ double scaleConst;
9342+ if (!matchPattern (op.getScale (), m_TorchConstantFloat (&scaleConst)))
9343+ return rewriter.notifyMatchFailure (op, " scale must be a Scalar constant" );
9344+
9345+ int64_t zpConst;
9346+ if (!matchPattern (op.getZeroPoint (), m_TorchConstantInt (&zpConst)))
9347+ return rewriter.notifyMatchFailure (op,
9348+ " zero point must be a Scalar constant" );
9349+
9350+ // Get input and result types.
9351+ auto inputTy = cast<RankedTensorType>(input.getType ());
9352+ auto inputElemTy = inputTy.getElementType ();
9353+ auto resultTy = cast<RankedTensorType>(
9354+ getTypeConverter ()->convertType (op->getResult (0 ).getType ()));
9355+ auto resultElemTy = resultTy.getElementType ();
9356+
9357+ // Rescale the input: input * (1.0 / scale)
9358+ auto scaleReciprocal = 1.0 / scaleConst;
9359+ auto scaleConstTensor = tosa::getConstTensor<float >(
9360+ rewriter, op, scaleReciprocal, {}, inputElemTy)
9361+ .value ();
9362+ if (mlir::tosa::EqualizeRanks (rewriter, loc, input, scaleConstTensor)
9363+ .failed ())
9364+ return rewriter.notifyMatchFailure (
9365+ op, " Failed to equalize ranks among operands" );
9366+ Value rescaledInput = tosa::createMulOpAndCast (
9367+ rewriter, op, inputTy, input, scaleConstTensor, /* shift =*/ 0 );
9368+
9369+ // Round
9370+ auto rounded =
9371+ tosa::createRoundHalfToEven (rewriter, op, rescaledInput, inputTy);
9372+ if (!rounded) {
9373+ return rewriter.notifyMatchFailure (
9374+ op, " failed to implement round-half-to-even with TOSA ops" );
9375+ }
9376+
9377+ // Cast to the destination integer type.
9378+ auto intermediateIntTy = resultTy.clone (resultElemTy);
9379+ Value castToInt =
9380+ tosa::CastOp::create (rewriter, loc, intermediateIntTy, *rounded);
9381+
9382+ // Add the zero point.
9383+ Value zpTensor =
9384+ tosa::createZeroPointTensor (rewriter, loc, intermediateIntTy, zpConst)
9385+ .value ();
9386+ if (mlir::tosa::EqualizeRanks (rewriter, loc, castToInt, zpTensor).failed ())
9387+ return failure ();
9388+ Value withZp = tosa::AddOp::create (rewriter, loc, intermediateIntTy,
9389+ castToInt, zpTensor);
9390+
9391+ // Clamp the result to the valid range of the quantized type.
9392+ std::optional<int64_t > minInt,
9393+ maxInt; // no initialization needed as we want to clamp to the numeric
9394+ // limits of the type
9395+ IntegerAttr minIntAttr, maxIntAttr;
9396+ if (failed (tosa::getIntegerClampAttrs (rewriter, op, resultElemTy, minInt,
9397+ maxInt, minIntAttr, maxIntAttr))) {
9398+ return failure ();
9399+ }
9400+ Value clamped = tosa::ClampOp::create (
9401+ rewriter, loc, resultTy, withZp, minIntAttr, maxIntAttr,
9402+ /* nan_mode=*/
9403+ tosa::NanPropagationModeAttr::get (rewriter.getContext (),
9404+ tosa::NanPropagationMode::PROPAGATE));
9405+
9406+ rewriter.replaceOp (op, clamped);
9407+ return success ();
9408+ }
9409+
94379410} // namespace
94389411
94399412// -----------------------------------------------------------------------------
@@ -9808,6 +9781,7 @@ std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(
98089781 INSERT_ATENOP_PATTERN (AtenTanOp);
98099782 INSERT_ATENOP_PATTERN (AtenUnfoldOp);
98109783 INSERT_ATENOP_PATTERN (AtenDequantizeTensorOp);
9784+ INSERT_ATENOP_PATTERN (AtenQuantizePerTensorOp);
98119785#undef INSERT_ATENOP_PATTERN
98129786
98139787#define INSERT_CLONE_ATENOP_PATTERN (AtenOp ) \
0 commit comments