Skip to content

Commit a70c43a

Browse files
authored
Add support for split (#2201)
* Add support for split Signed-off-by: philass <[email protected]> * Add lit tests Signed-off-by: philass <[email protected]> --------- Signed-off-by: philass <[email protected]>
1 parent ec91b39 commit a70c43a

File tree

8 files changed

+127
-10
lines changed

8 files changed

+127
-10
lines changed

src/Builder/OpBuildTable.inc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ op_dialect_version_map_["SoftmaxCrossEntropyLoss"] = {13};
188188
op_dialect_version_map_["Softplus"] = {1};
189189
op_dialect_version_map_["Softsign"] = {1};
190190
op_dialect_version_map_["SpaceToDepth"] = {13};
191-
op_dialect_version_map_["Split"] = {13, 11};
191+
op_dialect_version_map_["Split"] = {18, 13, 11};
192192
op_dialect_version_map_["SplitToSequence"] = {11};
193193
op_dialect_version_map_["Sqrt"] = {13};
194194
op_dialect_version_map_["Squeeze"] = {13, 11};
@@ -586,6 +586,8 @@ import_handler_map_["SpaceToDepth"] =
586586
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSpaceToDepthOp>;
587587
import_handler_map_["Split"] =
588588
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSplitOp>;
589+
import_handler_map_["SplitV13"] =
590+
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSplitV13Op>;
589591
import_handler_map_["SplitV11"] =
590592
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSplitV11Op>;
591593
import_handler_map_["SplitToSequence"] =

src/Dialect/ONNX/DialectBuilder.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,8 @@ ValueRange OnnxBuilder::split(
255255
IntegerAttr axisAttr =
256256
IntegerAttr::get(b().getIntegerType(64, /*isSigned=*/true),
257257
APInt(64, axis, /*isSigned=*/true));
258-
return createOpAndInferShapes<ONNXSplitOp>(
259-
toTensors(outputTypes), toTensor(input), toTensor(split), axisAttr)
258+
return createOpAndInferShapes<ONNXSplitOp>(toTensors(outputTypes),
259+
toTensor(input), toTensor(split), axisAttr, IntegerAttr())
260260
.getResults();
261261
}
262262

src/Dialect/ONNX/ONNXOps.td.inc

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8556,6 +8556,53 @@ def ONNXSpaceToDepthOp:ONNX_Op<"SpaceToDepth",
85568556
}
85578557

