@@ -8556,6 +8556,53 @@ def ONNXSpaceToDepthOp:ONNX_Op<"SpaceToDepth",
85568556}
85578557
85588558def ONNXSplitOp:ONNX_Op<"Split",
8559+ [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
8560+ let summary = "ONNX Split operation";
8561+ let description = [{
8562+ Split a tensor into a list of tensors, along the specified 'axis'.
8563+ Either input 'split' or the attribute 'num_outputs' should be specified, but not both.
8564+ If the attribute 'num_outputs' is specified, then the tensor is split into equal sized parts.
8565+ If the tensor is not evenly splittable into `num_outputs`, the last chunk will be smaller.
8566+ If the input 'split' is specified, it indicates the sizes of each output in the split.
8567+ }];
8568+ let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>]>:$input,
8569+ AnyTypeOf<[TensorOf<[I64]>, NoneType]>:$split,
8570+ DefaultValuedAttr<SI64Attr, "0">:$axis,
8571+ OptionalAttr<SI64Attr>:$num_outputs);
8572+ let results = (outs Variadic<AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>]>>:$outputs);
8573+ let builders = [
8574+ OpBuilder<(ins "Value":$input, "Value":$split, "IntegerAttr":$axis, "IntegerAttr":$num_outputs), [{
8575+ auto resultType = UnrankedTensorType::get(input.getType().cast<ShapedType>().getElementType());
8576+ build($_builder, $_state, resultType, input, split, axis, num_outputs);
8577+ }]>,
8578+ OpBuilder<(ins "ValueRange":$operands, "ArrayRef<NamedAttribute>":$attributes), [{
8579+ auto resultType = UnrankedTensorType::get(operands[0].getType().cast<ShapedType>().getElementType());
8580+ build($_builder, $_state, {resultType}, operands, attributes);
8581+ }]>
8582+ ];
8583+ let extraClassDeclaration = [{
8584+ static int getNumberOfOperands() {
8585+ return 2;
8586+ }
8587+ static int getNumberOfResults() {
8588+ return -1;
8589+ }
8590+ static std::vector<int> getTypeMap() {
8591+ return {20};
8592+ }
8593+ }];
8594+ let extraClassDefinition = [{
8595+ onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef<mlir::Value> oper,
8596+ onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) {
8597+ onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXSplitOpShapeHelper(op, oper, ieb, scope);
8598+ assert(sh && "failed to allocate shape helper");
8599+ return sh;
8600+ }
8601+ }];
8602+ let hasVerifier = 1;
8603+ }
8604+
8605+ def ONNXSplitV13Op:ONNX_Op<"SplitV13",
85598606 [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
85608607 let summary = "ONNX Split operation";
85618608 let description = [{
@@ -8591,12 +8638,11 @@ def ONNXSplitOp:ONNX_Op<"Split",
85918638 let extraClassDefinition = [{
85928639 onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef<mlir::Value> oper,
85938640 onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) {
8594- onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXSplitOpShapeHelper (op, oper, ieb, scope);
8641+ onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXSplitV13OpShapeHelper (op, oper, ieb, scope);
85958642 assert(sh && "failed to allocate shape helper");
85968643 return sh;
85978644 }
85988645 }];
8599- let hasVerifier = 1;
86008646}
86018647
86028648def ONNXSplitV11Op:ONNX_Op<"SplitV11",
0 commit comments