Skip to content

improve traceabaility to JAX#163

Open
n-gao wants to merge 1 commit into
orbital-materials:mainfrom
n-gao:ng/traceable
Open

improve traceabaility to JAX#163
n-gao wants to merge 1 commit into
orbital-materials:mainfrom
n-gao:ng/traceable

Conversation

@n-gao
Copy link
Copy Markdown

@n-gao n-gao commented May 12, 2026

When tracing the model to JAX with https://github.com/cusp-ai-oss/tojax, we ideally want to use symbolic shapes for the number of atoms, number of edges and number of systems. However, __len__ is required to be an instance of integer prohibiting the use of symbolic shapes for tracing a program in JAX. Here, we supply static shape information and replace __len__ with .shape[0].

@vsimkus
Copy link
Copy Markdown
Contributor

vsimkus commented May 16, 2026

Hi @n-gao, thanks for the PR and congrats on the tojax/kups release!

I've made a few edits on top of your changes to support all our current models (direct and conservative). See this branch and this diff.

Two things I wanted to ask about:

  1. Right now wrapping a model directly with tojax(model) doesn't work for conservative models because tojax can't translate torch.autograd.grad. Is there any plan to support this? It would be the most general way to convert models - no need to disassemble them into energy-only functions and then wrap in an ad-hoc jax.grad-friendly wrapper.

  2. So instead of tojax(model), I'm currently following the pattern from your export_orb.py - extracting the energy function, then using jax.grad for forces and stress. I'd like to have tests on our side to make sure our models stay compatible with tojax/kups, but since the code in export_common is not included in the package, I'm essentially reimplementing it here - but that logic shouldn't really live in orb-models.

    Would you consider making that wrapping code part of the public tojax (or kups) API? Something that takes a prediction function and returns a function that computes energy, forces, and stress via either jax.grad (for conservative models) or direct predictions (for direct models). So we could just call your wrapper in our tests and not have to reimplement it here.

@n-gao
Copy link
Copy Markdown
Author

n-gao commented May 26, 2026

Hi @vsimkus, thanks for the detailed feedback! :)
Currently, we use the tojax scripts as a one-time export to run the final models in kUPS (there we use jax.vjp internally for all gradient computations).

We did not include the wrapping code as part of the public API since it's more brittle than I would like it to be since some codebases are easier to trace through than other (yours is actually great!). Furthermore, it might change in the future and is not considered stable, so I'd prefer your way of reimplementing a small chunk of that logic here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants