[Feature] support ep and enhance step latency for muon #1716

Open

nil0x9 wants to merge 1 commit into InternLM:main from nil0x9:linty/feat-muon-opt

Conversation


nil0x9 (Collaborator) commented Apr 24, 2026

Summary

Highlight:

  • Support Muon training when EP > 1
  • Reduce latency and memory footprint caused by excessive padding for all2all.

Details:

  • Remove distributed_mesh parameter; extract device mesh and process group directly from DTensor metadata to support heterogeneous meshes (ViT and LM have different meshes).
  • Pre-compute adjust_lr ratios in __init__ based on global (unsharded) shape, avoiding incorrect shape references inside async tasks after communication.
  • Add MoE expert-parallel support: per-expert Newton-Schulz orthogonalization, which requires n_experts % ep_size == 0 on the EP dimension. On the FSDP dimension, skip communication when n_experts % fsdp_size == 0 (each rank holds complete experts), use sub-group all-gather when fsdp_size % n_experts == 0, and otherwise fall back to batched all-to-all (path selection sketched below).
  • Add AGRS (All-Gather + Reduce-Scatter) path for remainder batches to avoid zero-padding overhead, with even-sharding guard to prevent deadlock.
  • Refactor shared utilities (group_tensors_by_device_mesh_and_placements, cal_total_norm) from grad_norm.py to dtensor.py.
  • Remove # type: ignore from file head, fix lint, add full type annotations.
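
The FSDP-dimension path selection above boils down to two divisibility checks. A minimal sketch, with illustrative names (select_moe_fsdp_path is not the PR's actual function):

```python
def select_moe_fsdp_path(n_experts: int, fsdp_size: int) -> str:
    """Illustrative path selection for MoE expert parameters on the FSDP dimension."""
    if n_experts % fsdp_size == 0:
        # Each rank already holds whole experts: run Newton-Schulz locally, no communication.
        return "no_communication"
    if fsdp_size % n_experts == 0:
        # Several ranks share one expert: all-gather within that small sub-group,
        # orthogonalize the reconstructed expert, then slice the local shard back out.
        return "subgroup_all_gather"
    # Neither divides the other: fall back to the batched all-to-all path.
    return "batched_all_to_all"
```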

Communication Path Table

| Path | Condition | Communication | Computation | Code Location |
| --- | --- | --- | --- | --- |
| No communication | MoE, n_experts % fsdp_size == 0 (each rank holds complete experts) | None | Local NS per expert | skip_communication = True |
| Sub-group all-gather | MoE, fsdp_size % n_experts == 0, batch < fsdp_size | Small all-gather within sub-group | Local NS on reconstructed expert, then slice back | subgroup_process_group path |
| All-to-all (even) | Batch fills world_size, global_dim % world_size == 0 | Forward a2a + backward a2a | Each rank orthogonalizes 1 full matrix | Even sharding branch |
| All-to-all (uneven) | Batch fills world_size, global_dim % world_size != 0 | Padded forward a2a + padded backward a2a | Each rank orthogonalizes 1 full matrix (after unpad) | Uneven sharding branch |
| AGRS | Batch < world_size, evenly sharded (dim % world_size == 0) | All-gather + reduce-scatter | Selective NS: rank r computes params where i % W == r | use_agrs = True |
| Padded all-to-all | Batch < world_size, unevenly sharded (dim % world_size != 0) | Pad batch with zeros_like to world_size, then even/uneven a2a | Padded zeros go through NS (wasted computation) | Fallback when is_evenly_sharded = False |
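
To make the AGRS row above concrete, here is a minimal sketch of the all-gather + reduce-scatter idea for a remainder batch of R sharded parameters on W ranks, assuming even sharding along dim 0. The function and variable names are illustrative, not the PR's implementation:

```python
import torch
import torch.distributed as dist


def agrs_update(local_shards, newton_schulz, process_group=None):
    """Illustrative AGRS path: all-gather shards, selectively orthogonalize, reduce-scatter back."""
    W = dist.get_world_size(process_group)
    r = dist.get_rank(process_group)
    outputs = []
    for i, shard in enumerate(local_shards):  # R shards, R < W
        # All-gather so every rank can reconstruct the full (unsharded) matrix.
        full = torch.empty((W * shard.shape[0], *shard.shape[1:]), dtype=shard.dtype, device=shard.device)
        dist.all_gather_into_tensor(full, shard.contiguous(), group=process_group)
        if i % W == r:
            # Selective Newton-Schulz: only the owning rank computes this parameter.
            full = newton_schulz(full)
        else:
            # Non-owners contribute zeros, so the SUM reduce-scatter below returns
            # exactly the owner's result (the zeros/bandwidth tradeoff noted in the review).
            full = torch.zeros_like(full)
        out = torch.empty_like(shard)
        dist.reduce_scatter_tensor(out, full, op=dist.ReduceOp.SUM, group=process_group)
        outputs.append(out)
    return outputs
```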

Constraints

  • EP dimension: n_experts % ep_size == 0 is required (assertion).
  • FSDP dimension: no hard constraint; the three MoE paths (no-comm, sub-group, fall-through to a2a) are optimizations, not requirements.

End Results

Experiments were run on a model with the Qwen3.5 architecture, on 128 GPUs.

Regression Test

Loss aligns almost perfectly when XTUNER_DETERMINISTIC is turned on.
[image: loss curves, base vs. this PR]

Step Latency

Profile was collected at step 10 (rank0). Latencies might vary run-to-run.

| Version | EP_SIZE | Latency |
| --- | --- | --- |
| Base | 1 | 783 ms |
| Base | 8 | Not supported |
| This PR | 1 | 236 ms |
| This PR | 8 | 232 ms |

The original implementation spends a lot of time creating and transmitting padding tensors, which is especially wasteful when num_layers < world_size. The following timeline shows this behavior:
[image: profiler timeline of the base implementation]

The problem gets worse as world_size grows. In practice, world_size may be anywhere between 64 and 2048, while the number of layers (which determines how many same-shape params are batched and all2all'ed) is typically under 100. For large-scale training, the memory footprint and latency caused by padding tensors therefore grow linearly with world size. In my experiment, I hit OOM when training a 397B MoE model (60 layers, 128 GPUs; see below) for exactly this reason.
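
As a back-of-the-envelope illustration with the numbers from this PR (60 layers batched per shape, 128 GPUs), assuming the base code pads every batch with zeros up to world_size as described in the path table above:

```python
num_layers, world_size = 60, 128
padded_batch = world_size               # base code pads each batch with zero tensors up to world_size
wasted = padded_batch - num_layers      # full-size zero tensors created and all2all'ed per shape group
print(f"{wasted}/{padded_batch} tensors ({wasted / padded_batch:.0%}) are padding")
# -> 68/128 tensors (53%) are padding
```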

The following is the timeline when running the same experiment using code in this PR:
[image: profiler timeline with this PR]

397B Model

| Version | EP_SIZE | Latency |
| --- | --- | --- |
| Base | 1 | OOM |
| Base | 8 | Not supported |
| This PR | 1 | 491 ms |
| This PR | 8 | 1019 ms |

Limitation

Padding fallback. When the number of experts satisfies neither n_experts % fsdp_size == 0 nor fsdp_size % n_experts == 0, the code falls back to creating large padding tensors and running all-to-all to distribute expert parameters across ranks. This incurs significant memory and latency overhead. In practice, however, such model configurations are uncommon.

Dispatch efficiency. The current parameter dispatch follows a two-stage pattern: tensors are first grouped by device mesh and sharding placement, then routed within each group. This grouping-before-dispatching approach can introduce waste. A more general global parameter-to-device dispatcher could eliminate this overhead and is left for future work.

nil0x9 force-pushed the linty/feat-muon-opt branch from 4531e5d to 5710a74 on April 24, 2026, 13:25

nil0x9 commented Apr 24, 2026

@claude review

Comment thread xtuner/v1/optim/muon.py
elif len(shard_placements) == 2:
fsdp_mesh_dim = self._find_fsdp_mesh_dim(device_mesh)

if fsdp_mesh_dim is None:

Claude: Warning — fragile cache key

id(device_mesh) is a memory address. If the DeviceMesh object is garbage-collected and a new one is allocated at the same address (unlikely during training, but possible on checkpoint reload), the cache could silently return stale process groups or miss a valid entry.

Consider keying on a stable identity instead, e.g. (device_mesh.mesh_dim_names, tuple(device_mesh.mesh.flatten().tolist()), sharded_mesh_dim, subgroup_size) — that way the cache is correct even if the Python object is re-created with the same logical mesh.
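
A minimal sketch of such a key, assuming DeviceMesh exposes mesh_dim_names and a mesh tensor (the helper name is hypothetical):

```python
def stable_mesh_cache_key(device_mesh, sharded_mesh_dim, subgroup_size):
    """Identify the mesh by its logical contents rather than id(), so the key
    survives DeviceMesh re-creation (e.g. after a checkpoint reload)."""
    return (
        device_mesh.mesh_dim_names,
        tuple(device_mesh.mesh.flatten().tolist()),
        sharded_mesh_dim,
        subgroup_size,
    )
```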

Comment thread xtuner/v1/optim/muon.py
Comment on lines +478 to +487
# Case A: each rank holds complete experts → no sub-groups needed
if ns_num_experts % fsdp_size == 0:
continue

# Case B: sub-group all-gather
if fsdp_size % ns_num_experts == 0:
subgroup_size = fsdp_size // ns_num_experts
self._create_subgroup_pgs(device_mesh, sharded_mesh_dim, subgroup_size)

def _create_subgroup_pgs(

Claude: Nit — leaked process groups on non-matching ranks

In the inner loop, dist.new_group(sg_ranks) is called for every subgroup across all ranks (necessary for the collective), but only the rank's own group is stored in my_pg. The other ProcessGroup objects created on this rank are never assigned to a variable and become garbage. While PyTorch's NCCL backend should handle this gracefully, it would be cleaner to track that these are intentionally discarded — a brief comment here would suffice.
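
For reference, the pattern being described looks roughly like this; names are illustrative, and the key point is that dist.new_group must be called collectively on every rank for every sub-group, while the groups a rank does not belong to are intentionally dropped:

```python
import torch.distributed as dist


def create_subgroup_pg(all_ranks, subgroup_size, my_rank):
    """Create all sub-groups collectively; keep only the one containing my_rank."""
    my_pg = None
    for start in range(0, len(all_ranks), subgroup_size):
        sg_ranks = all_ranks[start:start + subgroup_size]
        # Every rank must call new_group() for every sub-group; ProcessGroup objects
        # for sub-groups this rank is not a member of are intentionally discarded.
        pg = dist.new_group(ranks=sg_ranks)
        if my_rank in sg_ranks:
            my_pg = pg
    return my_pg
```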

Comment thread xtuner/v1/optim/muon.py
Comment on lines +656 to +662
group_world_size = 1
group_device_rank = 0

# Create batches within this mesh group
for params in create_param_batches(mesh_params, batch_size=group_world_size):
gradients: list[Tensor] = [g for p in params if (g := p.grad) is not None]


Claude: Warning — gradients length may diverge from params

The walrus-operator comprehension filters out None grads:

gradients: list[Tensor] = [g for p in params if (g := p.grad) is not None]

While group_params was already filtered for non-None grads upstream, if a grad is somehow freed between filtering and batch creation (e.g. by another concurrent task), len(gradients) < len(params). This would pass mismatched X/G lists to muon_update_batch_async, causing assert len(X) == len(G) to fail with a confusing error.

Consider either keeping the simple [p.grad for p in params] (since grads are guaranteed non-None at this point) or adding an explicit check:

assert len(gradients) == len(params), "Some gradients became None after filtering"

Comment thread xtuner/v1/optim/muon.py
Comment on lines +899 to +913
# Move the W dimension to front for reduce_scatter_tensor
w_pos = shard_dim + 1
perm = [w_pos] + list(range(0, w_pos)) + list(range(w_pos + 1, rs_input.ndim))
rs_input = rs_input.permute(perm).contiguous()

# Flatten trailing dims: (W, R*shard_numel) for reduce_scatter
rs_flat_in = rs_input.reshape(W, -1)
rs_flat_out = torch.empty(rs_flat_in.shape[1], dtype=rs_flat_in.dtype, device=rs_flat_in.device)

work = dist.reduce_scatter_tensor(
rs_flat_out,
rs_flat_in,
op=dist.ReduceOp.SUM,
group=process_group,
async_op=True,

Claude: Nit — AGRS selective NS allocates zeros_like for skipped params

When R params exist and W ranks participate, each rank allocates R - ceil(R/W) zero tensors at full parameter size. For the intended use case (remainder batches, so R is small), this is fine. But it might be worth adding a brief note that this is the expected tradeoff (small R means minimal waste), since the zeros also flow through reduce-scatter (wasted bandwidth on zero contributions).

Comment thread xtuner/v1/optim/muon.py
Comment on lines +246 to +249
"""Distributed Muon optimizer for PyTorch FSDP2.

All parameters must be DTensors. The optimizer extracts device mesh and process group
information directly from DTensor metadata.

Claude: Nit — docstring says "All parameters must be DTensors"

The class now requires DTensors for Muon params, but AdamW params still go through batch_to_local which handles plain Tensors. The docstring should clarify that the DTensor requirement applies specifically to Muon parameter groups, not all params.

Also, since this removes DDP support (the old docstring said "Also compatible with DDP"), this is a breaking API change. The old TODO comment # TODO: 暂不支持 EP>1; maybe rm device_mesh dependency? ("EP>1 not yet supported") shows this was intentional, but it would be good to mention the DDP compatibility drop in the PR description or a migration note.

Comment thread xtuner/v1/optim/muon.py
return dim
return None

def _init_moe_subgroups(self) -> None:

Claude: Warning — FSDP dimension detected by name heuristic

_find_fsdp_mesh_dim (defined above) checks if "fsdp" appears as a substring in the mesh dimension name. This is brittle — it will silently return None if the mesh dimension is named differently (e.g. "dp_shard", "shard", "sharding"), causing a RuntimeError here or incorrect fallback behavior in other call sites.

Consider also accepting common alternative names, or documenting the naming requirement prominently in the class docstring so users know their mesh dimensions must contain "fsdp".
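
For example, a variant of _find_fsdp_mesh_dim that also accepts common alternative dimension names could look like this (the alias list and the standalone function shape are assumptions, not the PR's code):

```python
_FSDP_DIM_ALIASES = ("fsdp", "dp_shard", "shard")


def find_fsdp_mesh_dim(device_mesh) -> int | None:
    """Return the index of the FSDP-like mesh dimension, or None if no name matches."""
    names = device_mesh.mesh_dim_names or ()
    for dim, name in enumerate(names):
        if any(alias in name.lower() for alias in _FSDP_DIM_ALIASES):
            return dim
    return None
```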


claude Bot commented Apr 24, 2026

Claude: Review Summary

Solid work — the EP support, AGRS path, and pre-computed lr ratios are well-designed. The communication path table in the PR description is excellent documentation. The 3x latency improvement and the fix for OOM on large models are compelling.

Issues

Warning

  • muon.py:462 (_subgroup_cache): uses id(device_mesh) as cache key, which is a memory address and won't survive object recreation (e.g. checkpoint reload). Consider keying on stable mesh properties.
  • muon.py:422 (_find_fsdp_mesh_dim): relies on "fsdp" substring in dimension names. Brittle if users name their mesh dimensions differently.
  • muon.py:246-249: Docstring says "All parameters must be DTensors", which drops DDP compatibility (previously documented). This is a breaking API change worth noting.
  • muon.py:656-662: Gradients list built with a walrus-operator filter could silently diverge from the params length; an explicit assert would be safer.

Nit

  • muon.py:478-487: dist.new_group() returns are silently discarded for non-matching ranks; a brief comment would clarify this is intentional.
  • muon.py:899-913: AGRS selective NS allocates zeros_like for skipped params, which also flow through reduce-scatter; fine for small remainder batches but worth a note.

Positive

  • Pre-computing lr_ratio in __init__ avoids shape ambiguity in async tasks — good defensive design.
  • The _get_or_initialize_state change from if not state: to if "momentum" not in state: correctly handles the new lr_ratio state entries.
  • In-place torch._foreach_mul_ in muon_update_post_orthogonalize is a nice micro-optimization.
  • Refactoring shared utilities to dtensor.py is clean and the re-exports in grad_norm.py maintain backward compatibility.

Verdict

COMMENT — no critical bugs found. The algorithm changes look correct across all communication paths. The warnings above are about robustness and API documentation, not correctness.

nil0x9 force-pushed the linty/feat-muon-opt branch from 5710a74 to 6af69e5 on April 24, 2026, 19:39
- Remove distributed_mesh parameter; extract device mesh and process group directly from DTensor metadata to support heterogeneous meshes (ViT 1D + LM 2D).
- Pre-compute adjust_lr ratios in __init__ based on global (unsharded) shape, avoiding incorrect shape references inside async tasks after communication.
- Add MoE expert-parallel support: per-expert Newton-Schulz orthogonalization, requires n_experts % ep_size == 0 on the EP dimension. On the FSDP dimension, skip communication when n_experts % fsdp_size == 0 (each rank holds complete experts), use sub-group all-gather when fsdp_size % n_experts == 0, otherwise fall back to batched all-to-all.
- Add AGRS (All-Gather + Reduce-Scatter) path for remainder batches to avoid zero-padding overhead, with even-sharding guard to prevent deadlock.
- Refactor shared utilities (group_tensors_by_device_mesh_and_placements, cal_total_norm) from grad_norm.py to dtensor.py.
- Remove `# type: ignore` from file head, fix lint, add full type annotations.
nil0x9 force-pushed the linty/feat-muon-opt branch from 6af69e5 to 7337959 on April 24, 2026, 19:43