Skip to content

Commit 388fd31

Browse files
authored
Generate: Mistral/Mixtral FA2 cache fix when going beyond the context window (#28037)
1 parent 0ede762 commit 388fd31

2 files changed

Lines changed: 28 additions & 10 deletions

File tree

src/transformers/models/mistral/modeling_mistral.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,12 @@ def forward(
363363

364364
kv_seq_len = key_states.shape[-2]
365365
if past_key_value is not None:
366+
if self.layer_idx is None:
367+
raise ValueError(
368+
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
369+
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
370+
"with a layer index."
371+
)
366372
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
367373

368374
# Because the input can be padded, the absolute sequence length depends on the max position id.
@@ -385,11 +391,16 @@ def forward(
385391

386392
if past_key_value is not None:
387393
# Activate slicing cache only if the config has a value `sliding_windows` attribute
388-
if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window:
394+
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
395+
if (
396+
getattr(self.config, "sliding_window", None) is not None
397+
and kv_seq_len > self.config.sliding_window
398+
and cache_has_contents
399+
):
389400
slicing_tokens = 1 - self.config.sliding_window
390401

391-
past_key = past_key_value[0]
392-
past_value = past_key_value[1]
402+
past_key = past_key_value[self.layer_idx][0]
403+
past_value = past_key_value[self.layer_idx][1]
393404

394405
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
395406
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
@@ -400,8 +411,6 @@ def forward(
400411
f" {past_key.shape}"
401412
)
402413

403-
past_key_value = (past_key, past_value)
404-
405414
if attention_mask is not None:
406415
attention_mask = attention_mask[:, slicing_tokens:]
407416
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

src/transformers/models/mixtral/modeling_mixtral.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,12 @@ def forward(
414414

415415
kv_seq_len = key_states.shape[-2]
416416
if past_key_value is not None:
417+
if self.layer_idx is None:
418+
raise ValueError(
419+
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
420+
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
421+
"with a layer index."
422+
)
417423
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
418424

419425
# Because the input can be padded, the absolute sequence length depends on the max position id.
@@ -436,11 +442,16 @@ def forward(
436442

437443
if past_key_value is not None:
438444
# Activate slicing cache only if the config has a value `sliding_windows` attribute
439-
if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window:
445+
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
446+
if (
447+
getattr(self.config, "sliding_window", None) is not None
448+
and kv_seq_len > self.config.sliding_window
449+
and cache_has_contents
450+
):
440451
slicing_tokens = 1 - self.config.sliding_window
441452

442-
past_key = past_key_value[0]
443-
past_value = past_key_value[1]
453+
past_key = past_key_value[self.layer_idx][0]
454+
past_value = past_key_value[self.layer_idx][1]
444455

445456
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
446457
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
@@ -451,8 +462,6 @@ def forward(
451462
f" {past_key.shape}"
452463
)
453464

454-
past_key_value = (past_key, past_value)
455-
456465
if attention_mask is not None:
457466
attention_mask = attention_mask[:, slicing_tokens:]
458467
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

0 commit comments

Comments
 (0)