33import torch
44import triton
55from sgl_kernel import topk_softmax
6+ from utils import get_model_config , parse_args
67
78
89def vllm_topk_softmax (gating_output , topk ):
@@ -23,7 +24,35 @@ def vllm_topk_softmax(gating_output, topk):
2324 return topk_weights , topk_indices
2425
2526
26- def sglang_topk_softmax (gating_output , topk ):
27+ def navtive_topk_softmax (
28+ gating_output : torch .Tensor ,
29+ topk : int ,
30+ renormalize : bool ,
31+ ):
32+ num_tokens , num_experts = gating_output .shape
33+
34+ import torch .nn .functional as F
35+
36+ topk_weights = torch .empty (
37+ (num_tokens , topk ), device = gating_output .device , dtype = torch .float32
38+ )
39+ topk_indices = torch .empty (
40+ (num_tokens , topk ), dtype = torch .int32 , device = gating_output .device
41+ )
42+ topk_weights = F .softmax (gating_output .float (), dim = - 1 )
43+ topk_weights , topk_indices = torch .topk (topk_weights , topk , dim = - 1 )
44+
45+ if renormalize :
46+ topk_weights = topk_weights / topk_weights .sum (dim = - 1 , keepdim = True )
47+
48+ return topk_weights , topk_indices
49+
50+
51+ def sglang_topk_softmax (
52+ gating_output : torch .Tensor ,
53+ topk : int ,
54+ renormalize : bool ,
55+ ):
2756 num_tokens , num_experts = gating_output .shape
2857
2958 topk_weights = torch .empty (
@@ -37,18 +66,18 @@ def sglang_topk_softmax(gating_output, topk):
3766 )
3867
3968 topk_softmax (
40- topk_weights = topk_weights ,
41- topk_ids = topk_indices ,
42- token_expert_indices = token_expert_indices ,
43- gating_output = gating_output ,
69+ topk_weights ,
70+ topk_indices ,
71+ gating_output ,
72+ renormalize = renormalize ,
4473 )
4574
4675 return topk_weights , topk_indices
4776
4877
4978def calculate_diff (num_tokens , num_experts , topk ):
5079 gating_output = torch .randn (
51- (num_tokens , num_experts ), device = "cuda" , dtype = torch .float32
80+ (num_tokens , num_experts ), device = gating_output . device , dtype = torch .float32
5281 )
5382 weights_vllm , indices_vllm = vllm_topk_softmax (gating_output .clone (), topk )
5483 weights_sglang , indices_sglang = sglang_topk_softmax (gating_output .clone (), topk )
@@ -67,52 +96,67 @@ def calculate_diff(num_tokens, num_experts, topk):
6796 )
6897
6998
70- num_tokens_range = [128 , 512 , 1024 , 2048 , 4096 , 8192 , 16384 , 32768 ]
71- num_experts_range = [32 , 64 , 128 , 256 , 12 , 512 ]
72- topk_range = [1 , 2 , 4 , 8 ]
99+ def get_benchmark (device = "xpu" ):
100+ @triton .testing .perf_report (
101+ triton .testing .Benchmark (
102+ x_names = ["num_tokens" , "num_experts" , "topk" , "dtype" , "renormalize" ],
103+ x_vals = configs ,
104+ line_arg = "provider" ,
105+ line_vals = ["sglang" , "native" ],
106+ line_names = ["SGLang" , "native" ],
107+ styles = [("blue" , "-" ), ("green" , "-" )],
108+ ylabel = "Latency (us)" ,
109+ plot_name = "topk-softmax-performance" ,
110+ args = {},
111+ )
112+ )
113+ def benchmark (num_tokens , num_experts , topk , dtype , renormalize , provider ):
73114
74- configs = list (itertools .product (num_tokens_range , num_experts_range , topk_range ))
115+ gating_output = torch .randn (
116+ (num_tokens , num_experts ), device = device , dtype = dtype
117+ )
75118
119+ if provider == "sglang" or provider == "sglang1" :
120+ fn = lambda : sglang_topk_softmax (gating_output , topk , renormalize )
121+ elif provider == "native" :
122+ fn = lambda : navtive_topk_softmax (gating_output , topk , renormalize )
76123
77- @triton .testing .perf_report (
78- triton .testing .Benchmark (
79- x_names = ["num_tokens" , "num_experts" , "topk" ],
80- x_vals = configs ,
81- line_arg = "provider" ,
82- line_vals = ["sglang" , "vllm" ],
83- line_names = ["SGLang" , "VLLM" ],
84- styles = [("blue" , "-" ), ("green" , "-" )],
85- ylabel = "Latency (us)" ,
86- plot_name = "topk-softmax-performance" ,
87- args = {},
88- )
89- )
90- def benchmark (num_tokens , num_experts , topk , provider ):
124+ quantiles = [0.5 , 0.2 , 0.8 ]
125+ ms , min_ms , max_ms = triton .testing .do_bench (fn , quantiles = quantiles )
91126
92- gating_output = torch . randn (
93- ( num_tokens , num_experts ), device = "cuda" , dtype = torch . float32
94- )
127+ return 1000 * ms , 1000 * max_ms , 1000 * min_ms
128+
129+ return benchmark
95130
96- if provider == "vllm" or provider == "vllm1" :
97- fn = lambda : vllm_topk_softmax (gating_output , topk )
98- elif provider == "sglang" or provider == "sglang1" :
99- fn = lambda : sglang_topk_softmax (gating_output , topk )
100131
101- quantiles = [0.5 , 0.2 , 0.8 ]
102- ms , min_ms , max_ms = triton .testing .do_bench (fn , quantiles = quantiles )
132+ if __name__ == "__main__" :
133+ # Run correctness test on small configs if not using a real model
134+ args = parse_args ()
135+ params = get_model_config (args )
136+
137+ sweep_params = {
138+ "num_tokens" : args .num_tokens ,
139+ "num_experts" : params ["num_experts" ] or [64 ],
140+ "top_k" : params ["top_k" ] or [2 , 4 ],
141+ "dtype" : [torch .bfloat16 ],
142+ "renormalize" : [False ],
143+ }
144+
145+ keys = sweep_params .keys ()
146+ configs = list (itertools .product (* sweep_params .values ()))
147+ print (f"Testing { len (configs )} configurations..." )
148+ for config in configs :
149+ num_tokens , num_experts , topk , dtype , renormalize = config
150+ print (
151+ f"Config: num_tokens={ num_tokens } , num_experts={ num_experts } , topk={ topk } , dtype={ dtype } , renormalize={ renormalize } "
152+ )
103153
104- return 1000 * ms , 1000 * max_ms , 1000 * min_ms
154+ # calculate_diff(num_tokens, num_experts, topk)
105155
156+ global benchmark_configs
157+ benchmark_configs = configs
106158
107- if __name__ == "__main__" :
108- configs = [
109- (20 , 256 , 4 ),
110- (20 , 256 , 8 ),
111- (20 , 12 , 4 ),
112- (20 , 12 , 1 ),
113- (20 , 512 , 4 ),
114- (20 , 512 , 1 ),
115- ]
116- for num_tokens , num_experts , topk in configs :
117- calculate_diff (num_tokens , num_experts , topk )
118- benchmark .run (print_data = True )
159+ # Run benchmark
160+ print ("Starting performance benchmark..." )
161+ benchmark = get_benchmark ()
162+ benchmark .run (print_data = True , show_plots = False , save_path = "." )
0 commit comments