Skip to content

Commit 2bd2e67

Browse files
author
The gemma Authors
committed
Fix gemma4's cache indexing for non-empty cache, in the attention layer.
PiperOrigin-RevId: 907003276
1 parent 73a9ac5 commit 2bd2e67

1 file changed

Lines changed: 15 additions & 5 deletions

File tree

gemma/gm/nn/gemma4/_modules.py

Lines changed: 15 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -297,22 +297,32 @@ def __call__(
297297
if kv_shared_cache is not None:
298298
cache_positions = kv_shared_cache.get('positions')
299299
elif cache is not None:
300-
end_index = cache['end_index']
300+
# `cache['end_index']` contains the last valid index in the cache -- we
301+
# want to start writing at the next position. However, it is initialized
302+
# to 0. So in that case, we start writing at 0, which is the first empty
303+
# slot.
304+
cache_write_position_start = jnp.where(
305+
cache['end_index'] == 0,
306+
cache['end_index'],
307+
cache['end_index'] + 1,
308+
)
301309
cache_size = cache['v'].shape[1]
302310
seq_len = x.shape[1]
303311
# [batch_size, seq_len]
304-
indices = (end_index[:, None] + jnp.arange(seq_len)[None, :]) % cache_size
312+
new_indices = (
313+
cache_write_position_start[:, None] + jnp.arange(seq_len)[None, :]
314+
) % cache_size
305315
batch_indices = jnp.arange(x.shape[0])[:, None]
306316

307317
# [batch_size, cache_size, num_heads, key_size]
308-
value_proj = cache['v'].at[batch_indices, indices].set(value_proj)
318+
value_proj = cache['v'].at[batch_indices, new_indices].set(value_proj)
309319

310320
# [batch_size, cache_size, num_heads, key_size]
311-
key_proj = cache['k'].at[batch_indices, indices].set(key_proj)
321+
key_proj = cache['k'].at[batch_indices, new_indices].set(key_proj)
312322

313323
# [batch_size, cache_size]
314324
cache_positions = (
315-
cache['positions'].at[batch_indices, indices].set(segment_pos)
325+
cache['positions'].at[batch_indices, new_indices].set(segment_pos)
316326
)
317327
else:
318328
cache_positions = None

0 commit comments

Comments
 (0)