Skip to content

Commit c53f36c

Browse files
authored
Fix an error on krnl-to-affine conversion for onnx.CategoryMapper with rank-2+ inputs (#2086)
This PR includes fixes an error on krnl-to-affine conversion for onnx.CategoryMapper with rank 2 inputs and their lit tests. Signed-off-by: Yasushi Negishi <[email protected]>
1 parent 057f3ad commit c53f36c

File tree

4 files changed

+16
-3
lines changed

4 files changed

+16
-3
lines changed

src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ struct ONNXCategoryMapperOpLowering
8787

8888
// Basic information.
8989
int64_t rank = memRefType.getShape().size();
90+
assert(((rank == 1) || (rank == 2)) && "Invalid rank of input");
9091
ShapedType inputType = X.getType().cast<ShapedType>();
9192
Type elementType = inputType.getElementType();
9293

@@ -259,8 +260,14 @@ struct ONNXCategoryMapperOpLowering
259260
MathBuilder createMath(createKrnl);
260261
Value zero = createMath.constant(
261262
createMath.getBuilder().getIntegerType(64), 0);
263+
ArrayRef<int64_t> shape =
264+
memref.getType().cast<ShapedType>().getShape();
265+
SmallVector<int64_t, 4> newShape;
266+
for (uint64_t i = 0; i < shape.size(); i++)
267+
newShape.emplace_back(
268+
(shape[i] == ShapedType::kDynamic) ? 1 : shape[i]);
262269
auto memRefType = MemRefType::get(
263-
{rank}, krnl::StringType::get(elementType.getContext()));
270+
newShape, krnl::StringType::get(elementType.getContext()));
264271
Value stringMemRef = createKrnl.getRef(memRefType, memref, zero);
265272
inputElem = createKrnl.load(stringMemRef, loopInd);
266273
})

src/Dialect/ONNX/ONNXOps/ML/CategoryMapper.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ LogicalResult ONNXCategoryMapperOp::verify() {
4747
}
4848

4949
ShapedType inputType = X.getType().cast<ShapedType>();
50+
if ((inputType.getRank() != 1) && (inputType.getRank() != 2))
51+
return emitOpError("input rank must be one or two");
5052
Type elementType = inputType.getElementType();
5153
if (!elementType.isInteger(64) && !elementType.isa<ONNXStringType>())
5254
return emitOpError("input must be a tensor of int64 or string");

src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ namespace onnx_mlir {
3232
/// Handle shape inference for unary element-wise operators.
3333
LogicalResult inferShapeForUnaryOps(Operation *op) {
3434
Value input = op->getOperand(0);
35+
if (!hasShapeAndRank(input))
36+
return success();
3537
RankedTensorType inputType = input.getType().dyn_cast<RankedTensorType>();
3638
return inferShapeForUnaryOps(
3739
op, inputType.getElementType(), inputType.getEncoding());
@@ -41,6 +43,8 @@ LogicalResult inferShapeForUnaryOps(Operation *op) {
4143
/// type.
4244
LogicalResult inferShapeForUnaryOps(Operation *op, Type elementType) {
4345
Value input = op->getOperand(0);
46+
if (!hasShapeAndRank(input))
47+
return success();
4448
RankedTensorType inputType = input.getType().dyn_cast<RankedTensorType>();
4549
return inferShapeForUnaryOps(op, elementType, inputType.getEncoding());
4650
}

test/mlir/onnx/onnx_lowering_category_mapper.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ func.func private @test_category_mapper_string_to_int64(%arg0 : tensor<2x2x!onnx
2020
// CHECK-DAG: [[LOOP_0:%.+]]:2 = krnl.define_loops 2
2121
// CHECK: krnl.iterate([[LOOP_0]]#0, [[LOOP_0]]#1) with ([[LOOP_0]]#0 -> [[I_0:%.+]] = 0 to 2, [[LOOP_0]]#1 -> [[I_1:%.+]] = 0 to 2){
2222
// CHECK: [[IVS:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0]]#0, [[LOOP_0]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
23-
// CHECK: [[REF:%.+]] = "krnl.getref"(%arg0, [[ZERO_i64]]) : (memref<2x2x!krnl.string>, i64) -> memref<2x!krnl.string>
24-
// CHECK: [[LOAD1:%.+]] = krnl.load [[REF]]{{.}}[[IVS]]#0, [[IVS]]#1{{.}} : memref<2x!krnl.string>
23+
// CHECK: [[REF:%.+]] = "krnl.getref"(%arg0, [[ZERO_i64]]) : (memref<2x2x!krnl.string>, i64) -> memref<2x2x!krnl.string>
24+
// CHECK: [[LOAD1:%.+]] = krnl.load [[REF]]{{.}}[[IVS]]#0, [[IVS]]#1{{.}} : memref<2x2x!krnl.string>
2525
// CHECK: [[INDEX:%.+]] = "krnl.find_index"([[LOAD1]], [[G]], [[V]], [[LEN]]) : (!krnl.string, memref<3xi32>, memref<3xi32>, i32) -> index
2626
// CHECK: [[LOAD2:%.+]] = krnl.load [[CAT_STRINGS]]{{.}}[[INDEX]]{{.}} : memref<3x!krnl.string>
2727
// CHECK: [[STRLEN:%.+]] = "krnl.strlen"([[LOAD2]]) : (!krnl.string) -> i64

0 commit comments

Comments
 (0)