Skip to content

Commit 973dd6b

Browse files
author
The gemma Authors
committed
Optionally allow models to keep the last cache item at prefill time.
PiperOrigin-RevId: 907003276
1 parent 73a9ac5 commit 973dd6b

1 file changed

Lines changed: 6 additions & 3 deletions

File tree

gemma/gm/text/_prefill.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,9 +217,12 @@ def prefill(
217217
# A cleaner implementation could be to have a per-batch cache index, to
218218
# remove padding. But I leave this to my future self (or to future Gemini).
219219

220-
new_used_cache_length = (
221-
prev_turns.used_cache_length + input.length_with_mm - 1
222-
)
220+
if hasattr(model, 'keep_last_prefill_kv') and model.keep_last_prefill_kv:
221+
new_used_cache_length = prev_turns.used_cache_length + input.length_with_mm
222+
else:
223+
new_used_cache_length = (
224+
prev_turns.used_cache_length + input.length_with_mm - 1
225+
)
223226
cache = cache.set_end_index(new_used_cache_length)
224227

225228
# TODO(epot): The first token was predicted, so could use this, but would

0 commit comments

Comments
 (0)