Skip to content

Commit 37c0b35

Browse files
committed
[HWToLLVM][ArcToLLVM] Spill array values early
1 parent 120dd53 commit 37c0b35

File tree

3 files changed

+175
-24
lines changed

3 files changed

+175
-24
lines changed

include/circt/Conversion/HWToLLVM.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,22 @@ struct HWToLLVMEndianessConverter {
4242
StringRef fieldName);
4343
};
4444

45+
struct ArraySpillCache {
46+
void spillNonHWOps(mlir::OpBuilder &builder,
47+
mlir::LLVMTypeConverter &converter,
48+
Operation *containerOp);
49+
void map(mlir::Value arrayValue, mlir::Value bufferPtr);
50+
Value lookup(Value arrayValue);
51+
Value spillLLVMArrayValue(OpBuilder &builder, Location loc, Value llvmArray);
52+
Value spillHWArrayValue(OpBuilder &builder, Location loc,
53+
mlir::LLVMTypeConverter &converter, Value hwArray,
54+
bool replaceUses);
55+
56+
private:
57+
// Map LLVM Array values to pointers to constant (!) buffers
58+
llvm::DenseMap<Value, Value> spillMap;
59+
};
60+
4561
/// Get the HW to LLVM type conversions.
4662
void populateHWToLLVMTypeConversions(mlir::LLVMTypeConverter &converter);
4763

@@ -50,7 +66,8 @@ void populateHWToLLVMConversionPatterns(
5066
mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns,
5167
Namespace &globals,
5268
DenseMap<std::pair<Type, ArrayAttr>, mlir::LLVM::GlobalOp>
53-
&constAggregateGlobalsMap);
69+
&constAggregateGlobalsMap,
70+
ArraySpillCache &spillCache);
5471

5572
/// Create an HW to LLVM conversion pass.
5673
std::unique_ptr<OperationPass<ModuleOp>> createConvertHWToLLVMPass();

