This repository was archived by the owner on Apr 4, 2026. It is now read-only.
Commit 28e2dfe
Stacked cache for MLPerf (#154)
* Almost working except the mask; need to rebase to main to pick up the ring buffer support, then fix the mask. Int8 updates are also included but not tested.
* Fixed the test_model_impl for llama, but test_llama_e2e is still failing.
* Adds lazy_cache_update and restructure the cache flags.
* Disable all the prints. Fix create engine.
* Fix typos and minor errors.
* Fixes create engine.
* Adds new_cache_stacked and fixes cache update.
* Fix cache update when new_cache_stacked is False.
* Fix the cache manager and make unit tests pass except for 1.
* Updates the exportable model to return cache.
* Removes the fori loop in cache finalize; moves cache.finalize() to the end of the existing cache attention.
* Try to use shard_map for cache update.
* Fix the single-cache-line update in cache.finalize().
* Adds int8 support.
* Int8 left aligned lazy cache update working, performance still not good enough.
* Fix the stacked cache introduced in the previous couple of commits.
* Put original ragged attention back.
* Add the original ragged attention kernel.
* Fixes the bf16/int8 cache stack.
* Fix int8 stacked cache insertion in engine and finalization.
* Fixes int8 with lazy cache update.
* Updates the int8 test.
* Fix the int8 ragged attention output sharding.
* Fix group query attention broadcasting issue.
* Fix the shard_map input issue: variables not listed as inputs are frozen into the jit function.
* Fix the flash attention mask shape; fix the quantized version of the single-cache-line update.
* Adds the kv cache test.
* Replace quantized cache "pos" with "input_pos" to align with bf16 cache. Fix the kv cache quantization test.
* Fix the prefill cache insertion issue for the stacked cache; change the quantization reduce dims from (1, 3) to (-3, -1) to make them more robust.
* Adds lazy cache update with the generate cache stacked and the new cache unstacked, for performance validation.
* Fix the shard map sharding for stacked generate cache and unstacked new cache.
* Use the JAX API for slicing instead of PyTorch index slicing.
* Adds stacked cache support in ragged attention reference kernel.
* Adds stacked cache support for the modified ragged kernel.
* Llama2 70b int8 optimization done. Output not correct yet.
* Remove testing temp output files.
* Fix the llama 70b output accuracy issue caused by GQA.
* Fixes the attention output slicing issue when not using flash attention. Refactor to use only one flash attention kernel. Update the modified ring-buffer ragged attention kernel to handle quantization, layers, etc.
* Fix the pallas kernel OOB issue
* Fix tests; fix lint issues.
* Fix the interactive script.
* Fix lint errors.
* Fix errors.
* Fix the comments.
* Fix based on comments; Fix all the unit tests.
* Fix the remaining pylint errors.
* Default the ring buffer back to true so that all the test_run_server and run_interactive tests work in CPU mode. When we default the ring buffer to false, we should add flags to the run_interactive CI that set test mode to true so the Pallas kernel can run.
* Fix all the lint errors.
* Remove the deps/JetStream changes.
* Fix merge errors and lint errors.

1 parent 50a6d10 · commit 28e2dfe
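The shard_map fix above hinges on a JAX behavior worth illustrating: values closed over by a jit-compiled function are baked in at trace time rather than read on each call. A minimal sketch (the names here are illustrative, not from this commit):

```python
import jax
import jax.numpy as jnp

scale = 2.0  # module-level value captured by the closure


@jax.jit
def apply_scale(x):
    # `scale` is not an argument, so it is frozen into the traced
    # computation the first time apply_scale runs.
    return x * scale


x = jnp.ones(3)
first = apply_scale(x)   # traced with scale == 2.0

scale = 10.0             # rebinding the global does not invalidate the
second = apply_scale(x)  # cached trace; this still multiplies by 2.0


# Passing the value as an explicit argument avoids the stale capture:
@jax.jit
def apply_scale_arg(x, s):
    return x * s


third = apply_scale_arg(x, 10.0)
```

This is why variables used inside a shard_map or jit region must be listed as inputs if they are expected to vary between calls.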
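The switch of the quantization reduce dims from (1, 3) to (-3, -1) can be seen with a small NumPy sketch: negative axes keep pointing at the same trailing dimensions whether or not a leading layer axis is stacked on. (Shapes and the helper name are illustrative, not the repo's actual code.)

```python
import numpy as np


def quantize_scale(cache):
    # Per-channel scale from the max magnitude, reducing over the
    # trailing axes so the same code handles both stacked and
    # unstacked caches without renumbering dimensions.
    return np.abs(cache).max(axis=(-3, -1), keepdims=True)


unstacked = np.random.rand(2, 4, 16, 8)   # (batch, heads, seq, dim)
stacked = np.random.rand(6, 2, 4, 16, 8)  # (layers, batch, heads, seq, dim)

print(quantize_scale(unstacked).shape)  # (2, 1, 16, 1)
print(quantize_scale(stacked).shape)    # (6, 2, 1, 16, 1)
```

With positive axes (1, 3), the stacked case would reduce over the wrong dimensions once the layer axis is prepended.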
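The move from PyTorch index slicing to the JAX slicing API typically means using `jax.lax.dynamic_update_slice`, which accepts traced start indices and so works under jit. A hypothetical single-cache-line update (names and shapes are assumptions for illustration):

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def update_cache_line(cache, new_line, pos):
    # cache: (seq_len, head_dim); new_line: (1, head_dim).
    # dynamic_update_slice takes traced start indices, unlike a
    # Python-style `cache[pos] = new_line`, which does not work on
    # immutable JAX arrays with a traced `pos`.
    return lax.dynamic_update_slice(cache, new_line, (pos, 0))


cache = jnp.zeros((8, 4))
new_line = jnp.ones((1, 4))
cache = update_cache_line(cache, new_line, 3)
```

The update is functional: a new array is returned rather than the cache being mutated in place.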
17 files changed
Lines changed: 1604 additions & 266 deletions
File tree
- benchmarks
- jetstream_pt
- third_party
- gemma
- llama
- mixtral
- tests