Commit f47f0d4

fused_moe and moe_align_block_size benchmarks

1 parent 16b3f49 commit f47f0d4
6 files changed: +33 -131 lines changed

.github/workflows/pr-test-xpu.yml

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ jobs:
         timeout-minutes: 20
         run: |
           docker exec -w /root/sglang ci_sglang_xpu \
-            /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/benchmark && python3 bench_flash_attn.py && python3 bench_moe_topk_softmax.py && python3 bench_fused_moe.py "
+            /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/benchmark && python3 bench_flash_attn.py && python3 bench_moe_topk_softmax.py && python3 benchmark_fused_moe.py "

       - name: Run E2E Bfloat16 tests
         timeout-minutes: 20

benchmark/bench_moe_align_block_size.py

Lines changed: 26 additions & 88 deletions
@@ -5,7 +5,6 @@
 import triton
 import triton.language as tl
 from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
-from vllm import _custom_ops as ops

 USE_RANDOM_PERM = False

@@ -143,102 +142,63 @@ def moe_align_block_size_triton(
 def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
     topk_ids = torch.stack(
         [
-            torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
+            torch.randperm(num_experts, dtype=torch.int32, device="xpu")[:topk]
             for _ in range(num_tokens)
         ]
     )

     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids_cuda = torch.empty(
+    sorted_ids_xpu = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
-    sorted_ids_cuda.fill_(topk_ids.numel())
+    sorted_ids_xpu.fill_(topk_ids.numel())
     max_num_m_blocks = max_num_tokens_padded // block_size
-    expert_ids_cuda = torch.zeros(
+    expert_ids_xpu = torch.zeros(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
-    num_tokens_post_pad_cuda = torch.empty(
+    num_tokens_post_pad_xpu = torch.empty(
         (1), dtype=torch.int32, device=topk_ids.device
     )
-    token_cnts_buffer = torch.zeros(
-        (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
-    )
     cumsum_buffer = torch.zeros(
-        num_experts + 1, dtype=torch.int32, device=topk_ids.device
+        num_experts + 2, dtype=torch.int32, device=topk_ids.device
     )

-    sorted_ids_triton = torch.empty_like(sorted_ids_cuda)
+    sorted_ids_triton = torch.empty_like(sorted_ids_xpu)
     sorted_ids_triton.fill_(topk_ids.numel())
-    expert_ids_triton = torch.zeros_like(expert_ids_cuda)
-    num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
-
-    sorted_ids_vllm = torch.empty_like(sorted_ids_cuda)
-    sorted_ids_vllm.fill_(topk_ids.numel())
-    expert_ids_vllm = torch.zeros_like(expert_ids_cuda)
-    num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_cuda)
+    expert_ids_triton = torch.zeros_like(expert_ids_xpu)
+    num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_xpu)

-    # compare the performance of cuda, triton and vllm implementation
+    # compare the performance of xpu and triton implementation
     sgl_moe_align_block_size(
         topk_ids,
-        num_experts,
+        num_experts + 1,
         block_size,
-        sorted_ids_cuda,
-        expert_ids_cuda,
-        num_tokens_post_pad_cuda,
-        token_cnts_buffer,
+        sorted_ids_xpu,
+        expert_ids_xpu,
+        num_tokens_post_pad_xpu,
         cumsum_buffer,
+        False,
     )
     moe_align_block_size_triton(
         topk_ids,
-        num_experts,
+        num_experts + 1,
         block_size,
         sorted_ids_triton,
         expert_ids_triton,
         num_tokens_post_pad_triton,
     )

-    try:
-        ops.moe_align_block_size(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids_vllm,
-            expert_ids_vllm,
-            num_tokens_post_pad_vllm,
-        )
-        print(f"✅ VLLM implementation works with {num_experts} experts!")
-        vllm_works = True
-    except RuntimeError as e:
-        print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
-        vllm_works = False
-
-    if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
-        num_tokens_post_pad_cuda, num_tokens_post_pad_triton
+    if torch.allclose(expert_ids_xpu, expert_ids_triton) and torch.allclose(
+        num_tokens_post_pad_xpu, num_tokens_post_pad_triton
     ):
         print("✅ SGL and Triton implementations match")
     else:
         print("❌ SGL and Triton implementations do not match")
-        print("SGL expert_ids:", expert_ids_cuda)
+        print("SGL expert_ids:", expert_ids_xpu)
         print("Triton expert_ids:", expert_ids_triton)
-        print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
+        print("SGL num_tokens_post_pad:", num_tokens_post_pad_xpu)
         print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)

-    if (
-        vllm_works
-        and torch.allclose(expert_ids_cuda, expert_ids_vllm)
-        and torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_vllm)
-    ):
-        print("✅ SGL and VLLM implementations match")
-    else:
-        if not vllm_works:
-            print("⚠️ VLLM comparison skipped due to failure")
-        else:
-            print("❌ SGL and VLLM implementations do not match")
-            print("SGL expert_ids:", expert_ids_cuda)
-            print("VLLM expert_ids:", expert_ids_vllm)
-            print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
-            print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)

 # Test range
 num_tokens_range = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

@@ -249,9 +209,9 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):

 def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
-    topk_ids = torch.zeros((num_tokens, topk), dtype=torch.int32, device="cuda")
+    topk_ids = torch.zeros((num_tokens, topk), dtype=torch.int32, device="xpu")
     for i in range(num_tokens):
-        topk_ids[i, :] = torch.randperm(num_experts, dtype=torch.int32, device="cuda")[
+        topk_ids[i, :] = torch.randperm(num_experts, dtype=torch.int32, device="xpu")[
             :topk
         ]
     return topk_ids

@@ -262,8 +222,8 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
         x_names=["num_tokens", "num_experts", "topk"],
         x_vals=configs,
         line_arg="provider",
-        line_vals=["sgl", "triton", "vllm"],
-        line_names=["SGL", "Triton", "VLLM"],
+        line_vals=["sgl", "triton"],
+        line_names=["SGL", "Triton"],
         styles=[("blue", "-"), ("red", "-"), ("green", "-")],
         ylabel="us",
         plot_name="moe-align-block-size-performance",

@@ -281,7 +241,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
         num_experts,
         (num_tokens, topk),
         dtype=torch.int32,
-        device="cuda",
+        device="xpu",
     )

     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)

@@ -306,11 +266,6 @@ def sgl_moe_align_block_size_with_empty(
            expert_ids,
            num_tokens_post_pad,
        ):
-            token_cnts_buffer = torch.empty(
-                (num_experts + 1) * num_experts,
-                dtype=torch.int32,
-                device=topk_ids.device,
-            )
            cumsum_buffer = torch.empty(
                num_experts + 1, dtype=torch.int32, device=topk_ids.device
            )

@@ -322,8 +277,8 @@ def sgl_moe_align_block_size_with_empty(
                sorted_ids.clone(),
                expert_ids.clone(),
                num_tokens_post_pad.clone(),
-                token_cnts_buffer,
                cumsum_buffer,
+                False,
            )

        ms, min_ms, max_ms = triton.testing.do_bench(

@@ -349,23 +304,6 @@ def sgl_moe_align_block_size_with_empty(
            ),
            quantiles=quantiles,
        )
-    else:  # vllm
-        try:
-            ms, min_ms, max_ms = triton.testing.do_bench(
-                lambda: ops.moe_align_block_size(
-                    topk_ids,
-                    num_experts,
-                    block_size,
-                    sorted_ids.clone(),
-                    expert_ids.clone(),
-                    num_tokens_post_pad.clone(),
-                ),
-                quantiles=quantiles,
-            )
-        except RuntimeError as e:
-            print(f"❌ VLLM benchmark failed with {num_experts} experts: {e}")
-            # Return extreme values to indicate failure in the chart
-            return float("inf"), float("inf"), float("inf")

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
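Taken together, the benchmark's updated SGL call path boils down to the short sketch below. It only restates what the diff above already shows; the concrete sizes (64 tokens, 256 experts, block size 128, top-k 8) are illustrative, and the need for the num_experts + 1 argument and the num_experts + 2 cumsum buffer is carried over from this benchmark rather than from documented API behavior.

import torch
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size

# Illustrative sizes only, loosely matching the benchmark's test range.
num_tokens, num_experts, block_size, topk = 64, 256, 128, 8
topk_ids = torch.randint(
    0, num_experts, (num_tokens, topk), dtype=torch.int32, device="xpu"
)

max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="xpu")
sorted_ids.fill_(topk_ids.numel())  # sentinel fill, as in calculate_diff()
expert_ids = torch.zeros(
    (max_num_tokens_padded // block_size,), dtype=torch.int32, device="xpu"
)
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="xpu")
cumsum_buffer = torch.zeros(num_experts + 2, dtype=torch.int32, device="xpu")

sgl_moe_align_block_size(
    topk_ids,
    num_experts + 1,      # the benchmark now passes num_experts + 1
    block_size,
    sorted_ids,
    expert_ids,
    num_tokens_post_pad,
    cumsum_buffer,        # replaces the removed token_cnts_buffer argument
    False,                # pad_sorted_token_ids
)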

python/sgl_kernel/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -53,7 +53,6 @@
     fp8_blockwise_scaled_grouped_mm,
     fused_experts,
     moe_align_block_size,
-    moe_align_block_size_impl,
     moe_fused_gate,
     moe_sum,
     moe_sum_reduce,
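With moe_align_block_size_impl no longer re-exported, downstream code imports only the remaining entry point; a trivial sketch of the import after this change:

from sgl_kernel import moe_align_block_size, moe_sum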

python/sgl_kernel/moe.py

Lines changed: 1 addition & 38 deletions
@@ -3,7 +3,7 @@
 import torch


-def moe_align_block_size_impl(
+def moe_align_block_size(
     topk_ids,
     num_experts,
     block_size,

@@ -25,43 +25,6 @@ def moe_align_block_size_impl(
     )


-def moe_align_block_size(
-    topk_ids,
-    num_experts,
-    block_size,
-    pad_sorted_token_ids=False,
-):
-    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
-
-    sorted_ids_xpu = torch.empty(
-        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
-    )
-    if not pad_sorted_token_ids:
-        sorted_ids_xpu.fill_(topk_ids.numel())
-    max_num_m_blocks = max_num_tokens_padded // block_size
-    expert_ids_xpu = torch.zeros(
-        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
-    )
-    num_tokens_post_pad_xpu = torch.empty(
-        (1), dtype=torch.int32, device=topk_ids.device
-    )
-    cumsum_buffer = torch.empty(
-        num_experts + 2, dtype=torch.int32, device=topk_ids.device
-    )
-    moe_align_block_size_impl(
-        topk_ids,
-        num_experts + 1,
-        block_size,
-        sorted_ids_xpu,
-        expert_ids_xpu,
-        num_tokens_post_pad_xpu,
-        cumsum_buffer,
-        pad_sorted_token_ids,
-    )
-
-    return sorted_ids_xpu, expert_ids_xpu, num_tokens_post_pad_xpu
-
-
 def topk_softmax(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
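Since the Python convenience wrapper above is deleted and the low-level kernel now carries its name, code that previously relied on the wrapper has to allocate the output buffers itself. Below is a caller-side sketch reconstructed from the deleted wrapper body; the helper name align_block_size is hypothetical and not part of the package.

import torch
from sgl_kernel import moe_align_block_size


def align_block_size(topk_ids, num_experts, block_size, pad_sorted_token_ids=False):
    # Buffer shapes mirror the deleted wrapper shown in the diff above.
    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    if not pad_sorted_token_ids:
        sorted_ids.fill_(topk_ids.numel())
    expert_ids = torch.zeros(
        (max_num_tokens_padded // block_size,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=topk_ids.device)
    cumsum_buffer = torch.empty(num_experts + 2, dtype=torch.int32, device=topk_ids.device)
    moe_align_block_size(
        topk_ids,
        num_experts + 1,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
        cumsum_buffer,
        pad_sorted_token_ids,
    )
    return sorted_ids, expert_ids, num_tokens_post_pad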

tests/test_moe_align.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 import torch
 import triton
 import triton.language as tl
-from sgl_kernel import moe_align_block_size_impl, moe_sum
+from sgl_kernel import moe_align_block_size, moe_sum


 def ceil_div(a, b):

@@ -180,7 +180,7 @@ def test_moe_align_block_size_compare_implementations(
     expert_ids_triton = torch.zeros_like(expert_ids_xpu)
     num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_xpu)

-    moe_align_block_size_impl(
+    moe_align_block_size(
         topk_ids,
         num_experts + 1,
         block_size,

tests/test_moe_gemm.py

Lines changed: 3 additions & 1 deletion
@@ -93,7 +93,9 @@ def test_moe_gemm(num_tokens, topk, num_experts, hidden_size, intermediate_size)
         topk_ids,
     )
     # import pdb; pdb.set_trace()
-    assert torch.allclose(torch_output, sglang_output, rtol=rtol, atol=atol * hidden_size)
+    assert torch.allclose(
+        torch_output, sglang_output, rtol=rtol, atol=atol * hidden_size
+    )


 if __name__ == "__main__":
