feat: add jax-cache recipe using jax.linearize #577

Open. jpbrodrick89 wants to merge 2 commits into dion/vjp-cache from jpb/jax-cache

@jpbrodrick89 (Contributor)

This seems to work pretty well. I wrote it so that if you set _CACHE_SIZE=0 we recover the current jax recipe, so we could consider replacing it entirely. Runtime for apply + vjp on a warm cache is reduced by 40%.

PasteurBot (Contributor) commented Apr 29, 2026

Benchmark Results

Benchmarks use a no-op Tesseract to measure pure framework overhead.

🚀 0 faster, ⚠️ 0 slower, ✅ 36 unchanged

✅ No significant performance changes detected.

Full results
| Benchmark | Baseline | Current | Change | Status |
|---|---|---|---|---|
| api/apply_1,000 | 0.590ms | 0.581ms | -1.5% | |
| api/apply_100,000 | 0.593ms | 0.580ms | -2.3% | |
| api/apply_10,000,000 | 0.590ms | 0.588ms | -0.2% | |
| cli/apply_1,000 | 1690.101ms | 1662.733ms | -1.6% | |
| cli/apply_100,000 | 1677.679ms | 1675.031ms | -0.2% | |
| cli/apply_10,000,000 | 1731.726ms | 1733.926ms | +0.1% | |
| decoding/base64_1,000 | 0.036ms | 0.036ms | +1.3% | |
| decoding/base64_100,000 | 0.851ms | 0.857ms | +0.7% | |
| decoding/base64_10,000,000 | 96.140ms | 96.163ms | +0.0% | |
| decoding/binref_1,000 | 0.200ms | 0.199ms | -0.5% | |
| decoding/binref_100,000 | 0.242ms | 0.243ms | +0.3% | |
| decoding/binref_10,000,000 | 10.745ms | 10.681ms | -0.6% | |
| decoding/json_1,000 | 0.107ms | 0.107ms | +0.2% | |
| decoding/json_100,000 | 8.772ms | 9.134ms | +4.1% | |
| decoding/json_10,000,000 | 1068.862ms | 1088.958ms | +1.9% | |
| encoding/base64_1,000 | 0.042ms | 0.040ms | -3.2% | |
| encoding/base64_100,000 | 0.146ms | 0.145ms | -0.7% | |
| encoding/base64_10,000,000 | 25.348ms | 26.043ms | +2.7% | |
| encoding/binref_1,000 | 0.304ms | 0.301ms | -1.2% | |
| encoding/binref_100,000 | 0.480ms | 0.479ms | -0.2% | |
| encoding/binref_10,000,000 | 18.989ms | 18.807ms | -1.0% | |
| encoding/json_1,000 | 0.153ms | 0.150ms | -2.1% | |
| encoding/json_100,000 | 12.955ms | 12.913ms | -0.3% | |
| encoding/json_10,000,000 | 1393.614ms | 1379.632ms | -1.0% | |
| http/apply_1,000 | 3.153ms | 3.121ms | -1.0% | |
| http/apply_100,000 | 8.908ms | 9.097ms | +2.1% | |
| http/apply_10,000,000 | 773.099ms | 770.493ms | -0.3% | |
| roundtrip/base64_1,000 | 0.089ms | 0.088ms | -1.9% | |
| roundtrip/base64_100,000 | 1.010ms | 1.012ms | +0.2% | |
| roundtrip/base64_10,000,000 | 122.415ms | 122.530ms | +0.1% | |
| roundtrip/binref_1,000 | 0.524ms | 0.519ms | -1.0% | |
| roundtrip/binref_100,000 | 0.725ms | 0.721ms | -0.5% | |
| roundtrip/binref_10,000,000 | 29.781ms | 30.490ms | +2.4% | |
| roundtrip/json_1,000 | 0.274ms | 0.273ms | -0.3% | |
| roundtrip/json_100,000 | 19.692ms | 19.686ms | -0.0% | |
| roundtrip/json_10,000,000 | 2466.023ms | 2461.123ms | -0.2% | |

  • Runner: Linux 6.17.0-1010-azure x86_64

@dionhaefner (Contributor)

@jpbrodrick89 I believe you're missing a few commits here? :)

@jpbrodrick89 changed the base branch from main to dion/vjp-cache (April 29, 2026 11:37)

@jpbrodrick89 (Contributor Author)

Should be there now, sorry! 😅 Targeting your branch so you can see the diff more clearly.

codecov Bot commented Apr 29, 2026

Codecov Report

❌ Patch coverage is 64.28571% with 5 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.12%. Comparing base (aa48ac2) to head (7cf3423).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...ract_core/sdk/templates/jax-cache/tesseract_api.py | 0.00% | 5 Missing ⚠️ |

Additional details and impacted files
@@                Coverage Diff                 @@
##           dion/vjp-cache     #577      +/-   ##
==================================================
- Coverage           77.31%   76.12%   -1.20%     
==================================================
  Files                  34       34              
  Lines                4990     4519     -471     
  Branches              883      740     -143     
==================================================
- Hits                 3858     3440     -418     
+ Misses                814      777      -37     
+ Partials              318      302      -16     


with self._lock:
    if key in self._cache:
        self._cache.move_to_end(key)
        if len(self._cache) > 1 and next(reversed(self._cache)) != key:

Contributor

Why?

Contributor Author

Just a path to skip the move_to_end; maybe not necessary. Btw, is there a reason you chose not to use functools.lru_cache?
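
For readers following the thread, here is a minimal sketch of the kind of lock-protected LRU cache being discussed, built on `collections.OrderedDict`. The class body, method names (`get`, `put`, `keys`) and the zero-size behaviour are assumptions for illustration, not the PR's actual implementation. The check before `move_to_end` mirrors the optional fast path mentioned above.

```python
import threading
from collections import OrderedDict


class LRUCache:
    """Minimal thread-safe LRU cache (illustrative; not the PR's class)."""

    def __init__(self, maxsize: int):
        self.maxsize = maxsize
        self._cache: OrderedDict = OrderedDict()
        self._lock = threading.Lock()

    def get(self, key, default=None):
        with self._lock:
            if key not in self._cache:
                return default
            # Optional fast path: skip the reorder if key is already newest.
            if next(reversed(self._cache)) != key:
                self._cache.move_to_end(key)
            return self._cache[key]

    def put(self, key, value):
        if self.maxsize <= 0:
            return  # assumption: a zero-sized cache stores nothing
        with self._lock:
            self._cache[key] = value
            self._cache.move_to_end(key)
            while len(self._cache) > self.maxsize:
                self._cache.popitem(last=False)  # evict least recently used

    def keys(self):
        with self._lock:
            return list(self._cache)


cache = LRUCache(maxsize=2)
cache.put("a", 1)
cache.put("b", 2)
cache.get("a")     # "a" becomes the most recently used entry
cache.put("c", 3)  # exceeds maxsize, so "b" is evicted
```

By contrast, `functools.lru_cache` decorates a function and keys on its hashable call arguments, so it has no separate store and lookup steps like the `put`/`get` pair sketched here.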

# - Optimizers that only ever use gradient information and never inspect the
# loss value (e.g. Adam/SGD using jax.grad rather than jax.value_and_grad).
# No apply() call is issued, the cache is never populated, and every vjp
# pays a cold-miss linearize cost. Use --recipe jax instead.

Contributor

Is there really a non-negligible cost?

# interleave apply() calls before their corresponding VJP calls.
_vjp_cache = LRUCache(maxsize=1)
_cache = LRUCache(maxsize=_CACHE_SIZE)
_f_lin_ref: dict[tuple, Partial] = {}

Contributor

What is this?

Contributor Author

It's a module-level dict that holds linearised function templates to avoid re-jitting. I'll add a comment once we settle on the design.

Comment on lines +257 to +260
@eqx.filter_jit
def _jvp_with_lin(
    f_lin,
    dynamic_in: dict,

Contributor

Does this re-jit when f_lin changes?

Contributor Author

only when _shape_key changes, see above

Contributor

I think we'll need an e2e test to be confident that this isn't answer-changing in a non-trivial case (like silently returning VJPs for the wrong residuals if something goes wrong during cache invalidation).


dynamic_primals, f_lin_fresh = jax.linearize(_apply_dynamic_only, dynamic_in)

ref = _f_lin_ref.setdefault(_shape_key(inputs_dict), f_lin_fresh)

Contributor

What is this for? Looks like you never read from _f_lin_ref?

Contributor Author

Setdefault does more than it says on the tin

The setdefault() method returns the value of the item with the specified key. If the key does not exist, insert the key.

Thus, if we receive inputs with the same shape and dtype (i.e. the same _shape_key), we reuse the same function reference to avoid re-jitting. Closure values are updated on the next line with jax's Partial. If _shape_key changes, a new entry is created, which prompts re-jitting.

@jpbrodrick89 (Contributor Author)

Testing this again, it is not nearly as performant as it first seemed and might actually be slower. I need to iterate a bit more; will report back.

@jpbrodrick89 (Contributor Author)

So, after some new benchmarking, I've found that in value_and_grad optimisation loops this can be quite a big slowdown. Interestingly, even your vjp cache seems to be a slowdown for shallow nonlinearities (e.g. a 120% slowdown for a (2,1024) NN). My theory is that the extra memory traffic from passing residuals across the jit boundary is not worth it in these cases. For larger problems (e.g. NNs with more than ~20M parameters) both begin to win, but the win from your vjp approach is about twice as big (up to 23%) as that from the linearize approach. Having to do a linear transpose on the vjp call adds real weight.

Two options:

  1. Ship a neatened-up version of your template.
  2. Use a two-cache approach which uses jax.linearize only on jax.jvp calls (for use in iterative solvers) and uses the vjp cache otherwise.

(or do these separately: jvp-cache, vjp-cache)
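
To make the two routes concrete, here is a small sketch contrasting `jax.vjp` with `jax.linearize` followed by `jax.linear_transpose`. The function `f` and the inputs are made up for illustration and say nothing about the benchmarked workloads. The point is only that the linearize route needs an extra transpose step before it can serve a VJP, whereas `jax.vjp` returns the pullback directly.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical stand-in for a differentiable apply function.
    return jnp.tanh(x) * jnp.sum(x ** 2)


x = jnp.array([0.1, 0.2, 0.3])
ct = jnp.array([1.0, 0.0, -1.0])  # cotangent for the VJP
v = jnp.array([0.5, -0.5, 1.0])   # tangent for the JVP

# Route 1: jax.vjp returns the output and a ready-made pullback.
y_vjp, vjp_fn = jax.vjp(f, x)
(grad_vjp,) = vjp_fn(ct)

# Route 2: jax.linearize returns the output and a linear JVP function.
y_lin, f_jvp = jax.linearize(f, x)
jvp_out = f_jvp(v)  # JVPs are cheap to evaluate once linearized

# Serving a VJP from the linearized function requires transposing it;
# `x` is used here only for its shape and dtype.
f_vjp = jax.linear_transpose(f_jvp, x)
(grad_lin,) = f_vjp(ct)

# Reference JVP for comparison.
_, jvp_ref = jax.jvp(f, (x,), (v,))
```

Both routes should give the same primal output and the same cotangent up to floating-point tolerance, so the choice between them is about cost rather than results.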

3 participants