Commit ba513ee

[tosa] : Add support for quantize_per_tensor. (#4390)
Adds support for quantize_per_tensor (https://github.com/pytorch/pytorch/blob/a5436a5e8e4ee42d1debf52c2786c7ae0043a434/torch/ao/quantization/fx/_decomposed.py#L83) for the TOSA backend.
1 parent 5b049a1 commit ba513ee
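
For orientation, here is a minimal scalar sketch (illustration only; the function name quantizePerTensorScalar is hypothetical and this is not code from PyTorch or from this patch) of the formula the new legalization implements, Q = clamp(round(X / scale) + zero_point), for a signed 8-bit target:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar reference for Q = clamp(round(X / scale) + zero_point).
// std::nearbyint rounds half to even under the default FE_TONEAREST
// rounding mode, matching torch.round's behavior.
int8_t quantizePerTensorScalar(float x, float scale, int32_t zeroPoint,
                               int32_t quantMin = -128,
                               int32_t quantMax = 127) {
  int32_t q = static_cast<int32_t>(std::nearbyint(x / scale)) + zeroPoint;
  return static_cast<int8_t>(std::clamp(q, quantMin, quantMax));
}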

File tree

5 files changed: +331 additions, -149 deletions


include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h

Lines changed: 25 additions & 0 deletions
@@ -111,6 +111,31 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op,
                           RankedTensorType output_type, Value input_value,
                           ElementsAttr axes_elems, bool keep_dims);
 
+// Creates IntegerAttrs for clamping, using provided min/max values or the
+// numeric limits of the element type if the values are not provided.
+LogicalResult getIntegerClampAttrs(ConversionPatternRewriter &rewriter,
+                                   Operation *op, Type elemTy,
+                                   std::optional<int64_t> minInt,
+                                   std::optional<int64_t> maxInt,
+                                   IntegerAttr &minAttr, IntegerAttr &maxAttr);
+
+// Creates FloatAttrs for clamping, using provided min/max values or the numeric
+// limits of the element type if the values are not provided.
+LogicalResult getFloatClampAttrs(ConversionPatternRewriter &rewriter,
+                                 Operation *op, Type elemTy,
+                                 std::optional<double> minFloat,
+                                 std::optional<double> maxFloat,
+                                 FloatAttr &minAttr, FloatAttr &maxAttr);
+
+// Implements "round half to even" logic for aten.round using TOSA ops.
+// if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)):
+//   res = floor(input)
+// else:
+//   res = ceil(input)
+std::optional<Value> createRoundHalfToEven(ConversionPatternRewriter &rewriter,
+                                           Operation *op, Value input,
+                                           RankedTensorType resultTy);
+
 } // namespace tosa
 } // namespace mlir
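
The definitions of these helpers live in TosaLegalizeCommon.cpp, which is not part of this excerpt. As a hypothetical sketch of the documented contract (a provided value wins; otherwise the element type's signed numeric limits are used), getIntegerClampAttrs could look roughly like this:

// Hypothetical sketch only; the actual implementation in
// TosaLegalizeCommon.cpp is not shown in this diff.
LogicalResult tosa::getIntegerClampAttrs(ConversionPatternRewriter &rewriter,
                                         Operation *op, Type elemTy,
                                         std::optional<int64_t> minInt,
                                         std::optional<int64_t> maxInt,
                                         IntegerAttr &minAttr,
                                         IntegerAttr &maxAttr) {
  unsigned width = elemTy.getIntOrFloatBitWidth();
  if (width != 8 && width != 16 && width != 32 && width != 64)
    return rewriter.notifyMatchFailure(op, "Unsupported integer type");
  // Signed N-bit limits: [-2^(N-1), 2^(N-1) - 1].
  int64_t typeMin = llvm::APInt::getSignedMinValue(width).getSExtValue();
  int64_t typeMax = llvm::APInt::getSignedMaxValue(width).getSExtValue();
  minAttr = rewriter.getIntegerAttr(elemTy, minInt.value_or(typeMin));
  maxAttr = rewriter.getIntegerAttr(elemTy, maxInt.value_or(typeMax));
  return success();
}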

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 119 additions & 145 deletions
@@ -5405,69 +5405,45 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
       dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));
   auto outElemTy = outType.getElementType();
 
-  int64_t minInt, maxInt;
-  double minFloat, maxFloat;
-  bool isMinNotNone = false;
-  bool isMaxNotNone = false;
-
-  auto isMinInt = matchPattern(op.getMin(), m_TorchConstantInt(&minInt));
-  auto isMinFloat = matchPattern(op.getMin(), m_TorchConstantFloat(&minFloat));
-  if (isMinInt) {
-    minFloat = static_cast<float>(minInt);
-    isMinNotNone = true;
-  } else if (isMinFloat) {
-    minInt = static_cast<int64_t>(minFloat);
-    isMinNotNone = true;
-  } else {
-    if (succeeded(checkNotNone(rewriter, op, op.getMin())))
+  std::optional<int64_t> minInt;
+  std::optional<double> minFloat;
+  {
+    int64_t minIntVal;
+    double minFloatVal;
+    if (matchPattern(op.getMin(), m_TorchConstantInt(&minIntVal))) {
+      minInt = minIntVal;
+      minFloat = static_cast<double>(minIntVal);
+    } else if (matchPattern(op.getMin(), m_TorchConstantFloat(&minFloatVal))) {
+      minFloat = minFloatVal;
+      minInt = static_cast<int64_t>(minFloatVal);
+    } else if (succeeded(checkNotNone(rewriter, op, op.getMin()))) {
       return rewriter.notifyMatchFailure(op,
                                          "min attr should be a torch constant");
+    }
   }
 
-  auto isMaxInt = matchPattern(op.getMax(), m_TorchConstantInt(&maxInt));
-  auto isMaxFloat = matchPattern(op.getMax(), m_TorchConstantFloat(&maxFloat));
-  if (isMaxInt) {
-    maxFloat = static_cast<float>(maxInt);
-    isMaxNotNone = true;
-  } else if (isMaxFloat) {
-    maxInt = static_cast<int64_t>(maxFloat);
-    isMaxNotNone = true;
-  } else {
-    if (succeeded(checkNotNone(rewriter, op, op.getMax())))
+  std::optional<int64_t> maxInt;
+  std::optional<double> maxFloat;
+  {
+    int64_t maxIntVal;
+    double maxFloatVal;
+    if (matchPattern(op.getMax(), m_TorchConstantInt(&maxIntVal))) {
+      maxInt = maxIntVal;
+      maxFloat = static_cast<double>(maxIntVal);
+    } else if (matchPattern(op.getMax(), m_TorchConstantFloat(&maxFloatVal))) {
+      maxFloat = maxFloatVal;
+      maxInt = static_cast<int64_t>(maxFloatVal);
+    } else if (succeeded(checkNotNone(rewriter, op, op.getMax()))) {
       return rewriter.notifyMatchFailure(op,
                                          "max attr should be a torch constant");
+    }
   }
 
   if (!isa<mlir::FloatType>(outElemTy)) {
     IntegerAttr minIntAttr, maxIntAttr;
-    if (outElemTy.isInteger(8)) {
-      minIntAttr = rewriter.getIntegerAttr(
-          outElemTy,
-          isMinNotNone ? minInt : std::numeric_limits<int8_t>::min());
-      maxIntAttr = rewriter.getIntegerAttr(
-          outElemTy,
-          isMaxNotNone ? maxInt : std::numeric_limits<int8_t>::max());
-    } else if (outElemTy.isInteger(16)) {
-      minIntAttr = rewriter.getIntegerAttr(
-          outElemTy,
-          isMinNotNone ? minInt : std::numeric_limits<int16_t>::min());
-      maxIntAttr = rewriter.getIntegerAttr(
-          outElemTy,
-          isMaxNotNone ? maxInt : std::numeric_limits<int16_t>::max());
-    } else if (outElemTy.isInteger(32)) {
-      minIntAttr = rewriter.getIntegerAttr(
-          outElemTy,
-          isMinNotNone ? minInt : std::numeric_limits<int32_t>::min());
-      maxIntAttr = rewriter.getIntegerAttr(
-          outElemTy,
-          isMaxNotNone ? maxInt : std::numeric_limits<int32_t>::max());
-    } else if (outElemTy.isInteger(64)) {
-      minIntAttr = rewriter.getI64IntegerAttr(
-          isMinNotNone ? minInt : std::numeric_limits<int64_t>::min());
-      maxIntAttr = rewriter.getI64IntegerAttr(
-          isMaxNotNone ? maxInt : std::numeric_limits<int64_t>::max());
-    } else {
-      return rewriter.notifyMatchFailure(op, "Unsupported integer type");
+    if (failed(tosa::getIntegerClampAttrs(rewriter, op, outElemTy, minInt,
+                                          maxInt, minIntAttr, maxIntAttr))) {
+      return failure();
     }
 
     rewriter.replaceOpWithNewOp<tosa::ClampOp>(
@@ -5477,28 +5453,10 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
         tosa::NanPropagationMode::PROPAGATE));
   } else {
     FloatAttr minFloatAttr, maxFloatAttr;
-    if (outElemTy.isF16()) {
-      minFloatAttr =
-          rewriter.getF16FloatAttr(isMinNotNone ? minFloat : Float16Lowest);
-      maxFloatAttr =
-          rewriter.getF16FloatAttr(isMaxNotNone ? maxFloat : Float16Max);
-    } else if (outElemTy.isBF16()) {
-      minFloatAttr = rewriter.getFloatAttr(
-          rewriter.getBF16Type(), isMinNotNone ? minFloat : BFloat16Lowest);
-      maxFloatAttr = rewriter.getFloatAttr(
-          rewriter.getBF16Type(), isMaxNotNone ? maxFloat : BFloat16Max);
-    } else if (outElemTy.isF32()) {
-      minFloatAttr = rewriter.getF32FloatAttr(
-          isMinNotNone ? minFloat : std::numeric_limits<float>::lowest());
-      maxFloatAttr = rewriter.getF32FloatAttr(
-          isMaxNotNone ? maxFloat : std::numeric_limits<float>::max());
-    } else if (outElemTy.isF64()) {
-      minFloatAttr = rewriter.getF64FloatAttr(
-          isMinNotNone ? minFloat : std::numeric_limits<double>::lowest());
-      maxFloatAttr = rewriter.getF64FloatAttr(
-          isMaxNotNone ? maxFloat : std::numeric_limits<double>::max());
-    } else {
-      return rewriter.notifyMatchFailure(op, "Unsupported floating-point type");
+    if (failed(tosa::getFloatClampAttrs(rewriter, op, outElemTy, minFloat,
+                                        maxFloat, minFloatAttr,
+                                        maxFloatAttr))) {
+      return failure();
     }
 
     rewriter.replaceOpWithNewOp<tosa::ClampOp>(
@@ -7403,17 +7361,6 @@ template <>
 LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
     AtenRoundOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
-  // To round to the nearest integer, we will consider the fractional part of
-  // the input element (= input element - integer part of element). If the
-  // fractional part is smaller than 0.5, round the number down. If the
-  // fractional part is 0.5, apply "round half to even" rule. If the fractional
-  // part is greater than 0.5, round up.
-  //
-  // if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)):
-  //   res = floor(input)
-  // else:
-  //   res = ceil(input)
-
   auto self = adaptor.getSelf();
 
   auto selfTy = dyn_cast<TensorType>(self.getType());
@@ -7423,67 +7370,13 @@ LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
   auto resultTy =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
 
-  auto boolTy =
-      RankedTensorType::get(resultTy.getShape(), rewriter.getIntegerType(1));
-
-  auto resultElemTy = resultTy.getElementType();
-
-  auto oneHalf =
-      tosa::getConstTensor<float>(rewriter, op, 0.5, {}, resultElemTy).value();
-
-  auto two =
-      tosa::getConstTensor<float>(rewriter, op, 2, {}, resultElemTy).value();
-
-  if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, oneHalf)
-          .failed() ||
-      mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, two).failed())
+  auto result = tosa::createRoundHalfToEven(rewriter, op, self, resultTy);
+  if (!result) {
     return rewriter.notifyMatchFailure(
-        op, "Failed to equalize ranks among operands and result");
-
-  auto floorInput =
-      tosa::FloorOp::create(rewriter, op->getLoc(), resultTy, self);
-
-  // input - floor(input)
-  auto fractionalPart = tosa::SubOp::create(rewriter, op->getLoc(), resultTy,
-                                            self, floorInput.getResult());
-
-  auto ceilInput = tosa::CeilOp::create(rewriter, op->getLoc(), resultTy, self);
-
-  auto floorInputDivByTwo = tosa::createMulOpAndCast(
-      rewriter, op, resultTy, floorInput.getResult(), oneHalf, /*shift=*/0);
-
-  auto floorDivResult = tosa::FloorOp::create(rewriter, op->getLoc(), resultTy,
-                                              floorInputDivByTwo.getResult());
-
-  // (floor(input) // 2) * 2
-  auto evenComparison = tosa::createMulOpAndCast(
-      rewriter, op, resultTy, floorDivResult.getResult(), two, /*shift=*/0);
-
-  // floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0
-  auto floorInputEven =
-      tosa::EqualOp::create(rewriter, op->getLoc(), boolTy,
-                            floorInput.getResult(), evenComparison.getResult());
-
-  auto fracEqualOneHalf = tosa::EqualOp::create(
-      rewriter, op->getLoc(), boolTy, fractionalPart.getResult(), oneHalf);
-
-  auto fracLtOneHalf = tosa::GreaterOp::create(
-      rewriter, op->getLoc(), boolTy, oneHalf, fractionalPart.getResult());
-
-  // (frac == 0.5) && (floor(input) % 2 == 0)
-  auto fracEqualOneHalfCond = tosa::LogicalAndOp::create(
-      rewriter, op->getLoc(), boolTy, fracEqualOneHalf.getResult(),
-      floorInputEven.getResult());
-
-  // (frac < 0.5) || ((frac == 0.5) && (floor(input) % 2 == 0))
-  auto floorResultCond = tosa::LogicalOrOp::create(
-      rewriter, op->getLoc(), boolTy, fracLtOneHalf.getResult(),
-      fracEqualOneHalfCond.getResult());
-
-  rewriter.replaceOpWithNewOp<tosa::SelectOp>(
-      op, resultTy, floorResultCond.getResult(), floorInput.getResult(),
-      ceilInput.getResult());
+        op, "failed to implement round-half-to-even with TOSA ops");
+  }
 
+  rewriter.replaceOp(op, *result);
   return success();
 }
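
The inline op sequence removed above now lives behind tosa::createRoundHalfToEven. As a plain scalar sketch (illustration only; roundHalfToEvenScalar is a hypothetical name, not part of the patch), the predicate the TOSA select implements is:

#include <cmath>

// Scalar model of the lowering: take floor(input) when the fractional
// part is below 0.5, or exactly 0.5 with an even floor; otherwise ceil.
double roundHalfToEvenScalar(double input) {
  double fl = std::floor(input);
  double frac = input - fl;
  bool floorIsEven = std::fmod(fl, 2.0) == 0.0;
  if (frac < 0.5 || (frac == 0.5 && floorIsEven))
    return fl;
  return std::ceil(input);
}

For example: 2.5 maps to 2.0, 3.5 to 4.0, and -2.5 to -2.0.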

@@ -9434,6 +9327,86 @@ LogicalResult ConvertAtenOp<AtenDequantizeTensorOp>::matchAndRewrite(
   return success();
 }
 
+// Legalization for aten.quantize_per_tensor
+// Implements
+// Q = clamp(round(X / scale) + zero_point)
+template <>
+LogicalResult ConvertAtenOp<AtenQuantizePerTensorOp>::matchAndRewrite(
+    AtenQuantizePerTensorOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.getSelf();
+  auto loc = op->getLoc();
+
+  // Get scale and zero_point as constants.
+  double scaleConst;
+  if (!matchPattern(op.getScale(), m_TorchConstantFloat(&scaleConst)))
+    return rewriter.notifyMatchFailure(op, "scale must be a Scalar constant");
+
+  int64_t zpConst;
+  if (!matchPattern(op.getZeroPoint(), m_TorchConstantInt(&zpConst)))
+    return rewriter.notifyMatchFailure(op,
+                                       "zero point must be a Scalar constant");
+
+  // Get input and result types.
+  auto inputTy = cast<RankedTensorType>(input.getType());
+  auto inputElemTy = inputTy.getElementType();
+  auto resultTy = cast<RankedTensorType>(
+      getTypeConverter()->convertType(op->getResult(0).getType()));
+  auto resultElemTy = resultTy.getElementType();
+
+  // Rescale the input: input * (1.0 / scale)
+  auto scaleReciprocal = 1.0 / scaleConst;
+  auto scaleConstTensor = tosa::getConstTensor<float>(
+                              rewriter, op, scaleReciprocal, {}, inputElemTy)
+                              .value();
+  if (mlir::tosa::EqualizeRanks(rewriter, loc, input, scaleConstTensor)
+          .failed())
+    return rewriter.notifyMatchFailure(
+        op, "Failed to equalize ranks among operands");
+  Value rescaledInput = tosa::createMulOpAndCast(
+      rewriter, op, inputTy, input, scaleConstTensor, /*shift=*/0);
+
+  // Round
+  auto rounded =
+      tosa::createRoundHalfToEven(rewriter, op, rescaledInput, inputTy);
+  if (!rounded) {
+    return rewriter.notifyMatchFailure(
+        op, "failed to implement round-half-to-even with TOSA ops");
+  }
+
+  // Cast to the destination integer type.
+  auto intermediateIntTy = resultTy.clone(resultElemTy);
+  Value castToInt =
+      tosa::CastOp::create(rewriter, loc, intermediateIntTy, *rounded);
+
+  // Add the zero point.
+  Value zpTensor =
+      tosa::createZeroPointTensor(rewriter, loc, intermediateIntTy, zpConst)
+          .value();
+  if (mlir::tosa::EqualizeRanks(rewriter, loc, castToInt, zpTensor).failed())
+    return failure();
+  Value withZp = tosa::AddOp::create(rewriter, loc, intermediateIntTy,
+                                     castToInt, zpTensor);
+
+  // Clamp the result to the valid range of the quantized type.
+  std::optional<int64_t> minInt,
+      maxInt; // no initialization needed as we want to clamp to the numeric
+              // limits of the type
+  IntegerAttr minIntAttr, maxIntAttr;
+  if (failed(tosa::getIntegerClampAttrs(rewriter, op, resultElemTy, minInt,
+                                        maxInt, minIntAttr, maxIntAttr))) {
+    return failure();
+  }
+  Value clamped = tosa::ClampOp::create(
+      rewriter, loc, resultTy, withZp, minIntAttr, maxIntAttr,
+      /*nan_mode=*/
+      tosa::NanPropagationModeAttr::get(rewriter.getContext(),
+                                        tosa::NanPropagationMode::PROPAGATE));
+
+  rewriter.replaceOp(op, clamped);
+  return success();
+}
+
 } // namespace
 
 // -----------------------------------------------------------------------------
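
To trace the new legalization end to end on illustrative values (chosen here for exposition, not taken from the patch or its tests): with scale = 0.25, zero_point = 10, and an int8 result, x = 3.125 rescales to 12.5, which rounds half to even down to 12, giving 12 + 10 = 22; x = 3.1875 rescales to 12.75, which rounds to 13, giving 23; and x = 100.0 rescales to 400, giving 410, which is clamped to the int8 maximum 127.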
@@ -9808,6 +9781,7 @@ std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(
   INSERT_ATENOP_PATTERN(AtenTanOp);
   INSERT_ATENOP_PATTERN(AtenUnfoldOp);
   INSERT_ATENOP_PATTERN(AtenDequantizeTensorOp);
+  INSERT_ATENOP_PATTERN(AtenQuantizePerTensorOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp)                                    \
