1919#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
2020
2121#include " src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.hpp"
22+ #include " src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp"
2223#include " src/Dialect/Mlir/DialectBuilder.hpp"
2324#include " zdnn.h"
2425
@@ -29,7 +30,7 @@ namespace zlow {
2930
3031ApiRegistry RegisterAllApis (MLIRContext *context) {
3132 auto voidTy = LLVM::LLVMVoidType::get (context);
32- auto opaquePtrTy = LLVM::LLVMPointerType::get ( IntegerType::get ( context, 8 ) );
33+ auto opaquePtrTy = krnl::getI8PointerType ( context);
3334 auto int32Ty = IntegerType::get (context, 32 );
3435 auto int64Ty = IntegerType::get (context, 64 );
3536
@@ -105,7 +106,8 @@ Value ZTensorHelper::getPreTransformedDescPtr(zdnn_data_types zDNNDataType,
105106 Value one = create.llvm .constant (llvmI64Ty, (int64_t )1 );
106107
107108 Value preTransformedDescPtr = create.llvm ._alloca (
108- LLVM::LLVMPointerType::get (llvmZTensorDescStructTy), one,
109+ krnl::getPointerType (context, llvmZTensorDescStructTy),
110+ llvmZTensorDescStructTy, one,
109111 /* alignment=*/ 0 );
110112
111113 // Prepare operands for calling the function that initializes the zTensor
@@ -145,7 +147,8 @@ Value ZTensorHelper::getTransformedDescPtr(
145147 Value one = create.llvm .constant (llvmI64Ty, (int64_t )1 );
146148
147149 Value transformedDescPtr = create.llvm ._alloca (
148- LLVM::LLVMPointerType::get (llvmZTensorDescStructTy), one,
150+ krnl::getPointerType (context, llvmZTensorDescStructTy),
151+ llvmZTensorDescStructTy, one,
149152 /* alignment=*/ 0 );
150153
151154 if (isConcat) {
@@ -165,10 +168,11 @@ Value ZTensorHelper::getTransformedDescPtr(
165168
166169// Get the pointer to memref.
167170Value ZTensorHelper::getAlignedI8Ptr (Value memRef) {
171+ MLIRContext *context = rewriter.getContext ();
168172 MultiDialectBuilder<LLVMBuilder> create (rewriter, loc);
169173 MemRefDescriptor descriptor (memRef);
170174 Value alignedPtr = descriptor.alignedPtr (rewriter, loc);
171- return create.llvm .bitcastI8Ptr ( alignedPtr);
175+ return create.llvm .bitcast ( krnl::getI8PointerType (context), alignedPtr);
172176}
173177
174178// Get buffer size from a transformed descriptor.
@@ -202,7 +206,8 @@ ZTensor ZTensorHelper::getZTensor(Value bufferPtr, zdnn_data_types dataType,
202206 getTransformedDescPtr (preTransformedDescPtr, isConcat, concatInfo);
203207 // Create the input zTensor.
204208 Value alloc =
205- create.llvm ._alloca (LLVM::LLVMPointerType::get (llvmZTensorStructTy), one,
209+ create.llvm ._alloca (krnl::getPointerType (context, llvmZTensorStructTy),
210+ llvmZTensorStructTy, one,
206211 /* alignment=*/ 0 );
207212 // Buffer size.
208213 Value bufferSize = getBufferSize (transformedDescPtr);
@@ -235,7 +240,8 @@ ZTensor ZTensorHelper::getZTensor(Value preTransformedDescPtr,
235240 Type llvmZTensorStructTy = getZTensorStructTy (context);
236241 Value one = create.llvm .constant (rewriter.getI64Type (), (int64_t )1 );
237242 Value alloc =
238- create.llvm ._alloca (LLVM::LLVMPointerType::get (llvmZTensorStructTy), one,
243+ create.llvm ._alloca (krnl::getPointerType (context, llvmZTensorStructTy),
244+ llvmZTensorStructTy, one,
239245 /* alignment=*/ 0 );
240246 // clang-format off
241247 fillInZTensor (rewriter, loc, module , alloc,
@@ -370,10 +376,10 @@ std::vector<Value> getDimsFromShapeMemRefBySize(PatternRewriter &rewriter,
370376 Value alignedPtr = inputMRD.alignedPtr (rewriter, loc);
371377 Type int64Ty = IntegerType::get (context, 64 );
372378 for (int64_t i = 0 ; i < size; ++i) {
373- Value index = create. llvm . constant (int64Ty, i);
374- Value alignedGep = create.llvm .getElemPtr (
375- LLVM::LLVMPointerType::get (int64Ty), alignedPtr, {index });
376- Value dimI64 = create.llvm .load (alignedGep);
379+ Value alignedGep =
380+ create.llvm .getElemPtr (krnl::getPointerType (context, int64Ty), int64Ty,
381+ alignedPtr, ArrayRef<LLVM::GEPArg>{( int32_t )i });
382+ Value dimI64 = create.llvm .load (int64Ty, alignedGep);
377383 dims.emplace_back (dimI64);
378384 }
379385 return dims;
@@ -462,16 +468,16 @@ Type getZTensorStructTy(MLIRContext *context) {
462468 Type llvmI1Ty = IntegerType::get (context, 1 );
463469 Type llvmI8Ty = IntegerType::get (context, 8 );
464470 Type llvmArrayI8Ty = LLVM::LLVMArrayType::get (llvmI8Ty, 32 );
465- Type llvmI8PtrTy = LLVM::LLVMPointerType::get ( llvmI8Ty);
471+ Type llvmI8PtrTy = krnl::getPointerType (context, llvmI8Ty);
466472 Type llvmZTensorDescStructTy = getZTensorDescStructTy (context);
467473
468474 SmallVector<Type, 4 > zTensorTypeElements;
469475 // A pointer to pre-transformed descriptor struct type
470476 zTensorTypeElements.emplace_back (
471- LLVM::LLVMPointerType::get ( llvmZTensorDescStructTy));
477+ krnl::getPointerType (context, llvmZTensorDescStructTy));
472478 // A pointer to transformed descriptor struct type
473479 zTensorTypeElements.emplace_back (
474- LLVM::LLVMPointerType::get ( llvmZTensorDescStructTy));
480+ krnl::getPointerType (context, llvmZTensorDescStructTy));
475481 // zTensor size in bytes
476482 zTensorTypeElements.emplace_back (llvmI64Ty);
477483 // pointer to the zTensor in memory
@@ -490,8 +496,9 @@ Type getZTensorStructTy(MLIRContext *context) {
490496// / Function to cast an LLVM pointer to an opaque LLVM pointer.
491497Value toOpaquePtr (
492498 PatternRewriter &rewriter, Location loc, ModuleOp module , Value ptr) {
499+ MLIRContext *context = rewriter.getContext ();
493500 MultiDialectBuilder<LLVMBuilder> create (rewriter, loc);
494- return create.llvm .bitcastI8Ptr ( ptr);
501+ return create.llvm .bitcast ( krnl::getI8PointerType (context), ptr);
495502}
496503
497504void fillInZTensor (PatternRewriter &rewriter, Location loc, ModuleOp module ,
@@ -501,48 +508,35 @@ void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module,
501508 MultiDialectBuilder<LLVMBuilder> create (rewriter, loc);
502509
503510 Type llvmI1Ty = IntegerType::get (context, 1 );
504- Type llvmI8Ty = IntegerType::get (context, 8 );
505- Type llvmI8PtrTy = LLVM::LLVMPointerType::get (llvmI8Ty);
506- Type llvmI32Ty = IntegerType::get (context, 32 );
507- Type llvmI64Ty = IntegerType::get (context, 64 );
508- Type llvmZTensorDescTy =
509- LLVM::LLVMPointerType::get (getZTensorDescStructTy (context));
510-
511- // Got runtime error if using i64 as index to access zTensor. It looks
512- // like an error in MLIR. So use i32 here, which does not affect the
513- // correctness of the generated program.
514- Value zero = create.llvm .constant (llvmI32Ty, (int64_t )0 );
515- Value one = create.llvm .constant (llvmI32Ty, (int64_t )1 );
516- Value two = create.llvm .constant (llvmI32Ty, (int64_t )2 );
517- Value three = create.llvm .constant (llvmI32Ty, (int64_t )3 );
518- Value four = create.llvm .constant (llvmI32Ty, (int64_t )4 );
511+ Type llvmZTensorTy = getZTensorStructTy (context);
512+ Type llvmZTensorPtrTy = krnl::getPointerType (context, llvmZTensorTy);
519513
520514 // 1. Set pre-transformed descriptor.
521515 Value zTensorPreTransformedDescPtr = create.llvm .getElemPtr (
522- LLVM::LLVMPointerType::get (llvmZTensorDescTy), zTensor, {zero, zero });
516+ llvmZTensorPtrTy, llvmZTensorTy, zTensor, ArrayRef<LLVM::GEPArg>{ 0 , 0 });
523517 create.llvm .store (preTransformedDescPtr, zTensorPreTransformedDescPtr);
524518
525519 // 2. Set transformed descriptor.
526520 Value zTensorTransformedDescPtr = create.llvm .getElemPtr (
527- LLVM::LLVMPointerType::get (llvmZTensorDescTy), zTensor, {zero, one });
521+ llvmZTensorPtrTy, llvmZTensorTy, zTensor, ArrayRef<LLVM::GEPArg>{ 0 , 1 });
528522 create.llvm .store (transformedDescPtr, zTensorTransformedDescPtr);
529523
530524 // 3. Set buffer_size.
531525 Value bufferSizePtr = create.llvm .getElemPtr (
532- LLVM::LLVMPointerType::get (llvmI64Ty), zTensor, {zero, two });
526+ llvmZTensorPtrTy, llvmZTensorTy, zTensor, ArrayRef<LLVM::GEPArg>{ 0 , 2 });
533527 create.llvm .store (bufferSize, bufferSizePtr);
534528
535529 // 4. Set buffer. Buffer was allocated in advance by the stickified memref.
536530 // So get the pointer from the stickified memref and set it to the zTensor.
537531 Value bufferPtr = create.llvm .getElemPtr (
538- LLVM::LLVMPointerType::get (llvmI8PtrTy), zTensor, {zero, three });
532+ llvmZTensorPtrTy, llvmZTensorTy, zTensor, ArrayRef<LLVM::GEPArg>{ 0 , 3 });
539533 create.llvm .store (alignedBuffer, bufferPtr);
540534
541535 // 5. Set is_transformed.
542536 Value isTransformedVal =
543537 create.llvm .constant (llvmI1Ty, (int64_t )((isTransformed) ? 1 : 0 ));
544538 Value isTransformedDescPtr = create.llvm .getElemPtr (
545- LLVM::LLVMPointerType::get (llvmI1Ty), zTensor, {zero, four });
539+ llvmZTensorPtrTy, llvmZTensorTy, zTensor, ArrayRef<LLVM::GEPArg>{ 0 , 4 });
546540 create.llvm .store (isTransformedVal, isTransformedDescPtr);
547541
548542 // 6. Set reserved (not currently used), not touch
0 commit comments