feat: add jax-cache recipe using jax.linearize#577
Conversation
Benchmark ResultsBenchmarks use a no-op Tesseract to measure pure framework overhead. 🚀 0 faster, ✅ No significant performance changes detected. Full results
|
|
@jpbrodrick89 I believe you're missing a few commits here? :) |
|
should be there now sorry! 😅 Targetting your branch so you can see diff more clearly. |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## dion/vjp-cache #577 +/- ##
==================================================
- Coverage 77.31% 76.12% -1.20%
==================================================
Files 34 34
Lines 4990 4519 -471
Branches 883 740 -143
==================================================
- Hits 3858 3440 -418
+ Misses 814 777 -37
+ Partials 318 302 -16 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
| with self._lock: | ||
| if key in self._cache: | ||
| self._cache.move_to_end(key) | ||
| if len(self._cache) > 1 and next(reversed(self._cache)) != key: |
There was a problem hiding this comment.
Just a path to skip the mov_to_end, maybe not necessary. Btw, is there a reason you chose not to use functools.lrucache?
| # - Optimizers that only ever use gradient information and never inspect the | ||
| # loss value (e.g. Adam/SGD using jax.grad rather than jax.value_and_grad). | ||
| # No apply() call is issued, the cache is never populated, and every vjp | ||
| # pays a cold-miss linearize cost. Use --recipe jax instead. |
There was a problem hiding this comment.
Is there really a non-negligible cost?
| # interleave apply() calls before their corresponding VJP calls. | ||
| _vjp_cache = LRUCache(maxsize=1) | ||
| _cache = LRUCache(maxsize=_CACHE_SIZE) | ||
| _f_lin_ref: dict[tuple, Partial] = {} |
There was a problem hiding this comment.
Holds module-level dict of linearised function templates to avoid re-jitting. I'll add a comment once we center on design.
| @eqx.filter_jit | ||
| def _jvp_with_lin( | ||
| f_lin, | ||
| dynamic_in: dict, |
There was a problem hiding this comment.
Does this re-jit when f_lin changes?
There was a problem hiding this comment.
only when _shape_key changes, see above
There was a problem hiding this comment.
I think we'll need an e2e test to be confident that this isn't answer-changing in a non-trivial case (like silently returning VJPs for the wrong residuals if something goes wrong during cache invalidation).
|
|
||
| dynamic_primals, f_lin_fresh = jax.linearize(_apply_dynamic_only, dynamic_in) | ||
|
|
||
| ref = _f_lin_ref.setdefault(_shape_key(inputs_dict), f_lin_fresh) |
There was a problem hiding this comment.
What is this for? Looks like you never read from _f_lin_ref?
There was a problem hiding this comment.
Setdefault does more than it says on the tin
The setdefault() method returns the value of the item with the specified key. If the key does not exist, insert the key.
Thus if we receive inputs with the same shape and dtype (i.e. _shape_key) then we use the same function reference to avoid re-jitting. Closure values are update on the next line with jax's Partial. If _shape_key changes then a new entry is created which prompts re-jitting.
|
Testing this again this is not nearly as performant as it first seemed and might actually be slower I need to iterate a bit more, will report back. |
|
So, after some new benchmarking I've found that when applying val_and_grad optimisation loops this can be quite a big slow down. Interestingly, even your vjp cache seems to be a slow down for shallow nonlinearity (e.g. 120% slowdown for (2,1024) NN). My theory is that this is because the extra memory flow of residuals at the jit boundary is not worth it in these cases. For larger problems (e.g. NNs with greater than ~20M parameters) both begin to win but the win from your vjp approach is about twice as big (up to 23%) as the linearize approach. Having linear transpose on the vjp call adds real weight. Two options:
(or do these separately jvp-cache, vjp-cache) |
This seems to work pretty well, I wrote it so that it if you set
_CACHE_SIZE=0we recover the current jax recipe so we could consider replacing it entirely. Runtime for apply + vjp on a warm cache is reduced by 40%.