Skip to content

Commit 0fd3fa2

Browse files
committed
feat(engine): add Muon optimizer support for FSDP and Megatron engines
Add Muon optimizer (Newton-Schulz orthogonalization) with distributed FSDP support, ported from samsja/muon_fsdp_2 v0.3.0. Core changes: - areal/utils/optimizer.py: Full Muon implementation with Work pipeline for async NCCL overlap (Fsdp1dWork, SingleDeviceWork), Newton-Schulz iteration, Moonlight RMS scaling option, and AdamW fallback for non-2D parameters (embeddings, norms, biases) - areal/api/cli_args.py: Add Muon-specific config fields (muon_momentum, muon_nesterov, muon_ns_steps, muon_backend_steps, muon_rms_scale) - areal/engine/fsdp_engine.py: Integrate Muon into FSDP optimizer creation - areal/experimental/engine/archon_engine.py + archon_utils.py: Integrate Muon into Archon engine optimizer creation - pyproject.toml / pyproject.vllm.toml: Add muon_fsdp_2 dependency - tests/test_muon_optimizer.py: Unit tests for Newton-Schulz, scaling, optimizer step, and config validation feat(megatron): enable Muon optimizer via Megatron-Core native dispatch Megatron-Core natively supports Muon via _get_megatron_emerging_optimizer when optimizer type is not in ('adam', 'sgd'). It handles TP-aware Newton-Schulz, QKV splitting, and ChainedOptimizer (Muon for 2D weights, AdamW for norms/biases/embeddings) out of the box. - Allow 'muon' in _create_optimizer assertion - Forward muon_momentum/muon_nesterov/muon_num_ns_steps from OptimizerConfig to MCoreOptimizerConfig (with hasattr guard for older Megatron-Core) - Requires the 'emerging-optimizers' package to be installed at runtime
1 parent 95ca870 commit 0fd3fa2

9 files changed

Lines changed: 834 additions & 18 deletions

File tree

areal/api/cli_args.py

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from dataclasses import asdict, dataclass, field, fields
88
from enum import Enum
99
from pathlib import Path
10-
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar
10+
from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar
1111

