@@ -375,48 +375,33 @@ def fused_experts(
     else:
         out_hidden_states = torch.zeros_like(hidden_states)
 
-    # import pdb; pdb.set_trace()
-    idxs = topk_ids.flatten().argsort()
-    counts = topk_ids.flatten().to(torch.long).bincount(minlength=E).cpu().numpy()
-    tokens_per_expert = counts.cumsum()
-    num_per_tok = TopK
-    token_idxs = idxs // num_per_tok
-    offset = []
+    flat_topk = topk_ids.flatten()
+    idxs = flat_topk.argsort()
+    sorted_expert_ids = flat_topk[idxs]
+
+    counts = torch.bincount(sorted_expert_ids, minlength=E)  # [E]
+    token_idxs = idxs // TopK  # [num_tokens * TopK]
     input_A = torch.empty(
         (num_tokens * TopK, K), device=hidden_states.device, dtype=hidden_states.dtype
     )
-    for expert_id, end_idx in enumerate(tokens_per_expert):
-        start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1]
-        offset.append(end_idx - start_idx)
-        if start_idx == end_idx:
-            continue
-        exp_token_idxs = token_idxs[start_idx:end_idx]
-        # expert_tokens = hidden_states[exp_token_idxs]
-        # grouped_input_A.append(expert_tokens)
-        input_A[start_idx:end_idx, :].copy_(hidden_states[exp_token_idxs].squeeze(1))
-    offset = torch.tensor(offset, device="xpu", dtype=torch.int32)
+    input_A = hidden_states[token_idxs].squeeze(1)
+    offset = counts.to(torch.int32)
 
-    # import pdb; pdb.set_trace()
     torch.ops.sgl_kernel.moe_grouped_mm_nt(intermediate_cache1, input_A, w1, offset, E)
 
-    gate, up_ = torch.split(intermediate_cache1, N, dim=1)
-    act = torch.nn.SiLU()
-    intermediate_cache2 = act(gate) * up_
+    torch.ops.sgl_kernel.silu_and_mul(intermediate_cache2, intermediate_cache1)
 
     torch.ops.sgl_kernel.moe_grouped_mm_nt(
         intermediate_cache3, intermediate_cache2.contiguous(), w2, offset, E
     )
-    for expert_id, end_idx in enumerate(tokens_per_expert):
-        start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1]
-        if start_idx == end_idx:
-            continue
-
-        exp_token_idxs = token_idxs[start_idx:end_idx]
-        expert_out = intermediate_cache3[start_idx:end_idx]
-        expert_out.mul_(topk_weights.view(-1, 1)[idxs[start_idx:end_idx]])
-        # import pdb; pdb.set_trace()
-        out_hidden_states.scatter_reduce_(
-            0, exp_token_idxs.view(-1, 1).repeat(1, OutK), expert_out, reduce="sum"
-        )
+
+    flat_weights = topk_weights.to(intermediate_cache3.dtype).flatten()[idxs]  # [num_tokens * TopK]
+    intermediate_cache3 = intermediate_cache3 * flat_weights.unsqueeze(1)
+    out_hidden_states.scatter_reduce_(
+        0,
+        token_idxs.view(-1, 1).expand(-1, OutK),
+        intermediate_cache3,
+        reduce="sum",
+    )
 
     return out_hidden_states
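
The `silu_and_mul` call replaces the removed `torch.split` + `SiLU` eager path, so the deleted lines pin down the semantics the kernel is expected to match. As a pure-PyTorch illustration of that contract (not the sgl_kernel implementation; `silu_and_mul_ref` is a hypothetical name, and the real kernel writes into the preallocated `intermediate_cache2` rather than returning):

```python
import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Matches the removed eager path: split [*, 2N] into gate/up halves,
    # apply SiLU to the gate, multiply elementwise.
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up
```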
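Taken together, the new path sorts the flattened (token, expert) pairs by expert id, gathers activation rows into expert-contiguous order for the two grouped GEMMs, then scales by the routing weights and scatter-adds each row back onto its source token. It also drops the `.cpu().numpy()` round trip and the two per-expert Python loops from the old path. Below is a minimal pure-PyTorch sketch of the whole data path, with `moe_grouped_mm_nt` stubbed out by a per-expert loop since the real op is an XPU kernel; the function name and shapes are illustrative assumptions, not sgl_kernel code:

```python
import torch
import torch.nn.functional as F

def moe_dispatch_combine_ref(hidden_states, w1, w2, topk_ids, topk_weights):
    # Illustrative reference for the vectorized dispatch/combine above.
    # hidden_states: [num_tokens, K]; w1: [E, 2N, K]; w2: [E, OutK, N];
    # topk_ids (int64) / topk_weights: [num_tokens, TopK].
    num_tokens, top_k = topk_ids.shape
    E = w1.shape[0]

    # Dispatch: sort the flattened (token, expert) pairs so that rows routed
    # to the same expert become contiguous.
    flat_topk = topk_ids.flatten()
    idxs = flat_topk.argsort()
    counts = torch.bincount(flat_topk[idxs], minlength=E)  # rows per expert
    token_idxs = idxs // top_k                             # source token per sorted row
    input_a = hidden_states[token_idxs]                    # [num_tokens * TopK, K]

    # Grouped GEMMs emulated with a per-expert loop (stand-in for the
    # moe_grouped_mm_nt kernel, which consumes the counts as group offsets).
    start = 0
    expert_rows = []
    for e, end in enumerate(torch.cumsum(counts, 0).tolist()):
        if start != end:
            h = input_a[start:end] @ w1[e].T               # [rows, 2N]
            gate, up = h.chunk(2, dim=1)
            h = F.silu(gate) * up                          # silu_and_mul
            expert_rows.append(h @ w2[e].T)                # [rows, OutK]
        start = end
    expert_out = torch.cat(expert_rows, dim=0)

    # Combine: permute the routing weights into the same sorted order, scale,
    # and scatter-add each row back onto its source token.
    flat_w = topk_weights.to(expert_out.dtype).flatten()[idxs]
    expert_out = expert_out * flat_w.unsqueeze(1)
    out = torch.zeros(num_tokens, expert_out.shape[1], dtype=expert_out.dtype)
    out.scatter_reduce_(
        0,
        token_idxs.view(-1, 1).expand(-1, expert_out.shape[1]),
        expert_out,
        reduce="sum",
    )
    return out

# Tiny smoke test on CPU.
h = torch.randn(5, 16)
w1 = torch.randn(4, 2 * 32, 16)
w2 = torch.randn(4, 16, 32)
ids = torch.randint(0, 4, (5, 2))
weights = torch.softmax(torch.randn(5, 2), dim=-1)
assert moe_dispatch_combine_ref(h, w1, w2, ids, weights).shape == (5, 16)
```

With dense `topk_ids`, a single fancy-index gather replaces E small `copy_` calls, and keeping `counts` on-device avoids the host synchronization the old `.cpu().numpy()` call forced before the first GEMM.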