Skip to content

Commit 17cfeb7

Browse files
authored
[Enhancement] Improve error handling and assertion messages across runtime and argument binding (tile-ai#1356)
This commit enhances the error handling mechanisms in the runtime by introducing CPU-safe runtime helpers and refining assertion messages in the CodeGenCHost and ArgBinder. It includes structured packed error messages for various conditions, improving clarity in diagnostics. Additionally, the CMake configuration is updated to always include necessary runtime helpers, ensuring consistent error reporting. The changes aim to provide clearer feedback during runtime errors and improve the overall robustness of the argument binding process.
1 parent 36a2b2f commit 17cfeb7

File tree

9 files changed

+459
-138
lines changed

9 files changed

+459
-138
lines changed

3rdparty/tvm

Submodule tvm updated from e3af400 to fc7ed0b

CMakeLists.txt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,11 @@ file(GLOB TILE_LANG_SRCS
145145
src/target/intrin_rule*.cc
146146
)
147147

148+
# Always include CPU-safe runtime helpers
149+
list(APPEND TILE_LANG_SRCS
150+
src/runtime/error_helpers.cc
151+
)
152+
148153
# Track if the user explicitly selected a backend via cache options.
149154
set(TILELANG_BACKEND_USER_SELECTED OFF)
150155
foreach(BACKEND IN LISTS TILELANG_BACKENDS)
@@ -206,7 +211,7 @@ elseif(USE_CUDA)
206211
cmake_path(GET CUDAToolkit_BIN_DIR PARENT_PATH USE_CUDA)
207212

208213
file(GLOB TILE_LANG_CUDA_SRCS
209-
src/runtime/*.cc
214+
src/runtime/runtime.cc
210215
src/target/ptx.cc
211216
src/target/codegen_cuda.cc
212217
src/target/rt_mod_cuda.cc

src/runtime/error_helpers.cc

Lines changed: 167 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
/*
22
* Helper functions for nicer runtime error messages.
33
*/
4+
#include "error_helpers.h"
5+
46
#include <tvm/ffi/c_api.h>
7+
#include <tvm/ffi/error.h>
8+
#include <tvm/ffi/function.h>
59
#include <tvm/ffi/reflection/registry.h>
610
#include <tvm/runtime/data_type.h>
11+
#include <tvm/runtime/device_api.h>
712

813
#include <sstream>
914
#include <string>
@@ -25,8 +30,9 @@ static int DTypeMismatch(const tvm::ffi::String &kernel_name,
2530
static_cast<int>(expect_bits),
2631
static_cast<int>(expect_lanes));
2732
std::ostringstream os;
28-
os << std::string(kernel_name) << ": dtype of " << std::string(buffer_name)
29-
<< " is expected to be " << expect << ", but got " << actual;
33+
os << "kernel " << std::string(kernel_name) << " input "
34+
<< std::string(buffer_name) << " dtype expected " << expect << ", but got "
35+
<< actual;
3036
TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
3137
return -1;
3238
}
@@ -48,13 +54,169 @@ static int DTypeMismatchNoNames(int64_t actual_code, int64_t actual_bits,
4854
return -1;
4955
}
5056

51-
} // namespace tl
52-
} // namespace tvm
53-
57+
// Register packed versions, following the design in runtime.cc
5458
TVM_FFI_STATIC_INIT_BLOCK() {
5559
namespace refl = tvm::ffi::reflection;
60+
61+
// Packed: __tvm_error_dtype_mismatch(kernel_name, buffer_name,
62+
// actual_code, actual_bits, actual_lanes,
63+
// expect_code, expect_bits, expect_lanes)
64+
refl::GlobalDef().def_packed(
65+
tl::tvm_error_dtype_mismatch,
66+
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
67+
ICHECK(args.size() == 8) << "Expected 8 args: kernel, buffer, "
68+
"actual_code, actual_bits, actual_lanes, "
69+
<< "expect_code, expect_bits, expect_lanes";
70+
71+
auto kernel_name = args[0].cast<tvm::ffi::String>();
72+
auto buffer_name = args[1].cast<tvm::ffi::String>();
73+
int64_t actual_code = args[2].cast<int64_t>();
74+
int64_t actual_bits = args[3].cast<int64_t>();
75+
int64_t actual_lanes = args[4].cast<int64_t>();
76+
int64_t expect_code = args[5].cast<int64_t>();
77+
int64_t expect_bits = args[6].cast<int64_t>();
78+
int64_t expect_lanes = args[7].cast<int64_t>();
79+
80+
// Reuse the helper to format the message
81+
(void)DTypeMismatch(kernel_name, buffer_name, actual_code, actual_bits,
82+
actual_lanes, expect_code, expect_bits,
83+
expect_lanes);
84+
// Provide a return value for completeness, then signal the error
85+
*ret = -1;
86+
throw ::tvm::ffi::EnvErrorAlreadySet();
87+
});
88+
89+
// kernel, buffer, expect:int64, got:int64
90+
refl::GlobalDef().def_packed(
91+
tl::tvm_error_ndim_mismatch,
92+
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
93+
ICHECK(args.size() == 4)
94+
<< "__tvm_error_ndim_mismatch(kernel, buffer, expect, got)";
95+
auto kernel = args[0].cast<tvm::ffi::String>();
96+
auto buffer = args[1].cast<tvm::ffi::String>();
97+
int64_t expect = args[2].cast<int64_t>();
98+
int64_t got = args[3].cast<int64_t>();
99+
std::ostringstream os;
100+
os << "kernel " << std::string(kernel) << " input "
101+
<< std::string(buffer) << " ndim expected " << expect << ", but got "
102+
<< got;
103+
TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
104+
*ret = -1;
105+
throw ::tvm::ffi::EnvErrorAlreadySet();
106+
});
107+
108+
// kernel, buffer, expect:int64, got:int64
109+
refl::GlobalDef().def_packed(
110+
tl::tvm_error_byte_offset_mismatch,
111+
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
112+
ICHECK(args.size() == 4)
113+
<< "__tvm_error_byte_offset_mismatch(kernel, buffer, expect, got)";
114+
auto kernel = args[0].cast<tvm::ffi::String>();
115+
auto buffer = args[1].cast<tvm::ffi::String>();
116+
int64_t expect = args[2].cast<int64_t>();
117+
int64_t got = args[3].cast<int64_t>();
118+
std::ostringstream os;
119+
os << "kernel " << std::string(kernel) << " input "
120+
<< std::string(buffer) << " byte_offset expected " << expect
121+
<< ", but got " << got;
122+
TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
123+
*ret = -1;
124+
throw ::tvm::ffi::EnvErrorAlreadySet();
125+
});
126+
127+
// kernel, buffer, expect:int64, got:int64
128+
refl::GlobalDef().def_packed(
129+
tl::tvm_error_device_type_mismatch,
130+
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
131+
ICHECK(args.size() == 4)
132+
<< "__tvm_error_device_type_mismatch(kernel, buffer, expect, got)";
133+
auto kernel = args[0].cast<tvm::ffi::String>();
134+
auto buffer = args[1].cast<tvm::ffi::String>();
135+
int64_t expect = args[2].cast<int64_t>();
136+
int64_t got = args[3].cast<int64_t>();
137+
const char *expect_str =
138+
tvm::runtime::DLDeviceType2Str(static_cast<int>(expect));
139+
const char *got_str =
140+
tvm::runtime::DLDeviceType2Str(static_cast<int>(got));
141+
std::ostringstream os;
142+
os << "kernel " << std::string(kernel) << " input "
143+
<< std::string(buffer) << " device_type expected " << expect_str
144+
<< ", but got " << got_str;
145+
TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
146+
*ret = -1;
147+
throw ::tvm::ffi::EnvErrorAlreadySet();
148+
});
149+
150+
// kernel, buffer, field:String
151+
refl::GlobalDef().def_packed(
152+
tl::tvm_error_null_ptr,
153+
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
154+
ICHECK(args.size() == 3)
155+
<< "__tvm_error_null_ptr(kernel, buffer, field)";
156+
auto kernel = args[0].cast<tvm::ffi::String>();
157+
auto buffer = args[1].cast<tvm::ffi::String>();
158+
auto field = args[2].cast<tvm::ffi::String>();
159+
std::ostringstream os;
160+
os << "kernel " << std::string(kernel) << " input "
161+
<< std::string(buffer) << ' ' << std::string(field)
162+
<< " expected non-NULL, but got NULL";
163+
TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
164+
*ret = -1;
165+
throw ::tvm::ffi::EnvErrorAlreadySet();
166+
});
167+
168+
// kernel, buffer, field:String, expect:int64, got:int64
169+
refl::GlobalDef().def_packed(
170+
tl::tvm_error_expect_eq,
171+
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
172+
ICHECK(args.size() == 5)
173+
<< "__tvm_error_expect_eq(kernel, buffer, field, expect, got)";
174+
auto kernel = args[0].cast<tvm::ffi::String>();
175+
auto buffer = args[1].cast<tvm::ffi::String>();
176+
auto field = args[2].cast<tvm::ffi::String>();
177+
int64_t expect = args[3].cast<int64_t>();
178+
int64_t got = args[4].cast<int64_t>();
179+
std::ostringstream os;
180+
os << "kernel " << std::string(kernel) << " input "
181+
<< std::string(buffer) << ' ' << std::string(field) << " expected "
182+
<< expect << ", but got " << got;
183+
TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
184+
*ret = -1;
185+
throw ::tvm::ffi::EnvErrorAlreadySet();
186+
});
187+
188+
// kernel, buffer, field:String [, reason:String]
189+
refl::GlobalDef().def_packed(
190+
tl::tvm_error_constraint_violation,
191+
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
192+
ICHECK(args.size() == 3 || args.size() == 4)
193+
<< "__tvm_error_constraint_violation(kernel, buffer, field[, "
194+
"reason])";
195+
auto kernel = args[0].cast<tvm::ffi::String>();
196+
auto buffer = args[1].cast<tvm::ffi::String>();
197+
auto field = args[2].cast<tvm::ffi::String>();
198+
std::string reason;
199+
if (args.size() == 4) {
200+
reason = args[3].cast<tvm::ffi::String>();
201+
}
202+
std::ostringstream os;
203+
os << "kernel " << std::string(kernel) << " input "
204+
<< std::string(buffer) << ' ' << std::string(field)
205+
<< " constraint not satisfied";
206+
if (!reason.empty()) {
207+
os << ": " << reason;
208+
}
209+
TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
210+
*ret = -1;
211+
throw ::tvm::ffi::EnvErrorAlreadySet();
212+
});
213+
214+
// Legacy typed registrations for backward compatibility
56215
refl::GlobalDef().def("tilelang_error_dtype_mismatch",
57216
&tvm::tl::DTypeMismatch);
58217
refl::GlobalDef().def("tilelang_error_dtype_mismatch2",
59218
&tvm::tl::DTypeMismatchNoNames);
60219
}
220+
221+
} // namespace tl
222+
} // namespace tvm

