Skip to content
Draft
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/apidocs/orthogonalized-optimizers.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ emerging_optimizers.orthogonalized_optimizers
.. autoclass:: MuonHyperball
:members:

:hidden:`Spectron`
~~~~~~~~~~~~~~~~~~~

.. autoclass:: Spectron
:members:


:hidden:`Newton-Schulz`
~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
1 change: 1 addition & 0 deletions emerging_optimizers/orthogonalized_optimizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import *
from emerging_optimizers.orthogonalized_optimizers.scion import *
from emerging_optimizers.orthogonalized_optimizers.spectral_clipping_utils import *
from emerging_optimizers.orthogonalized_optimizers.spectron import *
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing trailing newline

The file is missing a trailing newline after the new import line. This is flagged by most linters and POSIX standards, and the previous version of the file had one.

Suggested change
from emerging_optimizers.orthogonalized_optimizers.spectron import *
from emerging_optimizers.orthogonalized_optimizers.spectron import *

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

251 changes: 251 additions & 0 deletions emerging_optimizers/orthogonalized_optimizers/spectron.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, overload, override

import torch
import torch.optim as optim
from absl import logging
from torch.optim.optimizer import ParamsT

from emerging_optimizers import mixin as opt_mixin
from emerging_optimizers import registry, utils
from emerging_optimizers.orthogonalized_optimizers import muon_utils
from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT
from emerging_optimizers.utils import FP32MatmulPrecT
from emerging_optimizers.utils.eig import power_iteration


__all__ = ["Spectron"]


@registry.register_optimizer("spectron")
class Spectron(opt_mixin.WeightDecayMixin, optim.Optimizer):
"""Spectron: Low-rank spectral optimizer with orthogonalized momentum.

Spectron maintains each 2D weight matrix W as a low-rank factorization W = A @ B^T,
where A ∈ R^(m×r) and B ∈ R^(n×r). It applies momentum, orthogonalizes the updates
using Newton-Schulz iteration, and scales the learning rate by the spectral radii
of both factors.

The algorithm:
1. Compute gradients with respect to A and B from parameter gradients
2. Apply momentum to both factors
3. Orthogonalize momentum buffers using Newton-Schulz iteration
4. Estimate spectral radius of A and B using power iteration
5. Update with scaled learning rate: η / (σ_A + σ_B + 1)
6. Reconstruct full weight matrix W = A @ B^T

References:
- Algorithm 1 (Spectron) and Algorithm 3 (PowerIter) from the Spectron paper (https://arxiv.org/abs/2602.12429).
Low-rank spectral optimization with orthogonalized momentum.

Warning:
- This optimizer requires that all parameters passed in are 2D.
- Low-rank factorization may not be suitable for all parameter types.

Args:
params: Iterable of parameters to optimize or dicts defining parameter groups
lr: The learning rate (η in the algorithm). Default: 3e-4
rank: The rank of the low-rank factorization. Default: 64
momentum_beta: The momentum decay coefficient (β). Default: 0.9
weight_decay: The weight decay coefficient. Default: 0.01
weight_decay_method: Method to apply weight decay. Default: "decoupled"
fp32_matmul_prec: Precision of matmul operations. Default: "medium"
num_ns_steps: Number of Newton-Schulz iteration steps. Default: 5
num_power_iter: Number of power iteration steps for spectral radius. Default: 1
coefficient_type: Type of coefficient set for Newton-Schulz. Default: "quintic"
"""

