Skip to content

Commit 1f1ad2b

Browse files
committed
update upstream mlir patch
Signed-off-by: lipracer <[email protected]>
1 parent ab7289a commit 1f1ad2b

File tree

1 file changed

+138
-25
lines changed

1 file changed

+138
-25
lines changed

cast.patch

Lines changed: 138 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,40 @@
1-
commit 3ea3d4bed57c4f6a35bed044bca8c1277fa2bb17
2-
Author: lipracer <[email protected]>
3-
Date: Fri Mar 29 23:25:07 2024 +0800
1+
From a2113b34ed4c5bebfc2d86187cc8d8272e3bd8ef Mon Sep 17 00:00:00 2001
2+
From: lipracer <[email protected]>
3+
Date: Fri, 29 Mar 2024 23:25:07 +0800
4+
Subject: [PATCH] [mlir] fix Undefined behavior in CastInfo::castFailed with
5+
From=<MLIR interface>
46

5-
[mlir] fix Undefined behavior in CastInfo::castFailed with From=<MLIR interface>
6-
7-
Fixes https://github.com/llvm/llvm-project/issues/86647
7+
Fixes https://github.com/llvm/llvm-project/issues/86647
88

9+
add CastInfo to support cast Interface to Op
10+
---
11+
config.sh | 10 +++
12+
mlir/include/mlir/IR/OpDefinition.h | 71 ++++++++++++++++++++++
13+
mlir/include/mlir/TableGen/Class.h | 2 +
14+
mlir/tools/mlir-tblgen/OpClass.cpp | 9 +++
15+
mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | 3 +-
16+
mlir/unittests/IR/InterfaceTest.cpp | 48 +++++++++++++++
17+
6 files changed, 142 insertions(+), 1 deletion(-)
18+
create mode 100644 config.sh
19+
20+
diff --git a/config.sh b/config.sh
21+
new file mode 100644
22+
index 000000000000..55ab08224a32
23+
--- /dev/null
24+
+++ b/config.sh
25+
@@ -0,0 +1,10 @@
26+
+cmake -G Ninja llvm -B build \
27+
+ -DCMAKE_C_COMPILER=clang \
28+
+ -DCMAKE_CXX_COMPILER=clang++ \
29+
+ -DLLVM_ENABLE_LLD=OFF \
30+
+ -DLLVM_ENABLE_PROJECTS=mlir \
31+
+ -DLLVM_BUILD_EXAMPLES=ON \
32+
+ -DLLVM_TARGETS_TO_BUILD="Native;NVPTX;AMDGPU" \
33+
+ -DCMAKE_BUILD_TYPE=Release \
34+
+ -DLLVM_ENABLE_ASSERTIONS=ON \
35+
+ -DMLIR_INCLUDE_INTEGRATION_TESTS=ON
936
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
10-
index bd68c2744574..5ba39b80b513 100644
37+
index 59f094d66909..52aac19289cf 100644
1138
--- a/mlir/include/mlir/IR/OpDefinition.h
1239
+++ b/mlir/include/mlir/IR/OpDefinition.h
1340
@@ -22,6 +22,7 @@
@@ -16,9 +43,9 @@ index bd68c2744574..5ba39b80b513 100644
1643
#include "mlir/IR/Operation.h"
1744
+#include "llvm/Support/Casting.h"
1845
#include "llvm/Support/PointerLikeTypeTraits.h"
19-
46+
2047
#include <optional>
21-
@@ -2110,6 +2111,34 @@ struct DenseMapInfo<T,
48+
@@ -2142,6 +2143,76 @@ struct DenseMapInfo<T,
2249
}
2350
static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
2451
};
@@ -36,7 +63,7 @@ index bd68c2744574..5ba39b80b513 100644
3663
+ void>> : NullableValueCastFailed<To>,
3764
+ DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
3865
+
39-
+ static bool isPossible(From &val) {
66+
+ static inline bool isPossible(From &val) {
4067
+ if constexpr (std::is_same_v<To, From>)
4168
+ return true;
4269
+ else
@@ -45,16 +72,92 @@ index bd68c2744574..5ba39b80b513 100644
4572
+ const_cast<std::remove_const_t<From> &>(val).getOperation());
4673
+ }
4774
+
48-
+ static To doCast(From &val) {
75+
+ static inline To doCast(From &val) {
76+
+ return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
77+
+ }
78+
+};
79+
+
80+
+template <typename OpT, typename = void>
81+
+struct is_concrete_op_type : public std::false_type {};
82+
+
83+
+template <typename OpT, template <typename T> typename... Traits>
84+
+constexpr auto concrete_op_base_type_impl(std::tuple<Traits<OpT>...>) {
85+
+ return mlir::Op<OpT, Traits...>(nullptr);
86+
+}
87+
+
88+
+template <typename OpT>
89+
+using concrete_op_base_type =
90+
+ decltype(concrete_op_base_type_impl<OpT>(typename OpT::traits()));
91+
+
92+
+template <typename OpT>
93+
+struct is_concrete_op_type<
94+
+ OpT, std::enable_if_t<std::is_base_of_v<concrete_op_base_type<OpT>, OpT>>>
95+
+ : public std::true_type {};
96+
+
97+
+template <typename To, typename From>
98+
+struct CastInfo<
99+
+ To, From,
100+
+ std::enable_if_t<
101+
+ is_concrete_op_type<To>() &&
102+
+ std::is_base_of_v<mlir::OpInterface<std::remove_const_t<From>,
103+
+ typename std::remove_const_t<
104+
+ From>::InterfaceTraits>,
105+
+ std::remove_const_t<From>>>>
106+
+ : NullableValueCastFailed<To>,
107+
+ DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
108+
+
109+
+ static inline bool isPossible(From &val) {
110+
+ if constexpr (std::is_same_v<To, From>)
111+
+ return true;
112+
+ else
113+
+ return isa<To>(
114+
+ const_cast<std::remove_const_t<From> &>(val).getOperation());
115+
+ }
116+
+
117+
+ static inline To doCast(From &val) {
49118
+ return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
50119
+ }
51120
+};
52121
+
53122
} // namespace llvm
54-
123+
55124
#endif
125+
diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
126+
index 92fec6a3b11d..7616f56aa2e3 100644
127+
--- a/mlir/include/mlir/TableGen/Class.h
128+
+++ b/mlir/include/mlir/TableGen/Class.h
129+
@@ -520,6 +520,8 @@ public:
130+
/// Write the parent class declaration.
131+
void writeTo(raw_indented_ostream &os) const;
132+
133+
+ friend class OpClass;
134+
+
135+
private:
136+
/// The fully resolved C++ name of the parent class.
137+
std::string name;
138+
diff --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp
139+
index 60fa1833ce62..5426302dfed3 100644
140+
--- a/mlir/tools/mlir-tblgen/OpClass.cpp
141+
+++ b/mlir/tools/mlir-tblgen/OpClass.cpp
142+
@@ -36,7 +36,16 @@ OpClass::OpClass(StringRef name, std::string extraClassDeclaration,
143+
}
144+
145+
void OpClass::finalize() {
146+
+ std::string traitList;
147+
+ llvm::raw_string_ostream os(traitList);
148+
+ iterator_range parentTemplateParams(std::begin(parent.templateParams) + 1,
149+
+ std::end(parent.templateParams));
150+
+ llvm::interleaveComma(parentTemplateParams, os, [&](auto &trait) {
151+
+ os << trait << "<" << getClassName().str() << ">";
152+
+ });
153+
+ declare<UsingDeclaration>("traits", "std::tuple<" + traitList + ">");
154+
Class::finalize();
155+
+
156+
declare<VisibilityDeclaration>(Visibility::Public);
157+
declare<ExtraClassDeclaration>(extraClassDeclaration, extraClassDefinition);
158+
}
56159
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
57-
index 2a7406f42f34..c6409e9ec30e 100644
160+
index 4b06b92fbc8a..a1cae23c1df9 100644
58161
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
59162
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
60163
@@ -544,7 +544,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
@@ -66,27 +169,28 @@ index 2a7406f42f34..c6409e9ec30e 100644
66169
+ " using InterfaceTraits = detail::{2};\n",
67170
interfaceName, interfaceName, interfaceTraitsName,
68171
interfaceBaseType);
69-
172+
70173
diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
71-
index 5ab4d9a10623..7012da669248 100644
174+
index 42196b003e7d..c9ae6938e8b4 100644
72175
--- a/mlir/unittests/IR/InterfaceTest.cpp
73176
+++ b/mlir/unittests/IR/InterfaceTest.cpp
74-
@@ -16,6 +16,9 @@
75-
#include "../../test/lib/Dialect/Test/TestAttributes.h"
177+
@@ -17,6 +17,10 @@
76178
#include "../../test/lib/Dialect/Test/TestDialect.h"
179+
#include "../../test/lib/Dialect/Test/TestOps.h"
77180
#include "../../test/lib/Dialect/Test/TestTypes.h"
78181
+#include "mlir/Dialect/Arith/IR/Arith.h"
182+
+#include "mlir/Dialect/SCF/IR/SCF.h"
79183
+#include "mlir/Parser/Parser.h"
80184
+#include "llvm/ADT/TypeSwitch.h"
81-
185+
82186
using namespace mlir;
83187
using namespace test;
84-
@@ -83,3 +86,40 @@ TEST(InterfaceTest, TestImplicitConversion) {
188+
@@ -84,3 +88,47 @@ TEST(InterfaceTest, TestImplicitConversion) {
85189
typeA = typeB;
86190
EXPECT_EQ(typeA, typeB);
87191
}
88192
+
89-
+TEST(OperationInterfaceTest, CastOpToInterface) {
193+
+TEST(OperationInterfaceTest, CastInterfaceToOpOrInterface) {
90194
+ DialectRegistry registry;
91195
+ MLIRContext ctx;
92196
+
@@ -103,13 +207,20 @@ index 5ab4d9a10623..7012da669248 100644
103207
+ OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
104208
+ Operation &op = cast<func::FuncOp>(module->front()).getBody().front().front();
105209
+
210+
+ static_assert(std::is_base_of_v<llvm::concrete_op_base_type<arith::AddIOp>,
211+
+ arith::AddIOp>,
212+
+ "");
213+
+ static_assert(llvm::is_concrete_op_type<arith::AddIOp>(), "");
214+
+ static_assert(!llvm::is_concrete_op_type<OpAsmOpInterface>(), "");
215+
+
106216
+ OpAsmOpInterface interface = llvm::cast<OpAsmOpInterface>(op);
107217
+
108-
+ bool constantOp =
109-
+ llvm::TypeSwitch<OpAsmOpInterface, bool>(interface)
110-
+ .Case<VectorUnrollOpInterface, arith::ConstantOp>([&](auto op) {
111-
+ return std::is_same_v<decltype(op), arith::ConstantOp>;
112-
+ });
218+
+ bool constantOp = llvm::TypeSwitch<OpAsmOpInterface, bool>(interface)
219+
+ .Case<arith::AddIOp, arith::ConstantOp>([&](auto op) {
220+
+ bool is_same =
221+
+ std::is_same_v<decltype(op), arith::ConstantOp>;
222+
+ return is_same;
223+
+ });
113224
+
114225
+ EXPECT_TRUE(constantOp);
115226
+
@@ -122,3 +233,5 @@ index 5ab4d9a10623..7012da669248 100644
122233
+ EXPECT_TRUE(llvm::isa<OpAsmOpInterface>(interface));
123234
+ EXPECT_TRUE(llvm::dyn_cast<OpAsmOpInterface>(interface));
124235
+}
236+
--
237+
2.25.1

0 commit comments

Comments
 (0)