[Fix] generate() prefills the full prompt instead of only the last token#962
[Fix] generate() prefills the full prompt instead of only the last token#962Sunt-ing wants to merge 1 commit into
Conversation
On transformers 5.x, FLAGenerationMixin.prepare_inputs_for_generation tested cache emptiness with `len(past_key_values) > 0`. For FLACache, `__len__` returns the number of preallocated layer slots, which is already nonzero before any token is processed, so the prefill step was mistaken for a decode step and the prompt was truncated to its last token. Every model using FLAGenerationMixin then generated conditioned on a single token instead of the prompt. Test emptiness with `past_key_values.get_seq_length() > 0` (seen tokens) in both the transformers 4.56+ and legacy branches: at prefill the count is 0 so the full prompt is kept; during decode it is positive so only the last token is forwarded. Signed-off-by: Ting Sun <suntcrick@gmail.com>
There was a problem hiding this comment.
Code Review
This pull request updates the input preparation logic for generation in fla/models/utils.py by replacing legacy length checks on past_key_values with calls to get_seq_length(). It also introduces a regression test to ensure that the generation process correctly conditions on the entire prompt during prefill, and integrates this test for the Gated DeltaNet model. The review feedback correctly points out that calling get_seq_length() directly on past_key_values without checking for its existence can lead to AttributeError exceptions when legacy cache formats (like tuples or lists) are used. It is recommended to apply the suggested fallback checks to maintain backward compatibility.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if input_ids is not None and input_ids.shape[1] != cache_position.shape[0]: | ||
| input_ids = input_ids[:, cache_position] | ||
| elif hasattr(past_key_values, '__len__') and len(past_key_values) > 0: | ||
| elif past_key_values.get_seq_length() > 0: |
There was a problem hiding this comment.
Calling past_key_values.get_seq_length() directly will raise an AttributeError if past_key_values is a legacy cache format (such as a standard tuple or list of tuples/dicts) which does not implement this method. To ensure backward compatibility and prevent runtime crashes, we should check if the method exists before calling it, falling back to checking the length of the cache.
| elif past_key_values.get_seq_length() > 0: | |
| elif (past_key_values.get_seq_length() if hasattr(past_key_values, 'get_seq_length') else len(past_key_values)) > 0: |
| model_inputs = {} | ||
| # only last token for `inputs_ids` if the `past_key_values` is not empty. | ||
| if past_key_values is not None and hasattr(past_key_values, '__len__') and len(past_key_values) > 0: | ||
| if past_key_values is not None and past_key_values.get_seq_length() > 0: |
There was a problem hiding this comment.
Calling past_key_values.get_seq_length() directly will raise an AttributeError if past_key_values is a legacy cache format (such as a standard tuple or list of tuples/dicts) which does not implement this method. To ensure backward compatibility and prevent runtime crashes, we should check if the method exists before calling it, falling back to checking the length of the cache.
| if past_key_values is not None and past_key_values.get_seq_length() > 0: | |
| if past_key_values is not None and (past_key_values.get_seq_length() if hasattr(past_key_values, 'get_seq_length') else len(past_key_values)) > 0: |
There was a problem hiding this comment.
On the 4.56+/5.x path this branch only runs with a non-None past_key_values, which generate() always supplies as a Cache (here FLACache) implementing get_seq_length(). A legacy tuple can't run on transformers 5.x anyway: non-Cache inputs go through Cache.from_legacy_cache(...), removed in 5.x, so forward would raise first, and the len() fallback just reintroduces the layer-count check this PR fixes.
Summary
On transformers 5.x,
model.generate()on any FLA causal LM silently drops the prompt. On the first decoding step it forwards only the last prompt token with an empty recurrent state, so generation is conditioned on a single token rather than on the prompt. Every model that usesFLAGenerationMixin(GatedDeltaNet, GLA, RetNet, GSA, HGRN2, Comba, RWKV7, ...) is affected.Root cause is in
FLAGenerationMixin.prepare_inputs_for_generation. Whencache_positionis not supplied (which is the case under transformers 5.x for these models), it falls back to the old "keep only the last token if the cache is not empty" behavior, but tests emptiness withlen(past_key_values) > 0. ForFLACache,__len__returns the number of layer slots, and those slots are allocated before any token is processed, so it is already non-zero at the prefill step. The prefill step is therefore mistaken for a decode step and the prompt is truncated to its last token.The fix tests emptiness with
past_key_values.get_seq_length() > 0(seen tokens) instead oflen(...) > 0(layer count), in both the transformers 4.56+ branch and the legacy branch. At prefill the seen-token count is 0 so the full prompt is kept; during decode it is positive so only the last token is forwarded, which is exactly what the surrounding comment intends ("only last token if the past is not empty").Minimal repro:
Test plan
A regression test was added (
run_test_generate_matches_forwardintests/models/test_modeling_base.py, wired in astest_generate_prefillfor GatedDeltaNet): a greedygenerate()must match a manual greedy decode that prefills the full prompt. It fails before the fix and passes after.pytest tests/models/test_modeling_gated_deltanet.py::test_generate_prefill -qpasses with the fix, fails without it.pytest tests/models/test_modeling_gated_deltanet.py::test_generation -qstill passes (no regression to the existing cache path).generate()now matches a manual greedy decode token for token; a standard transformers model under the same transformers version already behaved correctly, confirming the issue was FLA specific.ruff check/ruff format --checkon the touched files.Breaking changes
None.