src/runtime/error_helpers.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*!
2+
* \file tl/runtime/error_helpers.h
3+
* \brief Error helper FFI names for TileLang runtime.
4+
*/
5+
6+
#ifndef TVM_TL_RUNTIME_ERROR_HELPERS_H_
7+
#define TVM_TL_RUNTIME_ERROR_HELPERS_H_
8+
9+
namespace tvm {
10+
namespace tl {
11+
12+
// Error helper packed functions
13+
constexpr const char *tvm_error_dtype_mismatch = "__tvm_error_dtype_mismatch";
14+
constexpr const char *tvm_error_ndim_mismatch = "__tvm_error_ndim_mismatch";
15+
constexpr const char *tvm_error_byte_offset_mismatch =
16+
"__tvm_error_byte_offset_mismatch";
17+
constexpr const char *tvm_error_device_type_mismatch =
18+
"__tvm_error_device_type_mismatch";
19+
constexpr const char *tvm_error_null_ptr = "__tvm_error_null_ptr";
20+
constexpr const char *tvm_error_expect_eq = "__tvm_error_expect_eq";
21+
constexpr const char *tvm_error_constraint_violation =
22+
"__tvm_error_constraint_violation";
23+
24+
} // namespace tl
25+
} // namespace tvm
26+
27+
#endif // TVM_TL_RUNTIME_ERROR_HELPERS_H_

