[Feature] support ep and enhance step latency for muon #1716
nil0x9 wants to merge 1 commit into InternLM:main
Conversation
Force-pushed 4531e5d to 5710a74 (Compare)
@claude review
```python
elif len(shard_placements) == 2:
    fsdp_mesh_dim = self._find_fsdp_mesh_dim(device_mesh)

    if fsdp_mesh_dim is None:
```
Claude: Warning — fragile cache key
id(device_mesh) is a memory address. If the DeviceMesh object is garbage-collected and a new one is allocated at the same address (unlikely during training, but possible on checkpoint reload), the cache could silently return stale process groups or miss a valid entry.
Consider keying on a stable identity instead, e.g. (device_mesh.mesh_dim_names, tuple(device_mesh.mesh.flatten().tolist()), sharded_mesh_dim, subgroup_size) — that way the cache is correct even if the Python object is re-created with the same logical mesh.
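A minimal sketch of such a key, assuming standard `DeviceMesh` attributes (the helper name `_stable_mesh_key` and the cache attribute are hypothetical, not from the PR):

```python
from torch.distributed.device_mesh import DeviceMesh


def _stable_mesh_key(device_mesh: DeviceMesh, sharded_mesh_dim: int, subgroup_size: int):
    """Build a cache key from the mesh's logical contents rather than id()."""
    # mesh_dim_names may be None for unnamed meshes; fall back to an empty tuple.
    dim_names = tuple(device_mesh.mesh_dim_names or ())
    ranks = tuple(device_mesh.mesh.flatten().tolist())
    return (dim_names, ranks, sharded_mesh_dim, subgroup_size)


# Hypothetical usage inside the optimizer:
# key = _stable_mesh_key(device_mesh, sharded_mesh_dim, subgroup_size)
# pg = self._subgroup_cache.get(key)
```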
```python
# Case A: each rank holds complete experts → no sub-groups needed
if ns_num_experts % fsdp_size == 0:
    continue

# Case B: sub-group all-gather
if fsdp_size % ns_num_experts == 0:
    subgroup_size = fsdp_size // ns_num_experts
    self._create_subgroup_pgs(device_mesh, sharded_mesh_dim, subgroup_size)

def _create_subgroup_pgs(
```
Claude: Nit — leaked process groups on non-matching ranks
In the inner loop, dist.new_group(sg_ranks) is called for every subgroup across all ranks (necessary for the collective), but only the rank's own group is stored in my_pg. The other ProcessGroup objects created on this rank are never assigned to a variable and become garbage. While PyTorch's NCCL backend should handle this gracefully, it would be cleaner to track that these are intentionally discarded — a brief comment here would suffice.
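A minimal sketch of the pattern with the discard made explicit (the function and variable names are illustrative, not the PR's exact code):

```python
import torch.distributed as dist


def create_subgroup_pgs(world_size: int, subgroup_size: int, my_rank: int):
    """Create all sub-groups collectively; keep only the one containing this rank.

    dist.new_group() must be called by *every* rank with the same arguments for
    every subgroup, even for groups this rank does not belong to.
    """
    my_pg = None
    for start in range(0, world_size, subgroup_size):
        sg_ranks = list(range(start, start + subgroup_size))
        pg = dist.new_group(sg_ranks)
        if my_rank in sg_ranks:
            my_pg = pg
        # else: pg is intentionally discarded -- this rank never issues collectives
        # on it, but the collective new_group() call above was still required.
    return my_pg
```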
```python
group_world_size = 1
group_device_rank = 0

# Create batches within this mesh group
for params in create_param_batches(mesh_params, batch_size=group_world_size):
    gradients: list[Tensor] = [g for p in params if (g := p.grad) is not None]
```
Claude: Warning — gradients length may diverge from params
The walrus-operator comprehension filters out None grads:
```python
gradients: list[Tensor] = [g for p in params if (g := p.grad) is not None]
```

While `group_params` was already filtered for non-None grads upstream, if a grad is somehow freed between filtering and batch creation (e.g. by another concurrent task), `len(gradients) < len(params)`. This would pass mismatched X/G lists to `muon_update_batch_async`, causing `assert len(X) == len(G)` to fail with a confusing error.

Consider either keeping the simple `[p.grad for p in params]` (since grads are guaranteed non-None at this point) or adding an explicit check:

```python
assert len(gradients) == len(params), "Some gradients became None after filtering"
```

```python
# Move the W dimension to front for reduce_scatter_tensor
w_pos = shard_dim + 1
perm = [w_pos] + list(range(0, w_pos)) + list(range(w_pos + 1, rs_input.ndim))
rs_input = rs_input.permute(perm).contiguous()

# Flatten trailing dims: (W, R*shard_numel) for reduce_scatter
rs_flat_in = rs_input.reshape(W, -1)
rs_flat_out = torch.empty(rs_flat_in.shape[1], dtype=rs_flat_in.dtype, device=rs_flat_in.device)

work = dist.reduce_scatter_tensor(
    rs_flat_out,
    rs_flat_in,
    op=dist.ReduceOp.SUM,
    group=process_group,
    async_op=True,
```
Claude: Nit — AGRS selective NS allocates zeros_like for skipped params
When R params exist and W ranks participate, each rank allocates R - ceil(R/W) zero tensors at full parameter size. For the intended use case (remainder batches, so R is small), this is fine. But it might be worth adding a brief note that this is the expected tradeoff (small R means minimal waste), since the zeros also flow through reduce-scatter (wasted bandwidth on zero contributions).
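A condensed sketch of the tradeoff, assuming the AGRS scheme described in the PR (the helper names below are placeholders, and `newton_schulz_placeholder` stands in for the real orthogonalization): each rank only orthogonalizes every W-th parameter, but the reduce-scatter needs a contribution for all R slots, so the skipped slots are filled with zeros.

```python
import torch
from torch import Tensor


def newton_schulz_placeholder(g: Tensor) -> Tensor:
    # Stand-in for the real Newton-Schulz orthogonalization.
    return g


def build_agrs_contributions(grads: list[Tensor], my_rank: int, world_size: int) -> list[Tensor]:
    """Each rank orthogonalizes only params where i % world_size == my_rank; the
    remaining slots are zeros_like so the summed reduce-scatter stays correct."""
    out = []
    for i, g in enumerate(grads):
        if i % world_size == my_rank:
            out.append(newton_schulz_placeholder(g))  # real work on this rank
        else:
            out.append(torch.zeros_like(g))           # zero contribution: wasted memory and
                                                      # bandwidth, acceptable when R is small
    return out
```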
| """Distributed Muon optimizer for PyTorch FSDP2. | ||
|
|
||
| All parameters must be DTensors. The optimizer extracts device mesh and process group | ||
| information directly from DTensor metadata. |
Claude: Nit — docstring says "All parameters must be DTensors"
The class now requires DTensors for Muon params, but AdamW params still go through batch_to_local which handles plain Tensors. The docstring should clarify that the DTensor requirement applies specifically to Muon parameter groups, not all params.
Also, since this removes DDP support (the old docstring said "Also compatible with DDP"), this is a breaking API change. The old TODO comment `# TODO: 暂不支持 EP>1; maybe rm device_mesh dependency?` ("EP>1 not yet supported; maybe remove the device_mesh dependency?") shows this was intentional, but it would be good to mention the DDP compatibility drop in the PR description or a migration note.
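For example, the docstring could be narrowed along these lines (suggested wording only, not text from the PR):

```python
"""Distributed Muon optimizer for PyTorch FSDP2.

Parameters in Muon groups must be DTensors; the optimizer reads device mesh and
process-group information from their DTensor metadata. Parameters handled by the
AdamW fallback may still be plain Tensors. Plain-DDP (non-DTensor) setups are no
longer supported for Muon groups.
"""
```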
```python
        return dim
    return None

def _init_moe_subgroups(self) -> None:
```
Claude: Warning — FSDP dimension detected by name heuristic
_find_fsdp_mesh_dim (defined above) checks if "fsdp" appears as a substring in the mesh dimension name. This is brittle — it will silently return None if the mesh dimension is named differently (e.g. "dp_shard", "shard", "sharding"), causing a RuntimeError here or incorrect fallback behavior in other call sites.
Consider also accepting common alternative names, or documenting the naming requirement prominently in the class docstring so users know their mesh dimensions must contain "fsdp".
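A sketch of a more permissive lookup; the accepted alias set is an assumption, not behavior from the PR:

```python
from typing import Optional

from torch.distributed.device_mesh import DeviceMesh

# Names commonly used for the data-parallel sharding dimension.
_FSDP_DIM_ALIASES = ("fsdp", "dp_shard", "shard", "sharding")


def find_fsdp_mesh_dim(device_mesh: DeviceMesh) -> Optional[int]:
    names = device_mesh.mesh_dim_names or ()
    for dim, name in enumerate(names):
        if any(alias in name.lower() for alias in _FSDP_DIM_ALIASES):
            return dim
    return None
```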
Claude: ## Review Summary

Solid work — the EP support, AGRS path, and pre-computed lr ratios are well-designed. The communication path table in the PR description is excellent documentation. The 3x latency improvement and the fix for OOM on large models are compelling.

### Issues
Warning / Nit / Positive (see the inline comments above).

### Verdict
COMMENT — no critical bugs found. The algorithm changes look correct across all communication paths. The warnings above are about robustness and API documentation, not correctness.
Force-pushed 5710a74 to 6af69e5 (Compare)
- Remove `distributed_mesh` parameter; extract device mesh and process group directly from DTensor metadata to support heterogeneous meshes (ViT 1D + LM 2D).
- Pre-compute adjust_lr ratios in `__init__` based on the global (unsharded) shape, avoiding incorrect shape references inside async tasks after communication.
- Add MoE expert-parallel support: per-expert Newton-Schulz orthogonalization, requires `n_experts % ep_size == 0` on the EP dimension. On the FSDP dimension, skip communication when `n_experts % fsdp_size == 0` (each rank holds complete experts), use sub-group all-gather when `fsdp_size % n_experts == 0`, otherwise fall back to batched all-to-all.
- Add AGRS (All-Gather + Reduce-Scatter) path for remainder batches to avoid zero-padding overhead, with an even-sharding guard to prevent deadlock.
- Refactor shared utilities (`group_tensors_by_device_mesh_and_placements`, `cal_total_norm`) from grad_norm.py to dtensor.py.
- Remove `# type: ignore` from file head, fix lint, add full type annotations.
Force-pushed 6af69e5 to 7337959 (Compare)
Summary
Highlight:
Details:
- Pre-compute adjust_lr ratios in `__init__` based on the global (unsharded) shape, avoiding incorrect shape references inside async tasks after communication (see the sketch after this list).
- Remove `# type: ignore` from file head, fix lint, add full type annotations.
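A minimal sketch of the pre-computation idea (the scaling function below is a placeholder, not the PR's actual Muon adjustment): a DTensor's `.shape` reports the global, unsharded shape, so the ratio can be derived once in `__init__` rather than inside async tasks after communication.

```python
import math


def example_adjust_lr_ratio(rows: int, cols: int) -> float:
    # Placeholder scaling; the PR's actual Muon lr adjustment may differ.
    return math.sqrt(max(rows, cols) / max(rows, 1))


def precompute_lr_ratios(param_groups):
    """Compute per-parameter lr ratios once, from the global (unsharded) shape.

    For a DTensor, p.shape already reports the global shape, so no communication
    and no post-all-gather bookkeeping is needed here.
    """
    ratios = {}
    for group in param_groups:
        for p in group["params"]:
            rows, cols = p.shape[-2], p.shape[-1]
            ratios[p] = example_adjust_lr_ratio(rows, cols)
    return ratios
```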
Communication Path Table

| Condition | Path / flag |
| --- | --- |
| `n_experts % fsdp_size == 0` (each rank holds complete experts) | `skip_communication = True` |
| `fsdp_size % n_experts == 0`, batch < fsdp_size | `subgroup_process_group` path |
| batch == world_size, `global_dim % world_size == 0` | batched all-to-all (even split) |
| batch == world_size, `global_dim % world_size != 0` | batched all-to-all (uneven split) |
| batch < world_size, evenly sharded (`dim % world_size == 0`) | rank `r` computes params where `i % W == r`, `use_agrs = True` |
| batch < world_size, unevenly sharded (`dim % world_size != 0`) | pad with `zeros_like` to world_size, then even/uneven a2a, `is_evenly_sharded = False` |
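A condensed sketch of the dispatch logic summarized in the table, with hypothetical names and simplified conditions (the real code paths carry more state; the MoE checks only apply to expert weights):

```python
def choose_comm_path(n_experts: int, fsdp_size: int, batch: int,
                     world_size: int, global_dim: int) -> str:
    """Return a label for the communication path (hypothetical helper)."""
    # MoE expert weights on the FSDP dimension
    if n_experts % fsdp_size == 0:
        return "skip_communication"    # each rank already holds complete experts
    if fsdp_size % n_experts == 0 and batch < fsdp_size:
        return "subgroup_all_gather"   # sub-group process groups
    # Full batches: batched all-to-all, even or uneven split
    if batch == world_size:
        return "a2a_even" if global_dim % world_size == 0 else "a2a_uneven"
    # Remainder batch (batch < world_size)
    if global_dim % world_size == 0:
        return "agrs"                  # all-gather + reduce-scatter, use_agrs = True
    return "pad_then_a2a"              # zeros_like padding fallback, is_evenly_sharded = False
```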
Constraints

`n_experts % ep_size == 0` is required (assertion).

End Results
Experiments were run on a model with the Qwen3.5 architecture, using 128 GPUs.
Regression Test
Loss aligns almost perfectly when `XTUNER_DETERMINISTIC` is turned on.



Step Latency
The profile was collected at step 10 (rank 0). Latencies may vary from run to run.
The original implementation spends a lot of time creating and transmitting padding tensors, which is especially wasteful when num_layers < world_size. The following timeline showcases this behavior:

This is especially the case as the world_size grows. Typically the world_size is somewhere between 64 and 2048, but the number of layers (which determines how many params of the same shape are batched and all-to-all'ed) is usually smaller than 100. For large-scale training, the memory footprint and latency caused by padding tensors therefore grow linearly with world size. In my experiment, I hit OOM when training a 397B MoE model (60 layers, 128 GPUs, see below) for this exact reason.
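For a rough sense of scale (a back-of-the-envelope estimate based on the padding scheme described above, not an additional measurement): if each same-shape batch is padded up to world_size tensors before the all-to-all, then with roughly 60 layers and world_size = 128 about (128 - 60) / 128 ≈ 53% of the allocated and transmitted tensors are padding, and at world_size = 2048 the padding fraction rises to about (2048 - 60) / 2048 ≈ 97%.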
The following is the timeline when running the same experiment using code in this PR:

397B Model
Limitations
Padding fallback. When the number of experts satisfies neither `n_experts % fsdp_size == 0` nor `fsdp_size % n_experts == 0`, the code falls back to creating large padding tensors and running all-to-all to distribute expert parameters across ranks. This incurs significant memory and latency overhead. In practice, however, such model configurations are uncommon.
Dispatch efficiency. The current parameter dispatch follows a two-stage pattern: tensors are first grouped by device mesh and sharding placement, then routed within each group. This grouping-before-dispatching approach can introduce waste. A more general global parameter-to-device dispatcher could eliminate this overhead and is left for future work.