We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 73a9ac5, commit 973dd6b (Copy full SHA for 973dd6b)
1 file changed
gemma/gm/text/_prefill.py
@@ -217,9 +217,12 @@ def prefill(
   # A cleaner implementation could be to have a per-batch cache index, to
   # remove padding. But I leave this to my future self (or to future Gemini).

-  new_used_cache_length = (
-      prev_turns.used_cache_length + input.length_with_mm - 1
-  )
+  if hasattr(model, 'keep_last_prefill_kv') and model.keep_last_prefill_kv:
+    new_used_cache_length = prev_turns.used_cache_length + input.length_with_mm
+  else:
+    new_used_cache_length = (
+        prev_turns.used_cache_length + input.length_with_mm - 1
+    )
   cache = cache.set_end_index(new_used_cache_length)

   # TODO(epot): The first token was predicted, so could use this, but would
0 commit comments