[Gemma 4] use_cache=False corrupts attention computation, producing garbage logits #45242

@siwoolol

Description

Gemma 4 has a bug where use_cache=False corrupts the attention computation, producing garbage logits. Every QLoRA tutorial sets model.config.use_cache = False, but this breaks Gemma 4 specifically.

When fine-tuning Gemma 4 (E2B-it in this case) with a standard QLoRA/LoRA workflow, the model produces garbage logits during the forward pass, resulting in extremely high training loss (~10-15, near random chance for a 262K vocab). Generation via model.generate() works perfectly because it internally uses use_cache=True, while the training forward pass runs with use_cache=False.
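For reference, the reported loss is consistent with a uniform prediction over the vocabulary: assuming "262K vocab" means 2^18 = 262,144 tokens (my reading of the number above), the cross-entropy of a random-chance prediction is ln(262144) ≈ 12.5, squarely inside the observed 10-15 range.

```python
import math

# Cross-entropy of a uniform prediction over a 262,144-token vocabulary:
# every token gets probability 1/V, so the loss is -ln(1/V) = ln(V).
vocab_size = 262_144
random_chance_loss = math.log(vocab_size)
print(f"{random_chance_loss:.2f}")  # → 12.48
```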

Root cause

Setting model.config.use_cache = False, a standard step in virtually every LoRA/QLoRA fine-tuning tutorial, triggers a different attention code path in Gemma 4 that corrupts the output logits. Setting use_cache=True immediately fixes the issue.

Reproduction

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

MODEL_ID = "google/gemma-4-E2B-it"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float32,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    torch_dtype=torch.float32,
    device_map={"": 0},
)

# Prepare a simple prompt
messages = [{"role": "user", "content": "Say hello in a formal way."}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)

# use_cache=True (DEFAULT): Correct
with torch.no_grad():
    out_cached = model(**inputs, use_cache=True)
probs_cached = torch.softmax(out_cached.logits[0, -1], dim=-1)
top1_cached = torch.argmax(probs_cached)
print(f"use_cache=True  → Top-1: '{tokenizer.decode([top1_cached])}', prob={probs_cached[top1_cached].item():.4f}")
# Output: use_cache=True  → Top-1: 'Greetings', prob=0.9890

# use_cache=False: BROKEN
with torch.no_grad():
    out_uncached = model(**inputs, use_cache=False)
probs_uncached = torch.softmax(out_uncached.logits[0, -1], dim=-1)
top1_uncached = torch.argmax(probs_uncached)
print(f"use_cache=False → Top-1: '{tokenizer.decode([top1_uncached])}', prob={probs_uncached[top1_uncached].item():.4f}")
# Output: use_cache=False → Top-1: '//', prob=0.7385

# model.generate() always works (uses use_cache=True internally)
gen = model.generate(**inputs, max_new_tokens=1, do_sample=False)
print(f"generate()      → '{tokenizer.decode(gen[0, inputs['input_ids'].shape[1]:])}'")
# Output: generate()      → 'Greetings'
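To quantify the divergence rather than just eyeballing the top-1 token, a small helper can compare the two forward passes directly. The helper is generic; DummyModel below is a stand-in I made up so the snippet runs without downloading weights. With the real model you would pass `model` and the last-token logits from the reproduction above.

```python
class ForwardOutput:
    """Minimal stand-in for a model output carrying .logits (hypothetical)."""
    def __init__(self, logits):
        self.logits = logits

def max_logit_divergence(model, inputs):
    """Max absolute elementwise difference between cached and uncached logits."""
    cached = model(**inputs, use_cache=True).logits
    uncached = model(**inputs, use_cache=False).logits
    return max(abs(a - b) for a, b in zip(cached, uncached))

# Stand-in model that simulates the bug: the uncached path returns shifted logits.
class DummyModel:
    def __call__(self, input_ids=None, use_cache=True):
        return ForwardOutput([1.0, 2.0, 3.0] if use_cache else [1.0, 2.5, 3.0])

print(max_logit_divergence(DummyModel(), {"input_ids": [0]}))  # → 0.5
```

On a healthy model this value should be at floating-point noise level; on the affected Gemma 4 checkpoint it should be large.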

Diagnostics

| Condition | Top-1 prediction | Probability | Loss |
|---|---|---|---|
| use_cache=True | 'Greetings' | 98.9% | ~2-6 |
| use_cache=False | '//' | 73.8% | ~10-15 |
| model.generate() | 'Greetings' | 98.9% | N/A |

When use_cache=False, the model's top-5 predictions include garbage tokens from unrelated languages and scripts. Examples from my run:

  • - 27%
  • пол - 10%
  • ** - 5.6%
  • 差点 - 2.3%

Workaround

# Do NOT set model.config.use_cache = False
# Do NOT use gradient_checkpointing=True (it forces use_cache=False)

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, ...)
# model.config.use_cache = False  # <-- Remove this line

training_args = TrainingArguments(
    gradient_checkpointing=False,  # <-- Must be False for Gemma 4
)
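If gradient checkpointing cannot be disabled (e.g. memory limits on a T4), one possible stopgap is to shim the model so the flag can never be flipped back to False. force_use_cache below is a hypothetical helper I have not verified against the Trainer internals, not a transformers API:

```python
import functools

def force_use_cache(model):
    """Monkey-patch model.forward so every call runs with use_cache=True.

    Caution: keeping the KV cache alive during training raises memory use;
    treat this as a diagnostic shim, not a production fix.
    """
    original_forward = model.forward

    @functools.wraps(original_forward)
    def patched_forward(*args, **kwargs):
        kwargs["use_cache"] = True  # override whatever the caller passed
        return original_forward(*args, **kwargs)

    model.forward = patched_forward
    return model
```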

Issue Environment

  • Model: google/gemma-4-E2B-it
  • GPU: NVIDIA T4 (Kaggle)
  • transformers: 5.5.0
  • PyTorch: 2.10.0+cu128
  • PEFT: latest
  • bitsandbytes: latest

Expected Behavior

model(**inputs, use_cache=False) should produce identical logits to model(**inputs, use_cache=True). The use_cache flag should only affect whether KV cache tensors are returned, not the numerical result of the forward pass.
