[Gemma 4] use_cache=False corrupts attention computation, producing garbage logits #45242
Description
Gemma 4 has a bug where use_cache=False corrupts the attention computation, producing garbage logits. Every QLoRA tutorial sets model.config.use_cache = False, but this breaks Gemma 4 specifically.
When fine-tuning Gemma 4 (E2B-it in this case) with standard QLoRA/LoRA workflows, the model produces garbage logits during the forward pass, resulting in extremely high training loss (~10-15, near random chance for a 262K vocab). Generation via model.generate() works perfectly because it internally uses use_cache=True, whereas the training forward pass runs with use_cache=False.
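As a sanity check, the reported loss does match random chance: the cross-entropy of a uniform guess over a vocabulary of size V is ln(V). Assuming V = 262,144 (2^18, inferred from the "262K vocab" above; Gemma's exact vocabulary size may differ slightly):

```python
import math

# Cross-entropy loss of a uniform guess over the vocabulary.
# vocab_size = 262,144 is an assumption based on the "262K vocab"
# figure above, not an exact value pulled from the model config.
vocab_size = 262_144
random_chance_loss = math.log(vocab_size)  # natural log, as in cross-entropy
print(f"{random_chance_loss:.2f}")  # → 12.48, inside the observed ~10-15 range
```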
Root cause
model.config.use_cache = False
This line, a standard step in virtually every LoRA/QLoRA fine-tuning tutorial, selects a different attention code path in Gemma 4 that corrupts the output logits.
Setting use_cache=True immediately fixes the issue.
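Until this is fixed upstream, a small pre-flight guard can catch the problem before training starts. This is a hypothetical helper (not part of transformers); it only assumes the model exposes a standard `config` object:

```python
def assert_cache_enabled(model):
    """Hypothetical guard: fail fast if tutorial boilerplate disabled the
    KV cache, which corrupts Gemma 4 training logits as described above."""
    if getattr(model.config, "use_cache", True) is False:
        raise ValueError(
            "model.config.use_cache is False; on Gemma 4 this corrupts "
            "the forward pass. Set it back to True before training."
        )
```

Call it right before constructing the Trainer, after all config tweaks have run.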
Reproduction
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
MODEL_ID = "google/gemma-4-E2B-it"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float32,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    torch_dtype=torch.float32,
    device_map={"": 0},
)
# Prepare a simple prompt
messages = [{"role": "user", "content": "Say hello in a formal way."}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)
# use_cache=True (DEFAULT): Correct
with torch.no_grad():
    out_cached = model(**inputs, use_cache=True)
probs_cached = torch.softmax(out_cached.logits[0, -1], dim=-1)
top1_cached = torch.argmax(probs_cached)
print(f"use_cache=True → Top-1: '{tokenizer.decode([top1_cached])}', prob={probs_cached[top1_cached].item():.4f}")
# Output: use_cache=True → Top-1: 'Greetings', prob=0.9890
# use_cache=False: BROKEN
with torch.no_grad():
    out_uncached = model(**inputs, use_cache=False)
probs_uncached = torch.softmax(out_uncached.logits[0, -1], dim=-1)
top1_uncached = torch.argmax(probs_uncached)
print(f"use_cache=False → Top-1: '{tokenizer.decode([top1_uncached])}', prob={probs_uncached[top1_uncached].item():.4f}")
# Output: use_cache=False → Top-1: '//', prob=0.7385
# model.generate() always works (uses use_cache=True internally)
gen = model.generate(**inputs, max_new_tokens=1, do_sample=False)
print(f"generate() → '{tokenizer.decode(gen[0, inputs['input_ids'].shape[1]:])}'")
# Output: generate() → 'Greetings'
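The cached/uncached mismatch can be reduced to a single number with a helper like the following. This is a sketch, not code from transformers; it assumes any Hugging Face-style causal LM whose output exposes .logits:

```python
import torch

def max_logit_divergence(model, inputs):
    """Max absolute difference between final-position logits computed
    with and without the KV cache. Should be ~0 on a healthy model;
    the bug described above makes it large on Gemma 4."""
    with torch.no_grad():
        cached = model(**inputs, use_cache=True).logits[0, -1]
        uncached = model(**inputs, use_cache=False).logits[0, -1]
    return (cached - uncached).abs().max().item()
```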
Diagnostics
| Condition | Top-1 prediction | Probability | Loss |
|---|---|---|---|
| use_cache=True | 'Greetings' ✅ | 98.9% | ~2-6 |
| use_cache=False | '//' ❌ | 73.8% | ~10-15 |
| model.generate() | 'Greetings' ✅ | 98.9% | N/A |
When use_cache=False, the model's top-5 predictions include garbage tokens from other languages and scripts:
Examples from my run (token, probability):
- '⬇' (27%)
- 'пол' (10%)
- '**' (5.6%)
- '差点' (2.3%)
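A small helper to reproduce this kind of top-k report (a sketch; the tokenizer only needs a decode method, as with any Hugging Face tokenizer):

```python
import torch

def top_k_report(last_logits, tokenizer, k=5):
    """Decode the k most likely next tokens with their probabilities,
    as used to produce the list above."""
    probs = torch.softmax(last_logits, dim=-1)
    vals, idx = torch.topk(probs, k)
    return [(tokenizer.decode([i.item()]), round(v.item(), 4))
            for v, i in zip(vals, idx)]
```

For example, `top_k_report(out_uncached.logits[0, -1], tokenizer)` reproduces the garbage list above.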
Workaround
# Do NOT set model.config.use_cache = False
# Do NOT use gradient_checkpointing=True (it forces use_cache=False)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, ...)
# model.config.use_cache = False # <-- Remove this line
training_args = TrainingArguments(
gradient_checkpointing=False, # <-- Must be False for Gemma 4
)
Environment
- Model: google/gemma-4-E2B-it
- GPU: NVIDIA T4 (Kaggle)
- transformers: 5.5.0
- PyTorch: 2.10.0+cu128
- PEFT: latest
- bitsandbytes: latest
Expected Behavior
model(**inputs, use_cache=False) should produce identical logits to model(**inputs, use_cache=True). The use_cache flag should only affect whether KV cache tensors are returned, not the numerical result of the forward pass.
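This invariant can be stated as an executable check. A sketch, assuming any model whose output exposes .logits; the tolerance atol is an arbitrary choice, not a transformers default:

```python
import torch

def cache_is_invariant(model, inputs, atol=1e-4):
    """True iff use_cache only changes whether KV tensors are returned,
    not the logits themselves, which is the expected behavior above."""
    with torch.no_grad():
        with_cache = model(**inputs, use_cache=True).logits
        without_cache = model(**inputs, use_cache=False).logits
    return torch.allclose(with_cache, without_cache, atol=atol)
```

On a healthy model this returns True; on the affected Gemma 4 checkpoint it would return False, which is the bug.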