I’d like to request adding AdaHessian to Optax, and I’m happy to implement it.
AdaHessian is a second-order optimizer: it works like Adam, but scales updates by an estimate of the Hessian diagonal (obtained with Hutchinson's method) instead of by squared gradients. I already have a working JAX version (linked below) and would adapt it to Optax's optimizer structure:
Paper: https://arxiv.org/abs/2006.00719
PyTorch reference: https://github.com/amirgholami/adahessian
My JAX implementation: https://github.com/viralvgupta/adahessian-jax
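
For context, the core ingredient is the Hessian diagonal estimate, which can be sketched roughly as below. This is a minimal illustration written for this issue, not the code from the repositories above: the name `hutchinson_hessian_diag` and the toy loss are hypothetical, it draws a single Rademacher sample, and it leaves out the spatial averaging and moving averages that the paper applies on top.

```python
import jax
import jax.numpy as jnp


def hutchinson_hessian_diag(loss_fn, params, key):
    """One-sample Hutchinson estimate of diag(H), computed as z * (H @ z).

    Hypothetical helper for illustration; `params` is a pytree of float arrays.
    """
    leaves, treedef = jax.tree_util.tree_flatten(params)
    keys = jax.random.split(key, len(leaves))
    # Rademacher probe vector z (entries are +1 or -1), one array per leaf.
    z_leaves = [
        jax.random.rademacher(k, leaf.shape, dtype=leaf.dtype)
        for k, leaf in zip(keys, leaves)
    ]
    z = jax.tree_util.tree_unflatten(treedef, z_leaves)
    # Hessian-vector product H @ z via forward-over-reverse differentiation.
    _, hz = jax.jvp(jax.grad(loss_fn), (params,), (z,))
    # Elementwise z * (H @ z); its expectation over z is diag(H).
    return jax.tree_util.tree_map(lambda a, b: a * b, z, hz)


# Toy example (assumed for illustration): a quadratic whose Hessian is diag(a).
a = jnp.array([2.0, 4.0, 6.0])


def toy_loss(p):
    return 0.5 * jnp.sum(a * p["w"] ** 2)


estimate = hutchinson_hessian_diag(
    toy_loss, {"w": jnp.array([1.0, 2.0, 3.0])}, jax.random.PRNGKey(0)
)
```

Because the estimate needs the loss function (or a Hessian-vector product) and not just the gradients, how best to pass that into an Optax transformation is one of the design points I would like guidance on.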
If you’re open to it, I can start a PR and align with whatever Optax design conventions you prefer. Thanks!