1212
import uvloop
1313
import yaml
@@ -338,30 +338,46 @@ class OptimizerConfig:
338338
type: str = field(
339339
default="adam",
340340
metadata={
341-
"help": "Optimizer type. For FSDP Engine, adam_bf16 enables memory-efficient BF16 optimizer states. "
342-
"For Megatron Engine, adam_bf16 requires dtype=bfloat16 and is automatically converted to adam "
343-
"with precision-aware optimizer enabled.",
344-
"choices": ["adam", "sgd", "adam_bf16"],
341+
"help": "Optimizer type. 'adam': AdamW (default). 'adam_bf16': memory-efficient BF16 AdamW "
342+
"(FSDP: uses AnyPrecisionAdamW; Megatron: requires dtype=bfloat16, auto-converted to adam "
343+
"with precision-aware optimizer). 'sgd': plain SGD. 'muon': Muon optimizer for >=2D params "
344+
"with AdamW backend for <2D params (biases, norms, embeddings).",
345+
"choices": ["adam", "sgd", "adam_bf16", "muon"],
346+
},
347+
)
348+
lr: float = field(
349+
default=1e-3,
350+
metadata={
351+
"help": "Learning rate. When type='muon', this is the Muon lr for >=2D params "
352+
"(typical value: ~0.02). The AdamW backend lr is controlled by muon_backend_lr."
353+
},
354+
)
355+
weight_decay: float = field(
356+
default=0.01,
357+
metadata={
358+
"help": "Weight decay. Applied to all optimizer types including Muon (>=2D params) "
359+
"and AdamW backend (<2D params)."
345360
},
346361
)
347-
lr: float = field(default=1e-3, metadata={"help": "Learning rate"})
348-
weight_decay: float = field(default=0.01, metadata={"help": "Weight decay"})
349362
beta1: float = field(
350363
default=0.9,
351364
metadata={
352-
"help": "Adam beta1 parameter. Only effective when optimizer_type is adam/adam_bf16"
365+
"help": "Adam beta1 parameter. Used by adam/adam_bf16, and by the AdamW backend "
366+
"when type='muon'. Not used by the Muon sub-optimizer itself."
353367
},
354368
)
355369
beta2: float = field(
356370
default=0.999,
357371
metadata={
358-
"help": "Adam beta2 parameter. Only effective when optimizer_type is adam/adam_bf16"
372+
"help": "Adam beta2 parameter. Used by adam/adam_bf16, and by the AdamW backend "
373+
"when type='muon'. Not used by the Muon sub-optimizer itself."
359374
},
360375
)
361376
eps: float = field(
362377
default=1e-8,
363378
metadata={
364-
"help": "Adam epsilon parameter. Only effective when optimizer_type is adam/adam_bf16"
379+
"help": "Adam epsilon for numerical stability. Used by adam/adam_bf16, and by the "
380+
"AdamW backend when type='muon'. Not used by the Muon sub-optimizer itself."
365381
},
366382
)
367383
min_lr_ratio: float = field(
@@ -398,6 +414,45 @@ class OptimizerConfig:
398414
gradient_clipping: float = field(
399415
default=1.0, metadata={"help": "Gradient clipping threshold"}
400416
)
417+
muon_momentum: float = field(
418+
default=0.95,
419+
metadata={
420+
"help": "Muon momentum parameter. Only effective when optimizer_type is muon."
421+
},
422+
)
423+
muon_use_nesterov: bool = field(
424+
default=True,
425+
metadata={
426+
"help": "Whether to use Nesterov momentum in Muon. Only effective when type='muon'. "
427+
"Mirrors Megatron-Core OptimizerConfig.muon_use_nesterov."
428+
},
429+
)
430+
muon_num_ns_steps: int = field(
431+
default=5,
432+
metadata={
433+
"help": "Number of Newton-Schulz iteration steps in Muon. Only effective when type='muon'. "
434+
"Mirrors Megatron-Core OptimizerConfig.muon_num_ns_steps."
435+
},
436+
)
437+
muon_scale_mode: Literal["rms", "spectral"] = field(
438+
default="rms",
439+
metadata={
440+
"help": "Update-scaling mode for Muon. 'rms' (Moonlight-style) scales the update so its "
441+
"RMS matches Adam, allowing a single lr for all parameters (see https://arxiv.org/abs/2502.16982). "
442+
"'spectral' uses the Keller Jordan max(1, m/n)^0.5 spectral scaling. "
443+
"Only effective when type='muon'. Mirrors Megatron-Core OptimizerConfig.muon_scale_mode.",
444+
"choices": ["rms", "spectral"],
445+
},
446+
)
447+
muon_backend_lr: float | None = field(
448+
default=None,
449+
metadata={
450+
"help": "Learning rate for the AdamW backend optimizer in Muon (handles <2D params: "
451+
"biases, norms, embeddings). Typical value: ~3e-4. If None, falls back to the main lr "
452+
"with a warning (since Muon lr is typically ~100x larger). "
453+
"Only effective when type='muon'."
454+
},
455+
)
401456

402457

403458
@dataclass

areal/engine/fsdp_engine.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,11 @@
8383
)
8484
from areal.engine.fsdp_utils.checkpoint import DCPState
8585
from areal.engine.fsdp_utils.grad import fsdp2_clip_grad_norm
86-
from areal.engine.fsdp_utils.optimizer import AnyPrecisionAdamW, PerLayerOptimWrapper
86+
from areal.engine.fsdp_utils.muon import Muon as MuonOptimizer
87+
from areal.engine.fsdp_utils.optimizer import (
88+
AnyPrecisionAdamW,
89+
PerLayerOptimWrapper,
90+
)
8791
from areal.engine.fsdp_utils.parallel import ParallelHelper, parallelize_model
8892
from areal.infra.dist_rollout import DistRolloutCoordinator
8993
from areal.infra.platforms import current_platform
@@ -470,7 +474,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
470474
self._create_optimizer(ft_spec)
471475

