Skip to content

Commit ec91b39

Browse files
Shape Inference now succeeds for unimplemented ops (#2200)
Signed-off-by: Alexandre Eichenberger <[email protected]>
1 parent d2f4797 commit ec91b39

File tree

6 files changed

+53
-17
lines changed

6 files changed

+53
-17
lines changed

src/Dialect/ONNX/ONNXDimAnalysis.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ bool exploreSameInputDims(const onnx_mlir::DimAnalysis::DimT &dim,
9696
// Get its shape interface.
9797
onnx_mlir::ONNXOpShapeHelper *shapeHelper =
9898
shape_op.getShapeHelper(op, {}, nullptr, nullptr);
99-
if (!shapeHelper)
99+
// If no shape helper, or unimplemented, just abort.
100+
if (!shapeHelper || !shapeHelper->isImplemented())
100101
return false;
101102

102103
// Compute shape.

src/Dialect/ONNX/ONNXOps.cpp

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,12 @@
2424
// Unsupported Operations
2525
//===---------------------------------------------------------------------===//
2626

27-
// Operations for which shape inference has not been implemented yet
28-
// If you add the implementation for one op, move it out of this section
29-
// Also please add test case in test/mlir/onnx/onnx_shape_inference.mlir
30-
// Followed by the implementation of lowering to Krnl and
31-
// Enable the corresponding node test in check-onnx-backend
32-
27+
// Operations for which shape inference has not been implemented.
3328
#define UNSUPPORTED_OPS(OP_TYPE) \
3429
/* shape inference interface method */ \
3530
mlir::LogicalResult mlir::OP_TYPE::inferShapes( \
3631
std::function<void(mlir::Region &)> doShapeInference) { \
37-
return emitOpError( \
38-
"op is not supported at this time. Please open an issue on " \
39-
"https://github.com/onnx/onnx-mlir and/or consider contributing " \
40-
"code. " \
41-
"Error encountered in shape inference."); \
32+
return mlir::success(); \
4233
}
4334

4435
#include "src/Dialect/ONNX/ONNXUnsupportedOps.hpp"

src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ struct ONNXUnimplementedOpShapeHelper : public ONNXOpShapeHelper {
7575
: ONNXOpShapeHelper(op, operands, ieBuilder, scope) {}
7676
virtual ~ONNXUnimplementedOpShapeHelper() {}
7777

78-
mlir::LogicalResult computeShape() final { return mlir::failure(); }
78+
bool isImplemented() override { return false; }
79+
mlir::LogicalResult computeShape() final { return mlir::success(); }
7980
};
8081

8182
// Classes for unsupported ops, including shape inference and shape helpers.

src/Interface/ShapeHelperOpInterface.hpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,18 @@ struct ONNXOpShapeHelper {
8585
IndexExprScope *scope); /* Install local scope if null. */
8686
virtual ~ONNXOpShapeHelper();
8787

88+
// Return true if implemented.
89+
virtual bool isImplemented() { return true; }
90+
8891
// Every leaf class is expected to create a computeShape with the following
8992
// signature. This method is responsible to compute at a minimum the output
9093
// dims.
94+
// Unimplemented operations return success, as these operations may be
95+
// transformed later in a sequence of operations with implemented shape
96+
// inference. To ensure an implementation, check the `isImplemented` function.
97+
// This is used, for example, in dynamic analysis, where unimplemented shape
98+
// inferences are simply ignored (and conservatively assume no knowledge about
99+
// that operation's transfer function).
91100
virtual mlir::LogicalResult computeShape() = 0;
92101

93102
// Compute shape and assert on failure.
@@ -105,8 +114,17 @@ struct ONNXOpShapeHelper {
105114
mlir::ArrayRef<mlir::Attribute> encodingList = {});
106115

107116
// Get output dims for the N-th output dimension as Index Expressions.
108-
// Scalar may have a DimsExpr that is empty.
109-
DimsExpr &getOutputDims(int n = 0) { return privateOutputsDims[n]; }
117+
// Scalar may have a DimsExpr that is empty. Requires an implementation.
118+
DimsExpr &getOutputDims(int n = 0) {
119+
if (!isImplemented()) {
120+
llvm::errs() << "Implementation of shape helper for op " << op->getName()
121+
<< "is not currently available; please open an issue on "
122+
<< "\"https://github.com/onnx/onnx-mlir/\" and/or consider "
123+
<< "contributing code if this op is required.\n";
124+
llvm_unreachable("missing implementation for shape inference");
125+
}
126+
return privateOutputsDims[n];
127+
}
110128
// Set output dims, merging the dims associated with the current type with
111129
// inferred dims provided here, as appropriate.
112130
void setOutputDims(

src/Interface/ShapeHelperOpInterface.td

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,11 @@ def ShapeHelperOpInterface : OpInterface<"ShapeHelperOpInterface"> {
3030

3131
For operations that do not support shape helpers at this stage, a
3232
`ONNXUnimplementedOpShapeHelper` object is returned. This object does not
33-
compute shapes, and simply return failure when `computeShape` is called
34-
on it.
33+
compute shapes, and simply return success when `computeShape` is called
34+
on it. Users may verify if an operation has an actual implementation by
35+
calling `isImplemented()` on the shape helper object. An implementation
36+
is required when attempting to read the outputs of a shape helper object
37+
via the `getOutputDims` method.
3538

3639
The new object is allocated on the heap and it is the responsability
3740
of the object user to free the memory after last use.

test/mlir/onnx/onnx_shape_inference.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3547,3 +3547,25 @@ module {
35473547
// CHECK: }
35483548
}
35493549
}
3550+
3551+
// -----
3552+
3553+
// Check that ClipV6 operation shape inference goes through shape inference smoothly.
3554+
// ClipV6 has no shape inference as it is supposed to be first updated to the latest ClipOp.
3555+
// Using the latest shape inference, the default is to let unimplemented ops go through shape
3556+
// inference without asserts/failures. Asserts only occurs when the results of the shape
3557+
// inference is used.
3558+
// The output shoudl be the same as the input, as no shape inference is expected to be performed.
3559+
3560+
func.func @test_clipv6(%arg0: tensor<*xf32>) {
3561+
%0 = "onnx.ClipV6"(%arg0) {max = 6.000000e+00 : f32, min = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
3562+
return
3563+
3564+
// mlir2FileCheck.py
3565+
// CHECK-LABEL: func.func @test_clipv6
3566+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) {
3567+
// CHECK: [[VAR_0_:%.+]] = "onnx.ClipV6"([[PARAM_0_]]) {max = 6.000000e+00 : f32, min = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
3568+
// CHECK: return
3569+
// CHECK: }
3570+
}
3571+

0 commit comments

Comments
 (0)