@@ -23,7 +23,7 @@ def create_random_xpu_tensor(shape, dtype, mean=0, std=0.01):
2323 Returns:
2424 torch.Tensor: Randomly initialized xpu tensor
2525 """
26- return torch .randn (shape , device = "xpu" ).to ( dtype )
26+ return torch .empty (shape , dtype = dtype , device = "xpu" ).normal_ ( mean , std )
2727
2828
2929def torch_naive_moe (
@@ -65,7 +65,7 @@ def torch_naive_moe(
6565 ),
6666)
6767def test_moe_gemm (num_tokens , topk , num_experts , hidden_size , intermediate_size ):
68- rtol , atol = 2e-2 , 2e-1
68+ rtol , atol = 1e-1 , 1e-2
6969 a = create_random_xpu_tensor ((num_tokens , hidden_size ), torch .bfloat16 )
7070 w1 = create_random_xpu_tensor (
7171 (num_experts , 2 * intermediate_size , hidden_size ), torch .bfloat16
@@ -93,9 +93,7 @@ def test_moe_gemm(num_tokens, topk, num_experts, hidden_size, intermediate_size)
9393 topk_ids ,
9494 )
9595 # import pdb; pdb.set_trace()
96- assert torch .allclose (
97- torch_output , sglang_output , rtol = rtol , atol = atol * hidden_size
98- )
96+ torch .testing .assert_close (torch_output , sglang_output , rtol = rtol , atol = atol )
9997
10098
10199if __name__ == "__main__" :
0 commit comments