lib/Conversion/ArcToLLVM/LowerArcToLLVM.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -760,9 +760,15 @@ void LowerArcToLLVMPass::runOnOperation() {
760760

761761
// CIRCT patterns.
762762
DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp> constAggregateGlobalsMap;
763-
populateHWToLLVMConversionPatterns(converter, patterns, globals,
764-
constAggregateGlobalsMap);
765763
populateHWToLLVMTypeConversions(converter);
764+
ArraySpillCache spillCache;
765+
{
766+
OpBuilder spillBuilder(getOperation());
767+
spillCache.spillNonHWOps(spillBuilder, converter, getOperation());
768+
}
769+
populateHWToLLVMConversionPatterns(converter, patterns, globals,
770+
constAggregateGlobalsMap, spillCache);
771+
766772
populateCombToArithConversionPatterns(converter, patterns);
767773
populateCombToLLVMConversionPatterns(converter, patterns);
768774

lib/Conversion/HWToLLVM/HWToLLVM.cpp

Lines changed: 149 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1818
#include "mlir/Conversion/LLVMCommon/Pattern.h"
1919
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20+
#include "mlir/IR/Iterators.h"
2021
#include "mlir/Pass/Pass.h"
2122
#include "mlir/Transforms/DialectConversion.h"
2223
#include "llvm/ADT/TypeSwitch.h"
@@ -79,6 +80,95 @@ static Value zextByOne(Location loc, ConversionPatternRewriter &rewriter,
7980
return LLVM::ZExtOp::create(rewriter, loc, zextTy, value);
8081
}
8182

83+
void ArraySpillCache::spillNonHWOps(OpBuilder &builder,
84+
LLVMTypeConverter &converter,
85+
Operation *containerOp) {
86+
OpBuilder::InsertionGuard g(builder);
87+
containerOp->walk<mlir::WalkOrder::PostOrder, mlir::ReverseIterator>(
88+
[&](Operation *op) {
89+
// Spill Block arguments
90+
for (auto &region : op->getRegions()) {
91+
for (auto &block : region.getBlocks()) {
92+
builder.setInsertionPointToStart(&block);
93+
for (auto &arg : block.getArguments()) {
94+
if (isa<hw::ArrayType>(arg.getType()))
95+
spillHWArrayValue(builder, arg.getLoc(), converter, arg,
96+
/*replaceUses*/ true);
97+
}
98+
}
99+
}
100+
101+
// Spill Op Results
102+
for (auto result : op->getResults()) {
103+
if (isa<hw::ArrayType>(result.getType())) {
104+
builder.setInsertionPointAfter(op);
105+
spillHWArrayValue(builder, op->getLoc(), converter, result,
106+
/*replaceUses*/ true);
107+
}
108+
}
109+
});
110+
}
111+
112+
void ArraySpillCache::map(Value arrayValue, Value bufferPtr) {
113+
assert(llvm::isa<LLVM::LLVMArrayType>(arrayValue.getType()) &&
114+
"Key is not an LLVM array.");
115+
assert(llvm::isa<LLVM::LLVMPointerType>(bufferPtr.getType()) &&
116+
"Value is not a pointer.");
117+
auto insert = spillMap.insert({arrayValue, bufferPtr});
118+
(void)insert;
119+
assert(insert.second && "Key already mapped");
120+
}
121+
122+
Value ArraySpillCache::lookup(Value arrayValue) {
123+
assert(isa<LLVM::LLVMArrayType>(arrayValue.getType()) ||
124+
isa<hw::ArrayType>(arrayValue.getType()) && "Not an array value");
125+
while (isa<LLVM::LLVMArrayType>(arrayValue.getType()) ||
126+
isa<hw::ArrayType>(arrayValue.getType())) {
127+
if (isa<LLVM::LLVMArrayType>(arrayValue.getType())) {
128+
auto mapVal = spillMap.lookup(arrayValue);
129+
if (mapVal)
130+
return mapVal;
131+
}
132+
if (auto castOp = arrayValue.getDefiningOp<UnrealizedConversionCastOp>())
133+
arrayValue = castOp.getOperand(0);
134+
else
135+
break;
136+
}
137+
return {};
138+
}
139+
140+
Value ArraySpillCache::spillLLVMArrayValue(OpBuilder &builder, Location loc,
141+
Value llvmArray) {
142+
assert(isa<LLVM::LLVMArrayType>(llvmArray.getType()) &&
143+
"Expected an LLVM array");
144+
auto oneC = LLVM::ConstantOp::create(builder, loc, builder.getI32Type(),
145+
builder.getI32IntegerAttr(1));
146+
auto spillBuffer = LLVM::AllocaOp::create(
147+
builder, loc, LLVM::LLVMPointerType::get(builder.getContext()),
148+
llvmArray.getType(), oneC,
149+
/*alignment=*/4);
150+
LLVM::StoreOp::create(builder, loc, llvmArray, spillBuffer);
151+
auto loadOp =
152+
LLVM::LoadOp::create(builder, loc, llvmArray.getType(), spillBuffer);
153+
map(loadOp.getResult(), spillBuffer);
154+
return loadOp.getResult();
155+
}
156+
157+
Value ArraySpillCache::spillHWArrayValue(OpBuilder &builder, Location loc,
158+
LLVMTypeConverter &converter,
159+
Value hwArray, bool replaceUses) {
160+
assert(isa<hw::ArrayType>(hwArray.getType()) && "Expected an HW array");
161+
auto targetType = converter.convertType(hwArray.getType());
162+
auto hwToLLVMCast =
163+
UnrealizedConversionCastOp::create(builder, loc, targetType, hwArray);
164+
auto spilled = spillLLVMArrayValue(builder, loc, hwToLLVMCast.getResult(0));
165+
auto llvmToHWCast = UnrealizedConversionCastOp::create(
166+
builder, loc, hwArray.getType(), spilled);
167+
if (replaceUses)
168+
hwArray.replaceAllUsesExcept(llvmToHWCast.getResult(0), hwToLLVMCast);
169+
return llvmToHWCast.getResult(0);
170+
}
171+
82172
//===----------------------------------------------------------------------===//
83173
// Extraction operation conversions
84174
//===----------------------------------------------------------------------===//
@@ -115,6 +205,7 @@ struct StructExplodeOpConversion
115205
} // namespace
116206

