@@ -363,6 +363,12 @@ def forward(
363363
364364 kv_seq_len = key_states .shape [- 2 ]
365365 if past_key_value is not None :
366+ if self .layer_idx is None :
367+ raise ValueError (
368+ f"The cache structure has changed since version v4.36. If you are using { self .__class__ .__name__ } "
369+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
370+ "with a layer index."
371+ )
366372 kv_seq_len += past_key_value .get_usable_length (kv_seq_len , self .layer_idx )
367373
368374 # Because the input can be padded, the absolute sequence length depends on the max position id.
@@ -385,11 +391,16 @@ def forward(
385391
386392 if past_key_value is not None :
387393 # Activate cache slicing only if the config defines a non-None `sliding_window` attribute
388- if getattr (self .config , "sliding_window" , None ) is not None and kv_seq_len > self .config .sliding_window :
394+ cache_has_contents = past_key_value .get_seq_length (self .layer_idx ) > 0
395+ if (
396+ getattr (self .config , "sliding_window" , None ) is not None
397+ and kv_seq_len > self .config .sliding_window
398+ and cache_has_contents
399+ ):
389400 slicing_tokens = 1 - self .config .sliding_window
390401
391- past_key = past_key_value [0 ]
392- past_value = past_key_value [1 ]
402+ past_key = past_key_value [self . layer_idx ][ 0 ]
403+ past_value = past_key_value [self . layer_idx ][ 1 ]
393404
394405 past_key = past_key [:, :, slicing_tokens :, :].contiguous ()
395406 past_value = past_value [:, :, slicing_tokens :, :].contiguous ()
@@ -400,8 +411,6 @@ def forward(
400411 f" { past_key .shape } "
401412 )
402413
403- past_key_value = (past_key , past_value )
404-
405414 if attention_mask is not None :
406415 attention_mask = attention_mask [:, slicing_tokens :]
407416 attention_mask = torch .cat ([attention_mask , torch .ones_like (attention_mask [:, - 1 :])], dim = - 1 )
0 commit comments