Skip to content

Commit 35a61d3

Browse files
cjvolzka and tungld
authored
Add a flag to turn on/off the lowering of scalar broadcasting binary ops to NNPA (#2778) (#2782)
* Add a flag to turn on/off scalar broadcasting binary op in NNPA

Signed-off-by: Tung D. Le <[email protected]>

---------

Signed-off-by: Tung D. Le <[email protected]>
Co-authored-by: Alexandre Eichenberger <[email protected]>
(cherry picked from commit 08d4fed)

Co-authored-by: Tung D. Le <[email protected]>
1 parent 41e755a commit 35a61d3

File tree

13 files changed

+77
-60
lines changed

13 files changed

+77
-60
lines changed

src/Accelerators/NNPA/Compiler/CMakeLists.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
get_property(OMLibs GLOBAL PROPERTY ONNX_MLIR_LIBS)
2-
31
add_onnx_mlir_library(OMNNPACompilerOptions
42
NNPACompilerOptions.cpp
53

@@ -12,7 +10,6 @@ add_onnx_mlir_library(OMNNPACompilerOptions
1210
${NNPA_ONNX_MLIR_BIN_ROOT}
1311

1412
LINK_LIBS PUBLIC
15-
${OMLibs}
1613
OMCompilerOptions
1714

1815
ACCEL_INCLUDE_DIRS PRIVATE
@@ -32,7 +29,6 @@ add_onnx_mlir_library(OMNNPACompilerUtils
3229
${NNPA_ONNX_MLIR_BIN_ROOT}
3330

3431
LINK_LIBS PUBLIC
35-
${OMLibs}
3632
OMNNPACompilerOptions
3733
OMCompilerPasses
3834

src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,13 @@ llvm::cl::opt<bool> nnpaEnableCompilerStickUnstick(
5555
"stick/unstick code. Default is false."),
5656
llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions));
5757

58+
llvm::cl::opt<bool> nnpaEnableScalarBcastBinary(
59+
"nnpa-enable-scalar-bcast-binary",
60+
llvm::cl::desc("Enable the lowering to NNPA the broadcasting binary ops "
61+
"whose one of the operands is scalar. Currently support "
62+
"ONNXDiv only. Default is false."),
63+
llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions));
64+
5865
llvm::cl::opt<std::string> nnpaLoadDevicePlacementFile{
5966
"nnpa-load-device-placement-file",
6067
llvm::cl::desc(

src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,13 @@ typedef enum {
4949
} NNPAPlacementHeuristic;
5050

5151
extern llvm::cl::OptionCategory OnnxMlirOptions;
52+
extern llvm::cl::OptionCategory OnnxMlirCommonOptions;
5253
extern llvm::cl::opt<onnx_mlir::NNPAEmissionTargetType> nnpaEmissionTarget;
5354
extern llvm::cl::opt<bool> nnpaClipToDLFloatRange;
5455
extern llvm::cl::opt<bool> nnpaEnableZHighToOnnx;
5556
extern llvm::cl::opt<bool> nnpaEnableZHighDecomposeStickUnstick;
5657
extern llvm::cl::opt<bool> nnpaEnableCompilerStickUnstick;
58+
extern llvm::cl::opt<bool> nnpaEnableScalarBcastBinary;
5759
extern llvm::cl::opt<NNPAPlacementHeuristic> nnpaPlacementHeuristic;
5860
extern llvm::cl::opt<bool> profileZHighIR;
5961
extern llvm::cl::opt<std::string> nnpaLoadDevicePlacementFile;

src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ add_onnx_mlir_library(OMONNXToZHigh
1111
libzdnn
1212

1313
LINK_LIBS PUBLIC
14-
OMCompilerOptions
14+
OMNNPACompilerOptions
1515
OMONNXOps
1616
OMONNXToKrnl
1717
OMZHighOps
@@ -32,7 +32,7 @@ add_onnx_mlir_library(OMRewriteONNXForZHigh
3232
libzdnn
3333

3434
LINK_LIBS PUBLIC
35-
OMCompilerOptions
35+
OMNNPACompilerOptions
3636
OMONNXOps
3737
OMONNXToKrnl
3838
OMZHighOps

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -324,12 +324,19 @@ bool isSuitableForZDNN<ONNXDivOp>(
324324
// Check NNPA level.
325325
if (!isCompatibleWithNNPALevel(NNPA_Z16))
326326
return false;
327-
if (!isF32ScalarConstantTensor(A) && !isValidElementTypeAndRank(A))
327+
// Broadcast with a scalar operand.
328+
if (isEnableScalarBcastBinary()) {
329+
if (isF32ScalarConstantTensor(A) && isValidElementTypeAndRank(B))
330+
return true;
331+
if (isF32ScalarConstantTensor(B) && isValidElementTypeAndRank(A))
332+
return true;
333+
}
334+
// Non-broadcast cases.
335+
if (!isValidElementTypeAndRank(A))
328336
return false;
329-
if (!isF32ScalarConstantTensor(B) && !isValidElementTypeAndRank(B))
337+
if (!isValidElementTypeAndRank(B))
330338
return false;
331-
return isF32ScalarConstantTensor(A) || isF32ScalarConstantTensor(B) ||
332-
dimAnalysis->sameShape(A, B);
339+
return dimAnalysis->sameShape(A, B);
333340
}
334341

335342
/// Check legality for ONNXSum.

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.td"
2929
/// dag benefitsAdded = (addBenefit 0)
3030
/// >;
3131

32+
def IsEnableScalarBcastBinary: Constraint<CPred<"isEnableScalarBcastBinary()">>;
33+
3234
def IsNoneType : Constraint<CPred<"(($_self).getType().isa<NoneType>())">>;
3335

3436
def IsNotNoneType : Constraint<CPred<"(!($_self).getType().isa<NoneType>())">>;
@@ -227,7 +229,7 @@ def replaceONNXDivBroadcastPattern1 : Pat<
227229
(GetScalarF32AttrFromConstant $y),
228230
(NoneLayoutAttr)),
229231
(returnType $s_x))),
230-
[(IsF32ScalarConstantTensor $y)], [],
232+
[(IsEnableScalarBcastBinary), (IsF32ScalarConstantTensor $y)], [],
231233
(addBenefit 1)
232234
>;
233235

@@ -241,7 +243,7 @@ def replaceONNXDivBroadcastPattern2 : Pat<
241243
(NoneLayoutAttr)),
242244
(ZHighStickOp:$s_y $y, (NoneLayoutAttr)),
243245
(returnType $s_y))),
244-
[(IsF32ScalarConstantTensor $x)], [],
246+
[(IsEnableScalarBcastBinary), (IsF32ScalarConstantTensor $x)], [],
245247
(addBenefit 1)
246248
>;
247249

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@
1414
//===----------------------------------------------------------------------===//
1515

