1- commit 3ea3d4bed57c4f6a35bed044bca8c1277fa2bb17
2- Author: lipracer <
[email protected] >
3- Date: Fri Mar 29 23:25:07 2024 +0800
1+ From a2113b34ed4c5bebfc2d86187cc8d8272e3bd8ef Mon Sep 17 00:00:00 2001
2+ From: lipracer <
[email protected] >
3+ Date: Fri, 29 Mar 2024 23:25:07 +0800
4+ Subject: [PATCH] [mlir] fix Undefined behavior in CastInfo::castFailed with
5+ From=<MLIR interface>
46
5- [mlir] fix Undefined behavior in CastInfo::castFailed with From=<MLIR interface>
6-
7- Fixes https://github.com/llvm/llvm-project/issues/86647
7+ Fixes https://github.com/llvm/llvm-project/issues/86647
88
9+ add CastInfo to support cast Interface to Op
10+ ---
11+ config.sh | 10 +++
12+ mlir/include/mlir/IR/OpDefinition.h | 71 ++++++++++++++++++++++
13+ mlir/include/mlir/TableGen/Class.h | 2 +
14+ mlir/tools/mlir-tblgen/OpClass.cpp | 9 +++
15+ mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | 3 +-
16+ mlir/unittests/IR/InterfaceTest.cpp | 48 +++++++++++++++
17+ 6 files changed, 142 insertions(+), 1 deletion(-)
18+ create mode 100644 config.sh
19+
20+ diff --git a/config.sh b/config.sh
21+ new file mode 100644
22+ index 000000000000..55ab08224a32
23+ --- /dev/null
24+ +++ b/config.sh
25+ @@ -0,0 +1,10 @@
26+ + cmake -G Ninja llvm -B build \
27+ + -DCMAKE_C_COMPILER=clang \
28+ + -DCMAKE_CXX_COMPILER=clang++ \
29+ + -DLLVM_ENABLE_LLD=OFF \
30+ + -DLLVM_ENABLE_PROJECTS=mlir \
31+ + -DLLVM_BUILD_EXAMPLES=ON \
32+ + -DLLVM_TARGETS_TO_BUILD="Native;NVPTX;AMDGPU" \
33+ + -DCMAKE_BUILD_TYPE=Release \
34+ + -DLLVM_ENABLE_ASSERTIONS=ON \
35+ + -DMLIR_INCLUDE_INTEGRATION_TESTS=ON
936diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
10- index bd68c2744574..5ba39b80b513 100644
37+ index 59f094d66909..52aac19289cf 100644
1138--- a/mlir/include/mlir/IR/OpDefinition.h
1239+++ b/mlir/include/mlir/IR/OpDefinition.h
1340@@ -22,6 +22,7 @@
@@ -16,9 +43,9 @@ index bd68c2744574..5ba39b80b513 100644
1643 #include "mlir/IR/Operation.h"
1744+ #include "llvm/Support/Casting.h"
1845 #include "llvm/Support/PointerLikeTypeTraits.h"
19-
46+
2047 #include <optional>
21- @@ -2110 ,6 +2111,34 @@ struct DenseMapInfo<T,
48+ @@ -2142 ,6 +2143,76 @@ struct DenseMapInfo<T,
2249 }
2350 static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
2451 };
@@ -36,7 +63,7 @@ index bd68c2744574..5ba39b80b513 100644
3663+ void>> : NullableValueCastFailed<To>,
3764+ DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
3865+
39- + static bool isPossible(From &val) {
66+ + static inline bool isPossible(From &val) {
4067+ if constexpr (std::is_same_v<To, From>)
4168+ return true;
4269+ else
@@ -45,16 +72,92 @@ index bd68c2744574..5ba39b80b513 100644
4572+ const_cast<std::remove_const_t<From> &>(val).getOperation());
4673+ }
4774+
48- + static To doCast(From &val) {
75+ + static inline To doCast(From &val) {
76+ + return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
77+ + }
78+ + };
79+ +
80+ + template <typename OpT, typename = void>
81+ + struct is_concrete_op_type : public std::false_type {};
82+ +
83+ + template <typename OpT, template <typename T> typename... Traits>
84+ + constexpr auto concrete_op_base_type_impl(std::tuple<Traits<OpT>...>) {
85+ + return mlir::Op<OpT, Traits...>(nullptr);
86+ + }
87+ +
88+ + template <typename OpT>
89+ + using concrete_op_base_type =
90+ + decltype(concrete_op_base_type_impl<OpT>(typename OpT::traits()));
91+ +
92+ + template <typename OpT>
93+ + struct is_concrete_op_type<
94+ + OpT, std::enable_if_t<std::is_base_of_v<concrete_op_base_type<OpT>, OpT>>>
95+ + : public std::true_type {};
96+ +
97+ + template <typename To, typename From>
98+ + struct CastInfo<
99+ + To, From,
100+ + std::enable_if_t<
101+ + is_concrete_op_type<To>() &&
102+ + std::is_base_of_v<mlir::OpInterface<std::remove_const_t<From>,
103+ + typename std::remove_const_t<
104+ + From>::InterfaceTraits>,
105+ + std::remove_const_t<From>>>>
106+ + : NullableValueCastFailed<To>,
107+ + DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
108+ +
109+ + static inline bool isPossible(From &val) {
110+ + if constexpr (std::is_same_v<To, From>)
111+ + return true;
112+ + else
113+ + return isa<To>(
114+ + const_cast<std::remove_const_t<From> &>(val).getOperation());
115+ + }
116+ +
117+ + static inline To doCast(From &val) {
49118+ return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
50119+ }
51120+ };
52121+
53122 } // namespace llvm
54-
123+
55124 #endif
125+ diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
126+ index 92fec6a3b11d..7616f56aa2e3 100644
127+ --- a/mlir/include/mlir/TableGen/Class.h
128+ +++ b/mlir/include/mlir/TableGen/Class.h
129+ @@ -520,6 +520,8 @@ public:
130+ /// Write the parent class declaration.
131+ void writeTo(raw_indented_ostream &os) const;
132+
133+ + friend class OpClass;
134+ +
135+ private:
136+ /// The fully resolved C++ name of the parent class.
137+ std::string name;
138+ diff --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp
139+ index 60fa1833ce62..5426302dfed3 100644
140+ --- a/mlir/tools/mlir-tblgen/OpClass.cpp
141+ +++ b/mlir/tools/mlir-tblgen/OpClass.cpp
142+ @@ -36,7 +36,16 @@ OpClass::OpClass(StringRef name, std::string extraClassDeclaration,
143+ }
144+
145+ void OpClass::finalize() {
146+ + std::string traitList;
147+ + llvm::raw_string_ostream os(traitList);
148+ + iterator_range parentTemplateParams(std::begin(parent.templateParams) + 1,
149+ + std::end(parent.templateParams));
150+ + llvm::interleaveComma(parentTemplateParams, os, [&](auto &trait) {
151+ + os << trait << "<" << getClassName().str() << ">";
152+ + });
153+ + declare<UsingDeclaration>("traits", "std::tuple<" + traitList + ">");
154+ Class::finalize();
155+ +
156+ declare<VisibilityDeclaration>(Visibility::Public);
157+ declare<ExtraClassDeclaration>(extraClassDeclaration, extraClassDefinition);
158+ }
56159diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
57- index 2a7406f42f34..c6409e9ec30e 100644
160+ index 4b06b92fbc8a..a1cae23c1df9 100644
58161--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
59162+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
60163@@ -544,7 +544,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
@@ -66,27 +169,28 @@ index 2a7406f42f34..c6409e9ec30e 100644
66169+ " using InterfaceTraits = detail::{2};\n",
67170 interfaceName, interfaceName, interfaceTraitsName,
68171 interfaceBaseType);
69-
172+
70173diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
71- index 5ab4d9a10623..7012da669248 100644
174+ index 42196b003e7d..c9ae6938e8b4 100644
72175--- a/mlir/unittests/IR/InterfaceTest.cpp
73176+++ b/mlir/unittests/IR/InterfaceTest.cpp
74- @@ -16,6 +16,9 @@
75- #include "../../test/lib/Dialect/Test/TestAttributes.h"
177+ @@ -17,6 +17,10 @@
76178 #include "../../test/lib/Dialect/Test/TestDialect.h"
179+ #include "../../test/lib/Dialect/Test/TestOps.h"
77180 #include "../../test/lib/Dialect/Test/TestTypes.h"
78181+ #include "mlir/Dialect/Arith/IR/Arith.h"
182+ + #include "mlir/Dialect/SCF/IR/SCF.h"
79183+ #include "mlir/Parser/Parser.h"
80184+ #include "llvm/ADT/TypeSwitch.h"
81-
185+
82186 using namespace mlir;
83187 using namespace test;
84- @@ -83 ,3 +86,40 @@ TEST(InterfaceTest, TestImplicitConversion) {
188+ @@ -84 ,3 +88,47 @@ TEST(InterfaceTest, TestImplicitConversion) {
85189 typeA = typeB;
86190 EXPECT_EQ(typeA, typeB);
87191 }
88192+
89- + TEST(OperationInterfaceTest, CastOpToInterface ) {
193+ + TEST(OperationInterfaceTest, CastInterfaceToOpOrInterface ) {
90194+ DialectRegistry registry;
91195+ MLIRContext ctx;
92196+
@@ -103,13 +207,20 @@ index 5ab4d9a10623..7012da669248 100644
103207+ OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
104208+ Operation &op = cast<func::FuncOp>(module->front()).getBody().front().front();
105209+
210+ + static_assert(std::is_base_of_v<llvm::concrete_op_base_type<arith::AddIOp>,
211+ + arith::AddIOp>,
212+ + "");
213+ + static_assert(llvm::is_concrete_op_type<arith::AddIOp>(), "");
214+ + static_assert(!llvm::is_concrete_op_type<OpAsmOpInterface>(), "");
215+ +
106216+ OpAsmOpInterface interface = llvm::cast<OpAsmOpInterface>(op);
107217+
108- + bool constantOp =
109- + llvm::TypeSwitch<OpAsmOpInterface, bool>(interface)
110- + .Case<VectorUnrollOpInterface, arith::ConstantOp>([&](auto op) {
111- + return std::is_same_v<decltype(op), arith::ConstantOp>;
112- + });
218+ + bool constantOp = llvm::TypeSwitch<OpAsmOpInterface, bool>(interface)
219+ + .Case<arith::AddIOp, arith::ConstantOp>([&](auto op) {
220+ + bool is_same =
221+ + std::is_same_v<decltype(op), arith::ConstantOp>;
222+ + return is_same;
223+ + });
113224+
114225+ EXPECT_TRUE(constantOp);
115226+
@@ -122,3 +233,5 @@ index 5ab4d9a10623..7012da669248 100644
122233+ EXPECT_TRUE(llvm::isa<OpAsmOpInterface>(interface));
123234+ EXPECT_TRUE(llvm::dyn_cast<OpAsmOpInterface>(interface));
124235+ }
236+ - -
237+ 2.25.1
0 commit comments