From 1097006e0a1f6ad9dd1cd94073eeb37084277490 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Thu, 11 Dec 2025 10:27:49 +0000
Subject: [PATCH 1/2] [TorchToLinalg] Handle output vector type inconsistency
 in AtenTransposeInt op lowering

This commit adds support for handling output tensor type inconsistency in
the case of dynamic dims for the AtenTransposeInt op lowering, by adding a
cast operation that ensures the output tensor type matches the expected
type.

It also updates the tests in datamovement.mlir to reflect the changes in
tensor handling for dynamic dimensions in transpose operations.

Example IR:
"""
%0 = torch.prim.ListConstruct %int1, %int8, %int7, %int8, %int7, %int96 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1,56,56,96],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?,?,?],f32>
%2 = torch.aten.transpose.int %1, %int2, %int3 : !torch.vtensor<[?,?,?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?,?,?],f32>
"""
---
 lib/Conversion/TorchToLinalg/DataMovement.cpp |  6 +++++
 .../TorchToLinalg/datamovement.mlir           | 25 +++++++++++++++++++
 2 files changed, 31 insertions(+)

diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index 1cbb4aa99288..45c47b9210da 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -1740,6 +1740,12 @@ class ConvertAtenTransposeIntOp
     Value outVector = tensor::EmptyOp::create(
         rewriter, loc, getAsOpFoldResult(outputDims), elementType);
 
+    // Note: The empty tensor type may not match `outType` due to folding
+    // performed by `getAsOpFoldResult` of `tensor::DimOp`.
+    // Cast to `outType` if needed to ensure type consistency.
+    if (outVector.getType() != outType)
+      outVector = rewriter.create<tensor::CastOp>(loc, outType, outVector);
+
     SmallVector<int64_t> permutation(inputRank);
     std::iota(permutation.begin(), permutation.end(), 0);
     permutation[dim0] = dim1;
diff --git a/test/Conversion/TorchToLinalg/datamovement.mlir b/test/Conversion/TorchToLinalg/datamovement.mlir
index b24bd7bd2d4f..82ecd69dd996 100644
--- a/test/Conversion/TorchToLinalg/datamovement.mlir
+++ b/test/Conversion/TorchToLinalg/datamovement.mlir
@@ -71,3 +71,28 @@ func.func @torch.aten.reflection_pad2d(%arg0: !torch.vtensor<[1,1,4,4],f32>) ->
   %1 = torch.aten.reflection_pad2d %arg0, %0 : !torch.vtensor<[1,1,4,4],f32>, !torch.list<int> -> !torch.vtensor<[1,1,8,9],f32>
   return %1 : !torch.vtensor<[1,1,8,9],f32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.transpose.int$dynamic_dims(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[1,56,56,96],f32>) -> !torch.vtensor<[?,?,?,?,?,?],f32> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,56,56,96],f32> -> tensor<1x56x56x96xf32>
+// CHECK:           %[[VAL_9:.*]] = tensor.expand_shape %[[VAL_1]] {{\[\[}}0], [1, 2], [3, 4], [5]] output_shape [1, 8, 7, 8, 7, 96] : tensor<1x56x56x96xf32> into tensor<1x8x7x8x7x96xf32>
+// CHECK:           %[[EMPTY:.*]] = tensor.empty() : tensor<1x8x8x7x7x96xf32>
+// CHECK:           %[[TRANSPOSE:.*]] = linalg.transpose ins(%[[VAL_9]] {{.*}} outs(%[[EMPTY]] {{.*}} permutation = [0, 1, 3, 2, 4, 5]
+// CHECK:           %[[RESULT_CAST:.*]] = tensor.cast %[[TRANSPOSE]] : tensor<1x8x8x7x7x96xf32> to tensor<?x?x?x?x?x?xf32>
+// CHECK:           %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[RESULT_CAST]] : tensor<?x?x?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?,?,?],f32>
+// CHECK:           return %[[VAL_10]] : !torch.vtensor<[?,?,?,?,?,?],f32>
+// CHECK:         }
+func.func @torch.aten.transpose.int$dynamic_dims(%arg0: !torch.vtensor<[1,56,56,96],f32>) -> !torch.vtensor<[?,?,?,?,?,?],f32> {
+  %int1 = torch.constant.int 1
+  %int8 = torch.constant.int 8
+  %int7 = torch.constant.int 7
+  %int96 = torch.constant.int 96
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %0 = torch.prim.ListConstruct %int1, %int8, %int7, %int8, %int7, %int96 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1,56,56,96],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?,?,?],f32>
+  %2 = torch.aten.transpose.int %1, %int2, %int3 : !torch.vtensor<[?,?,?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?,?,?],f32>
+  return %2 : !torch.vtensor<[?,?,?,?,?,?],f32>
+}

From f7a2015e845aee309b90792d02f1b5348293be21 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Thu, 11 Dec 2025 11:20:57 +0000
Subject: [PATCH 2/2] Fix deprecation error

---
 lib/Conversion/TorchToLinalg/DataMovement.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index 45c47b9210da..f3858034023c 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -1744,7 +1744,7 @@ class ConvertAtenTransposeIntOp
     // performed by `getAsOpFoldResult` of `tensor::DimOp`.
     // Cast to `outType` if needed to ensure type consistency.
     if (outVector.getType() != outType)
-      outVector = rewriter.create<tensor::CastOp>(loc, outType, outVector);
+      outVector = tensor::CastOp::create(rewriter, loc, outType, outVector);
 
     SmallVector<int64_t> permutation(inputRank);
     std::iota(permutation.begin(), permutation.end(), 0);
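
Below is a minimal standalone sketch of the cast-reconciliation pattern that
PATCH 1/2 applies, shown outside the patch for reference. It assumes the MLIR
C++ tensor dialect API used above; the helper name `castToExpectedType` is
hypothetical and not part of the patch.

"""
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Hypothetical helper, for illustration only (not part of the patch).
// When `getAsOpFoldResult` folds away `tensor.dim` ops of a statically
// shaped source, `tensor.empty` can receive a fully static type even though
// the expected result type keeps dynamic dims; a `tensor.cast` reconciles
// the two types without changing the underlying data.
static Value castToExpectedType(OpBuilder &b, Location loc, Value v,
                                RankedTensorType expectedType) {
  if (v.getType() == expectedType)
    return v;
  return tensor::CastOp::create(b, loc, expectedType, v);
}
"""

In the patch itself, the same check is applied to `outVector` against
`outType` immediately after the `tensor.empty` op is created.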