Description:
I’d like to propose adding the SOAP (ShampoO with Adam in the Preconditioner's eigenbasis) optimizer to optax.contrib. SOAP is a preconditioned, second-order-style method that builds on Shampoo: it runs Adam updates in the eigenbasis of Shampoo's preconditioner, which makes it both faster to converge and more stable than Shampoo.
Paper (ICLR 2025): SOAP: Improving and Stabilizing Shampoo using Adam
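Idea:
- Performance: ~40% fewer iterations and ~35% less wall-clock time than AdamW in the paper's language-model pre-training experiments.
- Stability: performance degrades much less than Shampoo's when the eigendecomposition is recomputed infrequently.
- Continuous updates: Adam's second-moment estimate is updated at every step in the current eigenbasis, so it does not go stale between basis refreshes.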
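For reference, the update summarized above can be written for a single weight matrix W (m × n) with gradient G_t roughly as follows. This is a simplified sketch of the method as described in this proposal, not a transcription of the paper's algorithm: it omits bias correction and weight decay, uses β₂ for both the Kronecker-factor and second-moment averages, and writes the basis refresh as an eigendecomposition every f steps.

```latex
\begin{aligned}
L_t &= \beta_2 L_{t-1} + (1-\beta_2)\, G_t G_t^\top, \qquad
R_t = \beta_2 R_{t-1} + (1-\beta_2)\, G_t^\top G_t \\
Q_L,\, Q_R &= \operatorname{eigvecs}(L_t),\ \operatorname{eigvecs}(R_t)
  \quad \text{(recomputed only when } t \bmod f = 0\text{, otherwise kept)} \\
\tilde{G}_t &= Q_L^\top G_t Q_R \\
M_t &= \beta_1 M_{t-1} + (1-\beta_1)\, G_t, \qquad \tilde{M}_t = Q_L^\top M_t Q_R \\
V_t &= \beta_2 V_{t-1} + (1-\beta_2)\, \tilde{G}_t \odot \tilde{G}_t \\
W_t &= W_{t-1} - \eta\, Q_L \big( \tilde{M}_t \oslash (\sqrt{V_t} + \epsilon) \big) Q_R^\top
\end{aligned}
```

Note that V_t is updated at every step in whatever basis is current, while Q_L and Q_R change only every f steps.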
Proposed Implementation:
- Core: Add scale_by_soap and soap alias to optax/contrib/_soap.py.
- State: Use a JAX-idiomatic NamedTuple to manage basis and momentum buffers.
- JIT: Use jax.lax.cond for periodic eigendecomposition to remain JIT-friendly.
- Generalization: Support nD parameter shapes via dimension numbers (Muon/Shampoo style).
- Tests: Full unit tests in optax/contrib/_soap_test.py.
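A minimal JAX sketch of how a NamedTuple state and a `jax.lax.cond`-based periodic refresh could fit together, for a single 2D parameter only. The names `SoapState`, `soap_init`, `soap_update`, `precond_freq`, and the default hyperparameter values are hypothetical placeholders, not the planned optax API. The sketch omits weight decay, nD support via dimension numbers, and any transformation of `nu` when the basis changes.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SoapState(NamedTuple):
    """Hypothetical state for one 2D parameter (names are illustrative only)."""
    count: jax.Array  # int32 step counter
    mu: jax.Array     # first moment, kept in the original coordinates
    nu: jax.Array     # second moment, kept in the rotated coordinates
    gg_l: jax.Array   # running average of G @ G.T, shape (m, m)
    gg_r: jax.Array   # running average of G.T @ G, shape (n, n)
    q_l: jax.Array    # current left eigenbasis (columns are eigenvectors)
    q_r: jax.Array    # current right eigenbasis


def soap_init(param: jax.Array) -> SoapState:
    m, n = param.shape
    return SoapState(
        count=jnp.zeros((), jnp.int32),
        mu=jnp.zeros_like(param),
        nu=jnp.zeros_like(param),
        gg_l=jnp.zeros((m, m), param.dtype),
        gg_r=jnp.zeros((n, n), param.dtype),
        q_l=jnp.eye(m, dtype=param.dtype),
        q_r=jnp.eye(n, dtype=param.dtype),
    )


def soap_update(grad, state, b1=0.95, b2=0.95, eps=1e-8, precond_freq=10):
    """Returns (preconditioned direction, new state); caller scales by -lr."""
    # Kronecker-factor statistics are accumulated at every step.
    gg_l = b2 * state.gg_l + (1 - b2) * (grad @ grad.T)
    gg_r = b2 * state.gg_r + (1 - b2) * (grad.T @ grad)

    def refresh(factors):
        # eigh returns (eigenvalues, eigenvectors); keep only the eigenvectors.
        left, right = factors
        return jnp.linalg.eigh(left)[1], jnp.linalg.eigh(right)[1]

    def keep(factors):
        del factors  # reuse the previous basis unchanged
        return state.q_l, state.q_r

    # Both branches are traced, so the whole step stays jit-compatible.
    q_l, q_r = jax.lax.cond(
        state.count % precond_freq == 0, refresh, keep, (gg_l, gg_r)
    )

    count = state.count + 1
    t = count.astype(grad.dtype)
    mu = b1 * state.mu + (1 - b1) * grad
    g_rot = q_l.T @ grad @ q_r                   # gradient in the eigenbasis
    nu = b2 * state.nu + (1 - b2) * g_rot**2     # Adam second moment, current basis
    mu_hat = (q_l.T @ mu @ q_r) / (1 - b1**t)    # bias-corrected, rotated momentum
    nu_hat = nu / (1 - b2**t)
    direction = q_l @ (mu_hat / (jnp.sqrt(nu_hat) + eps)) @ q_r.T  # rotate back
    return direction, SoapState(count, mu, nu, gg_l, gg_r, q_l, q_r)
```

Because `jax.lax.cond` traces both branches, the step compiles once under `jax.jit` (e.g. `step = jax.jit(soap_update)`) whether or not a given step refreshes the basis. A real implementation would still need to decide how to treat `nu` across basis changes and how to map nD parameters to matrices via dimension numbers; neither is attempted here.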
Motivation:
I plan to use haydn-jones/SOAP_JAX as a reference implementation.
Would you be open to a PR for this?