Skip to content

Commit b238664

Browse files
authored
Add bitwise op support (#2043)
* Add e2e support for bitwise ops

  Signed-off-by: philass <[email protected]>

* Fix docs

  Signed-off-by: philass <[email protected]>

* Add lit tests

  Signed-off-by: philass <[email protected]>

---------

Signed-off-by: philass <[email protected]>
1 parent a70c43a commit b238664

File tree

4 files changed

+91
-3
lines changed

4 files changed

+91
-3
lines changed

docs/SupportedONNXOps-cpu.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 18. Limitatio
3030
| **Bernoulli** | |unsupported | |
3131
| **Binarizer** | |unsupported | |
3232
| **BitShift** | |unsupported | |
33-
| **BitwiseAnd** | |unsupported | |
33+
| **BitwiseAnd** |18 | | |
3434
| **BitwiseNot** | |unsupported | |
35-
| **BitwiseOr** | |unsupported | |
36-
| **BitwiseXor** | |unsupported | |
35+
| **BitwiseOr** |18 | | |
36+
| **BitwiseXor** |18 | | |
3737
| **BlackmanWindow** | |unsupported | |
3838
| **Cast** |13 |Cast only between float and double types. | |
3939
| **CastLike** | |unsupported | |

src/Conversion/ONNXToKrnl/Math/Elementwise.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,24 @@ struct ScalarOp<ONNXXorOp> {
160160
using IOp = arith::XOrIOp;
161161
};
162162

163+
template <>
164+
struct ScalarOp<ONNXBitwiseAndOp> {
165+
using FOp = arith::AndIOp; // Not used.
166+
using IOp = arith::AndIOp;
167+
};
168+
169+
template <>
170+
struct ScalarOp<ONNXBitwiseOrOp> {
171+
using FOp = arith::OrIOp; // Not used.
172+
using IOp = arith::OrIOp;
173+
};
174+
175+
template <>
176+
struct ScalarOp<ONNXBitwiseXorOp> {
177+
using FOp = arith::XOrIOp; // Not used.
178+
using IOp = arith::XOrIOp;
179+
};
180+
163181
template <>
164182
struct ScalarOp<ONNXExpOp> {
165183
using FOp = math::ExpOp;
@@ -2294,6 +2312,9 @@ void populateLoweringONNXElementwiseOpPattern(RewritePatternSet &patterns,
22942312
ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
22952313
ONNXElementwiseVariadicOpLowering<mlir::ONNXAndOp>,
22962314
ONNXElementwiseUnaryOpLowering<mlir::ONNXAtanOp>,
2315+
ONNXElementwiseBinaryOpLowering<mlir::ONNXBitwiseAndOp>,
2316+
ONNXElementwiseBinaryOpLowering<mlir::ONNXBitwiseOrOp>,
2317+
ONNXElementwiseBinaryOpLowering<mlir::ONNXBitwiseXorOp>,
22972318
ONNXElementwiseUnaryOpLowering<mlir::ONNXCastOp>,
22982319
ONNXElementwiseUnaryOpLowering<mlir::ONNXCeilOp>,
22992320
ONNXElementwiseUnaryOpLowering<mlir::ONNXCosOp>,

test/backend/inference_backend.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,19 @@ def get_test_models():
164164

165165
# Bitshift
166166

167+
# ==OP== BitwiseAnd
168+
"test_bitwise_and_i32_2d_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
169+
"test_bitwise_and_i16_3d_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
170+
171+
# ==OP== BitwiseOr
172+
"test_bitwise_or_i32_2d_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
173+
"test_bitwise_or_i16_4d_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
174+
175+
# ==OP== BitwiseXor
176+
"test_bitwise_xor_i32_2d_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
177+
"test_bitwise_xor_i16_3d_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
178+
179+
167180
# ==OP== Cast
168181
# ==LIM== Cast only between float and double types
169182
"test_cast_FLOAT_to_DOUBLE_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},

test/mlir/onnx/onnx_lowering.mlir

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,60 @@ func.func private @test_xor(%arg0 : tensor<10x10xi1>, %arg1 : tensor<10x10xi1>)
187187

188188
// -----
189189

190+
func.func private @test_bitwise_and(%arg0 : tensor<10x10xi8>, %arg1 : tensor<10x10xi8>) -> tensor<*xi8> {
191+
%0 = "onnx.BitwiseAnd"(%arg0, %arg1) : (tensor<10x10xi8>, tensor<10x10xi8>) -> tensor<*xi8>
192+
"func.return"(%0) : (tensor<*xi8>) -> ()
193+
194+
// CHECK-LABEL: test_bitwise_and
195+
// CHECK: [[RES:%.+]] = memref.alloc() {{.*}}: memref<10x10xi8>
196+
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
197+
// CHECK: krnl.iterate([[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10){
198+
// CHECK: [[IV:%.+]]:2 = krnl.get_induction_var_value([[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
199+
// CHECK: [[LOAD1:%.+]] = krnl.load %arg0[[[IV]]#0, [[IV]]#1] : memref<10x10xi8>
200+
// CHECK: [[LOAD2:%.+]] = krnl.load %arg1[[[IV]]#0, [[IV]]#1] : memref<10x10xi8>
201+
// CHECK: [[AND:%.+]] = arith.andi [[LOAD1]], [[LOAD2]] : i8
202+
// CHECK: krnl.store [[AND]], [[RES]][[[IV]]#0, [[IV]]#1] : memref<10x10xi8>
203+
// CHECK: return [[RES]] : memref<10x10xi8>
204+
}
205+
206+
// -----
207+
208+
func.func private @test_bitwise_or(%arg0 : tensor<10x10xi16>, %arg1 : tensor<10x10xi16>) -> tensor<*xi16> {
209+
%0 = "onnx.BitwiseOr"(%arg0, %arg1) : (tensor<10x10xi16>, tensor<10x10xi16>) -> tensor<*xi16>
210+
"func.return"(%0) : (tensor<*xi16>) -> ()
211+
212+
// CHECK-LABEL: test_bitwise_or
213+
// CHECK: [[RES:%.+]] = memref.alloc() {{.*}}: memref<10x10xi16>
214+
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
215+
// CHECK: krnl.iterate([[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10){
216+
// CHECK: [[IV:%.+]]:2 = krnl.get_induction_var_value([[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
217+
// CHECK: [[LOAD1:%.+]] = krnl.load %arg0[[[IV]]#0, [[IV]]#1] : memref<10x10xi16>
218+
// CHECK: [[LOAD2:%.+]] = krnl.load %arg1[[[IV]]#0, [[IV]]#1] : memref<10x10xi16>
219+
// CHECK: [[OR:%.+]] = arith.ori [[LOAD1]], [[LOAD2]] : i16
220+
// CHECK: krnl.store [[OR]], [[RES]][[[IV]]#0, [[IV]]#1] : memref<10x10xi16>
221+
// CHECK: return [[RES]] : memref<10x10xi16>
222+
}
223+
224+
// -----
225+
226+
func.func private @test_bitwise_xor(%arg0 : tensor<10x10xi32>, %arg1 : tensor<10x10xi32>) -> tensor<*xi32> {
227+
%0 = "onnx.BitwiseXor"(%arg0, %arg1) : (tensor<10x10xi32>, tensor<10x10xi32>) -> tensor<*xi32>
228+
"func.return"(%0) : (tensor<*xi32>) -> ()
229+
230+
// CHECK-LABEL: test_bitwise_xor
231+
// CHECK: [[RES:%.+]] = memref.alloc() {{.*}}: memref<10x10xi32>
232+
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
233+
// CHECK: krnl.iterate([[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10){
234+
// CHECK: [[IV:%.+]]:2 = krnl.get_induction_var_value([[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
235+
// CHECK: [[LOAD1:%.+]] = krnl.load %arg0[[[IV]]#0, [[IV]]#1] : memref<10x10xi32>
236+
// CHECK: [[LOAD2:%.+]] = krnl.load %arg1[[[IV]]#0, [[IV]]#1] : memref<10x10xi32>
237+
// CHECK: [[XOR:%.+]] = arith.xori [[LOAD1]], [[LOAD2]] : i32
238+
// CHECK: krnl.store [[XOR]], [[RES]][[[IV]]#0, [[IV]]#1] : memref<10x10xi32>
239+
// CHECK: return [[RES]] : memref<10x10xi32>
240+
}
241+
242+
// -----
243+
190244
func.func private @test_exp(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
191245
%0 = "onnx.Exp"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
192246
"func.return"(%0) : (tensor<*xf32>) -> ()

0 commit comments

Comments (0)