@@ -219,11 +219,22 @@ def compare(a, b):
219219 upper_lim = ceil_log2 (size )
220220
221221 def get_merge_begin (source , base_idx , aCount , bCount , aStart , bStart , diag , step_count ):
222- first = ib .allocate ("int64" , (1 ,), name = "first" , scope = "local" )
223- mid = ib .allocate ("int64" , (1 ,), name = "mid" , scope = "local" )
224- last = ib .allocate ("int64" , (1 ,), name = "last" , scope = "local" )
225- first [0 ] = tvm .te .max (0 , diag - bCount )
226- last [0 ] = tvm .te .min (diag , aCount )
222+ target = tvm .target .Target .current ()
223+ is_webgpu = "webgpu" in str (target )
224+ target_dtype = "int32" if is_webgpu else "int64"
225+
226+ first = ib .allocate (target_dtype , (1 ,), name = "first" , scope = "local" )
227+ mid = ib .allocate (target_dtype , (1 ,), name = "mid" , scope = "local" )
228+ last = ib .allocate (target_dtype , (1 ,), name = "last" , scope = "local" )
229+ max_val = tvm .te .max (0 , diag - bCount )
230+ min_val = tvm .te .min (diag , aCount )
231+ if is_webgpu :
232+ first [0 ] = cast (max_val , target_dtype )
233+ last [0 ] = cast (min_val , target_dtype )
234+ else :
235+ first [0 ] = max_val
236+ last [0 ] = min_val
237+
227238 with ib .while_loop (first [0 ] < last [0 ]):
228239 mid = (first [0 ] + last [0 ]) >> 1
229240 a = source [base_idx + (aStart + mid )]
@@ -250,10 +261,20 @@ def serial_merge(
250261 first ,
251262 last ,
252263 ):
253- i = ib .allocate ("int64" , (1 ,), name = "i" , scope = "local" )
254- j = ib .allocate ("int64" , (1 ,), name = "j" , scope = "local" )
255- i [0 ] = aStart + first
256- j [0 ] = bStart + diag - last
264+ target = tvm .target .Target .current ()
265+ is_webgpu = "webgpu" in str (target )
266+ target_dtype = "int32" if is_webgpu else "int64"
267+ i = ib .allocate (target_dtype , (1 ,), name = "i" , scope = "local" )
268+ j = ib .allocate (target_dtype , (1 ,), name = "j" , scope = "local" )
269+ i_val = aStart + first
270+ j_val = bStart + diag - last
271+ if is_webgpu :
272+ i [0 ] = cast (i_val , target_dtype )
273+ j [0 ] = cast (j_val , target_dtype )
274+ else :
275+ i [0 ] = i_val
276+ j [0 ] = j_val
277+
257278 with ib .for_range (0 , tvm .te .min (aCount + bCount - diag , step_count )) as count :
258279 i_idx = base_idx + i [0 ]
259280 j_idx = base_idx + j [0 ]
@@ -287,7 +308,9 @@ def assign_j():
287308 with ib .else_scope ():
288309 assign_j ()
289310
290- with ib .for_range (0 , cast (upper_lim - lower_lim , "int64" ), dtype = "int64" ) as l2_width :
311+ target = tvm .target .Target .current ()
312+ target_dtype = "int32" if "webgpu" in str (target ) else "int64"
313+ with ib .for_range (0 , cast (upper_lim - lower_lim , target_dtype ), dtype = target_dtype ) as l2_width :
291314 width = 2 << (l2_width + lower_lim )
292315 # Define and launch the cuda kernel
293316 with ib .new_scope ():
@@ -359,8 +382,10 @@ def merge(source, dest, source_idx, dest_idx):
359382 def mergesort (source , dest , source_idx , dest_idx , size , width , even ):
360383 # calculate the start, mid, and end points of this section
361384 start = width * bz
362- middle = cast (tvm .te .min (start + tvm .tir .indexdiv (width , 2 ), size ), "int64" )
363- end = cast (tvm .te .min (start + width , size ), "int64" )
385+ target = tvm .target .Target .current ()
386+ target_dtype = "int32" if "webgpu" in str (target ) else "int64"
387+ middle = cast (tvm .te .min (start + tvm .tir .indexdiv (width , 2 ), size ), target_dtype )
388+ end = cast (tvm .te .min (start + width , size ), target_dtype )
364389 with ib .if_scope (start < size ):
365390 with ib .if_scope (nbx == 1 ):
366391 ## merge the start->middle and middle->end arrays
0 commit comments