1717#include " mlir/Conversion/LLVMCommon/ConversionTarget.h"
1818#include " mlir/Conversion/LLVMCommon/Pattern.h"
1919#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
20+ #include " mlir/IR/Iterators.h"
2021#include " mlir/Pass/Pass.h"
2122#include " mlir/Transforms/DialectConversion.h"
2223#include " llvm/ADT/TypeSwitch.h"
@@ -79,6 +80,95 @@ static Value zextByOne(Location loc, ConversionPatternRewriter &rewriter,
7980 return LLVM::ZExtOp::create (rewriter, loc, zextTy, value);
8081}
8182
83+ void ArraySpillCache::spillNonHWOps (OpBuilder &builder,
84+ LLVMTypeConverter &converter,
85+ Operation *containerOp) {
86+ OpBuilder::InsertionGuard g (builder);
87+ containerOp->walk <mlir::WalkOrder::PostOrder, mlir::ReverseIterator>(
88+ [&](Operation *op) {
89+ // Spill Block arguments
90+ for (auto ®ion : op->getRegions ()) {
91+ for (auto &block : region.getBlocks ()) {
92+ builder.setInsertionPointToStart (&block);
93+ for (auto &arg : block.getArguments ()) {
94+ if (isa<hw::ArrayType>(arg.getType ()))
95+ spillHWArrayValue (builder, arg.getLoc (), converter, arg,
96+ /* replaceUses*/ true );
97+ }
98+ }
99+ }
100+
101+ // Spill Op Results
102+ for (auto result : op->getResults ()) {
103+ if (isa<hw::ArrayType>(result.getType ())) {
104+ builder.setInsertionPointAfter (op);
105+ spillHWArrayValue (builder, op->getLoc (), converter, result,
106+ /* replaceUses*/ true );
107+ }
108+ }
109+ });
110+ }
111+
112+ void ArraySpillCache::map (Value arrayValue, Value bufferPtr) {
113+ assert (llvm::isa<LLVM::LLVMArrayType>(arrayValue.getType ()) &&
114+ " Key is not an LLVM array." );
115+ assert (llvm::isa<LLVM::LLVMPointerType>(bufferPtr.getType ()) &&
116+ " Value is not a pointer." );
117+ auto insert = spillMap.insert ({arrayValue, bufferPtr});
118+ (void )insert;
119+ assert (insert.second && " Key already mapped" );
120+ }
121+
122+ Value ArraySpillCache::lookup (Value arrayValue) {
123+ assert (isa<LLVM::LLVMArrayType>(arrayValue.getType ()) ||
124+ isa<hw::ArrayType>(arrayValue.getType ()) && " Not an array value" );
125+ while (isa<LLVM::LLVMArrayType>(arrayValue.getType ()) ||
126+ isa<hw::ArrayType>(arrayValue.getType ())) {
127+ if (isa<LLVM::LLVMArrayType>(arrayValue.getType ())) {
128+ auto mapVal = spillMap.lookup (arrayValue);
129+ if (mapVal)
130+ return mapVal;
131+ }
132+ if (auto castOp = arrayValue.getDefiningOp <UnrealizedConversionCastOp>())
133+ arrayValue = castOp.getOperand (0 );
134+ else
135+ break ;
136+ }
137+ return {};
138+ }
139+
140+ Value ArraySpillCache::spillLLVMArrayValue (OpBuilder &builder, Location loc,
141+ Value llvmArray) {
142+ assert (isa<LLVM::LLVMArrayType>(llvmArray.getType ()) &&
143+ " Expected an LLVM array" );
144+ auto oneC = LLVM::ConstantOp::create (builder, loc, builder.getI32Type (),
145+ builder.getI32IntegerAttr (1 ));
146+ auto spillBuffer = LLVM::AllocaOp::create (
147+ builder, loc, LLVM::LLVMPointerType::get (builder.getContext ()),
148+ llvmArray.getType (), oneC,
149+ /* alignment=*/ 4 );
150+ LLVM::StoreOp::create (builder, loc, llvmArray, spillBuffer);
151+ auto loadOp =
152+ LLVM::LoadOp::create (builder, loc, llvmArray.getType (), spillBuffer);
153+ map (loadOp.getResult (), spillBuffer);
154+ return loadOp.getResult ();
155+ }
156+
157+ Value ArraySpillCache::spillHWArrayValue (OpBuilder &builder, Location loc,
158+ LLVMTypeConverter &converter,
159+ Value hwArray, bool replaceUses) {
160+ assert (isa<hw::ArrayType>(hwArray.getType ()) && " Expected an HW array" );
161+ auto targetType = converter.convertType (hwArray.getType ());
162+ auto hwToLLVMCast =
163+ UnrealizedConversionCastOp::create (builder, loc, targetType, hwArray);
164+ auto spilled = spillLLVMArrayValue (builder, loc, hwToLLVMCast.getResult (0 ));
165+ auto llvmToHWCast = UnrealizedConversionCastOp::create (
166+ builder, loc, hwArray.getType (), spilled);
167+ if (replaceUses)
168+ hwArray.replaceAllUsesExcept (llvmToHWCast.getResult (0 ), hwToLLVMCast);
169+ return llvmToHWCast.getResult (0 );
170+ }
171+
82172// ===----------------------------------------------------------------------===//
83173// Extraction operation conversions
84174// ===----------------------------------------------------------------------===//
@@ -115,6 +205,7 @@ struct StructExplodeOpConversion
115205} // namespace
116206
117207namespace {
208+
118209// / Convert a StructExtractOp to LLVM dialect.
119210// / Pattern: struct_extract(input, fieldname) =>
120211// / extractvalue(input, fieldname_to_index(fieldname))
@@ -200,25 +291,54 @@ struct ArrayInjectOpConversion
200291};
201292} // namespace
202293
294+ namespace {
295+ template <typename SourceOp>
296+ struct HWArrayOpToLLVMPattern : public ConvertOpToLLVMPattern <SourceOp> {
297+
298+ using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
299+ HWArrayOpToLLVMPattern (LLVMTypeConverter &converter,
300+ ArraySpillCache &spillCache)
301+ : ConvertOpToLLVMPattern<SourceOp>(converter), spillCache(spillCache) {}
302+
303+ ArraySpillCache &spillCache;
304+ };
305+
306+ } // namespace
307+
203308namespace {
204309// / Convert an ArrayGetOp to the LLVM dialect.
205310// / Pattern: array_get(input, index) =>
206311// / load(gep(store(input, alloca), zext(index)))
207- struct ArrayGetOpConversion : public ConvertOpToLLVMPattern <hw::ArrayGetOp> {
208- using ConvertOpToLLVMPattern <hw::ArrayGetOp>::ConvertOpToLLVMPattern ;
312+ struct ArrayGetOpConversion : public HWArrayOpToLLVMPattern <hw::ArrayGetOp> {
313+ using HWArrayOpToLLVMPattern <hw::ArrayGetOp>::HWArrayOpToLLVMPattern ;
209314
210315 LogicalResult
211316 matchAndRewrite (hw::ArrayGetOp op, OpAdaptor adaptor,
212317 ConversionPatternRewriter &rewriter) const override {
213- auto oneC = LLVM::ConstantOp::create (
214- rewriter, op->getLoc (), IntegerType::get (rewriter.getContext (), 32 ),
215- rewriter.getI32IntegerAttr (1 ));
216- Value arrPtr = LLVM::AllocaOp::create (
217- rewriter, op->getLoc (),
218- LLVM::LLVMPointerType::get (rewriter.getContext ()),
219- adaptor.getInput ().getType (), oneC,
220- /* alignment=*/ 4 );
221- LLVM::StoreOp::create (rewriter, op->getLoc (), adaptor.getInput (), arrPtr);
318+
319+ Value arrPtr = spillCache.lookup (adaptor.getInput ());
320+
321+ /*
322+ if (arrPtr) {
323+ llvm::dbgs() << ">>> Reused spilled array\n";
324+ } else {
325+ llvm::dbgs() << ">>> Should have been spilled: " << op << "\n DefOp: \n";
326+ adaptor.getInput().getDefiningOp()->dumpPretty();
327+ assert(false && "Should have been spilled");
328+ }
329+ */
330+
331+ if (!arrPtr) {
332+ auto oneC = LLVM::ConstantOp::create (
333+ rewriter, op->getLoc (), IntegerType::get (rewriter.getContext (), 32 ),
334+ rewriter.getI32IntegerAttr (1 ));
335+ arrPtr = LLVM::AllocaOp::create (
336+ rewriter, op->getLoc (),
337+ LLVM::LLVMPointerType::get (rewriter.getContext ()),
338+ adaptor.getInput ().getType (), oneC,
339+ /* alignment=*/ 4 );
340+ LLVM::StoreOp::create (rewriter, op->getLoc (), adaptor.getInput (), arrPtr);
341+ }
222342
223343 auto arrTy = typeConverter->convertType (op.getInput ().getType ());
224344 auto elemTy = typeConverter->convertType (op.getResult ().getType ());
@@ -466,9 +586,10 @@ class AggregateConstantOpConversion
466586 LLVMTypeConverter &typeConverter,
467587 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
468588 &constAggregateGlobalsMap,
469- Namespace &globals)
589+ Namespace &globals, ArraySpillCache &spillCache )
470590 : ConvertOpToLLVMPattern(typeConverter),
471- constAggregateGlobalsMap(constAggregateGlobalsMap), globals(globals) {}
591+ constAggregateGlobalsMap(constAggregateGlobalsMap), globals(globals),
592+ spillCache(spillCache) {}
472593
473594 LogicalResult
474595 matchAndRewrite (hw::AggregateConstantOp op, OpAdaptor adaptor,
@@ -478,6 +599,7 @@ class AggregateConstantOpConversion
478599 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
479600 &constAggregateGlobalsMap;
480601 Namespace &globals;
602+ ArraySpillCache &spillCache;
481603};
482604} // namespace
483605
@@ -661,7 +783,10 @@ LogicalResult AggregateConstantOpConversion::matchAndRewrite(
661783 // Get the global array address and load it to return an array value.
662784 auto addr = LLVM::AddressOfOp::create (rewriter, op->getLoc (),
663785 constAggregateGlobalsMap[typeAttrPair]);
664- rewriter.replaceOpWithNewOp <LLVM::LoadOp>(op, llvmTy, addr);
786+ auto newOp = rewriter.replaceOpWithNewOp <LLVM::LoadOp>(op, llvmTy, addr);
787+
788+ if (llvm::isa<hw::ArrayType>(aggregateType))
789+ spillCache.map (newOp.getResult (), addr);
665790
666791 return success ();
667792}
@@ -703,24 +828,26 @@ void circt::populateHWToLLVMConversionPatterns(
703828 LLVMTypeConverter &converter, RewritePatternSet &patterns,
704829 Namespace &globals,
705830 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
706- &constAggregateGlobalsMap) {
831+ &constAggregateGlobalsMap,
832+ ArraySpillCache &spillCache) {
707833 MLIRContext *ctx = converter.getDialect ()->getContext ();
708834
709835 // Value creation conversion patterns.
710836 patterns.add <HWConstantOpConversion>(ctx, converter);
711837 patterns.add <HWDynamicArrayCreateOpConversion, HWStructCreateOpConversion>(
712838 converter);
713839 patterns.add <AggregateConstantOpConversion>(
714- converter, constAggregateGlobalsMap, globals);
840+ converter, constAggregateGlobalsMap, globals, spillCache );
715841
716842 // Bitwise conversion patterns.
717843 patterns.add <BitcastOpConversion>(converter);
718844
719845 // Extraction operation conversion patterns.
720- patterns.add <ArrayInjectOpConversion, ArrayGetOpConversion,
721- ArraySliceOpConversion, ArrayConcatOpConversion,
722- StructExplodeOpConversion, StructExtractOpConversion,
723- StructInjectOpConversion>(converter);
846+ patterns.add <ArrayInjectOpConversion, ArraySliceOpConversion,
847+ ArrayConcatOpConversion, StructExplodeOpConversion,
848+ StructExtractOpConversion, StructInjectOpConversion>(converter);
849+
850+ patterns.add <ArrayGetOpConversion>(converter, spillCache);
724851}
725852
726853void circt::populateHWToLLVMTypeConversions (LLVMTypeConverter &converter) {
@@ -732,6 +859,7 @@ void circt::populateHWToLLVMTypeConversions(LLVMTypeConverter &converter) {
732859
733860void HWToLLVMLoweringPass::runOnOperation () {
734861 DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp> constAggregateGlobalsMap;
862+ ArraySpillCache spillCache;
735863 Namespace globals;
736864 SymbolCache cache;
737865 cache.addDefinitions (getOperation ());
@@ -746,7 +874,7 @@ void HWToLLVMLoweringPass::runOnOperation() {
746874
747875 // Setup the conversion.
748876 populateHWToLLVMConversionPatterns (converter, patterns, globals,
749- constAggregateGlobalsMap);
877+ constAggregateGlobalsMap, spillCache );
750878
751879 // Apply the partial conversion.
752880 ConversionConfig config;
0 commit comments