@@ -297,22 +297,32 @@ def __call__(
297297 if kv_shared_cache is not None :
298298 cache_positions = kv_shared_cache .get ('positions' )
299299 elif cache is not None :
300- end_index = cache ['end_index' ]
300+ # `cache['end_index']` contains the last valid index in the cache -- we
301+ # want to start writing at the next position. However, it is initialized
302+ # to 0. So in that case, we start writing at 0, which is the first empty
303+ # slot.
304+ cache_write_position_start = jnp .where (
305+ cache ['end_index' ] == 0 ,
306+ cache ['end_index' ],
307+ cache ['end_index' ] + 1 ,
308+ )
301309 cache_size = cache ['v' ].shape [1 ]
302310 seq_len = x .shape [1 ]
303311 # [batch_size, seq_len]
304- indices = (end_index [:, None ] + jnp .arange (seq_len )[None , :]) % cache_size
312+ new_indices = (
313+ cache_write_position_start [:, None ] + jnp .arange (seq_len )[None , :]
314+ ) % cache_size
305315 batch_indices = jnp .arange (x .shape [0 ])[:, None ]
306316
307317 # [batch_size, cache_size, num_heads, key_size]
308- value_proj = cache ['v' ].at [batch_indices , indices ].set (value_proj )
318+ value_proj = cache ['v' ].at [batch_indices , new_indices ].set (value_proj )
309319
310320 # [batch_size, cache_size, num_heads, key_size]
311- key_proj = cache ['k' ].at [batch_indices , indices ].set (key_proj )
321+ key_proj = cache ['k' ].at [batch_indices , new_indices ].set (key_proj )
312322
313323 # [batch_size, cache_size]
314324 cache_positions = (
315- cache ['positions' ].at [batch_indices , indices ].set (segment_pos )
325+ cache ['positions' ].at [batch_indices , new_indices ].set (segment_pos )
316326 )
317327 else :
318328 cache_positions = None
0 commit comments