improve traceabaility to JAX#163
Conversation
|
Hi @n-gao, thanks for the PR and congrats on the tojax/kups release! I've made a few edits on top of your changes to support all our current models (direct and conservative). See this branch and this diff. Two things I wanted to ask about:
|
|
Hi @vsimkus, thanks for the detailed feedback! :) We did not include the wrapping code as part of the public API since it's more brittle than I would like it to be since some codebases are easier to trace through than other (yours is actually great!). Furthermore, it might change in the future and is not considered stable, so I'd prefer your way of reimplementing a small chunk of that logic here. |
When tracing the model to JAX with https://github.com/cusp-ai-oss/tojax, we ideally want to use symbolic shapes for the number of atoms, number of edges and number of systems. However,
__len__is required to be an instance of integer prohibiting the use of symbolic shapes for tracing a program in JAX. Here, we supply static shape information and replace__len__with.shape[0].