Thanks for your great work and for implementing the baselines! It will really benefit future work.
I have a question about the implementation of KV cache selection in Quest.
It looks like in this repo, the Quest cache unconditionally keeps all generated tokens (see `MagicPIG/evaluations/RULER/pred/quest_cache.py`, line 127 at commit ac9aa36:
`self.selected_key_cache[layer_idx] = torch.cat([self.selected_key_cache[layer_idx], key_states], dim=-2)`).
As a result, token selection only applies to the prompt tokens, because the KV pages are built only during prefill.
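To make the current behavior concrete, here is a minimal plain-Python sketch of my understanding of it. It assumes Quest-style page metadata (per-channel min and max of the keys in each page) and scores a page for query `q` with the upper bound `sum_i max(q_i * min_i, q_i * max_i)`. The names `build_pages`, `page_score`, and `select_pages` are hypothetical and are not taken from this repo or from the Quest code; lists stand in for the KV tensors.

```python
from typing import List, Tuple

Vector = List[float]
Page = Tuple[Vector, Vector]  # (per-channel min, per-channel max) of the keys in one page


def build_pages(keys: List[Vector], page_size: int) -> List[Page]:
    """Group keys into fixed-size pages and record per-channel min/max for each page."""
    pages = []
    for start in range(0, len(keys), page_size):
        chunk = keys[start:start + page_size]
        mins = [min(col) for col in zip(*chunk)]
        maxs = [max(col) for col in zip(*chunk)]
        pages.append((mins, maxs))
    return pages


def page_score(query: Vector, page: Page) -> float:
    """Upper bound on q.k over any key in the page: sum_i max(q_i * min_i, q_i * max_i)."""
    mins, maxs = page
    return sum(max(q * lo, q * hi) for q, lo, hi in zip(query, mins, maxs))


def select_pages(query: Vector, pages: List[Page], top_k: int) -> List[int]:
    """Return the indices of the top_k highest-scoring pages, in positional order."""
    ranked = sorted(range(len(pages)), key=lambda i: page_score(query, pages[i]), reverse=True)
    return sorted(ranked[:top_k])


# Static variant: pages are built once from the prompt at prefill time.
prompt_keys = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9], [-1.0, 0.0], [-0.9, -0.1]]
pages = build_pages(prompt_keys, page_size=2)

# Decode-time keys are only appended (like the torch.cat line above); they are never
# paged, so they are always attended and never take part in the selection.
generated_keys: List[Vector] = []
for step_key in ([0.5, 0.5], [0.2, 0.8]):
    generated_keys.append(step_key)

selected = select_pages([1.0, 0.0], pages, top_k=1)  # only prompt pages compete
```

In this toy run only the three prompt pages are candidates, while the two generated keys sit outside the page table no matter how relevant they are.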
It looks like the original Quest implementation (https://github.com/mit-han-lab/Quest/blob/main/evaluation/quest_attention.py) dynamically updates the KV pages during decoding.
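Here is a rough sketch of what I mean by dynamic page updating, again in plain Python and under the same min/max metadata assumption. Each new key is folded into the last page if that page still has room; otherwise it opens a new page. Generated tokens are therefore scored and selected exactly like prompt tokens. `PagedKeyMetadata`, `append_key`, and `select` are hypothetical names, this is not the code from `quest_attention.py`, and it tracks only the page metadata, not the KV tensors themselves.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class PagedKeyMetadata:
    """Per-page channel-wise min/max of keys, updated incrementally as tokens arrive."""
    page_size: int
    mins: List[List[float]] = field(default_factory=list)
    maxs: List[List[float]] = field(default_factory=list)
    counts: List[int] = field(default_factory=list)  # number of tokens in each page

    def append_key(self, key: List[float]) -> None:
        """Fold one key into the last page, or open a new page if the last one is full."""
        if not self.counts or self.counts[-1] == self.page_size:
            self.mins.append(list(key))
            self.maxs.append(list(key))
            self.counts.append(1)
            return
        self.mins[-1] = [min(a, b) for a, b in zip(self.mins[-1], key)]
        self.maxs[-1] = [max(a, b) for a, b in zip(self.maxs[-1], key)]
        self.counts[-1] += 1

    def scores(self, query: List[float]) -> List[float]:
        """Quest-style upper bound on q.k per page: sum_i max(q_i * min_i, q_i * max_i)."""
        return [
            sum(max(q * lo, q * hi) for q, lo, hi in zip(query, mn, mx))
            for mn, mx in zip(self.mins, self.maxs)
        ]

    def select(self, query: List[float], top_k: int) -> List[int]:
        """Indices of the top_k highest-scoring pages, in positional order."""
        s = self.scores(query)
        ranked = sorted(range(len(s)), key=lambda i: s[i], reverse=True)
        return sorted(ranked[:top_k])


meta = PagedKeyMetadata(page_size=2)
for k in ([1.0, 0.0], [0.9, 0.1], [0.0, 1.0]):  # prefill: 3 prompt keys -> 2 pages
    meta.append_key(k)
meta.append_key([0.1, 0.9])  # decode step 1: fills the partially filled last page
meta.append_key([0.0, 2.0])  # decode step 2: last page is full, so a new page is opened
print(meta.counts, meta.select([0.0, 1.0], top_k=1))  # prints: [2, 2, 1] [2]
```

Because decode-time keys live in pages, they compete for the token budget instead of being kept unconditionally; in this toy run the page opened during decoding (index 2) is the one selected.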
Would you consider implementing dynamic KV page updating? I implemented a very simple, though not perfect, version here: https://github.com/Monstertail/MagicPIG/blob/b635d06ae2c68c1d2949f2e95f358fb5746f6108/RULER/RULER/scripts/pred/quest_cache.py#L253. If you are interested, we can think together about how to improve it.