472476
if self.config.fsdp.per_layer_optim_step:
473-
if self.optimizer_config.type != "adam":
477+
if self.optimizer_config.type not in ("adam",):
474478
raise ValueError(
475479
f"per_layer_optim_step only supports 'adam' optimizer, got '{self.optimizer_config.type}'."
476480
)
@@ -1111,7 +1115,8 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None:
11111115
"adam",
11121116
"adam_bf16",
11131117
"sgd",
1114-
], "Only adam/adam_bf16/sgd optimizer is supported in this engine."
1118+
"muon",
1119+
], "Only adam/adam_bf16/sgd/muon optimizer is supported in this engine."
11151120
if self.optimizer_config.type in ["sgd", "adam_bf16"]:
11161121
self.logger.warning(
11171122
f"Using the '{self.optimizer_config.type}' optimizer with FSDP may be less stable. Consider using the 'adam' (AdamW) optimizer for improved stability and performance."
@@ -1121,7 +1126,53 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None:
11211126
beta1 = self.optimizer_config.beta1
11221127
beta2 = self.optimizer_config.beta2
11231128
eps = self.optimizer_config.eps
1124-
if self.optimizer_config.type == "adam":
1129+
if self.optimizer_config.type == "muon":
1130+
muon_params: list[torch.nn.Parameter] = []
1131+
backend_params: list[torch.nn.Parameter] = []
1132+
for p in self.model.parameters():
1133+
if not p.requires_grad:
1134+
continue
1135+
if p.ndim >= 2:
1136+
muon_params.append(p)
1137+
else:
1138+
backend_params.append(p)
1139+
if self.optimizer_config.muon_backend_lr is not None:
1140+
backend_lr = self.optimizer_config.muon_backend_lr
1141+
else:
1142+
backend_lr = lr
1143+
self.logger.warning(
1144+
"muon_backend_lr is not set; falling back to main lr (%.2e) for AdamW backend. "
1145+
"Typical Muon setups use a much smaller backend lr (e.g. 3e-4). "
1146+
"Set muon_backend_lr explicitly to suppress this warning.",
1147+
lr,
1148+
)
1149+
self.optimizer = MuonOptimizer(
1150+
[
1151+
dict(
1152+
params=muon_params,
1153+
lr=lr,
1154+
momentum=self.optimizer_config.muon_momentum,
1155+
weight_decay=weight_decay,
1156+
rms_scale=self.optimizer_config.muon_scale_mode == "rms",
1157+
nesterov=self.optimizer_config.muon_use_nesterov,
1158+
ns_steps=self.optimizer_config.muon_num_ns_steps,
1159+
use_muon=True,
1160+
),
1161+
dict(
1162+
params=backend_params,
1163+
lr=backend_lr,
1164+
betas=(beta1, beta2),
1165+
eps=eps,
1166+
weight_decay=weight_decay,
1167+
use_muon=False,
1168+
),
1169+
]
1170+
)
1171+
self.logger.info(
1172+
f"Muon optimizer: {len(muon_params)} params (>=2D), "
1173+
f"AdamW backend: {len(backend_params)} params (<2D)"
1174+
)
1175+
elif self.optimizer_config.type == "adam":
11251176
self.optimizer = torch.optim.AdamW(
11261177
self.model.parameters(),
11271178
lr=lr,

areal/engine/fsdp_utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414
from transformers import PreTrainedModel
1515

16+
from areal.engine.fsdp_utils.muon import Muon
1617
from areal.engine.fsdp_utils.optimizer import (
1718
AdamKernel,
1819
OptimKernel,
@@ -33,6 +34,7 @@
3334
"apply_fsdp2",
3435
"fsdp2_load_full_state_dict",
3536
"get_cosine_schedule_with_warmup",
37+
"Muon",
3638
"PerLayerOptimWrapper",
3739
"OptimKernel",
3840
"AdamKernel",

0 commit comments

Comments
 (0)