Skip to content

Commit 77d0c3e

Browse files
Integrate LLVM at 7bfdaa51f155432346e507d8ce389802c92eb530 (#4399)
Update LLVM to llvm/llvm-project@7bfdaa5 Update StableHLO to openxla/stablehlo@1ef9e39 This commit makes changes to the `BackendTypeConversionPass` as per this commit: llvm/llvm-project@d4c41b7. This commit also adds support for handling complex type in the `create[Zero|One]InitTensor` utility. --------- Signed-off-by: Vivek Khandelwal <[email protected]>
1 parent cb0f5dc commit 77d0c3e

File tree

6 files changed

+34
-17
lines changed

6 files changed

+34
-17
lines changed

externals/llvm-project

Submodule llvm-project updated 14481 files

externals/stablehlo

Submodule stablehlo updated 75 files

lib/Conversion/Utils/Utils.cpp

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,18 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
134134
Value initTensor =
135135
tensor::EmptyOp::create(b, loc, getAsOpFoldResult(sizes), elemTy);
136136

137-
Type fillValElemTy = elemTy;
138-
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy))
139-
fillValElemTy = cast<mlir::FloatType>(dtypeComplex.getElementType());
140-
141-
Value c0 = arith::ConstantOp::create(b, loc, b.getZeroAttr(fillValElemTy));
137+
Value c0;
138+
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy)) {
139+
// For complex types, create a complex zero (0.0 + 0.0j)
140+
Type floatType = cast<mlir::FloatType>(dtypeComplex.getElementType());
141+
Value realZero =
142+
arith::ConstantOp::create(b, loc, b.getZeroAttr(floatType));
143+
Value imagZero =
144+
arith::ConstantOp::create(b, loc, b.getZeroAttr(floatType));
145+
c0 = complex::CreateOp::create(b, loc, elemTy, realZero, imagZero);
146+
} else {
147+
c0 = arith::ConstantOp::create(b, loc, b.getZeroAttr(elemTy));
148+
}
142149
return linalg::FillOp::create(b, loc, c0, initTensor).getResult(0);
143150
}
144151

@@ -147,11 +154,17 @@ Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
147154
Value initTensor =
148155
tensor::EmptyOp::create(b, loc, getAsOpFoldResult(sizes), elemTy);
149156

150-
Type fillValElemTy = elemTy;
151-
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy))
152-
fillValElemTy = cast<mlir::FloatType>(dtypeComplex.getElementType());
153-
154-
Value c1 = arith::ConstantOp::create(b, loc, b.getOneAttr(fillValElemTy));
157+
Value c1;
158+
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy)) {
159+
// For complex types, create a complex one (1.0 + 0.0j)
160+
Type floatType = cast<mlir::FloatType>(dtypeComplex.getElementType());
161+
Value realOne = arith::ConstantOp::create(b, loc, b.getOneAttr(floatType));
162+
Value imagZero =
163+
arith::ConstantOp::create(b, loc, b.getZeroAttr(floatType));
164+
c1 = complex::CreateOp::create(b, loc, elemTy, realOne, imagZero);
165+
} else {
166+
c1 = arith::ConstantOp::create(b, loc, b.getOneAttr(elemTy));
167+
}
155168
return linalg::FillOp::create(b, loc, c1, initTensor).getResult(0);
156169
}
157170

lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
//===----------------------------------------------------------------------===//
99

1010
#include "mlir/Dialect/Arith/IR/Arith.h"
11+
#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
1112
#include "mlir/Dialect/Func/IR/FuncOps.h"
1213
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
1314
#include "mlir/IR/BuiltinOps.h"
@@ -60,14 +61,14 @@ void populateFuncBackendTypeConversionPatterns(TypeConverter &typeConverter,
6061
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
6162
typeConverter);
6263
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
63-
return typeConverter.isSignatureLegal(op.getFunctionType()) &&
64-
typeConverter.isLegal(&op.getBody());
64+
return typeConverter.isSignatureLegal(op.getFunctionType());
6565
});
6666
populateCallOpTypeConversionPattern(patterns, typeConverter);
6767
target.addDynamicallyLegalOp<func::CallOp>(
6868
[&](func::CallOp op) { return typeConverter.isLegal(op); });
6969

70-
populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
70+
cf::populateCFStructuralTypeConversionsAndLegality(typeConverter, patterns,
71+
target);
7172
populateReturnOpTypeConversionPattern(patterns, typeConverter);
7273
target.addLegalOp<ModuleOp>();
7374

lib/Dialect/TorchConversion/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
set(LinkedLibs
22
MLIRFuncTransforms
3+
MLIRControlFlowTransforms
34
MLIRIR
45
MLIRLinalgTransforms
56
MLIRMemRefTransforms

test/Conversion/TorchToLinalg/spectral.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<9x5xcomplex<f32>>
1111
// CHECK: %[[VAR0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[16,9],f32> -> tensor<16x9xf32>
1212
// CHECK-DAG: %[[VAR1:.*]] = tensor.empty() : tensor<16x5xcomplex<f32>>
13-
// CHECK: %[[VAR2:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAR1]] : tensor<16x5xcomplex<f32>>) -> tensor<16x5xcomplex<f32>>
13+
// CHECK-DAG: %[[CPLX:.*]] = complex.create %[[CST]], %[[CST]] : complex<f32>
14+
// CHECK: %[[VAR2:.*]] = linalg.fill ins(%[[CPLX]] : complex<f32>) outs(%[[VAR1]] : tensor<16x5xcomplex<f32>>) -> tensor<16x5xcomplex<f32>>
1415
// CHECK: %[[VAR3:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[VAR0]], %[[CST_0]] : tensor<16x9xf32>, tensor<9x5xcomplex<f32>>) outs(%[[VAR2]] : tensor<16x5xcomplex<f32>>) {
1516
// CHECK: ^bb0(%in: f32, %in_1: complex<f32>, %out: complex<f32>):
1617
// CHECK: %[[VAR5:.*]] = complex.re %in_1 : complex<f32>
@@ -41,7 +42,8 @@ func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) ->
4142
// CHECK-DAG: %[[VAR1:.*]] = tensor.empty() : tensor<23x36xf32>
4243
// CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[VAR0]] : tensor<36x23xf32>) outs(%[[VAR1]] : tensor<23x36xf32>) permutation = [1, 0]
4344
// CHECK-DAG: %[[VAR2:.*]] = tensor.empty() : tensor<23x19xcomplex<f32>>
44-
// CHECK: %[[VAR3:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAR2]] : tensor<23x19xcomplex<f32>>) -> tensor<23x19xcomplex<f32>>
45+
// CHECK-DAG: %[[CPLX:.*]] = complex.create %[[CST]], %[[CST]] : complex<f32>
46+
// CHECK: %[[VAR3:.*]] = linalg.fill ins(%[[CPLX]] : complex<f32>) outs(%[[VAR2]] : tensor<23x19xcomplex<f32>>) -> tensor<23x19xcomplex<f32>>
4547
// CHECK: %[[VAR4:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[TRANSPOSED]], %[[CST_0]] : tensor<23x36xf32>, tensor<36x19xcomplex<f32>>) outs(%[[VAR3]] : tensor<23x19xcomplex<f32>>) {
4648
// CHECK: ^bb0(%in: f32, %in_2: complex<f32>, %out: complex<f32>):
4749
// CHECK: %[[VAR7:.*]] = complex.re %in_2 : complex<f32>

0 commit comments

Comments
 (0)