@@ -207,6 +207,32 @@ def __call__(self, inputs, router_logits=None, gmm_tiling_configs=None):
207207 def _expert_parallel_forward_with_shard_map (
208208 self , inputs , router_logits , gmm_tiling_configs
209209 ):
210+ # Precompute the static tiling parameters (outside of shard_map)
211+ total_tokens , hidden_dim = inputs .shape
212+ m , k = total_tokens , hidden_dim
213+ n_gate = self .intermediate_dim
214+ n_down = hidden_dim
215+
216+ # Look up the optimal tiling configuration
217+ optimal_tiling_gate = self ._get_tiling_from_configs (
218+ gmm_tiling_configs , m , k , n_gate , self .num_experts
219+ )
220+ optimal_tiling_down = self ._get_tiling_from_configs (
221+ gmm_tiling_configs , m , n_gate , n_down , self .num_experts
222+ )
223+
224+ # Convert to static integer parameters (m is dynamic, so the m-tile is capped at a fixed upper bound to keep it a safe static value)
225+ static_tiling_gate = (
226+ min (optimal_tiling_gate [0 ], 16384 ), # 设置合理的最大值
227+ optimal_tiling_gate [1 ],
228+ optimal_tiling_gate [2 ],
229+ )
230+ static_tiling_down = (
231+ min (optimal_tiling_down [0 ], 16384 ),
232+ optimal_tiling_down [1 ],
233+ optimal_tiling_down [2 ],
234+ )
235+
210236 def _internal_moe_computation (
211237 hidden_states ,
212238 router_logits ,
@@ -253,15 +279,16 @@ def _internal_moe_computation(
253279 else :
254280 local_group_sizes = group_sizes
255281
256- # GMM
257- intermediate_output = self ._gmm_compute_with_sharded_weights (
282+ # GMM (using the precomputed static tiling)
283+ intermediate_output = self ._gmm_compute_with_static_tiling (
258284 x ,
259285 local_group_sizes ,
260286 selected_experts ,
261287 w0_weights ,
262288 w1_weights ,
263289 wo_weights ,
264- gmm_tiling_configs ,
290+ static_tiling_gate ,
291+ static_tiling_down ,
265292 )
266293
267294 # EP Combine
@@ -301,59 +328,41 @@ def _internal_moe_computation(
301328 self .wo .value ,
302329 )
303330
304- def _gmm_compute_with_sharded_weights (
331+ def _gmm_compute_with_static_tiling (
305332 self ,
306333 x ,
307334 local_group_sizes ,
308335 selected_experts ,
309336 w0_kernel ,
310337 w1_kernel ,
311338 wo_kernel ,
312- gmm_tiling_configs ,
339+ static_tiling_gate ,
340+ static_tiling_down ,
313341 ):
314342 if x .shape [0 ] == 0 :
315343 empty_output = jnp .zeros (
316344 (0 , wo_kernel .shape [- 1 ]), dtype = x .dtype
317345 ) # (0, hidden_dim)
318346 return empty_output
319347
320- m , k = x .shape [0 ], x .shape [1 ]
321- n_gate = w0_kernel .shape [2 ]
322- n_down = wo_kernel .shape [2 ]
323-
324- optimal_tiling_gate = self ._get_tiling_from_configs (
325- gmm_tiling_configs , m , k , n_gate , self .num_experts
326- )
327- optimal_tiling_down = self ._get_tiling_from_configs (
328- gmm_tiling_configs , m , n_gate , n_down , self .num_experts
329- )
330-
331- # Use JAX operations for tiling parameters (cannot use int() on tracers)
332- # tiling_gate = (
333- # jnp.minimum(optimal_tiling_gate[0], m),
334- # jnp.minimum(optimal_tiling_gate[1], k),
335- # jnp.minimum(optimal_tiling_gate[2], n_gate),
336- # )
337- # tiling_down = (
338- # jnp.minimum(optimal_tiling_down[0], m),
339- # jnp.minimum(optimal_tiling_down[1], n_gate),
340- # jnp.minimum(optimal_tiling_down[2], n_down),
341- # )
348+ # Use the precomputed static tiling parameters directly
349+ tiling_gate = static_tiling_gate
350+ tiling_down = static_tiling_down
342351 # gate
343352 layer_w0 = gmm (
344353 lhs = x ,
345354 rhs = w0_kernel ,
346355 group_sizes = local_group_sizes ,
347356 preferred_element_type = self .dtype ,
348- tiling = optimal_tiling_gate ,
357+ tiling = tiling_gate ,
349358 )
350359 # up
351360 layer_w1 = gmm (
352361 lhs = x ,
353362 rhs = w1_kernel ,
354363 group_sizes = local_group_sizes ,
355364 preferred_element_type = self .dtype ,
356- tiling = optimal_tiling_gate ,
365+ tiling = tiling_gate ,
357366 )
358367
359368 # activation
@@ -366,7 +375,7 @@ def _gmm_compute_with_sharded_weights(
366375 rhs = wo_kernel ,
367376 group_sizes = local_group_sizes ,
368377 preferred_element_type = self .dtype ,
369- tiling = optimal_tiling_down ,
378+ tiling = tiling_down ,
370379 )
371380
372381 return intermediate_output
@@ -381,13 +390,33 @@ def _single_device_forward(self, inputs, router_logits, gmm_tiling_configs):
381390
382391 top_k_weights = top_k_weights / jnp .sum (top_k_weights , axis = - 1 , keepdims = True )
383392
384- return self ._single_device_forward_impl (
385- inputs , top_k_indices , top_k_weights , gmm_tiling_configs
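393+ # Precompute static tiling parameters for the single-device path as well (NOTE(review): they are not passed to _single_device_forward_impl in this hunk; confirm they are actually used)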
394+ total_tokens , hidden_dim = inputs .shape
395+ m , k = total_tokens , hidden_dim
396+ n_gate = self .intermediate_dim
397+ n_down = hidden_dim
398+
399+ optimal_tiling_gate = self ._get_tiling_from_configs (
400+ gmm_tiling_configs , m , k , n_gate , self .num_experts
401+ )
402+ optimal_tiling_down = self ._get_tiling_from_configs (
403+ gmm_tiling_configs , m , n_gate , n_down , self .num_experts
404+ )
405+
406+ static_tiling_gate = (
407+ min (optimal_tiling_gate [0 ], 16384 ),
408+ optimal_tiling_gate [1 ],
409+ optimal_tiling_gate [2 ],
410+ )
411+ static_tiling_down = (
412+ min (optimal_tiling_down [0 ], 16384 ),
413+ optimal_tiling_down [1 ],
414+ optimal_tiling_down [2 ],
386415 )
387416
388- def _single_device_forward_impl (
389- self , inputs , top_k_indices , top_k_weights , gmm_tiling_configs
390- ):
417+ return self . _single_device_forward_impl (inputs , top_k_indices , top_k_weights )
418+
419+ def _single_device_forward_impl ( self , inputs , top_k_indices , top_k_weights ):
391420 num_tokens = inputs .shape [0 ] * (inputs .shape [1 ] if inputs .ndim > 1 else 1 )
392421 inputs_flat = inputs .reshape (num_tokens , - 1 )
393422
0 commit comments