Feature request: AdaHessian optimizer #1603

@viralvgupta

I’d like to request adding AdaHessian to Optax, and I’m happy to implement it.

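AdaHessian is a second-order optimizer that scales each update with an approximation of the Hessian diagonal, estimated with Hutchinson's method, rather than with squared gradients as Adam does. I already have a working JAX version (linked below) and would adapt it to Optax's optimizer structure: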

Paper: https://arxiv.org/abs/2006.00719
PyTorch reference: https://github.com/amirgholami/adahessian
My JAX implementation: https://github.com/viralvgupta/adahessian-jax
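
For context, here is a minimal sketch of the Hessian diagonal estimate at the heart of the method (Hutchinson's estimator with Rademacher probe vectors). It is not the code from the repository above: the function name `hutchinson_hessian_diag` and the `num_samples` argument are hypothetical, the parameters are assumed to be a single flat array rather than a pytree, and the paper's spatial averaging, moving averages, and Hessian power are left out.

```python
import jax
import jax.numpy as jnp


def hutchinson_hessian_diag(loss_fn, params, key, num_samples=1):
    """Estimate diag(H) of loss_fn at params as the mean of z * (H z).

    Hypothetical helper for illustration; assumes `params` is one flat array.
    """
    grad_fn = jax.grad(loss_fn)

    def one_sample(k):
        # Rademacher probe vector with entries in {-1, +1}.
        z = jax.random.rademacher(k, params.shape, dtype=params.dtype)
        # Hessian-vector product H z, computed as the JVP of the gradient.
        _, hz = jax.jvp(grad_fn, (params,), (z,))
        return z * hz

    keys = jax.random.split(key, num_samples)
    # Average the per-probe estimates over the sample axis.
    return jnp.mean(jax.vmap(one_sample)(keys), axis=0)


# Toy check: a quadratic whose Hessian is exactly diag(a).
a = jnp.array([1.0, 2.0, 3.0])
loss = lambda x: 0.5 * jnp.sum(a * x ** 2)
diag = hutchinson_hessian_diag(loss, jnp.ones(3), jax.random.PRNGKey(0))
```

For this toy quadratic the Hessian is diagonal, so `z * (H z)` equals `a` for every probe and the estimate is exact. For a general loss it is only an unbiased estimate, which is why an optimizer would smooth it over steps.
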
If you’re open to it, I can start a PR and align with whatever Optax design conventions you prefer. Thanks!