1616
#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp"
17+
#include "src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp"
1718
#include "src/Dialect/ONNX/DialectBuilder.hpp"
1819

1920
using namespace mlir;
2021
namespace onnx_mlir {
2122

23+
bool isEnableScalarBcastBinary() { return nnpaEnableScalarBcastBinary; }
24+
2225
/// Get transposed tensor by using a permutation array.
2326
Value emitONNXTranspose(
2427
Location loc, PatternRewriter &rewriter, Value x, ArrayRef<int64_t> perms) {

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ const std::string DEVICE_ATTRIBUTE = "device";
2727
const std::string CPU_DEVICE = "cpu";
2828
const std::string NNPA_DEVICE = "nnpa";
2929

30+
bool isEnableScalarBcastBinary();
31+
3032
template <typename OP_TYPE>
3133
void addDynamicallyLegalOpFor(mlir::ConversionTarget *target,
3234
const onnx_mlir::DimAnalysis *dimAnalysis,

src/Accelerators/NNPA/Pass/NNPAPasses.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,9 @@ std::unique_ptr<mlir::Pass> createDevicePlacementPass(
2929

3030
/// Add pass for lowering ONNX ops to ZHigh ops.
3131
std::unique_ptr<mlir::Pass> createONNXToZHighPass();
32-
std::unique_ptr<mlir::Pass> createONNXToZHighPass();
3332

3433
/// Add pass for rewriting ONNX ops for ZHigh.
3534
std::unique_ptr<mlir::Pass> createRewriteONNXForZHighPass();
36-
std::unique_ptr<mlir::Pass> createRewriteONNXForZHighPass();
3735

3836
/// Add pass for re-construct ONNX ops from ZHigh ops.
3937
std::unique_ptr<mlir::Pass> createZHighToONNXPass();

src/Conversion/KrnlToAffine/KrnlGetLinearOffsetIndex.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class KrnlGetLinearOffsetIndexLowering : public ConversionPattern {
5353

5454
auto memrefTy = llvm::dyn_cast<MemRefType>(memref.getType());
5555
int64_t rank = memrefTy.getRank();
56-
assert(mapResults.value().size() == rank && "Invalid indices");
56+
assert((int64_t)mapResults.value().size() == rank && "Invalid indices");
5757

5858
// Only lower this op after the memref is normalized.
5959
if (!memrefTy.getLayout().isIdentity())

0 commit comments

Comments
 (0)