Skip to content

Commit 462eeb7

Browse files
authored
[WebLLM] Replace int64s with int32s in WebGPU kernels (#18361)
This PR replaces int64s with int32s in the argsort and parallel_sampling_from_prob kernels when the target is WebGPU, since WGSL does not currently support i64.
1 parent 31a24a4 commit 462eeb7

File tree

3 files changed

+45
-15
lines changed

3 files changed

+45
-15
lines changed

python/tvm/relax/backend/gpu_generic/sampling.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import math
2121
from typing import Callable, Optional
22+
import tvm
2223
from tvm.script import tir as T
2324
from tvm.tir import PrimFunc
2425

@@ -69,6 +70,9 @@ def gpu_multinomial_from_uniform(
6970
The generated function
7071
"""
7172

73+
target = tvm.target.Target.current()
74+
target_dtype = "int32" if "webgpu" in str(target) else "int64"
75+
7276
TX = T.int64(tx_len) # threadIdx.x
7377
TY = T.int64(ty_len) # threadIdx.y
7478

@@ -282,15 +286,16 @@ def parallel_sampling_from_prob(
282286
# at least one iteration
283287
while T.tvm_thread_invariant(
284288
(step_iter[()] == 0 or aggregate[()] < u - eps)
285-
and T.Cast("int64", step_iter[()]) < T.ceildiv(vocab_size, block_elem)
289+
and T.Cast(target_dtype, step_iter[()])
290+
< T.Cast(target_dtype, T.ceildiv(vocab_size, block_elem))
286291
):
287292
single_batch_sampling(
288293
prob,
289294
row_idx,
290295
vocab_size,
291296
ty,
292297
tx,
293-
T.Cast("int64", step_iter[()]),
298+
T.Cast(target_dtype, step_iter[()]),
294299
0.0,
295300
aggregate,
296301
u,

python/tvm/topi/gpu/sort.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -219,11 +219,22 @@ def compare(a, b):
219219
upper_lim = ceil_log2(size)
220220

221221
def get_merge_begin(source, base_idx, aCount, bCount, aStart, bStart, diag, step_count):
222-
first = ib.allocate("int64", (1,), name="first", scope="local")
223-
mid = ib.allocate("int64", (1,), name="mid", scope="local")
224-
last = ib.allocate("int64", (1,), name="last", scope="local")
225-
first[0] = tvm.te.max(0, diag - bCount)
226-
last[0] = tvm.te.min(diag, aCount)
222+
target = tvm.target.Target.current()
223+
is_webgpu = "webgpu" in str(target)
224+
target_dtype = "int32" if is_webgpu else "int64"
225+
226+
first = ib.allocate(target_dtype, (1,), name="first", scope="local")
227+
mid = ib.allocate(target_dtype, (1,), name="mid", scope="local")
228+
last = ib.allocate(target_dtype, (1,), name="last", scope="local")
229+
max_val = tvm.te.max(0, diag - bCount)
230+
min_val = tvm.te.min(diag, aCount)
231+
if is_webgpu:
232+
first[0] = cast(max_val, target_dtype)
233+
last[0] = cast(min_val, target_dtype)
234+
else:
235+
first[0] = max_val
236+
last[0] = min_val
237+
227238
with ib.while_loop(first[0] < last[0]):
228239
mid = (first[0] + last[0]) >> 1
229240
a = source[base_idx + (aStart + mid)]
@@ -250,10 +261,20 @@ def serial_merge(
250261
first,
251262
last,
252263
):
253-
i = ib.allocate("int64", (1,), name="i", scope="local")
254-
j = ib.allocate("int64", (1,), name="j", scope="local")
255-
i[0] = aStart + first
256-
j[0] = bStart + diag - last
264+
target = tvm.target.Target.current()
265+
is_webgpu = "webgpu" in str(target)
266+
target_dtype = "int32" if is_webgpu else "int64"
267+
i = ib.allocate(target_dtype, (1,), name="i", scope="local")
268+
j = ib.allocate(target_dtype, (1,), name="j", scope="local")
269+
i_val = aStart + first
270+
j_val = bStart + diag - last
271+
if is_webgpu:
272+
i[0] = cast(i_val, target_dtype)
273+
j[0] = cast(j_val, target_dtype)
274+
else:
275+
i[0] = i_val
276+
j[0] = j_val
277+
257278
with ib.for_range(0, tvm.te.min(aCount + bCount - diag, step_count)) as count:
258279
i_idx = base_idx + i[0]
259280
j_idx = base_idx + j[0]
@@ -287,7 +308,9 @@ def assign_j():
287308
with ib.else_scope():
288309
assign_j()
289310

290-
with ib.for_range(0, cast(upper_lim - lower_lim, "int64"), dtype="int64") as l2_width:
311+
target = tvm.target.Target.current()
312+
target_dtype = "int32" if "webgpu" in str(target) else "int64"
313+
with ib.for_range(0, cast(upper_lim - lower_lim, target_dtype), dtype=target_dtype) as l2_width:
291314
width = 2 << (l2_width + lower_lim)
292315
# Define and launch the cuda kernel
293316
with ib.new_scope():
@@ -359,8 +382,10 @@ def merge(source, dest, source_idx, dest_idx):
359382
def mergesort(source, dest, source_idx, dest_idx, size, width, even):
360383
# calculate the start, mid, and end points of this section
361384
start = width * bz
362-
middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), size), "int64")
363-
end = cast(tvm.te.min(start + width, size), "int64")
385+
target = tvm.target.Target.current()
386+
target_dtype = "int32" if "webgpu" in str(target) else "int64"
387+
middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), size), target_dtype)
388+
end = cast(tvm.te.min(start + width, size), target_dtype)
364389
with ib.if_scope(start < size):
365390
with ib.if_scope(nbx == 1):
366391
## merge the start->middle and middle->end arrays

tests/python/relax/test_backend_dispatch_sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def parallel_sampling_from_prob(var_prob: T.handle, var_uniform_samples: T.handl
103103
u: T.float32 = uniform_samples[bx, 0]
104104
aggregate[()] = T.Cast("float32", 0)
105105
step_iter[()] = 0
106-
while T.tvm_thread_invariant((step_iter[()] == 0 or aggregate[()] < u - T.float32(9.9999999999999995e-07)) and T.Cast("int64", step_iter[()]) < (vocab_size + T.int64(512) - T.int64(1)) // T.int64(512)):
106+
while T.tvm_thread_invariant((step_iter[()] == 0 or aggregate[()] < u - T.float32(9.9999999999999995e-07)) and T.Cast("int64", step_iter[()]) < T.Cast("int64", (vocab_size + T.int64(512) - T.int64(1)) // T.int64(512))):
107107
with T.block(""):
108108
T.reads(step_iter[()], prob[row_idx, T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4):T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + T.int64(4)], aggregate[()])
109109
T.writes(sample_id_local[()], aggregate[()])

0 commit comments

Comments (0)