Skip to content

Commit d2f4797

Browse files
authored
Use opaque pointers when lowering to LLVMIR (#2190)
* Use opaque pointers when lowering to LLVMIR Signed-off-by: Tung D. Le <[email protected]>
1 parent 38b16b0 commit d2f4797

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+2281
-574
lines changed

docs/Testing.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ Again, these debug statements can then be activated by adding the `--debug-only=
377377

378378
## ONNX Model Zoo
379379

380-
We provide a Python script [RunONNXModelZoo.py](../utils/RunONNXModelZoo.py) to check inference accuracy with models in the [ONNX model zoo](https://github.com/onnx/models). [RunONNXModelZoo.py](../utils/RunONNXModelZoo.py) requires [RunONNXModel.py](../utils/RunONNXModel.py) to be in the same folder. For example,
380+
We provide a Python script [RunONNXModelZoo.py](../utils/RunONNXModelZoo.py) to check inference accuracy with models in the [ONNX model zoo](https://github.com/onnx/models). [RunONNXModelZoo.py](../utils/RunONNXModelZoo.py) requires [RunONNXModel.py](../utils/RunONNXModel.py) to be in the same folder. For example, to check inference accuracy with mnist-8:
381381

382382
```bash
383383
$ mkdir test && cd test
@@ -388,3 +388,5 @@ $ ONNX_MLIR_HOME=/onnx-mlir/build/Release/ python RunONNXModelZoo.py -m mnist-8
388388
Run the script with `-h` to see all the options. In addition to the `-m` flag to specify a model and `-c` flag to specify the compile options, useful options are the `-k` flag to leave the onnx model in the current directory as a `.tgz` file, and the `-l debug` flag to print lots of debugging info.
389389

390390
To find out which models are available, run the script with `-p` to print the list of available models; or `-m` followed by an incomplete name, and the script will suggest the exact names.
391+
392+
If no model is specified with `-m`, the script checks all models in the ONNX model zoo.

src/Accelerators/NNPA/Conversion/ZLowToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ add_onnx_mlir_library(OMZLowToLLVM
77

88
LINK_LIBS PUBLIC
99
MLIRLLVMCommonConversion
10+
OMKrnlToLLVM
1011
OMLayoutHelper
1112
OMZLowOps
1213
OMMlirDialects

src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVM.cpp

Lines changed: 85 additions & 115 deletions
Large diffs are not rendered by default.

src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp

Lines changed: 28 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
2020

2121
#include "src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.hpp"
22+
#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp"
2223
#include "src/Dialect/Mlir/DialectBuilder.hpp"
2324
#include "zdnn.h"
2425

@@ -29,7 +30,7 @@ namespace zlow {
2930

3031
ApiRegistry RegisterAllApis(MLIRContext *context) {
3132
auto voidTy = LLVM::LLVMVoidType::get(context);
32-
auto opaquePtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
33+
auto opaquePtrTy = krnl::getI8PointerType(context);
3334
auto int32Ty = IntegerType::get(context, 32);
3435
auto int64Ty = IntegerType::get(context, 64);
3536

@@ -105,7 +106,8 @@ Value ZTensorHelper::getPreTransformedDescPtr(zdnn_data_types zDNNDataType,
105106
Value one = create.llvm.constant(llvmI64Ty, (int64_t)1);
106107

107108
Value preTransformedDescPtr = create.llvm._alloca(
108-
LLVM::LLVMPointerType::get(llvmZTensorDescStructTy), one,
109+
krnl::getPointerType(context, llvmZTensorDescStructTy),
110+
llvmZTensorDescStructTy, one,
109111
/*alignment=*/0);
110112

111113
// Prepare operands for calling the function that initializes the zTensor
@@ -145,7 +147,8 @@ Value ZTensorHelper::getTransformedDescPtr(
145147
Value one = create.llvm.constant(llvmI64Ty, (int64_t)1);
146148

147149
Value transformedDescPtr = create.llvm._alloca(
148-
LLVM::LLVMPointerType::get(llvmZTensorDescStructTy), one,
150+
krnl::getPointerType(context, llvmZTensorDescStructTy),
151+
llvmZTensorDescStructTy, one,
149152
/*alignment=*/0);
150153

151154
if (isConcat) {
@@ -165,10 +168,11 @@ Value ZTensorHelper::getTransformedDescPtr(
165168

166169
// Get the pointer to memref.
167170
Value ZTensorHelper::getAlignedI8Ptr(Value memRef) {
171+
MLIRContext *context = rewriter.getContext();
168172
MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
169173
MemRefDescriptor descriptor(memRef);
170174
Value alignedPtr = descriptor.alignedPtr(rewriter, loc);
171-
return create.llvm.bitcastI8Ptr(alignedPtr);
175+
return create.llvm.bitcast(krnl::getI8PointerType(context), alignedPtr);
172176
}
173177

174178
// Get buffer size from a transformed descriptor.
@@ -202,7 +206,8 @@ ZTensor ZTensorHelper::getZTensor(Value bufferPtr, zdnn_data_types dataType,
202206
getTransformedDescPtr(preTransformedDescPtr, isConcat, concatInfo);
203207
// Create the input zTensor.
204208
Value alloc =
205-
create.llvm._alloca(LLVM::LLVMPointerType::get(llvmZTensorStructTy), one,
209+
create.llvm._alloca(krnl::getPointerType(context, llvmZTensorStructTy),
210+
llvmZTensorStructTy, one,
206211
/*alignment=*/0);
207212
// Buffer size.
208213
Value bufferSize = getBufferSize(transformedDescPtr);
@@ -235,7 +240,8 @@ ZTensor ZTensorHelper::getZTensor(Value preTransformedDescPtr,
235240
Type llvmZTensorStructTy = getZTensorStructTy(context);
236241
Value one = create.llvm.constant(rewriter.getI64Type(), (int64_t)1);
237242
Value alloc =
238-
create.llvm._alloca(LLVM::LLVMPointerType::get(llvmZTensorStructTy), one,
243+
create.llvm._alloca(krnl::getPointerType(context, llvmZTensorStructTy),
244+
llvmZTensorStructTy, one,
239245
/*alignment=*/0);
240246
// clang-format off
241247
fillInZTensor(rewriter, loc, module, alloc,
@@ -370,10 +376,10 @@ std::vector<Value> getDimsFromShapeMemRefBySize(PatternRewriter &rewriter,
370376
Value alignedPtr = inputMRD.alignedPtr(rewriter, loc);
371377
Type int64Ty = IntegerType::get(context, 64);
372378
for (int64_t i = 0; i < size; ++i) {
373-
Value index = create.llvm.constant(int64Ty, i);
374-
Value alignedGep = create.llvm.getElemPtr(
375-
LLVM::LLVMPointerType::get(int64Ty), alignedPtr, {index});
376-
Value dimI64 = create.llvm.load(alignedGep);
379+
Value alignedGep =
380+
create.llvm.getElemPtr(krnl::getPointerType(context, int64Ty), int64Ty,
381+
alignedPtr, ArrayRef<LLVM::GEPArg>{(int32_t)i});
382+
Value dimI64 = create.llvm.load(int64Ty, alignedGep);
377383
dims.emplace_back(dimI64);
378384
}
379385
return dims;
@@ -462,16 +468,16 @@ Type getZTensorStructTy(MLIRContext *context) {
462468
Type llvmI1Ty = IntegerType::get(context, 1);
463469
Type llvmI8Ty = IntegerType::get(context, 8);
464470
Type llvmArrayI8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 32);
465-
Type llvmI8PtrTy = LLVM::LLVMPointerType::get(llvmI8Ty);
471+
Type llvmI8PtrTy = krnl::getPointerType(context, llvmI8Ty);
466472
Type llvmZTensorDescStructTy = getZTensorDescStructTy(context);
467473

468474
SmallVector<Type, 4> zTensorTypeElements;
469475
// A pointer to pre-transformed descriptor struct type
470476
zTensorTypeElements.emplace_back(
471-
LLVM::LLVMPointerType::get(llvmZTensorDescStructTy));
477+
krnl::getPointerType(context, llvmZTensorDescStructTy));
472478
// A pointer to transformed descriptor struct type
473479
zTensorTypeElements.emplace_back(
474-
LLVM::LLVMPointerType::get(llvmZTensorDescStructTy));
480+
krnl::getPointerType(context, llvmZTensorDescStructTy));
475481
// zTensor size in bytes
476482
zTensorTypeElements.emplace_back(llvmI64Ty);
477483
// pointer to the zTensor in memory
@@ -490,8 +496,9 @@ Type getZTensorStructTy(MLIRContext *context) {
490496
/// Function to cast an LLVM pointer to an opaque LLVM pointer.
491497
Value toOpaquePtr(
492498
PatternRewriter &rewriter, Location loc, ModuleOp module, Value ptr) {
499+
MLIRContext *context = rewriter.getContext();
493500
MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
494-
return create.llvm.bitcastI8Ptr(ptr);
501+
return create.llvm.bitcast(krnl::getI8PointerType(context), ptr);
495502
}
496503

497504
void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module,
@@ -501,48 +508,35 @@ void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module,
501508
MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
502509

503510
Type llvmI1Ty = IntegerType::get(context, 1);
504-
Type llvmI8Ty = IntegerType::get(context, 8);
505-
Type llvmI8PtrTy = LLVM::LLVMPointerType::get(llvmI8Ty);
506-
Type llvmI32Ty = IntegerType::get(context, 32);
507-
Type llvmI64Ty = IntegerType::get(context, 64);
508-
Type llvmZTensorDescTy =
509-
LLVM::LLVMPointerType::get(getZTensorDescStructTy(context));
510-
511-
// Got runtime error if using i64 as index to access zTensor. It looks
512-
// like an error in MLIR. So use i32 here, which does not affect the
513-
// correctness of the generated program.
514-
Value zero = create.llvm.constant(llvmI32Ty, (int64_t)0);
515-
Value one = create.llvm.constant(llvmI32Ty, (int64_t)1);
516-
Value two = create.llvm.constant(llvmI32Ty, (int64_t)2);
517-
Value three = create.llvm.constant(llvmI32Ty, (int64_t)3);
518-
Value four = create.llvm.constant(llvmI32Ty, (int64_t)4);
511+
Type llvmZTensorTy = getZTensorStructTy(context);
512+
Type llvmZTensorPtrTy = krnl::getPointerType(context, llvmZTensorTy);
519513

520514
// 1. Set pre-transformed descriptor.
521515
Value zTensorPreTransformedDescPtr = create.llvm.getElemPtr(
522-
LLVM::LLVMPointerType::get(llvmZTensorDescTy), zTensor, {zero, zero});
516+
llvmZTensorPtrTy, llvmZTensorTy, zTensor, ArrayRef<LLVM::GEPArg>{0, 0});
523517
create.llvm.store(preTransformedDescPtr, zTensorPreTransformedDescPtr);
524518

525519
// 2. Set transformed descriptor.
526520
Value zTensorTransformedDescPtr = create.llvm.getElemPtr(
527-
LLVM::LLVMPointerType::get(llvmZTensorDescTy), zTensor, {zero, one});
521+
llvmZTensorPtrTy, llvmZTensorTy, zTensor, ArrayRef<LLVM::GEPArg>{0, 1});
528522
create.llvm.store(transformedDescPtr, zTensorTransformedDescPtr);
529523

530524
// 3. Set buffer_size.
531525
Value bufferSizePtr = create.llvm.getElemPtr(
532-
LLVM::LLVMPointerType::get(llvmI64Ty), zTensor, {zero, two});
526+
llvmZTensorPtrTy, llvmZTensorTy, zTensor, ArrayRef<LLVM::GEPArg>{0, 2});
533527
create.llvm.store(bufferSize, bufferSizePtr);
534528

535529
// 4. Set buffer. Buffer was allocated in advance by the stickified memref.
536530
// So get the pointer from the stickified memref and set it to the zTensor.
537531
Value bufferPtr = create.llvm.getElemPtr(
538-
LLVM::LLVMPointerType::get(llvmI8PtrTy), zTensor, {zero, three});
532+
llvmZTensorPtrTy, llvmZTensorTy, zTensor, ArrayRef<LLVM::GEPArg>{0, 3});
539533
create.llvm.store(alignedBuffer, bufferPtr);
540534

541535
// 5. Set is_transformed.
542536
Value isTransformedVal =
543537
create.llvm.constant(llvmI1Ty, (int64_t)((isTransformed) ? 1 : 0));
544538
Value isTransformedDescPtr = create.llvm.getElemPtr(
545-
LLVM::LLVMPointerType::get(llvmI1Ty), zTensor, {zero, four});
539+
llvmZTensorPtrTy, llvmZTensorTy, zTensor, ArrayRef<LLVM::GEPArg>{0, 4});
546540
create.llvm.store(isTransformedVal, isTransformedDescPtr);
547541

548542
// 6. Reserved field (not currently used): leave it untouched.

src/Compiler/CompilerPasses.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,8 @@ void addKrnlToLLVMPasses(
188188
pm.addNestedPass<func::FuncOp>(mlir::createConvertSCFToCFPass());
189189

190190
pm.addPass(mlir::memref::createFoldMemRefAliasOpsPass());
191-
pm.addPass(krnl::createConvertKrnlToLLVMPass(verifyInputTensors));
191+
pm.addPass(krnl::createConvertKrnlToLLVMPass(
192+
verifyInputTensors, /*useOpaquePointers=*/true));
192193
pm.addPass(mlir::createReconcileUnrealizedCastsPass());
193194
pm.addPass(mlir::createCanonicalizerPass());
194195
}

src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ using namespace mlir;
6363
namespace onnx_mlir {
6464
namespace krnl {
6565

66+
bool LLVM_USE_OPAQUE_POINTER = true;
67+
6668
uint64_t KRNL_ENTRY_POINT_ID = 0;
6769

6870
// Return true if the value owns the storage. A value defined by memref.alloc
@@ -275,9 +277,9 @@ void genSignatureFunction(ModuleOp &module,
275277
Type i8Type = IntegerType::get(context, 8);
276278
Type i32Type = IntegerType::get(context, 32);
277279
Type i64Type = IntegerType::get(context, 64);
278-
Type i64PtrTy = LLVM::LLVMPointerType::get(i64Type);
279-
Type i8PtrTy = LLVM::LLVMPointerType::get(i8Type);
280-
Type i8PtrPtrTy = LLVM::LLVMPointerType::get(i8PtrTy);
280+
Type i64PtrTy = getPointerType(context, i64Type);
281+
Type i8PtrTy = getPointerType(context, i8Type);
282+
Type i8PtrPtrTy = getPointerType(context, i8PtrTy);
281283

282284
uint64_t numOfEntryPoints = entryGlobalOps.size();
283285

@@ -300,16 +302,13 @@ void genSignatureFunction(ModuleOp &module,
300302
uint32_t index = 0;
301303
Value lastValue = array;
302304
for (const LLVM::GlobalOp &globalOp : entryGlobalOps) {
303-
Value address = create.llvm.addressOf(globalOp);
304-
Value zeroI64 = create.llvm.constant(i64Type, (int64_t)0);
305-
Value strAddr =
306-
create.llvm.getElemPtr(i8PtrTy, address, {zeroI64, zeroI64});
305+
Value strAddr = krnl::getPtrToGlobalString(globalOp, loc, b);
307306
lastValue =
308307
create.llvm.insertValue(arrayType, lastValue, strAddr, {index++});
309308
}
310309

311310
// The last element of the array is NULL.
312-
Value nullPtr = create.llvm.nullI8Ptr();
311+
Value nullPtr = create.llvm.null(getI8PointerType(context));
313312
lastValue =
314313
create.llvm.insertValue(arrayType, lastValue, nullPtr, {index++});
315314
create.llvm._return(lastValue);
@@ -339,16 +338,15 @@ void genSignatureFunction(ModuleOp &module,
339338
LLVM::ICmpPredicate::ne, numOfEntryPoints, nullPtr);
340339
}, /*then=*/
341340
[&](LLVMBuilder &createLLVM) {
342-
Value zero = createLLVM.constant(i64Type, (int64_t)0);
343-
Value numOfEntryPointsPtr =
344-
createLLVM.getElemPtr(i64PtrTy, numOfEntryPoints, {zero});
341+
Value numOfEntryPointsPtr = createLLVM.getElemPtr(
342+
i64PtrTy, i64Type, numOfEntryPoints, ArrayRef<LLVM::GEPArg>{0});
345343
Value noep =
346344
createLLVM.constant(i64Type, (int64_t)entryGlobalOps.size());
347345
createLLVM.store(noep, numOfEntryPointsPtr);
348346
});
349347
// Emit code to return the entry point array.
350348
Value entryAddr = create.llvm.addressOf(entryArrayOp);
351-
Value entryI8Ptr = create.llvm.bitcastI8PtrPtr(entryAddr);
349+
Value entryI8Ptr = create.llvm.bitcast(i8PtrPtrTy, entryAddr);
352350
create.llvm._return(entryI8Ptr);
353351
}
354352

@@ -388,10 +386,8 @@ void genSignatureFunction(ModuleOp &module,
388386
create.llvm.ifThenElse(/*cond=*/
389387
[&](LLVMBuilder &createLLVM) {
390388
// Read an entry point name.
391-
Value address = createLLVM.addressOf(globalEntryPoint);
392-
Value zeroI64 = createLLVM.constant(i64Type, (int64_t)0);
393389
Value entryI8Ptr =
394-
createLLVM.getElemPtr(i8PtrTy, address, {zeroI64, zeroI64});
390+
krnl::getPtrToGlobalString(globalEntryPoint, loc, b);
395391
// Compare it with the user's entry point name.
396392
FlatSymbolRefAttr StrncmpRef = krnl::getOrInsertStrncmp(b, module);
397393
Value length = createLLVM.constant(
@@ -404,13 +400,13 @@ void genSignatureFunction(ModuleOp &module,
404400
}, /*then=*/
405401
[&](LLVMBuilder &createLLVM) {
406402
Value sigAddr = createLLVM.addressOf(globalSignature);
407-
Value sigI8Ptr = createLLVM.bitcastI8Ptr(sigAddr);
403+
Value sigI8Ptr = createLLVM.bitcast(i8PtrTy, sigAddr);
408404
createLLVM._return(sigI8Ptr);
409405
});
410406
}
411407

412408
// Return NULL if not found.
413-
create.llvm._return(create.llvm.nullI8Ptr());
409+
create.llvm._return(create.llvm.null(getI8PointerType(context)));
414410
}
415411
}
416412

@@ -427,8 +423,9 @@ struct ConvertKrnlToLLVMPass
427423
ConvertKrnlToLLVMPass() = default;
428424
ConvertKrnlToLLVMPass(const ConvertKrnlToLLVMPass &pass)
429425
: PassWrapper<ConvertKrnlToLLVMPass, OperationPass<ModuleOp>>() {}
430-
ConvertKrnlToLLVMPass(bool verifyInputTensors) {
426+
ConvertKrnlToLLVMPass(bool verifyInputTensors, bool useOpaquePointers) {
431427
this->verifyInputTensors = verifyInputTensors;
428+
this->useOpaquePointers = useOpaquePointers;
432429
}
433430

434431
StringRef getArgument() const override { return "convert-krnl-to-llvm"; }
@@ -439,6 +436,11 @@ struct ConvertKrnlToLLVMPass
439436

440437
void runOnOperation() final;
441438

439+
Option<bool> useOpaquePointers{*this, "use-opaque-pointers",
440+
llvm::cl::desc("Whether to use opaque pointers instead of typed pointers "
441+
"when lowering to LLVM. Default: true"),
442+
llvm::cl::init(true)};
443+
442444
Option<bool> verifyInputTensors{*this, "verify-input-tensors",
443445
llvm::cl::desc(
444446
"Verify input tensors whenever the entry point function is called.\n"
@@ -453,9 +455,10 @@ void ConvertKrnlToLLVMPass::runOnOperation() {
453455
const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
454456
LowerToLLVMOptions options(ctx, dataLayoutAnalysis.getAtOrAbove(module));
455457

456-
// There are many places where we still rely on non-opaque pointers. Disable
457-
// opaque-pointers until we migrated the affected code parts
458-
options.useOpaquePointers = false;
458+
// MLIR/LLVM is moving to using opaque pointers instead of typed pointers.
459+
// Remove this once MLIR/LLVM completely uses opaque pointers.
460+
options.useOpaquePointers = useOpaquePointers; // for LLVMTypeConverter.
461+
LLVM_USE_OPAQUE_POINTER = useOpaquePointers; // for onnx-mlir util functions.
459462

460463
KRNL_ENTRY_POINT_ID = 0;
461464

@@ -533,8 +536,10 @@ void ConvertKrnlToLLVMPass::runOnOperation() {
533536
std::unique_ptr<Pass> createConvertKrnlToLLVMPass() {
534537
return std::make_unique<ConvertKrnlToLLVMPass>();
535538
}
536-
std::unique_ptr<Pass> createConvertKrnlToLLVMPass(bool verifyInputTensors) {
537-
return std::make_unique<ConvertKrnlToLLVMPass>(verifyInputTensors);
539+
std::unique_ptr<Pass> createConvertKrnlToLLVMPass(
540+
bool verifyInputTensors, bool useOpaquePointers) {
541+
return std::make_unique<ConvertKrnlToLLVMPass>(
542+
verifyInputTensors, useOpaquePointers);
538543
}
539544

540545
void populateKrnlToLLVMConversion(LLVMTypeConverter &typeConverter,

0 commit comments

Comments
 (0)