diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index 1cbb4aa99288..f3858034023c 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -1740,6 +1740,12 @@ class ConvertAtenTransposeIntOp
     Value outVector = tensor::EmptyOp::create(
         rewriter, loc, getAsOpFoldResult(outputDims), elementType);
 
+    // Note: The empty tensor type may not match `outType` due to folding
+    // performed by `getAsOpFoldResult` of `tensor::DimOp`.
+    // Cast to `outType` if needed to ensure type consistency.
+    if (outVector.getType() != outType)
+      outVector = tensor::CastOp::create(rewriter, loc, outType, outVector);
+
     SmallVector<int64_t> permutation(inputRank);
     std::iota(permutation.begin(), permutation.end(), 0);
     permutation[dim0] = dim1;
diff --git a/test/Conversion/TorchToLinalg/datamovement.mlir b/test/Conversion/TorchToLinalg/datamovement.mlir
index b24bd7bd2d4f..82ecd69dd996 100644
--- a/test/Conversion/TorchToLinalg/datamovement.mlir
+++ b/test/Conversion/TorchToLinalg/datamovement.mlir
@@ -71,3 +71,28 @@ func.func @torch.aten.reflection_pad2d(%arg0: !torch.vtensor<[1,1,4,4],f32>) ->
   %1 = torch.aten.reflection_pad2d %arg0, %0 : !torch.vtensor<[1,1,4,4],f32>, !torch.list<int> -> !torch.vtensor<[1,1,8,9],f32>
   return %1 : !torch.vtensor<[1,1,8,9],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.transpose.int$dynamic_dims(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,56,56,96],f32>) -> !torch.vtensor<[?,?,?,?,?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,56,56,96],f32> -> tensor<1x56x56x96xf32>
+// CHECK: %[[VAL_9:.*]] = tensor.expand_shape %[[VAL_1]] {{\[\[}}0], [1, 2], [3, 4], [5]] output_shape [1, 8, 7, 8, 7, 96] : tensor<1x56x56x96xf32> into tensor<1x8x7x8x7x96xf32>
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x8x8x7x7x96xf32>
+// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose ins(%[[VAL_9]] {{.*}} outs(%[[EMPTY]] {{.*}} permutation = [0, 1, 3, 2, 4, 5]
+// CHECK: %[[RESULT_CAST:.*]] = tensor.cast %[[TRANSPOSE]] : tensor<1x8x8x7x7x96xf32> to tensor<?x?x?x?x?x?xf32>
+// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[RESULT_CAST]] : tensor<?x?x?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?,?,?],f32>
+// CHECK: return %[[VAL_10]] : !torch.vtensor<[?,?,?,?,?,?],f32>
+// CHECK: }
+func.func @torch.aten.transpose.int$dynamic_dims(%arg0: !torch.vtensor<[1,56,56,96],f32>) -> !torch.vtensor<[?,?,?,?,?,?],f32> {
+  %int1 = torch.constant.int 1
+  %int8 = torch.constant.int 8
+  %int7 = torch.constant.int 7
+  %int96 = torch.constant.int 96
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %0 = torch.prim.ListConstruct %int1, %int8, %int7, %int8, %int7, %int96 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1,56,56,96],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?,?,?],f32>
+  %2 = torch.aten.transpose.int %1, %int2, %int3 : !torch.vtensor<[?,?,?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?,?,?],f32>
+  return %2 : !torch.vtensor<[?,?,?,?,?,?],f32>
+}