117207
namespace {
208+
118209
/// Convert a StructExtractOp to LLVM dialect.
119210
/// Pattern: struct_extract(input, fieldname) =>
120211
/// extractvalue(input, fieldname_to_index(fieldname))
@@ -200,25 +291,54 @@ struct ArrayInjectOpConversion
200291
};
201292
} // namespace
202293

294+
namespace {
295+
template <typename SourceOp>
296+
struct HWArrayOpToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
297+
298+
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
299+
HWArrayOpToLLVMPattern(LLVMTypeConverter &converter,
300+
ArraySpillCache &spillCache)
301+
: ConvertOpToLLVMPattern<SourceOp>(converter), spillCache(spillCache) {}
302+
303+
ArraySpillCache &spillCache;
304+
};
305+
306+
} // namespace
307+
203308
namespace {
204309
/// Convert an ArrayGetOp to the LLVM dialect.
205310
/// Pattern: array_get(input, index) =>
206311
/// load(gep(store(input, alloca), zext(index)))
207-
struct ArrayGetOpConversion : public ConvertOpToLLVMPattern<hw::ArrayGetOp> {
208-
using ConvertOpToLLVMPattern<hw::ArrayGetOp>::ConvertOpToLLVMPattern;
312+
struct ArrayGetOpConversion : public HWArrayOpToLLVMPattern<hw::ArrayGetOp> {
313+
using HWArrayOpToLLVMPattern<hw::ArrayGetOp>::HWArrayOpToLLVMPattern;
209314

210315
LogicalResult
211316
matchAndRewrite(hw::ArrayGetOp op, OpAdaptor adaptor,
212317
ConversionPatternRewriter &rewriter) const override {
213-
auto oneC = LLVM::ConstantOp::create(
214-
rewriter, op->getLoc(), IntegerType::get(rewriter.getContext(), 32),
215-
rewriter.getI32IntegerAttr(1));
216-
Value arrPtr = LLVM::AllocaOp::create(
217-
rewriter, op->getLoc(),
218-
LLVM::LLVMPointerType::get(rewriter.getContext()),
219-
adaptor.getInput().getType(), oneC,
220-
/*alignment=*/4);
221-
LLVM::StoreOp::create(rewriter, op->getLoc(), adaptor.getInput(), arrPtr);
318+
319+
Value arrPtr = spillCache.lookup(adaptor.getInput());
320+
321+
/*
322+
if (arrPtr) {
323+
llvm::dbgs() << ">>> Reused spilled array\n";
324+
} else {
325+
llvm::dbgs() << ">>> Should have been spilled: " << op << "\n DefOp: \n";
326+
adaptor.getInput().getDefiningOp()->dumpPretty();
327+
assert(false && "Should have been spilled");
328+
}
329+
*/
330+
331+
if (!arrPtr) {
332+
auto oneC = LLVM::ConstantOp::create(
333+
rewriter, op->getLoc(), IntegerType::get(rewriter.getContext(), 32),
334+
rewriter.getI32IntegerAttr(1));
335+
arrPtr = LLVM::AllocaOp::create(
336+
rewriter, op->getLoc(),
337+
LLVM::LLVMPointerType::get(rewriter.getContext()),
338+
adaptor.getInput().getType(), oneC,
339+
/*alignment=*/4);
340+
LLVM::StoreOp::create(rewriter, op->getLoc(), adaptor.getInput(), arrPtr);
341+
}
222342

223343
auto arrTy = typeConverter->convertType(op.getInput().getType());
224344
auto elemTy = typeConverter->convertType(op.getResult().getType());
@@ -466,9 +586,10 @@ class AggregateConstantOpConversion
466586
LLVMTypeConverter &typeConverter,
467587
DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
468588
&constAggregateGlobalsMap,
469-
Namespace &globals)
589+
Namespace &globals, ArraySpillCache &spillCache)
470590
: ConvertOpToLLVMPattern(typeConverter),
471-
constAggregateGlobalsMap(constAggregateGlobalsMap), globals(globals) {}
591+
constAggregateGlobalsMap(constAggregateGlobalsMap), globals(globals),
592+
spillCache(spillCache) {}
472593

473594
LogicalResult
474595
matchAndRewrite(hw::AggregateConstantOp op, OpAdaptor adaptor,
@@ -478,6 +599,7 @@ class AggregateConstantOpConversion
478599
DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
479600
&constAggregateGlobalsMap;
480601
Namespace &globals;
602+
ArraySpillCache &spillCache;
481603
};
482604
} // namespace
483605

@@ -661,7 +783,10 @@ LogicalResult AggregateConstantOpConversion::matchAndRewrite(
661783
// Get the global array address and load it to return an array value.
662784
auto addr = LLVM::AddressOfOp::create(rewriter, op->getLoc(),
663785
constAggregateGlobalsMap[typeAttrPair]);
664-
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmTy, addr);
786+
auto newOp = rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmTy, addr);
787+
788+
if (llvm::isa<hw::ArrayType>(aggregateType))
789+
spillCache.map(newOp.getResult(), addr);
665790

