diff --git a/externals/llvm-project b/externals/llvm-project index 41f65666f637..7bfdaa51f155 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 41f65666f6378bba7266be7c662c70074f04ed75 +Subproject commit 7bfdaa51f155432346e507d8ce389802c92eb530 diff --git a/externals/stablehlo b/externals/stablehlo index 4c0d4841519a..1ef9e390b529 160000 --- a/externals/stablehlo +++ b/externals/stablehlo @@ -1 +1 @@ -Subproject commit 4c0d4841519aed22e3689c30b72a0e4228051249 +Subproject commit 1ef9e390b5295e676d2b864fe1924bc2f3f4cf0f diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index f102d11617b3..b3f5052027e8 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -134,11 +134,18 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes, Value initTensor = tensor::EmptyOp::create(b, loc, getAsOpFoldResult(sizes), elemTy); - Type fillValElemTy = elemTy; - if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy)) - fillValElemTy = cast<mlir::FloatType>(dtypeComplex.getElementType()); - - Value c0 = arith::ConstantOp::create(b, loc, b.getZeroAttr(fillValElemTy)); + Value c0; + if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy)) { + // For complex types, create a complex zero (0.0 + 0.0j) + Type floatType = cast<mlir::FloatType>(dtypeComplex.getElementType()); + Value realZero = + arith::ConstantOp::create(b, loc, b.getZeroAttr(floatType)); + Value imagZero = + arith::ConstantOp::create(b, loc, b.getZeroAttr(floatType)); + c0 = complex::CreateOp::create(b, loc, elemTy, realZero, imagZero); + } else { + c0 = arith::ConstantOp::create(b, loc, b.getZeroAttr(elemTy)); + } return linalg::FillOp::create(b, loc, c0, initTensor).getResult(0); } @@ -147,11 +154,17 @@ Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes, Value initTensor = tensor::EmptyOp::create(b, loc, getAsOpFoldResult(sizes), elemTy); - Type fillValElemTy = elemTy; - if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy)) - fillValElemTy = 
cast<mlir::FloatType>(dtypeComplex.getElementType()); - - Value c1 = arith::ConstantOp::create(b, loc, b.getOneAttr(fillValElemTy)); + Value c1; + if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy)) { + // For complex types, create a complex one (1.0 + 0.0j) + Type floatType = cast<mlir::FloatType>(dtypeComplex.getElementType()); + Value realOne = arith::ConstantOp::create(b, loc, b.getOneAttr(floatType)); + Value imagZero = + arith::ConstantOp::create(b, loc, b.getZeroAttr(floatType)); + c1 = complex::CreateOp::create(b, loc, elemTy, realOne, imagZero); + } else { + c1 = arith::ConstantOp::create(b, loc, b.getOneAttr(elemTy)); + } return linalg::FillOp::create(b, loc, c1, initTensor).getResult(0); } diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp index 8625a55205d3..240c1ca1ada0 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp @@ -8,6 +8,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/IR/BuiltinOps.h" @@ -60,14 +61,14 @@ void populateFuncBackendTypeConversionPatterns(TypeConverter &typeConverter, populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, typeConverter); target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { - return typeConverter.isSignatureLegal(op.getFunctionType()) && - typeConverter.isLegal(&op.getBody()); + return typeConverter.isSignatureLegal(op.getFunctionType()); }); populateCallOpTypeConversionPattern(patterns, typeConverter); target.addDynamicallyLegalOp<func::CallOp>( [&](func::CallOp op) { return typeConverter.isLegal(op); }); - populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); + 
cf::populateCFStructuralTypeConversionsAndLegality(typeConverter, patterns, + target); populateReturnOpTypeConversionPattern(patterns, typeConverter); target.addLegalOp<ModuleOp>(); diff --git a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt index d5042926b63c..dd5ff074cc26 100644 --- a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt +++ b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ set(LinkedLibs MLIRFuncTransforms + MLIRControlFlowTransforms MLIRIR MLIRLinalgTransforms MLIRMemRefTransforms diff --git a/test/Conversion/TorchToLinalg/spectral.mlir b/test/Conversion/TorchToLinalg/spectral.mlir index abd45183bd84..1bb8235dd5f6 100644 --- a/test/Conversion/TorchToLinalg/spectral.mlir +++ b/test/Conversion/TorchToLinalg/spectral.mlir @@ -10,7 +10,8 @@ // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<9x5xcomplex<f32>> // CHECK: %[[VAR0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[16,9],f32> -> tensor<16x9xf32> // CHECK-DAG: %[[VAR1:.*]] = tensor.empty() : tensor<16x5xcomplex<f32>> -// CHECK: %[[VAR2:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAR1]] : tensor<16x5xcomplex<f32>>) -> tensor<16x5xcomplex<f32>> +// CHECK-DAG: %[[CPLX:.*]] = complex.create %[[CST]], %[[CST]] : complex<f32> +// CHECK: %[[VAR2:.*]] = linalg.fill ins(%[[CPLX]] : complex<f32>) outs(%[[VAR1]] : tensor<16x5xcomplex<f32>>) -> tensor<16x5xcomplex<f32>> // CHECK: %[[VAR3:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[VAR0]], %[[CST_0]] : tensor<16x9xf32>, tensor<9x5xcomplex<f32>>) outs(%[[VAR2]] : tensor<16x5xcomplex<f32>>) { // CHECK: ^bb0(%in: f32, %in_1: complex<f32>, %out: complex<f32>): // CHECK: %[[VAR5:.*]] = complex.re %in_1 : complex<f32> @@ -41,7 +42,8 @@ func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -> // CHECK-DAG: %[[VAR1:.*]] = tensor.empty() : tensor<23x36xf32> // CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[VAR0]] : 
tensor<36x23xf32>) outs(%[[VAR1]] : tensor<23x36xf32>) permutation = [1, 0] // CHECK-DAG: %[[VAR2:.*]] = tensor.empty() : tensor<23x19xcomplex<f32>> -// CHECK: %[[VAR3:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAR2]] : tensor<23x19xcomplex<f32>>) -> tensor<23x19xcomplex<f32>> +// CHECK-DAG: %[[CPLX:.*]] = complex.create %[[CST]], %[[CST]] : complex<f32> +// CHECK: %[[VAR3:.*]] = linalg.fill ins(%[[CPLX]] : complex<f32>) outs(%[[VAR2]] : tensor<23x19xcomplex<f32>>) -> tensor<23x19xcomplex<f32>> // CHECK: %[[VAR4:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[TRANSPOSED]], %[[CST_0]] : tensor<23x36xf32>, tensor<36x19xcomplex<f32>>) outs(%[[VAR3]] : tensor<23x19xcomplex<f32>>) { // CHECK: ^bb0(%in: f32, %in_2: complex<f32>, %out: complex<f32>): // CHECK: %[[VAR7:.*]] = complex.re %in_2 : complex<f32>