Skip to content

Commit 0496697

Browse files
committed
refactor test tolerance
1 parent f47f0d4 commit 0496697

File tree

2 files changed

+4
-6
lines changed

2 files changed

+4
-6
lines changed

python/sgl_kernel/moe.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -355,7 +355,7 @@ def fused_experts(
355355
torch.ops.sgl_kernel.silu_and_mul(intermediate_cache2, intermediate_cache1)
356356

357357
torch.ops.sgl_kernel.moe_grouped_mm_nt(
358-
intermediate_cache3, intermediate_cache2.contiguous(), w2, offset, E
358+
intermediate_cache3, intermediate_cache2, w2, offset, E
359359
)
360360

361361
flat_weights = topk_weights.to(intermediate_cache3.dtype).flatten()[idxs] # [N]

tests/test_moe_gemm.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ def create_random_xpu_tensor(shape, dtype, mean=0, std=0.01):
2323
Returns:
2424
torch.Tensor: Randomly initialized xpu tensor
2525
"""
26-
return torch.randn(shape, device="xpu").to(dtype)
26+
return torch.empty(shape, dtype=dtype, device="xpu").normal_(mean, std)
2727

2828

2929
def torch_naive_moe(
@@ -65,7 +65,7 @@ def torch_naive_moe(
6565
),
6666
)
6767
def test_moe_gemm(num_tokens, topk, num_experts, hidden_size, intermediate_size):
68-
rtol, atol = 2e-2, 2e-1
68+
rtol, atol = 1e-1, 1e-2
6969
a = create_random_xpu_tensor((num_tokens, hidden_size), torch.bfloat16)
7070
w1 = create_random_xpu_tensor(
7171
(num_experts, 2 * intermediate_size, hidden_size), torch.bfloat16
@@ -93,9 +93,7 @@ def test_moe_gemm(num_tokens, topk, num_experts, hidden_size, intermediate_size)
9393
topk_ids,
9494
)
9595
# import pdb; pdb.set_trace()
96-
assert torch.allclose(
97-
torch_output, sglang_output, rtol=rtol, atol=atol * hidden_size
98-
)
96+
torch.testing.assert_close(torch_output, sglang_output, rtol=rtol, atol=atol)
99 97

100 98

101 99
if __name__ == "__main__":

0 commit comments

Comments
 (0)