 import triton
 import triton.language as tl
 from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
-from vllm import _custom_ops as ops

 USE_RANDOM_PERM = False

@@ -143,102 +142,63 @@ def moe_align_block_size_triton(
 def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
     topk_ids = torch.stack(
         [
-            torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
+            torch.randperm(num_experts, dtype=torch.int32, device="xpu")[:topk]
             for _ in range(num_tokens)
         ]
     )

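+    # worst case: every expert's token count is rounded up to a block boundary,
+    # costing at most (block_size - 1) extra slots per expert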
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids_cuda = torch.empty(
+    sorted_ids_xpu = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
-    sorted_ids_cuda.fill_(topk_ids.numel())
+    sorted_ids_xpu.fill_(topk_ids.numel())
     max_num_m_blocks = max_num_tokens_padded // block_size
-    expert_ids_cuda = torch.zeros(
+    expert_ids_xpu = torch.zeros(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
-    num_tokens_post_pad_cuda = torch.empty(
+    num_tokens_post_pad_xpu = torch.empty(
         (1), dtype=torch.int32, device=topk_ids.device
     )
-    token_cnts_buffer = torch.zeros(
-        (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
-    )
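+    # cumsum needs one slot per expert plus one; the kernels below are driven
+    # with num_experts + 1 experts, hence num_experts + 2 entries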
     cumsum_buffer = torch.zeros(
-        num_experts + 1, dtype=torch.int32, device=topk_ids.device
+        num_experts + 2, dtype=torch.int32, device=topk_ids.device
     )

-    sorted_ids_triton = torch.empty_like(sorted_ids_cuda)
+    sorted_ids_triton = torch.empty_like(sorted_ids_xpu)
     sorted_ids_triton.fill_(topk_ids.numel())
-    expert_ids_triton = torch.zeros_like(expert_ids_cuda)
-    num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
-
-    sorted_ids_vllm = torch.empty_like(sorted_ids_cuda)
-    sorted_ids_vllm.fill_(topk_ids.numel())
-    expert_ids_vllm = torch.zeros_like(expert_ids_cuda)
-    num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_cuda)
+    expert_ids_triton = torch.zeros_like(expert_ids_xpu)
+    num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_xpu)

-    # compare the performance of cuda, triton and vllm implementation
+    # compare the outputs of the xpu (sgl_kernel) and triton implementations
     sgl_moe_align_block_size(
         topk_ids,
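+        # NOTE: the updated kernels run with one extra expert slot, presumably
+        # reserved for padded/invalid expert ids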
-        num_experts,
+        num_experts + 1,
         block_size,
-        sorted_ids_cuda,
-        expert_ids_cuda,
-        num_tokens_post_pad_cuda,
-        token_cnts_buffer,
+        sorted_ids_xpu,
+        expert_ids_xpu,
+        num_tokens_post_pad_xpu,
         cumsum_buffer,
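+        # new trailing flag (presumably pad_sorted_token_ids); disabled here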
+        False,
     )
     moe_align_block_size_triton(
         topk_ids,
-        num_experts,
+        num_experts + 1,
         block_size,
         sorted_ids_triton,
         expert_ids_triton,
         num_tokens_post_pad_triton,
     )

-    try:
-        ops.moe_align_block_size(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids_vllm,
-            expert_ids_vllm,
-            num_tokens_post_pad_vllm,
-        )
-        print(f"✅ VLLM implementation works with {num_experts} experts!")
-        vllm_works = True
-    except RuntimeError as e:
-        print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
-        vllm_works = False
-
-    if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
-        num_tokens_post_pad_cuda, num_tokens_post_pad_triton
+    if torch.allclose(expert_ids_xpu, expert_ids_triton) and torch.allclose(
+        num_tokens_post_pad_xpu, num_tokens_post_pad_triton
     ):
         print("✅ SGL and Triton implementations match")
     else:
         print("❌ SGL and Triton implementations do not match")
-        print("SGL expert_ids:", expert_ids_cuda)
+        print("SGL expert_ids:", expert_ids_xpu)
         print("Triton expert_ids:", expert_ids_triton)
-        print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
+        print("SGL num_tokens_post_pad:", num_tokens_post_pad_xpu)
         print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)

-    if (
-        vllm_works
-        and torch.allclose(expert_ids_cuda, expert_ids_vllm)
-        and torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_vllm)
-    ):
-        print("✅ SGL and VLLM implementations match")
-    else:
-        if not vllm_works:
-            print("⚠️ VLLM comparison skipped due to failure")
-        else:
-            print("❌ SGL and VLLM implementations do not match")
-            print("SGL expert_ids:", expert_ids_cuda)
-            print("VLLM expert_ids:", expert_ids_vllm)
-            print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
-            print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
-

 # Test range
 num_tokens_range = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
@@ -249,9 +209,9 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):


 def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
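+    # build each row with randperm so a token's topk expert ids stay distinct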
-    topk_ids = torch.zeros((num_tokens, topk), dtype=torch.int32, device="cuda")
+    topk_ids = torch.zeros((num_tokens, topk), dtype=torch.int32, device="xpu")
     for i in range(num_tokens):
-        topk_ids[i, :] = torch.randperm(num_experts, dtype=torch.int32, device="cuda")[
+        topk_ids[i, :] = torch.randperm(num_experts, dtype=torch.int32, device="xpu")[
             :topk
         ]
     return topk_ids
@@ -262,8 +222,8 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
         x_names=["num_tokens", "num_experts", "topk"],
         x_vals=configs,
         line_arg="provider",
-        line_vals=["sgl", "triton", "vllm"],
-        line_names=["SGL", "Triton", "VLLM"],
+        line_vals=["sgl", "triton"],
+        line_names=["SGL", "Triton"],
         styles=[("blue", "-"), ("red", "-"), ("green", "-")],
         ylabel="us",
         plot_name="moe-align-block-size-performance",
@@ -281,7 +241,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
         num_experts,
         (num_tokens, topk),
         dtype=torch.int32,
-        device="cuda",
+        device="xpu",
     )

     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
@@ -306,11 +266,6 @@ def sgl_moe_align_block_size_with_empty(
         expert_ids,
         num_tokens_post_pad,
     ):
-        token_cnts_buffer = torch.empty(
-            (num_experts + 1) * num_experts,
-            dtype=torch.int32,
-            device=topk_ids.device,
-        )
         cumsum_buffer = torch.empty(
             num_experts + 1, dtype=torch.int32, device=topk_ids.device
         )
@@ -322,8 +277,8 @@ def sgl_moe_align_block_size_with_empty(
             sorted_ids.clone(),
             expert_ids.clone(),
             num_tokens_post_pad.clone(),
-            token_cnts_buffer,
             cumsum_buffer,
+            False,
         )

         ms, min_ms, max_ms = triton.testing.do_bench(
@@ -349,23 +304,6 @@ def sgl_moe_align_block_size_with_empty(
             ),
             quantiles=quantiles,
         )
-    else:  # vllm
-        try:
-            ms, min_ms, max_ms = triton.testing.do_bench(
-                lambda: ops.moe_align_block_size(
-                    topk_ids,
-                    num_experts,
-                    block_size,
-                    sorted_ids.clone(),
-                    expert_ids.clone(),
-                    num_tokens_post_pad.clone(),
-                ),
-                quantiles=quantiles,
-            )
-        except RuntimeError as e:
-            print(f"❌ VLLM benchmark failed with {num_experts} experts: {e}")
-            # Return extreme values to indicate failure in the chart
-            return float("inf"), float("inf"), float("inf")

     return 1000 * ms, 1000 * max_ms, 1000 * min_ms

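For a quick local sanity check of this change, a minimal driver along the following lines should work. This is a sketch, assuming a PyTorch build with the XPU backend (torch.xpu, PyTorch >= 2.4) and that the script does not already define its own `__main__` block; `calculate_diff` and the `benchmark` perf report are the functions from the diff above.

```python
import torch

if __name__ == "__main__":
    # this benchmark targets Intel GPUs via PyTorch's XPU backend
    assert torch.xpu.is_available(), "no XPU device visible to PyTorch"
    calculate_diff(num_tokens=32)    # correctness: sgl_kernel vs. Triton
    benchmark.run(print_data=True)   # timing sweep over the configured test range
```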