Skip to content

Commit e568f7e

Browse files
authored
Move handling of integer signedness to the backend conversions (#2597)
The function `getTypeForScalarType` currently takes an argument specifying the signedness of integer types. This leaks backend-specific requirements into the torch dialect world. Because `getTypeForScalarType` is a utility function for the torch dialect, it should only produce types that match the sign conventions used by PyTorch (regular integers are signed and unsigned integers are unsigned). This commit removes the signedness argument from `getTypeForScalarType` and moves the backend-specific handling of integer types into the backend code.
1 parent 44f6942 commit e568f7e

File tree

8 files changed: +63 −20 lines

include/torch-mlir/Dialect/Torch/Utils/Utils.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,8 @@ bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
2626
std::optional<int64_t> matchLegalConstantIndexIntoListOfSize(Value v,
2727
int64_t length);
2828
torch_upstream::ScalarType getScalarTypeForType(Type type);
29-
FailureOr<Type> getTypeForScalarType(
30-
MLIRContext *context, torch_upstream::ScalarType dtypeInt,
31-
mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);
29+
FailureOr<Type> getTypeForScalarType(MLIRContext *context,
30+
torch_upstream::ScalarType dtypeInt);
3231

3332
Type getTypeForTorchType(
3433
MLIRContext *context, Type type,

lib/Conversion/TorchToLinalg/TensorConstructors.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,9 @@ class ConvertConstantTensorAllocOp : public OpConversionPattern<OpTy> {
127127
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
128128
return rewriter.notifyMatchFailure(
129129
op, "unimplemented: dtype must be a constant integer or none");
130-
FailureOr<Type> maybeResultElementType = getTypeForScalarType(
131-
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
132-
IntegerType::Signless);
130+
FailureOr<Type> maybeResultElementType =
131+
torch_to_linalg::getBackendTypeForScalarType(
132+
op->getContext(), (torch_upstream::ScalarType)dtypeInt);
133133
if (failed(maybeResultElementType)) {
134134
return rewriter.notifyMatchFailure(
135135
op, "unable to convert `dtypeInt` to builtin type");
@@ -233,9 +233,9 @@ class ConvertAtenEmptyMemoryFormatOp
233233
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
234234
return rewriter.notifyMatchFailure(
235235
op, "unimplemented: dtype must be a constant integer or none");
236-
FailureOr<Type> maybeResultElementType = getTypeForScalarType(
237-
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
238-
IntegerType::Signless);
236+
FailureOr<Type> maybeResultElementType =
237+
torch_to_linalg::getBackendTypeForScalarType(
238+
op->getContext(), (torch_upstream::ScalarType)dtypeInt);
239239
if (failed(maybeResultElementType)) {
240240
return rewriter.notifyMatchFailure(
241241
op, "unable to convert `dtypeInt` to builtin type");

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,9 +1057,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
10571057
atenToDtype.emitError("unimplemented: dtype must be a constant integer");
10581058
return nullptr;
10591059
}
1060-
FailureOr<Type> maybeResultElementType = getTypeForScalarType(
1061-
atenToDtype->getContext(), (torch_upstream::ScalarType)dtypeInt,
1062-
IntegerType::Signless);
1060+
FailureOr<Type> maybeResultElementType =
1061+
torch_to_linalg::getBackendTypeForScalarType(
1062+
atenToDtype->getContext(), (torch_upstream::ScalarType)dtypeInt);
10631063
if (failed(maybeResultElementType)) {
10641064
atenToDtype.emitError("unable to convert `dtypeInt` to builtin type");
10651065
return nullptr;

lib/Conversion/TorchToLinalg/Utils.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
#include "torch-mlir/Conversion/Utils/Utils.h"
2121
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
2222
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
23-
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
2423
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
2524

2625
using namespace mlir;
@@ -546,3 +545,18 @@ Value torch_to_linalg::convertTensorToElementType(OpBuilder &b, Location loc,
546545
return torch_to_linalg::createElementwiseLinalgGeneric(
547546
b, loc, {tensor}, elementType, dtypePromoteBody);
548547
}
548+
549+
/// Convert a torch scalar type to the builtin type used by the
/// linalg-on-tensors backend.
///
/// Returns failure when the scalar type has no builtin equivalent.
/// Integer results are always signless, since the linalg-on-tensors
/// backend does not accept signed/unsigned integer types.
FailureOr<Type> torch_to_linalg::getBackendTypeForScalarType(
    MLIRContext *context, torch_upstream::ScalarType dtypeInt) {
  // `dtypeInt` is already a `torch_upstream::ScalarType`; no cast needed.
  FailureOr<Type> maybeType = getTypeForScalarType(context, dtypeInt);
  if (failed(maybeType)) {
    return failure();
  }
  Type type = *maybeType;
  // The linalg-on-tensors backend currently expects integers to be signless,
  // so strip signedness information off any integer type (signed or
  // unsigned) while preserving its bit width.
  if (auto intType = type.dyn_cast<IntegerType>()) {
    type = IntegerType::get(context, intType.getWidth(), IntegerType::Signless);
  }
  return type;
}

lib/Conversion/TorchToLinalg/Utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
//===----------------------------------------------------------------------===//
99

1010
#include "mlir/Transforms/DialectConversion.h"
11+
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
1112

1213
namespace mlir {
1314
namespace torch {
@@ -88,6 +89,12 @@ Value removeSizeInformation(OpBuilder &b, Location loc, Value tensor);
8889
Value convertTensorToElementType(OpBuilder &b, Location loc, Value tensor,
8990
Type elementType);
9091

92+
// Convert a scalar type to the corresponding builtin type in the
93+
// linalg-on-tensors backend.
94+
FailureOr<Type>
95+
getBackendTypeForScalarType(MLIRContext *context,
96+
torch_upstream::ScalarType dtypeInt);
97+
9198
} // namespace torch_to_linalg
9299
} // namespace torch
93100
} // namespace mlir

lib/Conversion/TorchToStablehlo/Basic.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1672,13 +1672,18 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
16721672
return rewriter.notifyMatchFailure(
16731673
op, "unimplemented: dtype must be a constant integer or none");
16741674
FailureOr<Type> maybeResultElementType = getTypeForScalarType(
1675-
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
1676-
IntegerType::Signless);
1675+
op->getContext(), (torch_upstream::ScalarType)dtypeInt);
16771676
if (failed(maybeResultElementType)) {
16781677
return rewriter.notifyMatchFailure(
16791678
op, "unable to convert `dtypeInt` to builtin type");
16801679
}
16811680
resultElementType = *maybeResultElementType;
1681+
// The stablehlo backend expects signed integers to be signless.
1682+
if (resultElementType.isSignedInteger()) {
1683+
resultElementType = IntegerType::get(
1684+
op->getContext(), resultElementType.getIntOrFloatBitWidth(),
1685+
IntegerType::Signless);
1686+
}
16821687
}
16831688

16841689
// Create an uninitialized tensor of `resultSize` shape.

lib/Dialect/Torch/Utils/Utils.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,17 +85,16 @@ Type Torch::getTypeForTorchType(
8585

8686
FailureOr<Type>
8787
Torch::getTypeForScalarType(MLIRContext *context,
88-
torch_upstream::ScalarType dtypeInt,
89-
mlir::IntegerType::SignednessSemantics signedness) {
88+
torch_upstream::ScalarType dtypeInt) {
9089
switch (dtypeInt) {
9190
case torch_upstream::ScalarType::Float:
9291
return Float32Type::get(context);
9392
case torch_upstream::ScalarType::Double:
9493
return Float64Type::get(context);
9594
case torch_upstream::ScalarType::Long:
96-
return IntegerType::get(context, 64, signedness);
95+
return IntegerType::get(context, 64, mlir::IntegerType::Signed);
9796
case torch_upstream::ScalarType::Int:
98-
return IntegerType::get(context, 32, signedness);
97+
return IntegerType::get(context, 32, mlir::IntegerType::Signed);
9998
case torch_upstream::ScalarType::Bool:
10099
return IntegerType::get(context, 1);
101100
case torch_upstream::ScalarType::BFloat16:
@@ -105,7 +104,7 @@ Torch::getTypeForScalarType(MLIRContext *context,
105104
case torch_upstream::ScalarType::Byte:
106105
return mlir::IntegerType::get(context, 8, mlir::IntegerType::Unsigned);
107106
case torch_upstream::ScalarType::Char:
108-
return mlir::IntegerType::get(context, 8, signedness);
107+
return mlir::IntegerType::get(context, 8, mlir::IntegerType::Signed);
109108
case torch_upstream::ScalarType::ComplexHalf:
110109
return mlir::ComplexType::get(Float16Type::get(context));
111110
case torch_upstream::ScalarType::ComplexFloat:

projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,25 @@ def EmptyModule_int(module, tu: TestUtils):
451451
module.forward()
452452

453453

454+
class EmptyUInt8Module(torch.nn.Module):
    """E2E test for `aten.empty` with an unsigned dtype (uint8).

    The uninitialized tensor is zeroed via `zeros_like` and cast to int8 so
    the module's output is deterministic and comparable across backends.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        uninitialized = torch.ops.aten.empty([1], dtype=torch.uint8)
        return torch.ops.aten.zeros_like(uninitialized).to(torch.int8)


@register_test_case(module_factory=lambda: EmptyUInt8Module())
def EmptyModule_uint8(module, tu: TestUtils):
    module.forward()
471+
472+
454473
class EmptyFloatModule(torch.nn.Module):
455474

456475
def __init__(self):

Comments (0)