
Fix Gemma4 use_cache=False producing bad logits#45253

Open
Charly21r wants to merge 1 commit into huggingface:main from Charly21r:fix-gemma4-use-cache-false-kv-sharing

Conversation

@Charly21r

What does this PR do?

Fixes a bug where use_cache=False produces garbage logits in Gemma 4 models due to broken KV sharing between layers.

Fixes #45242

Root cause of the issue

Gemma 4 introduces two architectural features not present in Gemma 3:

  1. KV sharing (num_kv_shared_layers): The last N decoder layers ("receiver" layers) don't compute their own keys/values; instead, they reuse the K/V states of earlier "donor" layers.

  2. K=V attention (attention_k_eq_v): On global attention layers, keys and values share the same projection weights, so v_proj is set to None.
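The layer-role bookkeeping behind feature 1 can be sketched as follows. This is illustrative only: `num_kv_shared_layers` comes from the config field named above, but the helper, its name, and the specific donor/receiver pairing (each receiver reusing the layer `num_kv_shared_layers` below it) are assumptions, not the actual Gemma 4 mapping.

```python
def kv_sharing_roles(num_hidden_layers, num_kv_shared_layers):
    """Map each layer index to the donor layer it reads K/V from.

    Returns a dict: layer_idx -> donor_layer_idx, or None if the layer
    is a donor that computes (and publishes) its own K/V.
    Sketch only; the real mapping lives in the model config/implementation.
    """
    first_receiver = num_hidden_layers - num_kv_shared_layers
    roles = {}
    for layer_idx in range(num_hidden_layers):
        if layer_idx < first_receiver:
            roles[layer_idx] = None  # donor: computes its own K/V
        else:
            # receiver: one plausible scheme pairs it with the layer
            # num_kv_shared_layers positions earlier
            roles[layer_idx] = layer_idx - num_kv_shared_layers
    return roles
```

For a 6-layer model with 2 shared layers, layers 0-3 are donors and layers 4 and 5 reuse K/V from layers 2 and 3 respectively (under this toy pairing).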

Both mechanisms were implemented by piggybacking on past_key_values (the KV cache object):

  • Donor layers store their K/V into past_key_values.shared_layers[layer_idx]
  • Receiver layers retrieve shared K/V from past_key_values.shared_layers[donor_layer_idx]

Both code paths are guarded by if past_key_values is not None.
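Concretely, both paths hinge on that one guard. Here is a toy restatement: `ToyCache` and `toy_attention_kv` are stand-ins I made up for illustration; only the `shared_layers` attribute and the `past_key_values is not None` check mirror the real code.

```python
class ToyCache:
    """Stand-in for DynamicCache; only the shared_layers mapping matters here."""
    def __init__(self):
        self.shared_layers = {}

def toy_attention_kv(hidden, layer_idx, donor_idx, past_key_values, compute_kv):
    """Sketch of donor/receiver K/V selection, mirroring the guard described above."""
    if donor_idx is None:
        key, value = compute_kv(hidden)
        if past_key_values is not None:  # donor: publish K/V for receivers
            past_key_values.shared_layers[layer_idx] = (key, value)
        return key, value
    if past_key_values is not None:      # receiver: reuse the donor's K/V
        return past_key_values.shared_layers[donor_idx]
    return compute_kv(hidden)            # fallback taken when the cache is None
```

With a cache, a receiver gets its donor's K/V; without one, it falls into the fallback, which is exactly the failure mode described below.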

When use_cache=True, a DynamicCache is created, past_key_values is not None, and everything works correctly.

When use_cache=False, past_key_values remains None, so:

  1. Donor layers never store their K/V (the past_key_values is not None guard fails)
  2. Receiver layers can't retrieve shared K/V (same guard), so they fall into the else branch and try to compute their own K/V
  3. But receiver layers with attention_k_eq_v=True have v_proj = None, so the fallback value_states = self.v_proj(hidden_states) if self.v_proj is not None else key_states silently uses the keys as values
  4. Attention then computes with corrupted value states, producing garbage output
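Step 3 is where the corruption enters. A toy restatement of the quoted fallback line (only the `v_proj` / `key_states` names come from the source; the standalone function is illustrative):

```python
def fallback_value_states(hidden_states, v_proj, key_states):
    # Mirrors the quoted line:
    # value_states = self.v_proj(hidden_states) if self.v_proj is not None else key_states
    return v_proj(hidden_states) if v_proj is not None else key_states

# Normal layer: v_proj exists, values come from the value projection.
normal = fallback_value_states("h", lambda h: h + "_v", key_states="h_k")
# K=V receiver layer with no cache: v_proj is None, so keys leak in as values.
broken = fallback_value_states("h", None, key_states="h_k")
```

On a real receiver layer the `else key_states` branch is only safe when the shared K/V was fetched from the cache, which is precisely what the missing cache prevents.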

The current fix

The fix creates a DynamicCache even when use_cache=False, but after causal mask creation to avoid affecting mask computation (which also depends on past_key_values). The cache is then available internally for KV sharing between layers. The return value is set to None when use_cache=False to preserve the expected API behavior.

# Before (broken): cache only created when use_cache=True
if use_cache and past_key_values is None:
    past_key_values = DynamicCache(config=self.config)
# ... mask creation uses past_key_values ...
# ... decoder layers use past_key_values for KV sharing ...

# After (fixed): cache always created, but after masks and only returned when use_cache=True
if use_cache and past_key_values is None:
    past_key_values = DynamicCache(config=self.config)
# ... mask creation uses past_key_values (unchanged) ...

# NEW: ensure cache exists for KV sharing even when use_cache=False
if past_key_values is None:
    past_key_values = DynamicCache(config=self.config)
# ... decoder layers use past_key_values for KV sharing ...

return BaseModelOutputWithPast(
    last_hidden_state=hidden_states,
    past_key_values=past_key_values if use_cache else None,  # don't leak internal cache
)

Alternative approach considered

A more memory-efficient approach would decouple KV sharing from past_key_values entirely by introducing
a lightweight shared_kv dict passed through kwargs. This would avoid allocating a full DynamicCache
when use_cache=False, preserving the memory savings that users expect during training. However, this
changes the gradient flow: KV-shared receiver layers would no longer exercise their own k_proj/k_norm/
v_proj/v_norm params (since they correctly skip computing their own K/V), causing
test_training_gradient_checkpointing to fail. Those params are architecturally unused on receiver layers,
so the missing gradients are semantically correct, but it requires updating the test expectations.

The current fix (always creating a DynamicCache after mask creation) was chosen for simplicity and
because it passes all existing tests without modification. Happy to switch to the shared_kv approach
if reviewers prefer it.
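For reference, the shared_kv idea could be sketched roughly as below. Everything here is hypothetical plumbing (real decoder layers take many more arguments); the point is only that a plain dict scoped to one forward pass can replace the DynamicCache for sharing.

```python
def run_decoder(layer_specs, hidden):
    """Toy forward pass: layer_specs is a list of (layer_idx, donor_idx_or_None).

    Donors publish a K/V token into shared_kv; receivers read it back.
    The dict lives only for this call, so no DynamicCache is allocated.
    Layers just tag the hidden state so the data flow is visible.
    """
    shared_kv = {}  # hypothetical lightweight store passed via kwargs in the real model
    for layer_idx, donor_idx in layer_specs:
        if donor_idx is None:
            kv = f"kv{layer_idx}"      # donor computes its K/V once...
            shared_kv[layer_idx] = kv  # ...and publishes it for receivers
        else:
            kv = shared_kv[donor_idx]  # receiver skips its own projections entirely
        hidden = f"{hidden}|{kv}"
    return hidden
```

Note how receivers never touch their own projections here, which is what would change the gradient flow and trip test_training_gradient_checkpointing, as described above.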

Code Agent Policy

  • I confirm that this is not a pure code agent PR.

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker @Cyrilvallez @zucchini-nlp

@github-actions
Contributor

github-actions bot commented Apr 5, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: gemma4

@Charly21r changed the title to Fix Gemma4 use_cache=False producing bad logits Apr 5, 2026
@github-actions
Contributor

github-actions bot commented Apr 5, 2026

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45253&sha=bc35f6


Development

Successfully merging this pull request may close these issues.

[Gemma 4] use_cache=False corrupts attention computation, producing garbage logits