diff --git a/tests/pytorch/distributed/run_muon_optimizer.py b/tests/pytorch/distributed/run_muon_optimizer.py new file mode 100644 index 0000000000..6063df6c35 --- /dev/null +++ b/tests/pytorch/distributed/run_muon_optimizer.py @@ -0,0 +1,215 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Distributed Muon optimizer test worker. + +Launched via torchrun from test_muon_optimizer.py. +""" + +import argparse +import sys + +import torch +import torch.distributed as dist +from torch.distributed.elastic.multiprocessing.errors import record + +import transformer_engine.pytorch as te +from transformer_engine.pytorch.newton_schulz import get_coefficients +from transformer_engine.pytorch.optimizers.muon import get_muon_scale_factor + + +def _reference_orthogonalize( + grad: torch.Tensor, + *, + partition_dim: int, + coefficients: list[tuple[float, float, float]], + scale_mode: str, + extra_scale_factor: float, + eps: float, +) -> torch.Tensor: + global_shape = [grad.size(0), grad.size(1)] + + x = grad.clone() + if partition_dim == 0: + x = x.mT.contiguous() + + x = x / torch.sqrt((x.float() * x.float()).sum()).clamp_min(eps).to(dtype=x.dtype) + + for a, b, c in coefficients: + xxt = x @ x.mT + x = a * x + b * (xxt @ x) + c * ((xxt @ xxt) @ x) + + if partition_dim == 0: + x = x.mT.contiguous() + + scale = get_muon_scale_factor(global_shape[0], global_shape[1], mode=scale_mode) + return x * (scale * extra_scale_factor) + + +def _reference_step( + param: torch.Tensor, + grad: torch.Tensor, + momentum_buffer: torch.Tensor, + *, + lr: float, + momentum: float, + nesterov: bool, + weight_decay: float, + use_decoupled_weight_decay: bool, + partition_dim: int, + coefficients: list[tuple[float, float, float]], + scale_mode: str, + extra_scale_factor: float, + eps: float, +) -> tuple[torch.Tensor, torch.Tensor]: + param = param.clone() + grad = grad.clone() + momentum_buffer = momentum_buffer.clone() + + if use_decoupled_weight_decay: + param = param * (1.0 - lr * weight_decay) + elif weight_decay != 0: + grad = grad + weight_decay * param + + momentum_buffer = momentum * momentum_buffer + (1.0 - momentum) * grad + if nesterov: + update = (1.0 - momentum) * grad + momentum * momentum_buffer + else: + update = momentum_buffer + + orth_update = _reference_orthogonalize( + update, + partition_dim=partition_dim, + coefficients=coefficients, + scale_mode=scale_mode, + extra_scale_factor=extra_scale_factor, + eps=eps, + ) + param = param - lr * orth_update + return param, momentum_buffer + + +@record +def main(): + parser = argparse.ArgumentParser(description="Distributed Muon optimizer test") + parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "bfloat16"]) + parser.add_argument("--partition-dim", type=int, default=1, choices=[0, 1]) + parser.add_argument( + "--weight-decay-mode", type=str, default="decoupled", choices=["decoupled", "l2"] + ) + parser.add_argument("--num-steps", type=int, default=2) + args = parser.parse_args() + + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + torch.cuda.set_device(rank) + + dtype = torch.float32 if args.dtype == "float32" else torch.bfloat16 + if args.partition_dim == 0: + full_shape = (world_size * 64, 96) + else: + full_shape = (96, world_size * 64) + + lr = 3e-4 + momentum = 0.95 + nesterov = True + weight_decay = 0.01 + use_decoupled_weight_decay = args.weight_decay_mode == "decoupled" + coefficient_type = "quintic" + num_ns_steps = 5 + scale_mode = "spectral" + extra_scale_factor = 1.0 + eps = 1e-7 + coefficients = get_coefficients(num_ns_steps, coefficient_type) + + if rank == 0: + torch.manual_seed(1234) + full_param = torch.randn(full_shape, device="cuda", dtype=dtype) + full_grads = [ + torch.randn(full_shape, device="cuda", dtype=dtype) for _ in range(args.num_steps) + ] + else: + full_param = torch.empty(full_shape, device="cuda", dtype=dtype) + full_grads = [ + torch.empty(full_shape, device="cuda", dtype=dtype) for _ in range(args.num_steps) + ] + + dist.broadcast(full_param, src=0) + for grad in full_grads: + dist.broadcast(grad, src=0) + + shard_size = full_shape[args.partition_dim] // world_size + shard_slice = slice(rank * shard_size, (rank + 1) * shard_size) + if args.partition_dim == 0: + local_param_init = full_param[shard_slice, :].contiguous() + else: + local_param_init = full_param[:, shard_slice].contiguous() + + param = torch.nn.Parameter(local_param_init.clone()) + param.partition_dim = args.partition_dim + optimizer = te.optimizers.MuonOptimizer( + [param], + lr=lr, + momentum=momentum, + nesterov=nesterov, + weight_decay=weight_decay, + use_decoupled_weight_decay=use_decoupled_weight_decay, + coefficient_type=coefficient_type, + num_ns_steps=num_ns_steps, + scale_mode=scale_mode, + extra_scale_factor=extra_scale_factor, + process_group=dist.group.WORLD, + eps=eps, + ) + + ref_param = full_param.float() + ref_momentum = torch.zeros_like(ref_param) + for full_grad in full_grads: + if args.partition_dim == 0: + param.grad = full_grad[shard_slice, :].contiguous() + else: + param.grad = full_grad[:, shard_slice].contiguous() + optimizer.step() + + ref_param, ref_momentum = _reference_step( + ref_param, + full_grad.float(), + ref_momentum, + lr=lr, + momentum=momentum, + nesterov=nesterov, + weight_decay=weight_decay, + use_decoupled_weight_decay=use_decoupled_weight_decay, + partition_dim=args.partition_dim, + coefficients=coefficients, + scale_mode=scale_mode, + extra_scale_factor=extra_scale_factor, + eps=eps, + ) + + gathered = [torch.empty_like(param) for _ in range(world_size)] + dist.all_gather(gathered, param) + if args.partition_dim == 0: + test_param = torch.cat(gathered, dim=0) + else: + test_param = torch.cat(gathered, dim=1) + + if rank == 0: + expected = ref_param.to(dtype) + atol, rtol = (5e-2, 5e-2) if dtype == torch.bfloat16 else (2e-3, 2e-3) + if torch.allclose(test_param, expected, atol=atol, rtol=rtol): + print("MUON OPTIMIZER CHECK PASSED", flush=True) + else: + max_diff = (test_param - expected).abs().max().item() + print(f"Max |optimizer - reference|: {max_diff:.6e}", flush=True) + print("MUON OPTIMIZER CHECK FAILED", flush=True, file=sys.stderr) + sys.exit(1) + + optimizer.destroy() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/tests/pytorch/distributed/test_muon_optimizer.py b/tests/pytorch/distributed/test_muon_optimizer.py new file mode 100644 index 0000000000..24515d2766 --- /dev/null +++ b/tests/pytorch/distributed/test_muon_optimizer.py @@ -0,0 +1,84 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tests for distributed Muon optimizer.""" + +import os +import subprocess +from pathlib import Path + +import pytest +import torch + +from transformer_engine.pytorch.optimizers.muon import MuonOptimizer + +MULTI_GPU_AVAILABLE = torch.cuda.device_count() >= 2 +requires_multi_gpu = pytest.mark.skipif( + not MULTI_GPU_AVAILABLE, + reason="Muon optimizer distributed tests require at least 2 GPUs.", +) + +TEST_ROOT = Path(__file__).parent.resolve() +NUM_PROCS = torch.cuda.device_count() +LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] + + +def _run_test(dtype: str, partition_dim: int, weight_decay_mode: str) -> None: + test_path = TEST_ROOT / "run_muon_optimizer.py" + test_cmd = LAUNCH_CMD + [ + str(test_path), + f"--dtype={dtype}", + f"--partition-dim={partition_dim}", + f"--weight-decay-mode={weight_decay_mode}", + ] + result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False, timeout=300) + if ( + result.returncode != 0 + or "MUON OPTIMIZER CHECK FAILED" in result.stderr.decode() + or "MUON OPTIMIZER CHECK PASSED" not in result.stdout.decode() + ): + raise AssertionError( + "Muon optimizer test failed.\n" + f"stdout: {result.stdout.decode()}\n" + f"stderr: {result.stderr.decode()}" + ) + + +@requires_multi_gpu +@pytest.mark.parametrize("dtype", ["float32", "bfloat16"]) +@pytest.mark.parametrize("partition_dim", [0, 1]) +def test_muon_optimizer_matches_reference(dtype: str, partition_dim: int) -> None: + """Compare distributed Muon updates with a full-matrix reference.""" + _run_test(dtype, partition_dim, "decoupled") + + +@requires_multi_gpu +def test_muon_optimizer_l2_weight_decay() -> None: + """Exercise the L2 weight decay branch against the same reference.""" + _run_test("float32", 1, "l2") + + +def test_muon_optimizer_requires_explicit_process_group() -> None: + """Muon should not silently fall back to the world process group.""" + param = torch.nn.Parameter(torch.empty(2, 2)) + with pytest.raises(ValueError, match="explicit NCCL tensor-parallel process_group"): + MuonOptimizer([param], process_group=None, partition_dim=0) + + +def test_muon_optimizer_resolves_partition_dim_per_parameter() -> None: + """TE tensor-parallel metadata should provide per-parameter partition dims.""" + param = torch.empty(2, 2) + param.partition_dim = 0 + + assert MuonOptimizer._resolve_partition_dim(param, None) == 0 + + param_without_metadata = torch.empty(2, 2) + assert MuonOptimizer._resolve_partition_dim(param_without_metadata, 1) == 1 + + with pytest.raises(ValueError, match="Conflicting partition_dim"): + MuonOptimizer._resolve_partition_dim(param, 1) + + param.partition_dim = -1 + with pytest.raises(ValueError, match="Non-parallel parameters are not supported"): + MuonOptimizer._resolve_partition_dim(param, None) diff --git a/transformer_engine/pytorch/optimizers/__init__.py b/transformer_engine/pytorch/optimizers/__init__.py index 7220f1924a..c643d32287 100644 --- a/transformer_engine/pytorch/optimizers/__init__.py +++ b/transformer_engine/pytorch/optimizers/__init__.py @@ -16,4 +16,5 @@ ) from .fused_adam import FusedAdam from .fused_sgd import FusedSGD +from .muon import MuonOptimizer, get_muon_scale_factor from .multi_tensor_apply import MultiTensorApply, multi_tensor_applier diff --git a/transformer_engine/pytorch/optimizers/muon.py b/transformer_engine/pytorch/optimizers/muon.py new file mode 100644 index 0000000000..c9336597bf --- /dev/null +++ b/transformer_engine/pytorch/optimizers/muon.py @@ -0,0 +1,298 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Muon optimizer backed by distributed Newton-Schulz orthogonalization.""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any, Literal, Optional + +import torch +import torch.distributed as dist +from torch.optim import Optimizer + +from transformer_engine.pytorch.newton_schulz import ( + CusolverMpCtx, + NSCoeffT, + get_coefficients, + newton_schulz, +) + + +MuonScaleT = Literal["shape_scaling", "spectral", "unit_rms_norm"] +ParamsT = Iterable[torch.Tensor] | Iterable[dict[str, Any]] | Iterable[tuple[str, torch.Tensor]] + + +def get_muon_scale_factor(size_out: int, size_in: int, mode: MuonScaleT = "spectral") -> float: + """Return the Muon update scale factor for the logical matrix shape.""" + if mode == "shape_scaling": + return max(1, size_out / size_in) ** 0.5 + if mode == "spectral": + return max(size_out, size_in) ** 0.5 + if mode == "unit_rms_norm": + return (size_out / size_in) ** 0.5 + raise ValueError(f"Invalid mode for Muon update scale factor: {mode}") + + +class MuonOptimizer(Optimizer): + """Distributed Muon optimizer for 2D CUDA parameters. + + This optimizer applies SGD-momentum followed by Newton-Schulz orthogonalization + on tensor-parallel parameter shards. The local parameter shard must represent a + contiguous row or column partition of a logical 2D matrix across the provided + NCCL process group. Single-GPU, unsharded parameters and TE non-parallel + parameters with ``partition_dim == -1`` are not supported. + + Parameters + ---------- + params : iterable of torch.Tensor, dict, or tuple[str, torch.Tensor] + Parameters, parameter group dictionaries, or named parameters. The + optimizer delegates normalization of this input to ``torch.optim.Optimizer``. + lr : float, default = 3e-4 + Learning rate. + momentum : float, default = 0.95 + Momentum coefficient. + nesterov : bool, default = True + Whether to use Nesterov momentum. + weight_decay : float, default = 0.01 + Weight decay coefficient. + use_decoupled_weight_decay : bool, default = True + Whether to apply decoupled weight decay. + coefficient_type : str, default = "quintic" + Newton-Schulz coefficient schedule. + num_ns_steps : int, default = 5 + Number of Newton-Schulz iterations. + scale_mode : str, default = "spectral" + Muon update scale mode. + extra_scale_factor : float, default = 1.0 + Extra multiplicative scale applied after orthogonalization. + process_group : torch.distributed.ProcessGroup + Explicit NCCL tensor-parallel process group for distributed Newton-Schulz. + Pass ``dist.group.WORLD`` only when the world group is intentionally the + tensor-parallel group. + partition_dim : int, optional + Default partition dimension for parameters that do not carry TE + tensor-parallel metadata. If a parameter has a ``partition_dim`` attribute, + that per-parameter value is used instead. Must be 0 or 1 when provided. + eps : float, default = 1e-7 + Lower bound for the distributed normalization denominator. + """ + + def __init__( + self, + params: ParamsT, + lr: float = 3e-4, + momentum: float = 0.95, + nesterov: bool = True, + weight_decay: float = 0.01, + *, + use_decoupled_weight_decay: bool = True, + coefficient_type: NSCoeffT = "quintic", + num_ns_steps: int = 5, + scale_mode: MuonScaleT = "spectral", + extra_scale_factor: float = 1.0, + process_group: dist.ProcessGroup, + partition_dim: Optional[int] = None, + eps: float = 1e-7, + ) -> None: + self._ns_ctx: CusolverMpCtx | None = None + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0 or momentum >= 1.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if num_ns_steps < 1: + raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}") + if partition_dim is not None and partition_dim not in (0, 1): + raise ValueError(f"partition_dim must be 0 or 1, got {partition_dim}") + get_coefficients(num_ns_steps, coefficient_type) + + if process_group is None: + raise ValueError( + "MuonOptimizer requires an explicit NCCL tensor-parallel process_group. " + "Pass dist.group.WORLD explicitly only if it is the intended group." + ) + if not dist.is_initialized(): + raise RuntimeError("MuonOptimizer requires torch.distributed to be initialized.") + if dist.get_backend(process_group) != "nccl": + raise RuntimeError("MuonOptimizer requires an NCCL process group.") + + defaults = { + "lr": lr, + "momentum": momentum, + "nesterov": nesterov, + "weight_decay": weight_decay, + "use_decoupled_weight_decay": use_decoupled_weight_decay, + "coefficient_type": coefficient_type, + "num_ns_steps": num_ns_steps, + "scale_mode": scale_mode, + "extra_scale_factor": extra_scale_factor, + "partition_dim": partition_dim, + "eps": eps, + } + super().__init__(params, defaults) + for group in self.param_groups: + group_partition_dim = group["partition_dim"] + if group_partition_dim is not None and group_partition_dim not in (0, 1): + raise ValueError(f"partition_dim must be 0 or 1, got {group_partition_dim}") + self.process_group = process_group + + def __del__(self) -> None: + self.destroy() + + def destroy(self) -> None: + """Release the underlying cuSolverMp context.""" + if self._ns_ctx is not None: + self._ns_ctx.destroy() + self._ns_ctx = None + + def _get_ctx(self) -> CusolverMpCtx: + if self._ns_ctx is None: + self._ns_ctx = CusolverMpCtx(self.process_group) + return self._ns_ctx + + @staticmethod + def _validate_param(param: torch.Tensor, partition_dim: int) -> None: + if param.ndim != 2: + raise ValueError("MuonOptimizer only supports 2D parameters.") + if not param.is_cuda: + raise ValueError("MuonOptimizer only supports CUDA parameters.") + if param.dtype not in (torch.float32, torch.bfloat16): + raise ValueError( + f"MuonOptimizer requires float32 or bfloat16 parameters, got {param.dtype}." + ) + if param.size(partition_dim) == 0: + raise ValueError("MuonOptimizer does not support empty tensor-parallel shards.") + + @staticmethod + def _resolve_partition_dim( + param: torch.Tensor, + group_partition_dim: Optional[int], + ) -> int: + param_partition_dim = getattr(param, "partition_dim", None) + if param_partition_dim is None: + if group_partition_dim is None: + raise ValueError( + "MuonOptimizer requires a partition_dim for each parameter. " + "Set TE tensor-parallel metadata on the parameter or provide " + "partition_dim in the optimizer defaults/parameter group." + ) + partition_dim = group_partition_dim + else: + partition_dim = param_partition_dim + if group_partition_dim is not None and group_partition_dim != partition_dim: + raise ValueError( + "Conflicting partition_dim values for MuonOptimizer parameter: " + f"parameter has {partition_dim}, parameter group has {group_partition_dim}." + ) + + if partition_dim not in (0, 1): + raise ValueError( + "MuonOptimizer only supports tensor-parallel parameters sharded along " + f"dimension 0 or 1, got partition_dim={partition_dim}. Non-parallel " + "parameters are not supported." + ) + return partition_dim + + def _distributed_normalize_p2_( + self, + x: torch.Tensor, + eps: float, + ) -> None: + norm_sq = (x.float() * x.float()).sum() + dist.all_reduce(norm_sq, op=dist.ReduceOp.SUM, group=self.process_group) + x.div_(torch.sqrt(norm_sq).clamp_min(eps).to(dtype=x.dtype)) + + def _orthogonalize( + self, + grad: torch.Tensor, + *, + partition_dim: int, + coefficient_type: NSCoeffT, + num_ns_steps: int, + scale_mode: MuonScaleT, + extra_scale_factor: float, + eps: float, + ) -> torch.Tensor: + self._validate_param(grad, partition_dim) + world_size = dist.get_world_size(self.process_group) + global_shape = [grad.size(0), grad.size(1)] + global_shape[partition_dim] *= world_size + + orth_grad = grad.clone() + # The cuSolverMp Newton-Schulz backend expects columns to be distributed. + # Row-parallel shards are transposed into that layout. This assumes the + # usual contiguous row/column TP sharding; strided or irregular layouts + # are outside this optimizer's contract. + transposed = partition_dim == 0 + if transposed: + orth_grad = orth_grad.mT.contiguous() + else: + orth_grad = orth_grad.contiguous() + + self._distributed_normalize_p2_(orth_grad, eps) + coefficients = get_coefficients(num_ns_steps, coefficient_type) + newton_schulz(orth_grad, self._get_ctx(), num_ns_steps, coefficients=coefficients) + + if transposed: + orth_grad = orth_grad.mT.contiguous() + + scale_factor = get_muon_scale_factor(global_shape[0], global_shape[1], mode=scale_mode) + orth_grad.mul_(scale_factor * extra_scale_factor) + return orth_grad + + @torch.no_grad() + def step(self, closure=None): + """Perform a single optimization step.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + partition_dim = self._resolve_partition_dim(p, group["partition_dim"]) + self._validate_param(p, partition_dim) + grad = p.grad + if grad.dtype != p.dtype: + raise ValueError( + f"Gradient dtype {grad.dtype} must match parameter dtype {p.dtype}." + ) + if grad.shape != p.shape: + raise ValueError("Gradient shape must match parameter shape.") + + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p) + + if group["use_decoupled_weight_decay"]: + p.mul_(1.0 - group["lr"] * group["weight_decay"]) + elif group["weight_decay"] != 0: + grad = grad.add(p, alpha=group["weight_decay"]) + + momentum_buffer = state["momentum_buffer"] + momentum_buffer.lerp_(grad, 1.0 - group["momentum"]) + + if group["nesterov"]: + update = grad.lerp(momentum_buffer, group["momentum"]) + else: + update = momentum_buffer + + orth_update = self._orthogonalize( + update, + partition_dim=partition_dim, + coefficient_type=group["coefficient_type"], + num_ns_steps=group["num_ns_steps"], + scale_mode=group["scale_mode"], + extra_scale_factor=group["extra_scale_factor"], + eps=group["eps"], + ) + p.add_(orth_update, alpha=-group["lr"]) + + return loss