def __init__(
    self,
    params: ParamsT,
    lr: float = 3e-4,
    rank: int = 64,
    momentum_beta: float = 0.9,
    weight_decay: float = 0.01,
    *,
    weight_decay_method: opt_mixin.WeightDecayT = "decoupled",
    fp32_matmul_prec: FP32MatmulPrecT = "medium",
    num_ns_steps: int = 5,
    num_power_iter: int = 1,
    coefficient_type: NSCoeffT = "quintic",
) -> None:
    """Validate the hyperparameters and register the parameter groups.

    ``lr``, ``momentum_beta`` and ``weight_decay`` become per-group defaults.
    ``rank``, ``num_power_iter``, ``weight_decay_method`` and
    ``fp32_matmul_prec`` are stored on the optimizer instance, and the
    Newton-Schulz settings are captured by ``self.scaled_orthogonalize_fn``.

    Raises:
        ValueError: If ``lr`` or ``weight_decay`` is negative, ``rank``,
            ``num_ns_steps`` or ``num_power_iter`` is below 1, or
            ``momentum_beta`` lies outside ``[0, 1)``.
    """
    # Reject out-of-range hyperparameters before any optimizer state exists.
    if lr < 0.0:
        raise ValueError(f"Invalid learning rate: {lr}")
    if rank < 1:
        raise ValueError(f"Invalid rank: {rank}")
    if not 0.0 <= momentum_beta < 1.0:
        raise ValueError(f"Invalid momentum_beta: {momentum_beta}")
    if weight_decay < 0.0:
        raise ValueError(f"Invalid weight_decay: {weight_decay}")
    if num_ns_steps < 1:
        raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")
    if num_power_iter < 1:
        raise ValueError(f"num_power_iter must be at least 1, got {num_power_iter}")

    # Settings that apply to every parameter group live on the instance.
    self.rank = rank
    self.num_power_iter = num_power_iter
    self.weight_decay_method = weight_decay_method
    self.fp32_matmul_prec = fp32_matmul_prec

    def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
        # Closure over the Newton-Schulz configuration chosen at construction time.
        logging.debug(f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient")
        orthogonalized = muon_utils.newton_schulz(
            grad,
            steps=num_ns_steps,
            coefficient_type=coefficient_type,
        )
        return orthogonalized

    self.scaled_orthogonalize_fn = scaled_orthogonalize_fn

    # Per-group hyperparameters handed to the base optimizer.
    group_defaults = {
        "lr": lr,
        "momentum_beta": momentum_beta,
        "weight_decay": weight_decay,
    }
    super().__init__(params, group_defaults)

@overload
def step(self, closure: None = ...) -> None: ...

@overload
def step(self, closure: Callable[[], float]) -> float: ...

@torch.no_grad() # type: ignore[misc]
@override
def step(self, closure: Callable[[], float] | None = None) -> float | None:
"""Performs a single optimization step.

Args:
closure: A closure that reevaluates the model and returns the loss.
"""
if closure is None:
loss = None
else:
loss = closure()

for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue

if p.ndim != 2:
raise ValueError(f"Spectron only supports 2D parameters, got shape {p.shape}")

grad = p.grad
state = self.state[p]

# Initialize low-rank factors and momentum buffers
if "factor_A" not in state:
self._initialize_state(p, state)

# Get state variables
factor_A = state["factor_A"]
factor_B = state["factor_B"]
momentum_A = state["momentum_A"]
momentum_B = state["momentum_B"]
u_A = state["u_A"]
u_B = state["u_B"]

# Compute gradients for A and B from parameter gradient
# Using chain rule: ∂L/∂A = ∂L/∂W @ B, ∂L/∂B = ∂L/∂W^T @ A
grad_A = grad @ factor_B # shape: (m, r)
grad_B = grad.mT @ factor_A # shape: (n, r)

Comment on lines +178 to +181
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gradient dtype mismatch with non-fp32 parameters

grad = p.grad inherits p's dtype, but factor_B is always float32 (initialized from torch.linalg.svd(p.float(), ...)). When the parameter is bfloat16 — the standard dtype for LLM pretraining, which is the stated use case — the line grad @ factor_B will raise a RuntimeError at runtime:

RuntimeError: expected scalar type Float but found BFloat16

Even if PyTorch silently promotes the dtype in some contexts, momentum_A.lerp_(grad_A, ...) on line 187 will then fail because momentum_A is float32 but grad_A would be bfloat16.

The gradient should be explicitly cast to float32 before the matmul:

Suggested change
with utils.fp32_matmul_precision("highest"):
grad_A = grad @ factor_B # shape: (m, r)
grad_B = grad.mT @ factor_A # shape: (n, r)
with utils.fp32_matmul_precision("highest"):
grad_A = grad.float() @ factor_B # shape: (m, r)
grad_B = grad.float().mT @ factor_A # shape: (n, r)

# Apply weight decay
self._apply_weight_decay_inplace(factor_A, grad_A, group["lr"], group["weight_decay"])
self._apply_weight_decay_inplace(factor_B, grad_B, group["lr"], group["weight_decay"])
Comment thread
mkhona-nvidia marked this conversation as resolved.

# Update momentum buffers (EMA of gradients)
momentum_A.lerp_(grad_A, 1 - group["momentum_beta"])
momentum_B.lerp_(grad_B, 1 - group["momentum_beta"])

# Orthogonalize momentum using Newton-Schulz
with utils.fp32_matmul_precision(self.fp32_matmul_prec):
orth_momentum_A = self.scaled_orthogonalize_fn(momentum_A)
orth_momentum_B = self.scaled_orthogonalize_fn(momentum_B)
Comment thread
mkhona-nvidia marked this conversation as resolved.

# Estimate spectral radius using power iteration (Algorithm 3)
sigma_A, u_A = self._power_iteration(factor_A, u_A, self.num_power_iter)
sigma_B, u_B = self._power_iteration(factor_B, u_B, self.num_power_iter)

# Update power iteration vectors
state["u_A"] = u_A
state["u_B"] = u_B

# Compute scaled learning rate
scaled_lr = group["lr"] / (sigma_A + sigma_B + 1.0)
Comment thread
mkhona-nvidia marked this conversation as resolved.

# Update low-rank factors
factor_A.add_(orth_momentum_A, alpha=-scaled_lr)
factor_B.add_(orth_momentum_B, alpha=-scaled_lr)

# Reconstruct full weight matrix: W = A @ B^T
p.copy_(factor_A @ factor_B.mT)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am guessing this reconstruction is for the compatibility with the rest of the library. Otherwise the whole implementation looks correct.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I leave the weights of the model as a single matrix, but do the low-rank decomposition as optimizer states (rather than having the low-rank factored weights as 2 separate matrices in the model, which makes it harder to access them inside the optimizer). This is functionally identical but makes the software easier to use.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree


return loss

def _initialize_state(self, p: torch.Tensor, state: dict[str, torch.Tensor]) -> None:
"""Initialize low-rank factors and state for a parameter.

Args:
p: The parameter tensor (shape: m × n)
state: The state dictionary for this parameter
"""
m, n = p.shape
r = min(self.rank, m, n) # Ensure rank doesn't exceed dimensions

# Initialize A and B using SVD of the parameter
# This provides a good initialization close to the original weights
with torch.no_grad():
U, S, Vh = torch.linalg.svd(p.float(), full_matrices=False)
# Keep only top r singular values/vectors
sqrt_S = torch.sqrt(S[:r])
factor_A = (U[:, :r] * sqrt_S).to(p.dtype)
factor_B = (Vh[:r, :].mT * sqrt_S).to(p.dtype)

state["factor_A"] = factor_A.clone()
state["factor_B"] = factor_B.clone()
state["momentum_A"] = torch.zeros_like(factor_A)
state["momentum_B"] = torch.zeros_like(factor_B)

# Initialize power iteration vectors (normalized random vectors)
u_A = torch.randn(m, dtype=p.dtype, device=p.device)
u_A = u_A / u_A.norm()
u_B = torch.randn(n, dtype=p.dtype, device=p.device)
u_B = u_B / u_B.norm()

state["u_A"] = u_A
state["u_B"] = u_B

def _power_iteration(
    self, matrix: torch.Tensor, u: torch.Tensor, num_iters: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run power iteration on ``matrix`` and return the top singular value estimate.

    Thin wrapper over the shared ``power_iteration`` helper that drops the
    right singular vector, which Spectron does not use.

    Args:
        matrix: The matrix whose largest singular value is estimated (shape: p × q)
        u: The current approximation of the dominant left singular vector
        num_iters: Number of power iteration steps

    Returns:
        Tuple of (largest singular value, updated_u)
    """
    # The helper yields (sigma, u, v); keep only the first two entries.
    top_sigma, left_vec = power_iteration(matrix, u, k=num_iters)[:2]
    return top_sigma, left_vec
51 changes: 51 additions & 0 deletions emerging_optimizers/utils/eig.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,60 @@
"met_approx_eigvals_criteria",
"conjugate",
"orthogonal_iteration",
"power_iteration",
]


def power_iteration(
    W: torch.Tensor,
    u: torch.Tensor,
    k: int = 1,
    eps: float = 1e-8,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Estimate the dominant singular triplet of ``W`` by power iteration.

    Implements Algorithm 3 from the Spectron paper (https://arxiv.org/abs/2602.12429).
    Starting from ``u``, each step maps the left vector through ``W^T`` and back
    through ``W``, re-normalizing after every product. A final right vector is
    then derived from the refined left vector and the singular value estimate is
    read off as ``u^T W v``.

    Args:
        W: Matrix of shape (p, q) to analyze
        u: Initial left singular vector of shape (p,); it is re-normalized on entry
        k: Number of power iteration steps. Default: 1
        eps: Lower bound applied to each norm before dividing, to avoid
            division by zero. Default: 1e-8

    Returns:
        Tuple of (sigma, u, v) where:
            - sigma: Approximation of the largest singular value (scalar tensor)
            - u: Updated left singular vector of shape (p,)
            - v: Updated right singular vector of shape (q,)
    """

    def _normalized(vec: torch.Tensor) -> torch.Tensor:
        # Scale to unit L2 norm; the clamp keeps a zero vector from dividing by zero.
        return vec / vec.norm(p=2).clamp_min(eps)

    left = _normalized(u)

    for _ in range(k):
        # One round trip: left -> right through W^T, then right -> left through W.
        right = _normalized(W.mT @ left)
        left = _normalized(W @ right)

    # Right vector consistent with the final left vector, then sigma = u^T W v.
    right = _normalized(W.mT @ left)
    sigma = left @ (W @ right)

    return sigma, left, right


def eigh_with_fallback(
x: Tensor,
force_double: bool = False,
Expand Down
1 change: 1 addition & 0 deletions tests/ci/L0_Tests_CPU.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ error=0
torchrun --nproc_per_node=8 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py -v -2 || error=1
torchrun --nproc_per_node=4 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cpu -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_spectron.py --device=cpu -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_procrustes_step.py --device=cpu -v -2 || error=1

exit "${error}"
1 change: 1 addition & 0 deletions tests/ci/L0_Tests_GPU.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ error=0
coverage run -p --source=emerging_optimizers tests/test_muon_utils.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_adaptive_muon.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_orthogonalized_optimizer.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_spectron.py --device=cuda -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_soap_utils.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_soap.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py -v -2 || error=1
Expand Down
Loading
Loading