
[PyTorch] Add distributed Muon optimizer#2920

Open
vcherepanov-nv wants to merge 6 commits into NVIDIA:main from vcherepanov-nv:muon

Conversation

@vcherepanov-nv
Collaborator

Description

Add a distributed Muon optimizer, based on newton_schulz orthogonalization

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add an optimizer class and tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

vcherepanov-nv and others added 2 commits April 23, 2026 18:50
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@greptile-apps
Contributor

greptile-apps Bot commented Apr 23, 2026

Greptile Summary

This PR adds MuonOptimizer, a distributed Muon optimizer that applies SGD-momentum followed by distributed Newton-Schulz orthogonalization on tensor-parallel 2D parameter shards via an NCCL process group. The implementation is well-validated, mathematically correct relative to the full-matrix reference, and the closure/torch.enable_grad() interaction is handled properly.

  • eps ≤ 0 is never validated: passing eps=0 causes silent division by zero inside _distributed_normalize_p2_ when a gradient norm is exactly zero (clamp_min(0) does not prevent sqrt(0) / 0).
  • __del__ calling destroy() is a teardown ordering hazard: if dist.destroy_process_group() runs before the optimizer is garbage-collected, the finalizer will attempt to tear down a cuSolverMp context backed by a freed NCCL communicator.
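
A minimal hardening sketch for the two findings above, assuming the eps parameter and a destroy() teardown method as described in this PR; the class scaffolding here is illustrative, not the actual implementation:

import torch.distributed as dist


def validate_eps(eps: float) -> None:
    # Reject non-positive eps so the p2-norm denominator can never be exactly zero
    # when a gradient norm happens to be zero.
    if not eps > 0.0:
        raise ValueError(f"Invalid eps value: {eps}; expected eps > 0")


class TeardownSafe:
    """Illustrative finalizer guard; the real class owns a cuSolverMp context."""

    def destroy(self) -> None:
        # Tear down the cuSolverMp context / communicator-backed state here.
        pass

    def __del__(self):
        # Only tear down while torch.distributed is still initialized; otherwise the
        # NCCL communicator backing the context may already have been freed.
        if dist.is_available() and dist.is_initialized():
            self.destroy()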

Confidence Score: 5/5

Safe to merge; all findings are P2 suggestions with no blocking correctness or security issues.

All findings are P2 (style/hardening): missing eps validation, __del__ teardown ordering, and import-time NUM_PROCS. The core optimizer math is correct, distributed normalization is equivalent to the full-matrix reference, and previously discussed issues (closure/enable_grad, global_shape scaling) are properly handled in this version.

transformer_engine/pytorch/optimizers/muon.py — eps validation gap and __del__ teardown ordering.

Important Files Changed

  • transformer_engine/pytorch/optimizers/muon.py
    New MuonOptimizer class: distributed Newton-Schulz orthogonalization with SGD-momentum; well-structured with solid validation, but missing an eps > 0 guard and carrying a __del__ teardown ordering hazard.
  • transformer_engine/pytorch/optimizers/__init__.py
    Adds MuonOptimizer and get_muon_scale_factor to the public optimizer exports; straightforward one-line change.
  • tests/pytorch/distributed/run_muon_optimizer.py
    torchrun worker that validates the distributed optimizer against a full-matrix float32 reference; reference logic is correct, with no global_shape double-scaling bug present in this version.
  • tests/pytorch/distributed/test_muon_optimizer.py
    pytest harness launching torchrun workers; covers dtype, partition_dim, and weight-decay modes; NUM_PROCS baked in at import time could be 0 on CPU-only hosts.
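
For orientation, a hypothetical usage sketch of the new optimizer; the argument names are taken from the constructor snippets quoted in the review comments below, but the exact signature, defaults, and hyperparameter values are assumptions, not the confirmed API:

import torch
import torch.distributed as dist
from transformer_engine.pytorch.optimizers import MuonOptimizer

dist.init_process_group(backend="nccl")
tp_group = dist.new_group()  # in practice, the tensor-parallel group

# A local shard of a logical 2D weight matrix, partitioned along dim 1 across tp_group.
out_features, in_features = 1024, 4096
tp_world_size = dist.get_world_size(tp_group)
weight = torch.nn.Parameter(
    torch.empty(out_features, in_features // tp_world_size, device="cuda")
)

optimizer = MuonOptimizer(
    [weight],
    lr=0.02,
    momentum=0.95,
    nesterov=True,
    weight_decay=0.01,
    num_ns_steps=5,
    scale_mode="spectral",
    process_group=tp_group,  # pass the TP group explicitly (see the review discussion below)
    partition_dim=1,         # dimension along which this shard is partitioned
)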

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[step called] --> B{closure?}
    B -- yes --> C[enable_grad + call closure]
    B -- no --> D[iterate param groups]
    C --> D
    D --> E{p.grad is None?}
    E -- yes --> F[skip]
    E -- no --> G[resolve + validate partition_dim]
    G --> H{decoupled weight decay?}
    H -- yes --> I[p *= 1 - lr * wd]
    H -- no, wd != 0 --> J[grad += wd * p]
    I --> K[momentum_buffer.lerp_ grad]
    J --> K
    K --> L{nesterov?}
    L -- yes --> M[update = grad.lerp momentum_buffer momentum]
    L -- no --> N[update = momentum_buffer ref]
    M --> O[_orthogonalize update]
    N --> O
    O --> P[clone + maybe transpose]
    P --> Q[distributed_normalize_p2_ all_reduce norm]
    Q --> R[newton_schulz distributed]
    R --> S[maybe untranspose]
    S --> T[scale by get_muon_scale_factor * extra]
    T --> U[p += -lr * orth_update]
    U --> V[return loss]

Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +186 to +191
def step(self, closure=None):
    """Perform a single optimization step."""
    loss = None
    if closure is not None:
        loss = closure()

P1 Closure called inside @torch.no_grad(), preventing gradient computation

closure() is invoked while torch.no_grad() is active. Any loss.backward() call inside the closure will silently produce zero/no gradients. The standard PyTorch pattern (used in SGD, Adam, etc.) is to wrap the closure in with torch.enable_grad():.

Suggested change

Replace:

def step(self, closure=None):
    """Perform a single optimization step."""
    loss = None
    if closure is not None:
        loss = closure()

with:

@torch.no_grad()
def step(self, closure=None):
    """Perform a single optimization step."""
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

Comment on lines +28 to +33
    scale_mode: str,
    extra_scale_factor: float,
    eps: float,
) -> torch.Tensor:
    global_shape = [grad.size(0), grad.size(1)]
    global_shape[partition_dim] *= world_size

P1 Reference global_shape incorrectly scales an already-full tensor

_reference_orthogonalize receives the full matrix (shape full_shape) but then multiplies global_shape[partition_dim] by world_size a second time. For partition_dim=1 with world_size=2 and full_shape=(96, 128) this gives global_shape=[96, 256], so get_muon_scale_factor returns max(96,256)^0.5 = 16. The optimizer, operating on the shard (96, 64), correctly reconstructs global_shape=[96, 128] and computes max(96,128)^0.5 ≈ 11.3. This √2 discrepancy means the reference cannot correctly validate the optimizer's output.

The global_shape[partition_dim] *= world_size line should be removed since the input is already the full matrix.
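
A sketch of the corrected reference-side shape handling under the assumption stated above (the reference already receives the full matrix); the helper name is illustrative:

import torch


def reference_global_shape(full_grad: torch.Tensor) -> list:
    # The reference operates on the already-gathered full matrix, so its own shape
    # is the global shape; no world_size scaling is applied here.
    return [full_grad.size(0), full_grad.size(1)]


# With the full (96, 128) matrix this yields max(96, 128) ** 0.5 ≈ 11.3, matching the
# optimizer's shard-side reconstruction, instead of the erroneous max(96, 256) ** 0.5 = 16.
assert reference_global_shape(torch.empty(96, 128)) == [96, 128]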

Comment on lines +33 to +34
if mode == "unit_rms_norm":
    return (size_out / size_in) ** 0.5

P1 unit_rms_norm mode can divide by zero when size_in == 0

(size_out / size_in) ** 0.5 raises ZeroDivisionError when size_in is 0. While the optimizer validates that the partition dimension is non-empty, it doesn't ensure the other dimension is non-zero. Consider adding a guard or documenting that both dimensions must be strictly positive.
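
A guarded variant as a sketch; the function name is hypothetical, and the spectral branch is inferred from the max-dimension square-root values quoted in the review above, so it is an assumption rather than the confirmed implementation:

def get_muon_scale_factor_guarded(size_out: int, size_in: int, mode: str = "spectral") -> float:
    # Both dimensions must be strictly positive before any ratio is formed.
    if size_out <= 0 or size_in <= 0:
        raise ValueError(f"Both dimensions must be positive, got ({size_out}, {size_in})")
    if mode == "unit_rms_norm":
        return (size_out / size_in) ** 0.5
    # Assumed "spectral" behavior (max-dimension square root); illustrative only.
    return max(size_out, size_in) ** 0.5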

Comment on lines +218 to +221
if group["nesterov"]:
    update = grad.lerp(momentum_buffer, group["momentum"])
else:
    update = momentum_buffer

P2 Non-Nesterov update is an alias to momentum_buffer, not a copy

update = momentum_buffer holds a reference. If _orthogonalize ever modifies its input in-place in a future refactor, the momentum buffer will be silently corrupted. _orthogonalize currently clones the input immediately so this is safe today, but a defensive .clone() or comment would make the intent explicit.
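
A self-contained illustration of the defensive copy suggested above; shapes and hyperparameters are arbitrary:

import torch

grad = torch.randn(8, 8)
momentum_buffer = torch.randn(8, 8)
momentum, nesterov = 0.95, False

if nesterov:
    update = grad.lerp(momentum_buffer, momentum)
else:
    # Defensive clone: even if a future _orthogonalize stops cloning its input and
    # mutates it in place, the momentum state cannot be silently corrupted.
    update = momentum_buffer.clone()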

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@vcherepanov-nv vcherepanov-nv changed the title [Draft] [PyTorch] Add distributed Muon optimizer [PyTorch] Add distributed Muon optimizer Apr 27, 2026
@vcherepanov-nv vcherepanov-nv requested a review from cyanguwa April 27, 2026 18:12

@skyw skyw left a comment


I'd advise NOT exposing it in the public API. Keep it in tests only, if that is the purpose.

Having an optimizer with most code copied invites fragmentation.

Before this, all the optimizers TE provides are more optimized fused versions. I'd say a highly optimized fused Muon with a similar concept can be justified, but it would need more consideration because it has more dependencies on other parts of the training pipeline than elementwise optimizers.

on tensor-parallel parameter shards. The local parameter shard must represent a
partition of a logical 2D matrix across the provided NCCL process group.

Args:

Q: Does TE use NumPy-style docstrings instead of Google style?


def __init__(
    self,
    params: Iterable[torch.nn.Parameter | dict],

Nit: the type hint here doesn't match PyTorch's internal one. Should be fine for the purpose of this class.

    scale_mode: MuonScaleT = "spectral",
    extra_scale_factor: float = 1.0,
    process_group: Optional[dist.ProcessGroup] = None,
    partition_dim: int = 1,

Fix: partition_dim is per parameter.

    raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if num_ns_steps < 1:
    raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")
if partition_dim not in (0, 1):

Q: Does this class intend to support the non-distributed case? partition_dim would be -1 in TE in that case.


if process_group is None:
    if not dist.is_initialized():
        raise RuntimeError("MuonOptimizer requires torch.distributed to be initialized.")

Same question as above regarding single-GPU support.

if process_group is None:
    if not dist.is_initialized():
        raise RuntimeError("MuonOptimizer requires torch.distributed to be initialized.")
    process_group = dist.group.WORLD

Suggestion: this silent fallback is dangerous. If the user forgets to pass the correct TP group, the wrong group will be used.
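
One way to make the fallback explicit, sketched with a hypothetical helper name and illustrative error messages:

from typing import Optional

import torch.distributed as dist


def resolve_process_group(process_group: Optional[dist.ProcessGroup]) -> dist.ProcessGroup:
    # Refuse to silently substitute the default (WORLD) group, so a forgotten
    # tensor-parallel group surfaces as an error instead of a wrong-group all_reduce.
    if not dist.is_initialized():
        raise RuntimeError("MuonOptimizer requires torch.distributed to be initialized.")
    if process_group is None:
        raise ValueError(
            "MuonOptimizer requires an explicit process_group; pass the tensor-parallel "
            "group rather than relying on the default WORLD group."
        )
    return process_group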

    eps: float,
) -> torch.Tensor:
    self._validate_param(grad, partition_dim)
    world_size = dist.get_world_size(self.process_group)

Same suggestion as above: the silent fallback from a None process group to the default group is dangerous. (I understand this convention comes from PyTorch for historical reasons.)

global_shape[partition_dim] *= world_size

orth_grad = grad.clone()
transposed = partition_dim == 0

Attn: this follows the common row- and column-wise tensor parallelism used in most LLMs. It would be suboptimal for anything other than that. Add a comment if that assumption is being made.

@vcherepanov-nv
Collaborator Author

Having an optimizer with most code copied invites fragmentation.

The idea was to give something to users who use TE but not Megatron-LM. By fragmentation, do you mean that we want to encourage everyone to use Megatron-LM? Or that the optimizer is a relatively thin layer on top of the newton_schulz call, so users should have no trouble creating it themselves?

I don't think we gain anything by putting it into tests, since we already have tests for the newton_schulz call. So we need to decide whether we want this PR or should abandon it altogether. @cyanguwa

@skyw

skyw commented Apr 28, 2026

Having an optimizer with most code copied invites fragmentation.

The idea was to give something to users who use TE but not Megatron-LM. By fragmentation, do you mean that we want to encourage everyone to use Megatron-LM? Or that the optimizer is a relatively thin layer on top of the newton_schulz call, so users should have no trouble creating it themselves?

I don't think we gain anything by putting it into tests, since we already have tests for the newton_schulz call. So we need to decide whether we want this PR or should abandon it altogether. @cyanguwa

Fragmentation means there will be different flavors of Muon in the emerging optimizers and in TE, plus a lot of copied code. TE can end up with a stale feature when the emerging optimizer updates. Megatron-LM will always have its own version because there are implementation-specific things that need to be hooked together, for example how QKV is implemented, or fused SwiGLU.
For TE, I think an example of how to build a version of an emerging optimizer using the TE NS backend would be good to have. But providing an optimizer (not a fusion-optimized version) confuses customers.
Having said that, I would love for TE to have a more optimized version, similar in idea to FusedAdam, etc.
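
A rough, single-GPU sketch of what such an example might look like. newton_schulz5 below is a stand-in (the standard quintic Newton-Schulz iteration from the public Muon reference) for whatever routine the TE NS backend actually exposes, and TinyMuon, its hyperparameters, and the spectral-style scaling are illustrative only:

import torch


def newton_schulz5(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    # Stand-in for the TE NS backend; a real example would import TE's routine instead.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + eps)
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X


class TinyMuon(torch.optim.Optimizer):
    """Illustrative, non-fused Muon-style optimizer wrapping an NS orthogonalization backend."""

    def __init__(self, params, lr: float = 0.02, momentum: float = 0.95):
        super().__init__(params, dict(lr=lr, momentum=momentum))

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None or p.ndim != 2:
                    continue  # Muon-style updates apply to 2D weight matrices only
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(p)
                buf = state["momentum_buffer"]
                buf.lerp_(p.grad, 1 - group["momentum"])  # SGD-momentum accumulation
                update = newton_schulz5(buf)              # orthogonalize the update
                scale = max(p.size(0), p.size(1)) ** 0.5  # spectral-style scaling
                p.add_(update, alpha=-group["lr"] * scale)
        return loss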

Collaborator


Should we move newton_schulz.py to this directory? Also, how do we expect Megatron to call us for this functionality? Thanks.

vcherepanov-nv and others added 3 commits May 1, 2026 07:27
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>