Skip to content

Commit 3757e7a

Browse files
committed
[lumen] allow LLVMDialect to inherit LLVMContext
By allowing the LLVMDialect to be created with a reference to a non-owned LLVMContext, we can create one LLVMContext per thread and use it for everything. This prevents the issue where MLIR-generated modules were created with their own context and cannot be linked together or worked with using the thread-global LLVMContext we create for that purpose in Lumen.
1 parent 2fa4f5c commit 3757e7a

File tree

2 files changed

+57
-4
lines changed

2 files changed

+57
-4
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def LLVM_Dialect : Dialect {
2121
let cppNamespace = "LLVM";
2222
let hasRegionArgAttrVerify = 1;
2323
let extraClassDeclaration = [{
24+
LLVMDialect(mlir::MLIRContext *, llvm::LLVMContext *);
2425
~LLVMDialect();
2526
llvm::LLVMContext &getLLVMContext();
2627
llvm::Module &getLLVMModule();

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1664,12 +1664,32 @@ static LogicalResult verify(FenceOp &op) {
16641664
namespace mlir {
16651665
namespace LLVM {
16661666
namespace detail {
1667+
struct LLVMContextHandle {
1668+
bool owned;
1669+
llvm::LLVMContext *context;
1670+
1671+
LLVMContextHandle() :
1672+
owned(true), context(new llvm::LLVMContext()) {}
1673+
LLVMContextHandle(llvm::LLVMContext *ctx) :
1674+
owned(false), context(ctx) {}
1675+
1676+
~LLVMContextHandle() {
1677+
if (owned)
1678+
delete context;
1679+
}
1680+
};
1681+
16671682
struct LLVMDialectImpl {
1668-
LLVMDialectImpl() : module("LLVMDialectModule", llvmContext) {}
1683+
LLVMDialectImpl()
1684+
: module("LLVMDialectModule", *llvmContext.context) {}
1685+
LLVMDialectImpl(llvm::LLVMContext *ctx)
1686+
: llvmContext(ctx), module("LLVMDialectModule", *ctx) {}
16691687

1670-
llvm::LLVMContext llvmContext;
1688+
LLVMContextHandle llvmContext;
16711689
llvm::Module module;
16721690

1691+
bool ownsContext;
1692+
16731693
/// A set of LLVMTypes that are cached on construction to avoid any lookups or
16741694
/// locking.
16751695
LLVMType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
@@ -1684,6 +1704,38 @@ struct LLVMDialectImpl {
16841704
} // end namespace LLVM
16851705
} // end namespace mlir
16861706

1707+
LLVMDialect::LLVMDialect(MLIRContext *context, llvm::LLVMContext *llvmCtx)
1708+
: Dialect(getDialectNamespace(), context),
1709+
impl(new detail::LLVMDialectImpl(llvmCtx)) {
1710+
addTypes<LLVMType>();
1711+
addOperations<
1712+
#define GET_OP_LIST
1713+
#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
1714+
>();
1715+
1716+
// Support unknown operations because not all LLVM operations are registered.
1717+
allowUnknownOperations();
1718+
1719+
// Cache some of the common LLVM types to avoid the need for lookups/locking.
1720+
auto &llvmContext = impl->module.getContext();
1721+
/// Integer Types.
1722+
impl->int1Ty = LLVMType::get(context, llvm::Type::getInt1Ty(llvmContext));
1723+
impl->int8Ty = LLVMType::get(context, llvm::Type::getInt8Ty(llvmContext));
1724+
impl->int16Ty = LLVMType::get(context, llvm::Type::getInt16Ty(llvmContext));
1725+
impl->int32Ty = LLVMType::get(context, llvm::Type::getInt32Ty(llvmContext));
1726+
impl->int64Ty = LLVMType::get(context, llvm::Type::getInt64Ty(llvmContext));
1727+
impl->int128Ty = LLVMType::get(context, llvm::Type::getInt128Ty(llvmContext));
1728+
/// Float Types.
1729+
impl->doubleTy = LLVMType::get(context, llvm::Type::getDoubleTy(llvmContext));
1730+
impl->floatTy = LLVMType::get(context, llvm::Type::getFloatTy(llvmContext));
1731+
impl->halfTy = LLVMType::get(context, llvm::Type::getHalfTy(llvmContext));
1732+
impl->fp128Ty = LLVMType::get(context, llvm::Type::getFP128Ty(llvmContext));
1733+
impl->x86_fp80Ty =
1734+
LLVMType::get(context, llvm::Type::getX86_FP80Ty(llvmContext));
1735+
/// Other Types.
1736+
impl->voidTy = LLVMType::get(context, llvm::Type::getVoidTy(llvmContext));
1737+
}
1738+
16871739
LLVMDialect::LLVMDialect(MLIRContext *context)
16881740
: Dialect(getDialectNamespace(), context),
16891741
impl(new detail::LLVMDialectImpl()) {
@@ -1697,7 +1749,7 @@ LLVMDialect::LLVMDialect(MLIRContext *context)
16971749
allowUnknownOperations();
16981750

16991751
// Cache some of the common LLVM types to avoid the need for lookups/locking.
1700-
auto &llvmContext = impl->llvmContext;
1752+
auto &llvmContext = impl->module.getContext();
17011753
/// Integer Types.
17021754
impl->int1Ty = LLVMType::get(context, llvm::Type::getInt1Ty(llvmContext));
17031755
impl->int8Ty = LLVMType::get(context, llvm::Type::getInt8Ty(llvmContext));
@@ -1722,7 +1774,7 @@ LLVMDialect::~LLVMDialect() {}
17221774
#define GET_OP_CLASSES
17231775
#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
17241776

1725-
llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->llvmContext; }
1777+
llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->module.getContext(); }
17261778
llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }
17271779
llvm::sys::SmartMutex<true> &LLVMDialect::getLLVMContextMutex() {
17281780
return impl->mutex;

0 commit comments

Comments
 (0)