11/*
22 * Helper functions for nicer runtime error messages.
33 */
4+ #include " error_helpers.h"
5+
46#include < tvm/ffi/c_api.h>
7+ #include < tvm/ffi/error.h>
8+ #include < tvm/ffi/function.h>
59#include < tvm/ffi/reflection/registry.h>
610#include < tvm/runtime/data_type.h>
11+ #include < tvm/runtime/device_api.h>
712
813#include < sstream>
914#include < string>
@@ -25,8 +30,9 @@ static int DTypeMismatch(const tvm::ffi::String &kernel_name,
2530 static_cast <int >(expect_bits),
2631 static_cast <int >(expect_lanes));
2732 std::ostringstream os;
28- os << std::string (kernel_name) << " : dtype of " << std::string (buffer_name)
29- << " is expected to be " << expect << " , but got " << actual;
33+ os << " kernel " << std::string (kernel_name) << " input "
34+ << std::string (buffer_name) << " dtype expected " << expect << " , but got "
35+ << actual;
3036 TVMFFIErrorSetRaisedFromCStr (" RuntimeError" , os.str ().c_str ());
3137 return -1 ;
3238}
@@ -48,13 +54,169 @@ static int DTypeMismatchNoNames(int64_t actual_code, int64_t actual_bits,
4854 return -1 ;
4955}
5056
51- } // namespace tl
52- } // namespace tvm
53-
57+ // Register packed versions, following the design in runtime.cc
5458TVM_FFI_STATIC_INIT_BLOCK () {
5559 namespace refl = tvm::ffi::reflection;
60+
61+ // Packed: __tvm_error_dtype_mismatch(kernel_name, buffer_name,
62+ // actual_code, actual_bits, actual_lanes,
63+ // expect_code, expect_bits, expect_lanes)
64+ refl::GlobalDef ().def_packed (
65+ tl::tvm_error_dtype_mismatch,
66+ [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
67+ ICHECK (args.size () == 8 ) << " Expected 8 args: kernel, buffer, "
68+ " actual_code, actual_bits, actual_lanes, "
69+ << " expect_code, expect_bits, expect_lanes" ;
70+
71+ auto kernel_name = args[0 ].cast <tvm::ffi::String>();
72+ auto buffer_name = args[1 ].cast <tvm::ffi::String>();
73+ int64_t actual_code = args[2 ].cast <int64_t >();
74+ int64_t actual_bits = args[3 ].cast <int64_t >();
75+ int64_t actual_lanes = args[4 ].cast <int64_t >();
76+ int64_t expect_code = args[5 ].cast <int64_t >();
77+ int64_t expect_bits = args[6 ].cast <int64_t >();
78+ int64_t expect_lanes = args[7 ].cast <int64_t >();
79+
80+ // Reuse the helper to format the message
81+ (void )DTypeMismatch (kernel_name, buffer_name, actual_code, actual_bits,
82+ actual_lanes, expect_code, expect_bits,
83+ expect_lanes);
84+ // Provide a return value for completeness, then signal the error
85+ *ret = -1 ;
86+ throw ::tvm::ffi::EnvErrorAlreadySet ();
87+ });
88+
89+ // kernel, buffer, expect:int64, got:int64
90+ refl::GlobalDef ().def_packed (
91+ tl::tvm_error_ndim_mismatch,
92+ [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
93+ ICHECK (args.size () == 4 )
94+ << " __tvm_error_ndim_mismatch(kernel, buffer, expect, got)" ;
95+ auto kernel = args[0 ].cast <tvm::ffi::String>();
96+ auto buffer = args[1 ].cast <tvm::ffi::String>();
97+ int64_t expect = args[2 ].cast <int64_t >();
98+ int64_t got = args[3 ].cast <int64_t >();
99+ std::ostringstream os;
100+ os << " kernel " << std::string (kernel) << " input "
101+ << std::string (buffer) << " ndim expected " << expect << " , but got "
102+ << got;
103+ TVMFFIErrorSetRaisedFromCStr (" RuntimeError" , os.str ().c_str ());
104+ *ret = -1 ;
105+ throw ::tvm::ffi::EnvErrorAlreadySet ();
106+ });
107+
108+ // kernel, buffer, expect:int64, got:int64
109+ refl::GlobalDef ().def_packed (
110+ tl::tvm_error_byte_offset_mismatch,
111+ [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
112+ ICHECK (args.size () == 4 )
113+ << " __tvm_error_byte_offset_mismatch(kernel, buffer, expect, got)" ;
114+ auto kernel = args[0 ].cast <tvm::ffi::String>();
115+ auto buffer = args[1 ].cast <tvm::ffi::String>();
116+ int64_t expect = args[2 ].cast <int64_t >();
117+ int64_t got = args[3 ].cast <int64_t >();
118+ std::ostringstream os;
119+ os << " kernel " << std::string (kernel) << " input "
120+ << std::string (buffer) << " byte_offset expected " << expect
121+ << " , but got " << got;
122+ TVMFFIErrorSetRaisedFromCStr (" RuntimeError" , os.str ().c_str ());
123+ *ret = -1 ;
124+ throw ::tvm::ffi::EnvErrorAlreadySet ();
125+ });
126+
127+ // kernel, buffer, expect:int64, got:int64
128+ refl::GlobalDef ().def_packed (
129+ tl::tvm_error_device_type_mismatch,
130+ [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
131+ ICHECK (args.size () == 4 )
132+ << " __tvm_error_device_type_mismatch(kernel, buffer, expect, got)" ;
133+ auto kernel = args[0 ].cast <tvm::ffi::String>();
134+ auto buffer = args[1 ].cast <tvm::ffi::String>();
135+ int64_t expect = args[2 ].cast <int64_t >();
136+ int64_t got = args[3 ].cast <int64_t >();
137+ const char *expect_str =
138+ tvm::runtime::DLDeviceType2Str (static_cast <int >(expect));
139+ const char *got_str =
140+ tvm::runtime::DLDeviceType2Str (static_cast <int >(got));
141+ std::ostringstream os;
142+ os << " kernel " << std::string (kernel) << " input "
143+ << std::string (buffer) << " device_type expected " << expect_str
144+ << " , but got " << got_str;
145+ TVMFFIErrorSetRaisedFromCStr (" RuntimeError" , os.str ().c_str ());
146+ *ret = -1 ;
147+ throw ::tvm::ffi::EnvErrorAlreadySet ();
148+ });
149+
150+ // kernel, buffer, field:String
151+ refl::GlobalDef ().def_packed (
152+ tl::tvm_error_null_ptr,
153+ [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
154+ ICHECK (args.size () == 3 )
155+ << " __tvm_error_null_ptr(kernel, buffer, field)" ;
156+ auto kernel = args[0 ].cast <tvm::ffi::String>();
157+ auto buffer = args[1 ].cast <tvm::ffi::String>();
158+ auto field = args[2 ].cast <tvm::ffi::String>();
159+ std::ostringstream os;
160+ os << " kernel " << std::string (kernel) << " input "
161+ << std::string (buffer) << ' ' << std::string (field)
162+ << " expected non-NULL, but got NULL" ;
163+ TVMFFIErrorSetRaisedFromCStr (" RuntimeError" , os.str ().c_str ());
164+ *ret = -1 ;
165+ throw ::tvm::ffi::EnvErrorAlreadySet ();
166+ });
167+
168+ // kernel, buffer, field:String, expect:int64, got:int64
169+ refl::GlobalDef ().def_packed (
170+ tl::tvm_error_expect_eq,
171+ [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
172+ ICHECK (args.size () == 5 )
173+ << " __tvm_error_expect_eq(kernel, buffer, field, expect, got)" ;
174+ auto kernel = args[0 ].cast <tvm::ffi::String>();
175+ auto buffer = args[1 ].cast <tvm::ffi::String>();
176+ auto field = args[2 ].cast <tvm::ffi::String>();
177+ int64_t expect = args[3 ].cast <int64_t >();
178+ int64_t got = args[4 ].cast <int64_t >();
179+ std::ostringstream os;
180+ os << " kernel " << std::string (kernel) << " input "
181+ << std::string (buffer) << ' ' << std::string (field) << " expected "
182+ << expect << " , but got " << got;
183+ TVMFFIErrorSetRaisedFromCStr (" RuntimeError" , os.str ().c_str ());
184+ *ret = -1 ;
185+ throw ::tvm::ffi::EnvErrorAlreadySet ();
186+ });
187+
188+ // kernel, buffer, field:String [, reason:String]
189+ refl::GlobalDef ().def_packed (
190+ tl::tvm_error_constraint_violation,
191+ [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
192+ ICHECK (args.size () == 3 || args.size () == 4 )
193+ << " __tvm_error_constraint_violation(kernel, buffer, field[, "
194+ " reason])" ;
195+ auto kernel = args[0 ].cast <tvm::ffi::String>();
196+ auto buffer = args[1 ].cast <tvm::ffi::String>();
197+ auto field = args[2 ].cast <tvm::ffi::String>();
198+ std::string reason;
199+ if (args.size () == 4 ) {
200+ reason = args[3 ].cast <tvm::ffi::String>();
201+ }
202+ std::ostringstream os;
203+ os << " kernel " << std::string (kernel) << " input "
204+ << std::string (buffer) << ' ' << std::string (field)
205+ << " constraint not satisfied" ;
206+ if (!reason.empty ()) {
207+ os << " : " << reason;
208+ }
209+ TVMFFIErrorSetRaisedFromCStr (" RuntimeError" , os.str ().c_str ());
210+ *ret = -1 ;
211+ throw ::tvm::ffi::EnvErrorAlreadySet ();
212+ });
213+
214+ // Legacy typed registrations for backward compatibility
56215 refl::GlobalDef ().def (" tilelang_error_dtype_mismatch" ,
57216 &tvm::tl::DTypeMismatch);
58217 refl::GlobalDef ().def (" tilelang_error_dtype_mismatch2" ,
59218 &tvm::tl::DTypeMismatchNoNames);
60219}
220+
221+ } // namespace tl
222+ } // namespace tvm
0 commit comments