666791
return success();
667792
}
@@ -703,24 +828,26 @@ void circt::populateHWToLLVMConversionPatterns(
703828
LLVMTypeConverter &converter, RewritePatternSet &patterns,
704829
Namespace &globals,
705830
DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp>
706-
&constAggregateGlobalsMap) {
831+
&constAggregateGlobalsMap,
832+
ArraySpillCache &spillCache) {
707833
MLIRContext *ctx = converter.getDialect()->getContext();
708834

709835
// Value creation conversion patterns.
710836
patterns.add<HWConstantOpConversion>(ctx, converter);
711837
patterns.add<HWDynamicArrayCreateOpConversion, HWStructCreateOpConversion>(
712838
converter);
713839
patterns.add<AggregateConstantOpConversion>(
714-
converter, constAggregateGlobalsMap, globals);
840+
converter, constAggregateGlobalsMap, globals, spillCache);
715841

716842
// Bitwise conversion patterns.
717843
patterns.add<BitcastOpConversion>(converter);
718844

719845
// Extraction operation conversion patterns.
720-
patterns.add<ArrayInjectOpConversion, ArrayGetOpConversion,
721-
ArraySliceOpConversion, ArrayConcatOpConversion,
722-
StructExplodeOpConversion, StructExtractOpConversion,
723-
StructInjectOpConversion>(converter);
846+
patterns.add<ArrayInjectOpConversion, ArraySliceOpConversion,
847+
ArrayConcatOpConversion, StructExplodeOpConversion,
848+
StructExtractOpConversion, StructInjectOpConversion>(converter);
849+
850+
patterns.add<ArrayGetOpConversion>(converter, spillCache);
724851
}
725852

726853
void circt::populateHWToLLVMTypeConversions(LLVMTypeConverter &converter) {
@@ -732,6 +859,7 @@ void circt::populateHWToLLVMTypeConversions(LLVMTypeConverter &converter) {
732859

733860
void HWToLLVMLoweringPass::runOnOperation() {
734861
DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp> constAggregateGlobalsMap;
862+
ArraySpillCache spillCache;
735863
Namespace globals;
736864
SymbolCache cache;
737865
cache.addDefinitions(getOperation());
@@ -746,7 +874,7 @@ void HWToLLVMLoweringPass::runOnOperation() {
746874

747875
// Setup the conversion.
748876
populateHWToLLVMConversionPatterns(converter, patterns, globals,
749-
constAggregateGlobalsMap);
877+
constAggregateGlobalsMap, spillCache);
750878

751879
// Apply the partial conversion.
752880
ConversionConfig config;

0 commit comments

Comments
 (0)