85588558
def ONNXSplitOp:ONNX_Op<"Split",
8559+
[Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
8560+
let summary = "ONNX Split operation";
8561+
let description = [{
8562+
Split a tensor into a list of tensors, along the specified 'axis'.
8563+
Either input 'split' or the attribute 'num_outputs' should be specified, but not both.
8564+
If the attribute 'num_outputs' is specified, then the tensor is split into equal sized parts.
8565+
If the tensor is not evenly splittable into `num_outputs`, the last chunk will be smaller.
8566+
If the input 'split' is specified, it indicates the sizes of each output in the split.
8567+
}];
8568+
let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>]>:$input,
8569+
AnyTypeOf<[TensorOf<[I64]>, NoneType]>:$split,
8570+
DefaultValuedAttr<SI64Attr, "0">:$axis,
8571+
OptionalAttr<SI64Attr>:$num_outputs);
8572+
let results = (outs Variadic<AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>]>>:$outputs);
8573+
let builders = [
8574+
OpBuilder<(ins "Value":$input, "Value":$split, "IntegerAttr":$axis, "IntegerAttr":$num_outputs), [{
8575+
auto resultType = UnrankedTensorType::get(input.getType().cast<ShapedType>().getElementType());
8576+
build($_builder, $_state, resultType, input, split, axis, num_outputs);
8577+
}]>,
8578+
OpBuilder<(ins "ValueRange":$operands, "ArrayRef<NamedAttribute>":$attributes), [{
8579+
auto resultType = UnrankedTensorType::get(operands[0].getType().cast<ShapedType>().getElementType());
8580+
build($_builder, $_state, {resultType}, operands, attributes);
8581+
}]>
8582+
];
8583+
let extraClassDeclaration = [{
8584+
static int getNumberOfOperands() {
8585+
return 2;
8586+
}
8587+
static int getNumberOfResults() {
8588+
return -1;
8589+
}
8590+
static std::vector<int> getTypeMap() {
8591+
return {20};
8592+
}
8593+
}];
8594+
let extraClassDefinition = [{
8595+
onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef<mlir::Value> oper,
8596+
onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) {
8597+
onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXSplitOpShapeHelper(op, oper, ieb, scope);
8598+
assert(sh && "failed to allocate shape helper");
8599+
return sh;
8600+
}
8601+
}];
8602+
let hasVerifier = 1;
8603+
}
8604+
8605+
def ONNXSplitV13Op:ONNX_Op<"SplitV13",
85598606
[Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
85608607
let summary = "ONNX Split operation";
85618608
let description = [{
@@ -8591,12 +8638,11 @@ def ONNXSplitOp:ONNX_Op<"Split",
85918638
let extraClassDefinition = [{
85928639
onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef<mlir::Value> oper,
85938640
onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) {
8594-
onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXSplitOpShapeHelper(op, oper, ieb, scope);
8641+
onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXSplitV13OpShapeHelper(op, oper, ieb, scope);
85958642
assert(sh && "failed to allocate shape helper");
85968643
return sh;
85978644
}
85988645
}];
8599-
let hasVerifier = 1;
86008646
}
86018647

86028648
def ONNXSplitV11Op:ONNX_Op<"SplitV11",

src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,7 @@ struct ONNXCommonSplitOpShapeHelper : public ONNXOpShapeHelper {
598598
// clang-format off
599599
using ONNXSplitOpShapeHelper = ONNXCommonSplitOpShapeHelper<mlir::ONNXSplitOp>;
600600
using ONNXSplitV11OpShapeHelper = ONNXCommonSplitOpShapeHelper<mlir::ONNXSplitV11Op>;
601+
using ONNXSplitV13OpShapeHelper = ONNXCommonSplitOpShapeHelper<mlir::ONNXSplitV13Op>;
601602
// clang-format on
602603

603604
//===----------------------------------------------------------------------===//

src/Dialect/ONNX/ONNXOps/Tensor/Split.cpp

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,25 @@ LogicalResult ONNXCommonSplitOpShapeHelper<OP_TYPE>::customComputeShape(
5959
} else {
6060
// If split parameter is not specified, the dimension is split to
6161
// equal-sized parts.
62+
bool hasNumOutputsAttr = std::is_same_v<OP_TYPE, ONNXSplitOp>;
63+
// TODO figure out how to handle when numResults is determined by
64+
// num_outputs attribute introduced for Split in opset 18
65+
// Currently the whole graph depends on the number of outputs being
66+
// determined at the ONNX to ONNX-MLIR ingestion stage.
6267
IndexExpr splitInputDim = createIE->getShapeAsDim(input, axisIndex);
6368
LiteralIndexExpr numOfPartitions(numOfResults);
6469
if (splitInputDim.isLiteral() &&
65-
(splitInputDim.getLiteral() % numOfResults != 0))
70+
(splitInputDim.getLiteral() % numOfResults != 0) && !hasNumOutputsAttr)
6671
return op->emitError("The dimension at the split axis is "
6772
"expected to be divisible by the number of results");
73+
74+
unsigned numBiggerChunks = splitInputDim.isLiteral()
75+
? splitInputDim.getLiteral() % numOfResults
76+
: numOfResults;
6877
for (unsigned int i = 0; i < numOfResults; ++i) {
69-
IndexExpr splitDim = splitInputDim.ceilDiv(numOfPartitions);
78+
IndexExpr splitDim = (i < numBiggerChunks)
79+
? splitInputDim.ceilDiv(numOfPartitions)
80+
: splitInputDim.floorDiv(numOfPartitions);
7081
splitDims.emplace_back(splitDim);
7182
}
7283
}
@@ -103,6 +114,22 @@ LogicalResult ONNXSplitOpShapeHelper::computeShape() {
103114
return customComputeShape(indexExprArray);
104115
}
105116

117+
// Code for SplitV13Op compute shape.
118+
template <>
119+
LogicalResult ONNXSplitV13OpShapeHelper::computeShape() {
120+
ONNXSplitOpAdaptor operandAdaptor(operands, op->getAttrDictionary());
121+
Value split = operandAdaptor.getSplit();
122+
SmallVector<IndexExpr, 4> indexExprArray;
123+
if (isNoneValue(split)) {
124+
// None is fine, indexExprArray will be empty.
125+
} else {
126+
createIE->getIntFromArrayAsSymbols(split, indexExprArray);
127+
assert(IndexExpr::isLiteral(indexExprArray) &&
128+
"dynamic split not yet supported");
129+
}
130+
return customComputeShape(indexExprArray);
131+
}
132+
106133
// Code for SplitV11Op compute shape.
107134
template <>
108135
LogicalResult ONNXSplitV11OpShapeHelper::computeShape() {
@@ -157,6 +184,19 @@ LogicalResult ONNXSplitOp::inferShapes(
157184
return shapeHelper.computeShapeAndUpdateType(elementType);
158185
}
159186

187+
LogicalResult ONNXSplitV13Op::inferShapes(
188+
std::function<void(Region &)> doShapeInference) {
189+
// Cannot infer the output shape if the input shape isn't known yet.
190+
if (!hasShapeAndRank(getInput()))
191+
return success();
192+
193+
auto inputType = getInput().getType().cast<ShapedType>();
194+
Type elementType = inputType.getElementType();
195+
ONNXSplitV13OpShapeHelper shapeHelper(getOperation(), {});
196+
// Same time for all results.
197+
return shapeHelper.computeShapeAndUpdateType(elementType);
198+
}
199+
160200
LogicalResult ONNXSplitV11Op::inferShapes(
161201
std::function<void(Region &)> doShapeInference) {
162202
// Cannot infer the output shape if the input shape isn't known yet.
@@ -180,5 +220,6 @@ LogicalResult ONNXSplitV11Op::inferShapes(
180220

181221
namespace onnx_mlir {
182222
template struct ONNXCommonSplitOpShapeHelper<ONNXSplitOp>;
223+
template struct ONNXCommonSplitOpShapeHelper<ONNXSplitV13Op>;
183224
template struct ONNXCommonSplitOpShapeHelper<ONNXSplitV11Op>;
184225
} // namespace onnx_mlir

src/Transform/ONNX/Decompose.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,13 +323,13 @@ def ClipV12Pattern : Pat<
323323

324324
def SplitV11PatternNoAttr : Pat<
325325
(ONNXSplitV11Op $x, $axis, $split),
326-
(ONNXSplitOp $x, (CreateNoneValue), $axis),
326+
(ONNXSplitOp $x, (CreateNoneValue), $axis, (GetNullIntegerAttr)),
327327
[(AttributeIsNull:$split)], (addBenefit 1)
328328
>;
329329

330330
def SplitV11Pattern : Pat<
331331
(ONNXSplitV11Op $x, $axis, $split),
332-
(ONNXSplitOp $x, (ONNXConstantOpFromDenseAttr(createDenseArrayAttr $split)), $axis),
332+
(ONNXSplitOp $x, (ONNXConstantOpFromDenseAttr(createDenseArrayAttr $split)), $axis, (GetNullIntegerAttr)),
333333
[], (addBenefit 0)
334334
>;
335335

test/mlir/onnx/onnx_shape_inference.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,6 +1363,32 @@ func.func @test_split_5(%arg0 : tensor<16x?x64xf32>) -> tensor<*xf32> {
13631363

13641364
// -----
13651365

1366+
func.func @test_split_6(%arg0 : tensor<16x39x64xf32>) -> tensor<*xf32> {
1367+
%cst = "onnx.NoValue"() {value} : () -> none
1368+
%0, %1 = "onnx.Split"(%arg0, %cst) { axis = 1 : si64, num_outputs = 2 : si64} : (tensor<16x39x64xf32>, none) -> (tensor<*xf32>, tensor<*xf32>)
1369+
"func.return"(%0) : (tensor<*xf32>) -> ()
1370+
1371+
// CHECK-LABEL: test_split_6
1372+
// CHECK: [[CST:%.+]] = "onnx.NoValue"() {value} : () -> none
1373+
// CHECK-NEXT: [[RES:%.+]]:2 = "onnx.Split"(%arg0, [[CST]]) {axis = 1 : si64, num_outputs = 2 : si64} : (tensor<16x39x64xf32>, none) -> (tensor<16x20x64xf32>, tensor<16x19x64xf32>)
1374+
// CHECK: return [[RES]]#0 : tensor<16x20x64xf32>
1375+
}
1376+
1377+
// -----
1378+
1379+
func.func @test_split_7(%arg0 : tensor<16x38x64xf32>) -> tensor<*xf32> {
1380+
%cst = "onnx.NoValue"() {value} : () -> none
1381+
%0, %1, %2 = "onnx.Split"(%arg0, %cst) { axis = 1 : si64, num_outputs = 3 : si64} : (tensor<16x38x64xf32>, none) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
1382+
"func.return"(%0) : (tensor<*xf32>) -> ()
1383+
1384+
// CHECK-LABEL: test_split_7
1385+
// CHECK: [[CST:%.+]] = "onnx.NoValue"() {value} : () -> none
1386+
// CHECK-NEXT: [[RES:%.+]]:3 = "onnx.Split"(%arg0, [[CST]]) {axis = 1 : si64, num_outputs = 3 : si64} : (tensor<16x38x64xf32>, none) -> (tensor<16x13x64xf32>, tensor<16x13x64xf32>, tensor<16x12x64xf32>)
1387+
// CHECK: return [[RES]]#0 : tensor<16x13x64xf32>
1388+
}
1389+
1390+
// -----
1391+
13661392
func.func @test_splitv11_1(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
13671393
%0, %1 = "onnx.SplitV11"(%arg0) { axis = 1 : si64} : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
13681394
"func.return"(%0) : (tensor<*xf32>) -> ()

utils/gen_onnx_mlir.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@
256256
'Softplus': [1],
257257
'Softsign': [1],
258258
'SpaceToDepth': [13],
259-
'Split': [13, 11],
259+
'Split': [18, 13, 11],
260260
'SplitToSequence': [11],
261261
'Sqrt': [13],
262262
'Squeeze': [13, 11],
@@ -500,6 +500,7 @@
500500
'ReduceSumV11',
501501
'Softmax',
502502
'Split',
503+
'SplitV13',
503504
'Sqrt',
504505
'Squeeze',
505506
'SqueezeV11',

0 commit comments

Comments
 (0)