|
72 | 72 | .. [4] Flash-Muon: Triton-accelerated symmetric matmul for Newton-Schulz. |
73 | 73 | https://github.com/lintianyang/flash-muon (MIT License, Tianyang Lin) |
74 | 74 | .. [5] Magma: Momentum-Aligned Gradient Masking for Stable Optimizer Updates. |
75 | | - arXiv:2602.15322, 2025. |
| 75 | + arXiv:2602.15322, 2026. |
76 | 76 | https://arxiv.org/abs/2602.15322 |
77 | 77 | Implements block-wise momentum-gradient alignment scoring with EMA smoothing |
78 | 78 | and soft scaling for improved stability under heavy-tailed gradient noise. |
@@ -340,9 +340,7 @@ def _batched_newton_schulz_orth( |
340 | 340 | """ |
341 | 341 | # === Step 1. Validate and prepare matrix orientation === |
342 | 342 | if G.ndim != 3: |
343 | | - raise ValueError( |
344 | | - "Batched Newton-Schulz expects a 3D tensor with shape (B, m, n)." |
345 | | - ) |
| 343 | + raise ValueError("Batched Newton-Schulz expects a 3D tensor (B, m, n).") |
346 | 344 |
|
347 | 345 | X = G.to(dtype=torch.bfloat16) |
348 | 346 | transposed = X.size(-2) > X.size(-1) |
@@ -474,9 +472,7 @@ def get_matrix_view_shape( |
474 | 472 | rows = int(effective_shape[-2]) |
475 | 473 | cols = int(effective_shape[-1]) |
476 | 474 | return (batch_size, rows, cols) |
477 | | - raise ValueError( |
478 | | - f"Unsupported muon_mode '{muon_mode}'. Expected one of ['2d', 'flat', 'slice']." |
479 | | - ) |
| 475 | + raise ValueError(f"Invalid muon_mode '{muon_mode}'. Use '2d', 'flat', or 'slice'.") |
480 | 476 |
|
481 | 477 |
|
482 | 478 | class HybridMuonOptimizer(Optimizer): |
@@ -601,7 +597,9 @@ def __init__( |
601 | 597 | # === Step 1. Validate routing mode === |
602 | 598 | muon_mode = str(muon_mode).lower() |
603 | 599 | if muon_mode not in {"2d", "flat", "slice"}: |
604 | | - raise ValueError("muon_mode must be one of ['2d', 'flat', 'slice'].") |
| 600 | + raise ValueError( |
| 601 | + f"Invalid muon_mode '{muon_mode}'. Use '2d', 'flat', or 'slice'." |
| 602 | + ) |
605 | 603 |
|
606 | 604 | # === Step 2. Register optimizer defaults === |
607 | 605 | defaults = { |
@@ -1024,10 +1022,12 @@ def step( |
1024 | 1022 |
|
1025 | 1023 | # exp_avg = beta1 * exp_avg + (1 - beta1) * grad |
1026 | 1024 | # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2 |
1027 | | - for ea, g in zip(adam_no_decay_exp_avgs, adam_no_decay_grads_fp32): |
| 1025 | + for ea, g in zip( |
| 1026 | + adam_no_decay_exp_avgs, adam_no_decay_grads_fp32, strict=True |
| 1027 | + ): |
1028 | 1028 | ea.lerp_(g, 1 - adam_betas[0]) |
1029 | 1029 | grad_sq = [g * g for g in adam_no_decay_grads_fp32] |
1030 | | - for eas, gsq in zip(adam_no_decay_exp_avg_sqs, grad_sq): |
| 1030 | + for eas, gsq in zip(adam_no_decay_exp_avg_sqs, grad_sq, strict=True): |
1031 | 1031 | eas.lerp_(gsq, 1 - adam_betas[1]) |
1032 | 1032 |
|
1033 | 1033 | # === Step 1.3. Bias correction and parameter update === |
@@ -1083,10 +1083,12 @@ def step( |
1083 | 1083 |
|
1084 | 1084 | # exp_avg = beta1 * exp_avg + (1 - beta1) * grad |
1085 | 1085 | # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2 |
1086 | | - for ea, g in zip(adam_decay_exp_avgs, adam_decay_grads_fp32): |
| 1086 | + for ea, g in zip( |
| 1087 | + adam_decay_exp_avgs, adam_decay_grads_fp32, strict=True |
| 1088 | + ): |
1087 | 1089 | ea.lerp_(g, 1 - adam_betas[0]) |
1088 | 1090 | grad_sq = [g * g for g in adam_decay_grads_fp32] |
1089 | | - for eas, gsq in zip(adam_decay_exp_avg_sqs, grad_sq): |
| 1091 | + for eas, gsq in zip(adam_decay_exp_avg_sqs, grad_sq, strict=True): |
1090 | 1092 | eas.lerp_(gsq, 1 - adam_betas[1]) |
1091 | 1093 |
|
1092 | 1094 | # === Step 2.3. Bias correction and parameter update === |
|
0 commit comments