
Implement SOAP optimizer #1580

@Rahul-2k4


Description:

I’d like to propose adding the SOAP optimizer (ShampoO with Adam in the Preconditioner's eigenbasis) to optax.contrib. SOAP is a Shampoo-style preconditioned (approximate second-order) optimizer that improves on Shampoo by running Adam in the eigenbasis of Shampoo's preconditioner.

Paper (ICLR 2025): SOAP: Improving and Stabilizing Shampoo using Adam
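To make "Adam in the preconditioner's eigenbasis" concrete, here is a minimal sketch of one step for a single 2D parameter, assuming the eigenbases Q_L and Q_R (eigenvectors of Shampoo's left and right factors) are already computed and held fixed. `soap_step_fixed_basis` and its argument names are hypothetical, not optax API; keeping the first moment in the original space and the second moment in the rotated space is my reading of the paper, and the hyperparameter defaults are illustrative only.

```python
import jax
import jax.numpy as jnp


def soap_step_fixed_basis(grad, m, v, q_l, q_r, count,
                          lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """Hypothetical single SOAP-style step for a 2D parameter.

    Assumes q_l ([rows, rows]) and q_r ([cols, cols]) are orthogonal
    eigenbases of the left/right Shampoo factors, fixed for this step.
    `m` is stored in the original parameter space, `v` in the rotated
    space, and `count` is the 1-based step number.
    Returns (update, new_m, new_v); the caller adds `update` to params.
    """
    # Adam first moment, accumulated in the original parameter space.
    m = b1 * m + (1.0 - b1) * grad
    # Rotate gradient and momentum into the eigenbasis: Q_L^T X Q_R.
    g_rot = q_l.T @ grad @ q_r
    m_rot = q_l.T @ m @ q_r
    # Adam second moment, updated every step in the current basis.
    v = b2 * v + (1.0 - b2) * jnp.square(g_rot)
    # Standard Adam bias correction.
    m_hat = m_rot / (1.0 - b1 ** count)
    v_hat = v / (1.0 - b2 ** count)
    # Adam direction in the rotated space, rotated back: Q_L D Q_R^T.
    direction = q_l @ (m_hat / (jnp.sqrt(v_hat) + eps)) @ q_r.T
    return -lr * direction, m, v
```

With Q_L and Q_R equal to the identity this reduces exactly to Adam, which makes a convenient sanity check for the unit tests.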

Idea

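  • Performance: the paper reports ~40% fewer iterations and ~35% less wall-clock time than AdamW in language-model pre-training.
  • Stability: performance degrades much less than Shampoo's when the eigendecomposition is recomputed infrequently (low preconditioning frequency).
  • Continuous Updates: Adam's second-moment estimate is updated every step in the current eigenbasis, so between eigendecompositions only the basis goes stale, not the per-coordinate statistics.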
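Written out for a single matrix parameter W with gradient G_t, the per-step schedule behind these bullets is roughly the following. This is a sketch in my own notation, loosely following the paper: f is the preconditioning frequency, β₂ is reused for the Shampoo factor averages as a simplification, division and square root are elementwise, and Adam's bias correction is omitted.

```latex
\begin{aligned}
L_t &= \beta_2 L_{t-1} + (1-\beta_2)\, G_t G_t^\top, \qquad
R_t = \beta_2 R_{t-1} + (1-\beta_2)\, G_t^\top G_t \\
(Q_L, Q_R) &\leftarrow \bigl(\operatorname{eigvecs}(L_t),\ \operatorname{eigvecs}(R_t)\bigr)
  \quad \text{only if } t \bmod f = 0 \\
G'_t &= Q_L^\top G_t Q_R, \qquad
M_t = \beta_1 M_{t-1} + (1-\beta_1)\, G_t, \qquad
M'_t = Q_L^\top M_t Q_R \\
V_t &= \beta_2 V_{t-1} + (1-\beta_2)\, G'_t \odot G'_t
  \quad \text{(every step, in the current basis)} \\
W_t &= W_{t-1} - \eta\, Q_L \left( \frac{M'_t}{\sqrt{V_t} + \epsilon} \right) Q_R^\top
\end{aligned}
```

Between refreshes Q_L and Q_R are stale, but V_t keeps accumulating in whichever basis is current.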

Proposed Implementation:

  • Core: Add scale_by_soap and a soap alias to optax/contrib/_soap.py.
  • State: Use a NamedTuple state (the usual optax pattern) to hold the preconditioner statistics, eigenbases, and momentum buffers.
  • JIT: Use jax.lax.cond so the eigendecomposition runs only on refresh steps while the update stays JIT-compatible.
  • Generalization: Support nD parameter shapes via dimension numbers (Muon/Shampoo style).
  • Tests: Full unit tests in optax/contrib/_soap_test.py.
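For the State and JIT points, a rough sketch for one 2D parameter might look like this: a NamedTuple holding the Shampoo factor averages and current eigenbases, with jax.lax.cond choosing inside a jitted function whether to recompute the eigendecomposition. `SoapBasisState`, `init_basis_state`, `update_basis_state`, and `precondition_frequency` are hypothetical names, not optax API; the Adam moments, nD support via dimension numbers, and any handling of moments when the basis changes are omitted.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SoapBasisState(NamedTuple):
    """Hypothetical per-parameter state for a 2D weight (names illustrative)."""
    count: jax.Array  # int32 step counter
    l: jax.Array      # EMA of G G^T, shape [rows, rows]
    r: jax.Array      # EMA of G^T G, shape [cols, cols]
    q_l: jax.Array    # eigenvectors of l (as columns), shape [rows, rows]
    q_r: jax.Array    # eigenvectors of r (as columns), shape [cols, cols]


def init_basis_state(param):
    rows, cols = param.shape
    return SoapBasisState(
        count=jnp.zeros([], jnp.int32),
        l=jnp.zeros((rows, rows), param.dtype),
        r=jnp.zeros((cols, cols), param.dtype),
        q_l=jnp.eye(rows, dtype=param.dtype),  # identity until first refresh
        q_r=jnp.eye(cols, dtype=param.dtype),
    )


def update_basis_state(grad, state, b2=0.95, precondition_frequency=10):
    count = state.count + 1
    # Shampoo factor averages are updated every step.
    l = b2 * state.l + (1.0 - b2) * (grad @ grad.T)
    r = b2 * state.r + (1.0 - b2) * (grad.T @ grad)

    def refresh(l, r):
        # eigh returns (eigenvalues ascending, eigenvectors as columns).
        return jnp.linalg.eigh(l)[1], jnp.linalg.eigh(r)[1]

    def keep(*_):
        # Reuse the previous (possibly stale) basis.
        return state.q_l, state.q_r

    # Both branches are traced, but only the selected one executes.
    q_l, q_r = jax.lax.cond(
        count % precondition_frequency == 0, refresh, keep, l, r)
    return SoapBasisState(count, l, r, q_l, q_r)
```

Under jax.jit only the selected branch runs, so the eigendecomposition cost is paid once every precondition_frequency steps. If the update is vmapped with a batched predicate, cond lowers to a select and both branches execute.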

Motivation:

I plan to use haydn-jones/SOAP_JAX as a reference implementation.

Would you be open to a PR for this?
