Skip to content

Commit 61b7ed6

Browse files
committed
gmm tuning
1 parent 1ad5223 commit 61b7ed6

File tree

1 file changed

+64
-35
lines changed
  • python/sgl_jax/srt/layers

1 file changed

+64
-35
lines changed

python/sgl_jax/srt/layers/moe.py

Lines changed: 64 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -207,6 +207,32 @@ def __call__(self, inputs, router_logits=None, gmm_tiling_configs=None):
207207
def _expert_parallel_forward_with_shard_map(
208208
self, inputs, router_logits, gmm_tiling_configs
209209
):
210+
# Precompute the static tiling parameters (outside of shard_map)
211+
total_tokens, hidden_dim = inputs.shape
212+
m, k = total_tokens, hidden_dim
213+
n_gate = self.intermediate_dim
214+
n_down = hidden_dim
215+
216+
# Look up the optimal tiling configuration
217+
optimal_tiling_gate = self._get_tiling_from_configs(
218+
gmm_tiling_configs, m, k, n_gate, self.num_experts
219+
)
220+
optimal_tiling_down = self._get_tiling_from_configs(
221+
gmm_tiling_configs, m, n_gate, n_down, self.num_experts
222+
)
223+
224+
# Convert to static integer parameters (cap the m tile at a fixed upper bound so it is a safe static value for a dynamic m)
225+
static_tiling_gate = (
226+
min(optimal_tiling_gate[0], 16384), # 设置合理的最大值
227+
optimal_tiling_gate[1],
228+
optimal_tiling_gate[2],
229+
)
230+
static_tiling_down = (
231+
min(optimal_tiling_down[0], 16384),
232+
optimal_tiling_down[1],
233+
optimal_tiling_down[2],
234+
)
235+
210236
def _internal_moe_computation(
211237
hidden_states,
212238
router_logits,
@@ -253,15 +279,16 @@ def _internal_moe_computation(
253279
else:
254280
local_group_sizes = group_sizes
255281

256-
# GMM
257-
intermediate_output = self._gmm_compute_with_sharded_weights(
282+
# GMM (using the precomputed static tiling)
283+
intermediate_output = self._gmm_compute_with_static_tiling(
258284
x,
259285
local_group_sizes,
260286
selected_experts,
261287
w0_weights,
262288
w1_weights,
263289
wo_weights,
264-
gmm_tiling_configs,
290+
static_tiling_gate,
291+
static_tiling_down,
265292
)
266293

267294
# EP Combine
@@ -301,59 +328,41 @@ def _internal_moe_computation(
301328
self.wo.value,
302329
)
303330

304-
def _gmm_compute_with_sharded_weights(
331+
def _gmm_compute_with_static_tiling(
305332
self,
306333
x,
307334
local_group_sizes,
308335
selected_experts,
309336
w0_kernel,
310337
w1_kernel,
311338
wo_kernel,
312-
gmm_tiling_configs,
339+
static_tiling_gate,
340+
static_tiling_down,
313341
):
314342
if x.shape[0] == 0:
315343
empty_output = jnp.zeros(
316344
(0, wo_kernel.shape[-1]), dtype=x.dtype
317345
) # (0, hidden_dim)
318346
return empty_output
319347

320-
m, k = x.shape[0], x.shape[1]
321-
n_gate = w0_kernel.shape[2]
322-
n_down = wo_kernel.shape[2]
323-
324-
optimal_tiling_gate = self._get_tiling_from_configs(
325-
gmm_tiling_configs, m, k, n_gate, self.num_experts
326-
)
327-
optimal_tiling_down = self._get_tiling_from_configs(
328-
gmm_tiling_configs, m, n_gate, n_down, self.num_experts
329-
)
330-
331-
# Use JAX operations for tiling parameters (cannot use int() on tracers)
332-
# tiling_gate = (
333-
# jnp.minimum(optimal_tiling_gate[0], m),
334-
# jnp.minimum(optimal_tiling_gate[1], k),
335-
# jnp.minimum(optimal_tiling_gate[2], n_gate),
336-
# )
337-
# tiling_down = (
338-
# jnp.minimum(optimal_tiling_down[0], m),
339-
# jnp.minimum(optimal_tiling_down[1], n_gate),
340-
# jnp.minimum(optimal_tiling_down[2], n_down),
341-
# )
348+
# Use the precomputed static tiling parameters directly
349+
tiling_gate = static_tiling_gate
350+
tiling_down = static_tiling_down
342351
# gate
343352
layer_w0 = gmm(
344353
lhs=x,
345354
rhs=w0_kernel,
346355
group_sizes=local_group_sizes,
347356
preferred_element_type=self.dtype,
348-
tiling=optimal_tiling_gate,
357+
tiling=tiling_gate,
349358
)
350359
# up
351360
layer_w1 = gmm(
352361
lhs=x,
353362
rhs=w1_kernel,
354363
group_sizes=local_group_sizes,
355364
preferred_element_type=self.dtype,
356-
tiling=optimal_tiling_gate,
365+
tiling=tiling_gate,
357366
)
358367

359368
# activation
@@ -366,7 +375,7 @@ def _gmm_compute_with_sharded_weights(
366375
rhs=wo_kernel,
367376
group_sizes=local_group_sizes,
368377
preferred_element_type=self.dtype,
369-
tiling=optimal_tiling_down,
378+
tiling=tiling_down,
370379
)
371380

372381
return intermediate_output
@@ -381,13 +390,33 @@ def _single_device_forward(self, inputs, router_logits, gmm_tiling_configs):
381390

382391
top_k_weights = top_k_weights / jnp.sum(top_k_weights, axis=-1, keepdims=True)
383392

384-
return self._single_device_forward_impl(
385-
inputs, top_k_indices, top_k_weights, gmm_tiling_configs
393+
# Precompute the static tiling parameters for the single-device path as well
394+
total_tokens, hidden_dim = inputs.shape
395+
m, k = total_tokens, hidden_dim
396+
n_gate = self.intermediate_dim
397+
n_down = hidden_dim
398+
399+
optimal_tiling_gate = self._get_tiling_from_configs(
400+
gmm_tiling_configs, m, k, n_gate, self.num_experts
401+
)
402+
optimal_tiling_down = self._get_tiling_from_configs(
403+
gmm_tiling_configs, m, n_gate, n_down, self.num_experts
404+
)
405+
406+
static_tiling_gate = (
407+
min(optimal_tiling_gate[0], 16384),
408+
optimal_tiling_gate[1],
409+
optimal_tiling_gate[2],
410+
)
411+
static_tiling_down = (
412+
min(optimal_tiling_down[0], 16384),
413+
optimal_tiling_down[1],
414+
optimal_tiling_down[2],
386415
)
387416

388-
def _single_device_forward_impl(
389-
self, inputs, top_k_indices, top_k_weights, gmm_tiling_configs
390-
):
417+
return self._single_device_forward_impl(inputs, top_k_indices, top_k_weights)
418+
419+
def _single_device_forward_impl(self, inputs, top_k_indices, top_k_weights):
391420
num_tokens = inputs.shape[0] * (inputs.shape[1] if inputs.ndim > 1 else 1)
392421
inputs_flat = inputs.reshape(num_tokens, -1)
393422

0 commit comments

Comments
 (0)