Skip to content

Commit 5c545e6

Browse files
committed
docs: add doc
1 parent 0fd3fa2 commit 5c545e6

9 files changed

Lines changed: 275 additions & 48 deletions

File tree

areal/api/cli_args.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from dataclasses import asdict, dataclass, field, fields
88
from enum import Enum
99
from pathlib import Path
10-
from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar
10+
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar
1111

1212
import uvloop
1313
import yaml
@@ -348,8 +348,10 @@ class OptimizerConfig:
348348
lr: float = field(
349349
default=1e-3,
350350
metadata={
351-
"help": "Learning rate. When type='muon', this is the Muon lr for >=2D params "
352-
"(typical value: ~0.02). The AdamW backend lr is controlled by muon_backend_lr."
351+
"help": "Learning rate. When type='muon', this is shared by both the Muon sub-optimizer "
352+
"(>=2D params) and the AdamW backend (<2D params). Pair "
353+
"muon_scale_mode='spectral' with muon_extra_scale_factor=0.2 (Moonlight-style) to "
354+
"make Muon's update RMS match AdamW so a single lr works for both."
353355
},
354356
)
355357
weight_decay: float = field(
@@ -434,26 +436,31 @@ class OptimizerConfig:
434436
"Mirrors Megatron-Core OptimizerConfig.muon_num_ns_steps."
435437
},
436438
)
437-
muon_scale_mode: Literal["rms", "spectral"] = field(
438-
default="rms",
439+
muon_scale_mode: str = field(
440+
default="spectral",
439441
metadata={
440-
"help": "Update-scaling mode for Muon. 'rms' (Moonlight-style) scales the update so its "
441-
"RMS matches Adam, allowing a single lr for all parameters (see https://arxiv.org/abs/2502.16982). "
442-
"'spectral' uses the Keller Jordan max(1, m/n)^0.5 spectral scaling. "
443-
"Only effective when type='muon'. Mirrors Megatron-Core OptimizerConfig.muon_scale_mode.",
444-
"choices": ["rms", "spectral"],
442+
"help": "Muon update scaling mode (final scale = mode_factor * muon_extra_scale_factor). 'spectral': sqrt(max(m, n)); 'unit_rms_norm': sqrt(m / n); 'shape_scaling': max(1, m / n)**0.5. "
443+
"Only used when type='muon'. Mirrors Megatron-Core OptimizerConfig.muon_scale_mode.",
444+
"choices": ["spectral", "unit_rms_norm", "shape_scaling"],
445445
},
446446
)
447-
muon_backend_lr: float | None = field(
448-
default=None,
447+
muon_extra_scale_factor: float = field(
448+
default=1.0,
449449
metadata={
450-
"help": "Learning rate for the AdamW backend optimizer in Muon (handles <2D params: "
451-
"biases, norms, embeddings). Typical value: ~3e-4. If None, falls back to the main lr "
452-
"with a warning (since Muon lr is typically ~100x larger). "
453-
"Only effective when type='muon'."
450+
"help": "Extra multiplier on top of muon_scale_mode. Use 0.2 with "
451+
"muon_scale_mode='spectral' for Moonlight-style RMS-matched scaling. "
452+
"Only used when type='muon'. Mirrors Megatron-Core OptimizerConfig.muon_extra_scale_factor."
454453
},
455454
)
456455

456+
def __post_init__(self):
457+
"""Validate optimizer configuration."""
458+
valid_muon_scale_modes = {"spectral", "unit_rms_norm", "shape_scaling"}
459+
if self.muon_scale_mode not in valid_muon_scale_modes:
460+
raise ValueError(
461+
f"muon_scale_mode must be one of {valid_muon_scale_modes}, got {self.muon_scale_mode!r}. "
462+
)
463+
457464

458465
@dataclass
459466
class FSDPWrapPolicy:

areal/engine/fsdp_engine.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,31 +1136,22 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None:
11361136
muon_params.append(p)
11371137
else:
11381138
backend_params.append(p)
1139-
if self.optimizer_config.muon_backend_lr is not None:
1140-
backend_lr = self.optimizer_config.muon_backend_lr
1141-
else:
1142-
backend_lr = lr
1143-
self.logger.warning(
1144-
"muon_backend_lr is not set; falling back to main lr (%.2e) for AdamW backend. "
1145-
"Typical Muon setups use a much smaller backend lr (e.g. 3e-4). "
1146-
"Set muon_backend_lr explicitly to suppress this warning.",
1147-
lr,
1148-
)
11491139
self.optimizer = MuonOptimizer(
11501140
[
11511141
dict(
11521142
params=muon_params,
11531143
lr=lr,
11541144
momentum=self.optimizer_config.muon_momentum,
11551145
weight_decay=weight_decay,
1156-
rms_scale=self.optimizer_config.muon_scale_mode == "rms",
1146+
scale_mode=self.optimizer_config.muon_scale_mode,
1147+
extra_scale_factor=self.optimizer_config.muon_extra_scale_factor,
11571148
nesterov=self.optimizer_config.muon_use_nesterov,
11581149
ns_steps=self.optimizer_config.muon_num_ns_steps,
11591150
use_muon=True,
11601151
),
11611152
dict(
11621153
params=backend_params,
1163-
lr=backend_lr,
1154+
lr=lr,
11641155
betas=(beta1, beta2),
11651156
eps=eps,
11661157
weight_decay=weight_decay,

areal/engine/fsdp_utils/muon.py

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,44 @@ def apply_momentum(
8080
return update
8181

8282

83-
def apply_scaling(grad: Tensor, rms_scale: bool = False) -> Tensor:
84-
"""Post-NS scaling: either Moonlight RMS or Keller Jordan max(1, m/n)^0.5."""
85-
if rms_scale:
86-
# https://github.com/MoonshotAI/Moonlight/blob/5afcb6911077e7f182d05865fe90d9f39abcbcbd/examples/toy_train.py#L146
87-
grad *= 0.2 * math.sqrt(max(grad.shape[1], grad.shape[0]))
83+
def apply_scaling(
84+
grad: Tensor,
85+
mode: str = "spectral",
86+
extra_scale_factor: float = 1.0,
87+
) -> Tensor:
88+
"""Post-Newton-Schulz update scaling.
89+
90+
Naming aligned with Megatron-Core / emerging_optimizers (NVIDIA-NeMo).
91+
92+
Final scale = scale_factor(mode) * extra_scale_factor, where:
93+
- 'spectral' : sqrt(max(m, n))
94+
Kimi/Moonlight (arXiv:2502.16982); emerging_optimizers default.
95+
- 'unit_rms_norm' : sqrt(m / n)
96+
Scion (arXiv:2502.07529) / Bernstein
97+
(https://jeremybernste.in/writing/deriving-muon).
98+
- 'shape_scaling' : max(1, m / n)**0.5
99+
Keller Jordan original (https://kellerjordan.github.io/posts/muon).
100+
101+
Set extra_scale_factor=0.2 with mode='spectral' to reproduce the legacy
102+
Moonlight `https://github.com/MoonshotAI/Moonlight/blob/5afcb6911077e7f182d05865fe90d9f39abcbcbd/examples/toy_train.py#L146`
103+
setting (= 0.2 * sqrt(max(m, n))), which
104+
approximately matches AdamW's update RMS norm so a single lr works for
105+
both Muon and the AdamW backend.
106+
"""
107+
m = grad.size(-2)
108+
n = grad.size(-1)
109+
if mode == "spectral":
110+
scale = math.sqrt(max(m, n))
111+
elif mode == "unit_rms_norm":
112+
scale = math.sqrt(m / n)
113+
elif mode == "shape_scaling":
114+
scale = max(1, m / n) ** 0.5
88115
else:
89-
# https://github.com/KellerJordan/Muon/blob/f90a42b28e00b8d9d2d05865fe90d9f39abcbcbd/muon.py#L40
90-
grad *= max(1, grad.size(-2) / grad.size(-1)) ** 0.5
116+
raise ValueError(
117+
f"Invalid muon_scale_mode {mode!r}. Valid: "
118+
"{'spectral', 'unit_rms_norm', 'shape_scaling'}."
119+
)
120+
grad *= scale * extra_scale_factor
91121
return grad
92122

93123

@@ -194,7 +224,11 @@ def finish(self):
194224
else:
195225
scatter(grad.to_local(), None, src=dest_rank, group=pg, async_op=False)
196226

197-
update = apply_scaling(grad, self.group["rms_scale"])
227+
update = apply_scaling(
228+
grad,
229+
self.group["scale_mode"],
230+
self.group["extra_scale_factor"],
231+
)
198232

199233
self.param.mul_(1 - self.group["lr"] * self.group["weight_decay"])
200234
self.param.add_(update.reshape(self.param.shape), alpha=-self.group["lr"])
@@ -272,7 +306,11 @@ def finish(self):
272306
new_local = distribute_tensor(g_full, mesh, placements).to_local()
273307
grad.to_local().copy_(new_local)
274308

275-
update = apply_scaling(grad, self.group["rms_scale"])
309+
update = apply_scaling(
310+
grad,
311+
self.group["scale_mode"],
312+
self.group["extra_scale_factor"],
313+
)
276314

277315
self.param.mul_(1 - self.group["lr"] * self.group["weight_decay"])
278316
self.param.add_(update.reshape(self.param.shape), alpha=-self.group["lr"])
@@ -309,7 +347,11 @@ def start(self):
309347
)
310348
update = zeropower_via_newtonschulz5(update, self.group["ns_steps"])
311349
update = update.to(self.param.grad.dtype)
312-
update = apply_scaling(update, self.group["rms_scale"])
350+
update = apply_scaling(
351+
update,
352+
self.group["scale_mode"],
353+
self.group["extra_scale_factor"],
354+
)
313355
self.param.mul_(1 - self.group["lr"] * self.group["weight_decay"])
314356
self.param.add_(update.reshape(self.param.shape), alpha=-self.group["lr"])
315357

@@ -330,7 +372,8 @@ class Muon(torch.optim.Optimizer):
330372
331373
Notable changes:
332374
- DTensor/FSDP2 native: uses gather/scatter for distributed NS instead of DDP.
333-
- ``rms_scale`` argument following the Moonlight paper (https://arxiv.org/abs/2502.16982).
375+
- ``scale_mode`` / ``extra_scale_factor`` arguments aligned with Megatron-Core /
376+
emerging_optimizers (NVIDIA-NeMo). See :func:`apply_scaling` for details.
334377
335378
Example::
336379
@@ -340,7 +383,7 @@ class Muon(torch.optim.Optimizer):
340383
])
341384
342385
Param group args (``use_muon=True``):
343-
lr, momentum, weight_decay, rms_scale, nesterov, ns_steps
386+
lr, momentum, weight_decay, scale_mode, extra_scale_factor, nesterov, ns_steps
344387
345388
Param group args (``use_muon=False``):
346389
lr, betas, eps, weight_decay
@@ -353,7 +396,8 @@ def __init__(self, param_groups):
353396
group.setdefault("lr", 0.02)
354397
group.setdefault("momentum", 0.95)
355398
group.setdefault("weight_decay", 0)
356-
group.setdefault("rms_scale", True)
399+
group.setdefault("scale_mode", "spectral")
400+
group.setdefault("extra_scale_factor", 1.0)
357401
group.setdefault("nesterov", True)
358402
group.setdefault("ns_steps", 5)
359403
assert set(group.keys()) == {
@@ -362,7 +406,8 @@ def __init__(self, param_groups):
362406
"momentum",
363407
"weight_decay",
364408
"use_muon",
365-
"rms_scale",
409+
"scale_mode",
410+
"extra_scale_factor",
366411
"nesterov",
367412
"ns_steps",
368413
}

areal/engine/megatron_engine.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,14 +1327,15 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None:
13271327
# Forward Muon-specific hyperparameters onto Megatron-Core's OptimizerConfig.
13281328
# AReaL's muon_* fields are 1:1 aligned with Megatron-Core >= 0.17, so no
13291329
# translation is required. Fields not exposed by AReaL (muon_coefficient_type,
1330-
# muon_split_qkv, muon_tp_mode, muon_extra_scale_factor, muon_fp32_matmul_prec)
1331-
# keep their Megatron defaults.
1330+
# muon_split_qkv, muon_tp_mode, muon_fp32_matmul_prec) keep their Megatron
1331+
# defaults.
13321332
if self.optimizer_config.type == "muon":
13331333
muon_passthrough_fields = (
13341334
"muon_momentum",
13351335
"muon_use_nesterov",
13361336
"muon_num_ns_steps",
13371337
"muon_scale_mode",
1338+
"muon_extra_scale_factor",
13381339
)
13391340
for attr in muon_passthrough_fields:
13401341
if hasattr(mcore_opt_config, attr):

areal/experimental/engine/archon_utils.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,6 @@ def create_optimizer(
6060
eps=eps,
6161
fused=True,
6262
)
63-
elif optimizer_config.type == "muon":
64-
raise NotImplementedError(
65-
"Muon optimizer is not yet supported under ArchonEngine. "
66-
)
6763
elif optimizer_config.type == "sgd":
6864
return torch.optim.SGD(
6965
params,

docs/en/_toc.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ parts:
4040
- file: algorithms/prox_approx
4141
- caption: Reference
4242
chapters:
43+
- file: reference/optimizer
4344
- file: reference/checkpointing
4445
- file: reference/metrics_tracking
4546
- file: reference/alloc_mode

docs/en/reference/optimizer.md

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
(section-optimizer-guide)=
2+
3+
# Optimizer Configuration Guide
4+
5+
AReaL supports multiple optimizer types, configurable via the `optimizer.type` field.
6+
This document covers the support matrix across training backends and the implementation
7+
differences of the Muon optimizer.
8+
9+
## Supported Optimizer Types
10+
11+
| Type | Description |
12+
| ----------- | -------------------------------------------------------------------------------------------------- |
13+
| `adam` | AdamW optimizer (default) |
14+
| `adam_bf16` | BF16-precision AdamW, reduces optimizer state memory |
15+
| `sgd` | Standard SGD |
16+
| `muon` | Muon optimizer: Newton-Schulz orthogonalized updates for ≥2D params, AdamW backend for \<2D params |
17+
18+
## Engine Support Matrix
19+
20+
| Optimizer | FSDP Engine | Megatron Engine | Archon Engine |
21+
| ----------- | :--------------------: | :----------------------------: | :------------------: |
22+
| `adam`      |           ✅           |               ✅               |          ✅          |
23+
| `adam_bf16` | ✅ (AnyPrecisionAdamW) | ✅ (precision-aware optimizer) |          ❌          |
24+
| `sgd`       |           ✅           |               ✅               |          ✅          |
25+
| `muon`      |           ✅           |   ✅ (Megatron-Core ≥ 0.17)    | ❌ (not implemented) |
26+
27+
### Notes
28+
29+
- **FSDP Engine**: `adam_bf16` uses `AnyPrecisionAdamW`, storing momentum and variance
30+
in BF16.
31+
- **Megatron Engine**: `adam_bf16` requires model dtype to be bfloat16; it is
32+
auto-converted to adam with precision-aware optimizer enabled.
33+
- **Archon Engine**: Currently only supports `adam` and `sgd`. Muon support is under
34+
development.
35+
36+
## Muon Optimizer
37+
38+
### Overview
39+
40+
Muon (MomentUm Orthogonalized by Newton-schulz) is an optimizer that applies approximate
41+
orthogonalization to gradient momentum via Newton-Schulz iteration. The core idea is to
42+
impose an orthogonal constraint on weight matrix gradients, making update directions
43+
more "uniform" in parameter space and accelerating convergence.
44+
45+
### Reference Implementations and Papers
46+
47+
| Resource | Link |
48+
| ---------------------------------------- | -------------------------------------------------- |
49+
| Original implementation (Keller Jordan) | https://github.com/KellerJordan/Muon |
50+
| Moonlight paper (RMS scaling) | https://arxiv.org/abs/2502.16982 |
51+
| AReaL FSDP implementation | `areal/engine/fsdp_utils/muon.py` |
52+
| Emerging-Optimizers (Megatron-Core Muon) | https://github.com/NVIDIA-NeMo/Emerging-Optimizers |
53+
54+
### FSDP vs Megatron Implementation Differences
55+
56+
The FSDP Engine and Megatron Engine differ significantly in how they partition
57+
parameters for Muon:
58+
59+
| Dimension | FSDP Engine | Megatron Engine |
60+
| --------------------------------- | ----------------------------------------------------------------- | ------------------------------------------------------------------------------------- |
61+
| **Muon parameter scope** | **All** ≥2D parameters (including embedding weight matrices) | **Linear layer weights** |
62+
| **AdamW backend parameters** | All \<2D parameters (bias, LayerNorm weight/bias) | Embeddings, biases, norms, and non-Linear 2D parameters |
63+
| **Distributed NS implementation** | DTensor gather/scatter (FSDP2 native) | TP-aware `TensorParallelMuon` (distributed Newton-Schulz over TP communication group) |
64+
| **TP + EP support** | TP + FSDP 2D mesh ✅; TP + EP + FSDP 3D mesh ❌ (not implemented) | Full TP / EP / PP support |
65+
66+
### Configuration Example
67+
68+
```yaml
69+
optimizer:
70+
type: muon
71+
lr: 2e-3 # Shared lr (Muon and AdamW backend)
72+
muon_momentum: 0.95
73+
muon_use_nesterov: true
74+
muon_num_ns_steps: 5
75+
muon_scale_mode: spectral # spectral / unit_rms_norm / shape_scaling
76+
muon_extra_scale_factor: 0.2 # 0.2 + spectral = Moonlight-style RMS-matched scaling
77+
weight_decay: 0.05
78+
beta1: 0.9 # AdamW backend params
79+
beta2: 0.95
80+
eps: 1e-5
81+
lr_scheduler_type: cosine
82+
warmup_steps_proportion: 0.03
83+
```
84+
85+
### Configuration Parameters
86+
87+
| Parameter | Type | Default | Description |
88+
| ------------------------- | ----- | ------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
89+
| `lr` | float | 0.001 | Shared learning rate for both Muon (≥2D params) and AdamW backend (\<2D params). A single lr works well when pairing `muon_scale_mode=spectral` with `muon_extra_scale_factor=0.2` (Moonlight-style) |
90+
| `muon_momentum` | float | 0.95 | Muon momentum coefficient |
91+
| `muon_use_nesterov` | bool | true | Whether to use Nesterov momentum |
92+
| `muon_num_ns_steps` | int | 5 | Number of Newton-Schulz iteration steps |
93+
| `muon_scale_mode` | str | "spectral" | Update scaling mode. `spectral`: `sqrt(max(m, n))` (Kimi/Moonlight, emerging_optimizers default). `unit_rms_norm`: `sqrt(m / n)` (Scion / Bernstein). `shape_scaling`: `max(1, m/n)**0.5` (Keller Jordan original) |
94+
| `muon_extra_scale_factor` | float | 1.0 | Extra multiplicative scale; final scale = `scale_factor(mode) * muon_extra_scale_factor`. Use `0.2` with `spectral` to reproduce Moonlight-style RMS-matched scaling |
95+
| `weight_decay` | float | 0.01 | Weight decay, applied to both Muon and AdamW backend |
96+
| `beta1` / `beta2` / `eps` | float | 0.9 / 0.999 / 1e-8 | AdamW backend hyperparameters |

docs/zh/_toc.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ parts:
4040
- file: algorithms/prox_approx
4141
- caption: 参考
4242
chapters:
43+
- file: reference/optimizer
4344
- file: reference/checkpointing
4445
- file: reference/metrics_tracking
4546
- file: reference/alloc_mode

0 commit comments

Comments
 (0)