Skip to content

Commit 1dcb225

Browse files
zyzhou5akoumpa
authored andcommitted
fix: baichuan dynamic cache (#1865)
fix: Handle DynamicCache in Baichuan model for generation compatibility Baichuan's forward() and prepare_inputs_for_generation() assumed past_key_values is always a legacy tuple-of-tuples, but transformers 5.x passes DynamicCache objects during model.generate(). This caused TypeError/AttributeError in the baichuan_2_7b_squad_vllm_deploy and baichuan_2_7b_squad_peft_vllm_deploy CI tests. - Convert DynamicCache to legacy tuples in BaichuanModel.forward() - Treat empty DynamicCache as None in prepare_inputs_for_generation() - Fix position_ids truncation for transformers 5.x which passes position_ids via kwargs instead of letting the model compute them Signed-off-by: Zeyu Zhou <zezhou@nvidia.com> Co-authored-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com>
1 parent a53b52c commit 1dcb225

2 files changed

Lines changed: 40 additions & 2 deletions

File tree

nemo_automodel/components/models/baichuan/model.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from torch.nn import functional as F
4343
from transformers import GenerationMixin, PreTrainedModel
4444
from transformers.activations import ACT2FN
45+
from transformers.cache_utils import DynamicCache
4546
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
4647
from transformers.utils import logging
4748

@@ -364,6 +365,12 @@ def forward(
364365

365366
seq_length_with_past = seq_length
366367
past_key_values_length = 0
368+
if past_key_values is not None:
369+
if isinstance(past_key_values, DynamicCache):
370+
if past_key_values.get_seq_length() > 0:
371+
past_key_values = tuple((layer.keys, layer.values) for layer in past_key_values.layers)
372+
else:
373+
past_key_values = None
367374
if past_key_values is not None:
368375
past_key_values_length = past_key_values[0][0].shape[2]
369376
seq_length_with_past = seq_length_with_past + past_key_values_length
@@ -558,15 +565,18 @@ def prepare_inputs_for_generation(
558565
inputs_embeds=None,
559566
**kwargs,
560567
):
568+
# Treat empty DynamicCache as no cache so inputs stay consistent with forward()
569+
if isinstance(past_key_values, DynamicCache) and past_key_values.get_seq_length() == 0:
570+
past_key_values = None
561571
if past_key_values:
562572
input_ids = input_ids[:, -1:]
563573

564574
position_ids = kwargs.get("position_ids", None)
565575
if attention_mask is not None and position_ids is None:
566576
position_ids = attention_mask.long().cumsum(-1) - 1
567577
position_ids.masked_fill_(attention_mask == 0, 1)
568-
if past_key_values:
569-
position_ids = position_ids[:, -1].unsqueeze(-1)
578+
if past_key_values and position_ids is not None:
579+
position_ids = position_ids[:, -1].unsqueeze(-1)
570580

571581
if inputs_embeds is not None and past_key_values is None:
572582
model_inputs = {"inputs_embeds": inputs_embeds}

tests/unit_tests/models/baichuan/test_baichuan_model.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import pytest
1717
import torch
18+
from transformers.cache_utils import DynamicCache
1819

1920
from nemo_automodel.components.models.baichuan.configuration import BaichuanConfig
2021
from nemo_automodel.components.models.baichuan.model import (
@@ -483,6 +484,33 @@ def test_with_inputs_embeds_no_past(self):
483484
assert "input_ids" not in inputs
484485

485486

487+
class TestDynamicCacheCompat:
488+
"""Regression test for DynamicCache incompatibility (baichuan_2_7b_squad_vllm_deploy)."""
489+
490+
def test_forward_with_dynamic_cache(self):
491+
cfg = _tiny_config(use_cache=True)
492+
model = BaichuanModel(cfg)
493+
model.eval()
494+
bsz, seq_len = 1, 4
495+
input_ids = torch.randint(0, cfg.vocab_size, (bsz, seq_len))
496+
497+
# First forward to populate cache
498+
with torch.no_grad():
499+
out = model(input_ids=input_ids, use_cache=True)
500+
legacy_cache = out.past_key_values
501+
502+
# Convert legacy cache to DynamicCache (simulates what GenerationMixin does)
503+
dynamic_cache = DynamicCache()
504+
for layer_idx, (key, value) in enumerate(legacy_cache):
505+
dynamic_cache.update(key, value, layer_idx)
506+
507+
# Second forward with DynamicCache — this was the failing path
508+
next_token = torch.randint(0, cfg.vocab_size, (bsz, 1))
509+
with torch.no_grad():
510+
out2 = model(input_ids=next_token, past_key_values=dynamic_cache, use_cache=True)
511+
assert out2.last_hidden_state.shape == (bsz, 1, cfg.hidden_size)
512+
513+
486514
class TestReorderCache:
487515
def test_reorders_correctly(self):
488516
past = tuple((torch.randn(3, 2, 4, 8), torch.randn(3, 2, 4, 8)) for _ in range(2))

0 commit comments

Comments
 (0)