[Fix] generate() prefills the full prompt instead of only the last token #962

Open
Sunt-ing wants to merge 1 commit into fla-org:main from Sunt-ing:fla-01

@Sunt-ing

Summary

On transformers 5.x, model.generate() on any FLA causal LM silently drops the prompt. On the first decoding step it forwards only the last prompt token with an empty recurrent state, so generation is conditioned on a single token rather than on the prompt. Every model that uses FLAGenerationMixin (GatedDeltaNet, GLA, RetNet, GSA, HGRN2, Comba, RWKV7, ...) is affected.

Root cause is in FLAGenerationMixin.prepare_inputs_for_generation. When cache_position is not supplied (which is the case under transformers 5.x for these models), it falls back to the old "keep only the last token if the cache is not empty" behavior, but tests emptiness with len(past_key_values) > 0. For FLACache, __len__ returns the number of layer slots, and those slots are allocated before any token is processed, so it is already non-zero at the prefill step. The prefill step is therefore mistaken for a decode step and the prompt is truncated to its last token.
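
The gap between the two emptiness signals can be illustrated with a toy cache. This is a minimal sketch, not the real FLACache: the class ToyCache and its fields are hypothetical, and only the method names `__len__` and `get_seq_length` are taken from the description above.

```python
class ToyCache:
    """Hypothetical stand-in for a layer-slotted cache (not the real FLACache)."""

    def __init__(self, num_layers):
        # Assumption: one slot per layer is allocated up front,
        # before any token has been processed.
        self.states = [None] * num_layers
        self.seen_tokens = 0

    def __len__(self):
        # Number of layer slots, independent of how many tokens were seen.
        return len(self.states)

    def get_seq_length(self):
        # Number of tokens processed so far.
        return self.seen_tokens


cache = ToyCache(num_layers=2)

# At the prefill step no token has been processed yet.
len_says_nonempty = len(cache) > 0               # layer-count check
seq_says_nonempty = cache.get_seq_length() > 0   # seen-token check
print(len_says_nonempty, seq_says_nonempty)
```

Under these assumptions the layer-count check already reports a non-empty cache at prefill, while the seen-token check does not.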

The fix tests emptiness with past_key_values.get_seq_length() > 0 (seen tokens) instead of len(...) > 0 (layer count), in both the transformers 4.56+ branch and the legacy branch. At prefill the seen-token count is 0 so the full prompt is kept; during decode it is positive so only the last token is forwarded, which is exactly what the surrounding comment intends ("only last token if the past is not empty").
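
As a rough sketch of the corrected check, the token selection could look like the following. The helper `select_input_ids` and the `SeenTokens` stand-in are hypothetical names used only for illustration; the real logic lives in FLAGenerationMixin.prepare_inputs_for_generation and operates on tensors rather than Python lists.

```python
class SeenTokens:
    """Hypothetical cache stand-in that only tracks the seen-token count."""

    def __init__(self, seen=0):
        self.seen = seen

    def get_seq_length(self):
        return self.seen


def select_input_ids(input_ids, past_key_values):
    """Keep the full prompt at prefill, only the last token during decode."""
    if past_key_values is not None and past_key_values.get_seq_length() > 0:
        # Decode step: the state already summarizes earlier tokens.
        return [row[-1:] for row in input_ids]
    # Prefill step (or no cache at all): forward every token.
    return input_ids


prompt = [[7, 8, 9]]
prefill_ids = select_input_ids(prompt, SeenTokens(seen=0))
decode_ids = select_input_ids([[7, 8, 9, 4]], SeenTokens(seen=3))
print(prefill_ids, decode_ids)
```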

Minimal repro:

import torch
from fla.models import GatedDeltaNetConfig
from transformers import AutoModelForCausalLM

torch.manual_seed(0)
config = GatedDeltaNetConfig(hidden_size=256, num_hidden_layers=2, num_heads=4)
model = AutoModelForCausalLM.from_config(config).eval().to(torch.float32).cuda()

torch.manual_seed(1)
ids = torch.randint(1, config.vocab_size, (1, 16), device="cuda")
with torch.no_grad():
    # correct greedy first token = argmax of a full-prompt forward
    ref = model(input_ids=ids, use_cache=True).logits[:, -1].argmax(-1).item()
    gen = model.generate(ids, attention_mask=torch.ones_like(ids),
                         max_new_tokens=1, do_sample=False)[0, -1].item()
print("full-prompt argmax:", ref, "| generate first token:", gen)
# before this fix: 826 != 3675   (generate ignores the prompt)
# after  this fix: 826 == 826    (generate prefills the prompt)

Test plan

A regression test was added (run_test_generate_matches_forward in tests/models/test_modeling_base.py, wired in as test_generate_prefill for GatedDeltaNet): a greedy generate() must match a manual greedy decode that prefills the full prompt. It fails before the fix and passes after.

  • pytest tests/models/test_modeling_gated_deltanet.py::test_generate_prefill -q passes with the fix, fails without it.
  • pytest tests/models/test_modeling_gated_deltanet.py::test_generation -q still passes (no regression to the existing cache path).
  • end-to-end check across GatedDeltaNet, GLA, RetNet, GSA, HGRN2, Comba and RWKV7: greedy generate() now matches a manual greedy decode token for token; a standard transformers model under the same transformers version already behaved correctly, confirming the issue was FLA specific.
  • ruff check and ruff format --check were run on the touched files.
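
The idea behind the regression test (greedy generate() must match a manual greedy decode that prefills the full prompt) can be sketched with a toy recurrent model. Everything below is hypothetical and dependency-free: `ToyCache`, `forward`, `generate`, and `manual_greedy` are illustrative stand-ins, not the actual run_test_generate_matches_forward implementation or any FLA API.

```python
VOCAB = 97  # arbitrary toy vocabulary size


class ToyCache:
    """Hypothetical cache with preallocated per-layer state slots."""

    def __init__(self, num_layers=2):
        self.states = [0] * num_layers
        self.seen = 0

    def __len__(self):
        return len(self.states)

    def get_seq_length(self):
        return self.seen


def forward(tokens, cache):
    """Fold tokens into the recurrent state and return the greedy next token."""
    for t in tokens:
        cache.states = [(s * 31 + t + i) % VOCAB for i, s in enumerate(cache.states)]
        cache.seen += 1
    return sum(cache.states) % VOCAB


def generate(prompt, max_new_tokens, is_nonempty):
    """Toy generate loop; `is_nonempty` decides prefill vs. decode."""
    cache, out = ToyCache(), list(prompt)
    for _ in range(max_new_tokens):
        step_input = out[-1:] if is_nonempty(cache) else out
        out.append(forward(step_input, cache))
    return out


def manual_greedy(prompt, max_new_tokens):
    """Reference decode: prefill the whole prompt, then feed one token at a time."""
    cache, out = ToyCache(), list(prompt)
    nxt = forward(out, cache)
    out.append(nxt)
    for _ in range(max_new_tokens - 1):
        nxt = forward([nxt], cache)
        out.append(nxt)
    return out


prompt = [5, 11, 42]
reference = manual_greedy(prompt, 4)
fixed = generate(prompt, 4, lambda c: c.get_seq_length() > 0)
buggy = generate(prompt, 4, lambda c: len(c) > 0)
print(fixed == reference, buggy == reference)
```

In this toy setup the seen-token check reproduces the reference decode, while the layer-count check diverges from the first generated token because the prompt never reaches the state.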

Breaking changes

None.

On transformers 5.x, FLAGenerationMixin.prepare_inputs_for_generation tested
cache emptiness with `len(past_key_values) > 0`. For FLACache, `__len__` returns
the number of preallocated layer slots, which is already nonzero before any token
is processed, so the prefill step was mistaken for a decode step and the prompt
was truncated to its last token. Every model using FLAGenerationMixin then
generated conditioned on a single token instead of the prompt.

Test emptiness with `past_key_values.get_seq_length() > 0` (seen tokens) in both
the transformers 4.56+ and legacy branches: at prefill the count is 0 so the full
prompt is kept; during decode it is positive so only the last token is forwarded.

Signed-off-by: Ting Sun <suntcrick@gmail.com>

@gemini-code-assist (bot) left a comment

Code Review

This pull request updates the input preparation logic for generation in fla/models/utils.py by replacing legacy length checks on past_key_values with calls to get_seq_length(). It also introduces a regression test to ensure that the generation process correctly conditions on the entire prompt during prefill, and integrates this test for the Gated DeltaNet model. The review feedback correctly points out that calling get_seq_length() directly on past_key_values without checking for its existence can lead to AttributeError exceptions when legacy cache formats (like tuples or lists) are used. It is recommended to apply the suggested fallback checks to maintain backward compatibility.

Comment thread fla/models/utils.py
  if input_ids is not None and input_ids.shape[1] != cache_position.shape[0]:
      input_ids = input_ids[:, cache_position]
- elif hasattr(past_key_values, '__len__') and len(past_key_values) > 0:
+ elif past_key_values.get_seq_length() > 0:

Severity: high

Calling past_key_values.get_seq_length() directly will raise an AttributeError if past_key_values is a legacy cache format (such as a standard tuple or list of tuples/dicts) which does not implement this method. To ensure backward compatibility and prevent runtime crashes, we should check if the method exists before calling it, falling back to checking the length of the cache.

Suggested change
- elif past_key_values.get_seq_length() > 0:
+ elif (past_key_values.get_seq_length() if hasattr(past_key_values, 'get_seq_length') else len(past_key_values)) > 0:

Comment thread fla/models/utils.py
  model_inputs = {}
  # only last token for `inputs_ids` if the `past_key_values` is not empty.
- if past_key_values is not None and hasattr(past_key_values, '__len__') and len(past_key_values) > 0:
+ if past_key_values is not None and past_key_values.get_seq_length() > 0:

Severity: high

Calling past_key_values.get_seq_length() directly will raise an AttributeError if past_key_values is a legacy cache format (such as a standard tuple or list of tuples/dicts) which does not implement this method. To ensure backward compatibility and prevent runtime crashes, we should check if the method exists before calling it, falling back to checking the length of the cache.

Suggested change
- if past_key_values is not None and past_key_values.get_seq_length() > 0:
+ if past_key_values is not None and (past_key_values.get_seq_length() if hasattr(past_key_values, 'get_seq_length') else len(past_key_values)) > 0:

Reply from the PR author (@Sunt-ing):

On the 4.56+/5.x path this branch only runs with a non-None past_key_values, which generate() always supplies as a Cache (here FLACache) implementing get_seq_length(). A legacy tuple can't run on transformers 5.x anyway: non-Cache inputs go through Cache.from_legacy_cache(...), which was removed in 5.x, so forward would raise first. The len() fallback would just reintroduce the layer-count check this PR fixes.