src/target/codegen_c_host.cc

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -354,32 +354,44 @@ void CodeGenCHost::VisitStmt_(const tvm::tir::AssertStmtNode *op) { // NOLINT(*)
354354
stream << "if (!(" << cond << ")) {\n";
355355
int assert_if_scope = this->BeginScope();
356356
{
357-
// Prepare the base error message
357+
// Prepare the base error message: allow StringImm or general PrimExpr
358358
const auto *msg_node = op->message.as<tvm::tir::StringImmNode>();
359-
ICHECK(msg_node != nullptr) << "Assert message expected to be StringImm";
360-
const std::string &raw_msg = msg_node->value;
361-
const std::string esc_msg = tvm::support::StrEscape(
362-
raw_msg.c_str(), raw_msg.length(), /*use_octal_escape=*/true,
363-
/*escape_whitespace_special_chars=*/true);
364-
365-
// If the assertion is an equality check, append the actual LHS/RHS values
366-
if (const auto *eq = op->condition.as<tvm::tir::EQNode>()) {
367-
std::string lhs = PrintExpr(eq->a);
368-
std::string rhs = PrintExpr(eq->b);
369-
PrintIndent();
370-
stream << "char __tvm_assert_msg_buf[512];\n";
371-
PrintIndent();
372-
stream << "snprintf(__tvm_assert_msg_buf, 512, \"%s; expected: %lld, "
373-
"got: %lld\", \""
374-
<< esc_msg << "\", (long long)(" << lhs << "), (long long)("
375-
<< rhs << "));\n";
376-
PrintIndent();
377-
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", "
378-
"__tvm_assert_msg_buf);\n";
359+
bool msg_is_literal = (msg_node != nullptr);
360+
std::string esc_msg;
361+
std::string msg_expr;
362+
if (msg_is_literal) {
363+
const std::string &raw_msg = msg_node->value;
364+
esc_msg = tvm::support::StrEscape(
365+
raw_msg.c_str(), raw_msg.length(), /*use_octal_escape=*/true,
366+
/*escape_whitespace_special_chars=*/true);
367+
} else {
368+
msg_expr = PrintExpr(op->message);
369+
}
370+
371+
// Only print expected/got values for equality when message is StringImm
372+
if (msg_is_literal) {
373+
if (const auto *eq = op->condition.as<tvm::tir::EQNode>()) {
374+
std::string lhs = PrintExpr(eq->a);
375+
std::string rhs = PrintExpr(eq->b);
376+
PrintIndent();
377+
stream << "char __tvm_assert_msg_buf[512];\n";
378+
PrintIndent();
379+
stream << "snprintf(__tvm_assert_msg_buf, 512, \"%s; expected: %lld, "
380+
"got: %lld\", \""
381+
<< esc_msg << "\", (long long)(" << lhs << "), (long long)("
382+
<< rhs << "));\n";
383+
PrintIndent();
384+
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", "
385+
"__tvm_assert_msg_buf);\n";
386+
} else {
387+
PrintIndent();
388+
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", \""
389+
<< esc_msg << "\");\n";
390+
}
379391
} else {
380392
PrintIndent();
381-
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", \"" << esc_msg
382-
<< "\");\n";
393+
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", " << msg_expr
394+
<< ");\n";
383395
}
384396
}
385397
PrintIndent();

0 commit comments

Comments
 (0)