Summary
Training with PrefixTuningConfig on google/gemma-4-e2b-it crashes in the
first forward pass with a tensor expand-size mismatch. The error follows the
same pattern as #2881 (Qwen3 GQA prefix tuning), which was fixed for Qwen3 in
#2883 (peft 0.18.0). We're on peft 0.19.1, so the Qwen3 fix is included, but
Gemma-4 has a different attention implementation that the same fix doesn't
cover.
Reproducer
```python
import torch
from peft import PrefixTuningConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-4-e2b-it"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
)
cfg = PrefixTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=20,
    prefix_projection=False,
)
model = get_peft_model(model, cfg)
tok = AutoTokenizer.from_pretrained(model_id)

# Long enough to hit the 1632-token truncation limit, so the shapes match the error below
text = "Hello, world. " * 1000
inputs = tok(text, return_tensors="pt", truncation=True, max_length=1632).to(model.device)
inputs["labels"] = inputs["input_ids"].clone()

with torch.no_grad():
    out = model(**inputs)
print("loss:", out.loss.item())
```
Error
```
RuntimeError: The expanded size of the tensor (1632) must match the existing size (1652) at non-singleton dimension 3.
Target sizes: [1, 8, 1632, 1632]. Tensor sizes: [1, 1, 1632, 1652]
  File "/opt/venv/lib/python3.12/site-packages/peft/peft_model.py", line 1964, in forward
  ... (wrapped calls into Gemma-4 attention)
```
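There is a 20-token gap between the target K-dim (1632 = sequence length) and the
tensor K-dim (1652 = sequence length + num_virtual_tokens). The tensor being
expanded has shape [1, 1, Q, Q + 20], so peft has correctly prefix-extended the
attention mask. The target, however, is a square [B, H, Q, Q] attention-score
shape. That means the keys Gemma-4's attention implementation attends over do
not include the 20 prefix positions, so it is not honoring the prefix-extended K.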
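The shape arithmetic can be reproduced in plain PyTorch without Gemma-4 or peft. The sketch below uses toy sizes chosen for illustration (B=1, H=2, Q=4, 2 virtual tokens). A mask whose last dimension is Q + num_virtual_tokens cannot be expanded to square [B, H, Q, Q] scores. It expands without error once the key length includes the prefix.

```python
import torch

B, H, Q, P = 1, 2, 4, 2  # batch, heads, query length, num_virtual_tokens (toy values)

# Prefix-extended mask with the same layout as in the error message: [B, 1, Q, Q + P]
mask = torch.ones(B, 1, Q, Q + P, dtype=torch.bool)


def try_expand(key_len: int) -> str:
    """Try to broadcast the mask to the attention-score shape [B, H, Q, key_len]."""
    try:
        return str(tuple(mask.expand(B, H, Q, key_len).shape))
    except RuntimeError as err:
        return f"RuntimeError: {err}"


print(try_expand(Q))      # keys without the prefix: same class of error as above
print(try_expand(Q + P))  # keys with the prefix: (1, 2, 4, 6)
```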
Environment
- peft==0.19.1
- transformers==5.6.2
- torch==2.11.0
- Python 3.12
- Hardware: NVIDIA A100 80GB
- attn_implementation="sdpa" (also expected to fail on other CUDA devices with sdpa)
Related issues
- #2881: same shape-mismatch pattern for Qwen3 GQA prefix tuning. Fixed in
  #2883, shipped in peft 0.18.0. Gemma-4 needs the same class of adaptation.
- #1901: original GQA support for prefix tuning (Qwen2-era).
- #3129: open issue tracking Gemma-4 + peft compatibility. It is currently
  scoped to LoRA's Gemma4ClippableLinear, but Gemma-4 + PrefixTuning is part
  of the same ecosystem question.
Workaround attempt (in progress)
We suspect attn_implementation="eager" may avoid the crash. Our hypothesis is
that the sdpa path does not handle the manually prepended K/V the way the eager
implementation does. We will update with results once verified. If eager works,
that points the fix toward the sdpa attention path in the Gemma-4 modeling code.
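Before blaming the SDPA kernel itself, it may help to check at the PyTorch level whether `F.scaled_dot_product_attention` handles prepended K/V on its own. The sketch below uses random toy tensors, and all sizes are illustrative assumptions. Keys and values already include P prefix positions, and the mask matches at [Q, P + Q]. With those inputs, an eager-style computation and SDPA should agree. If they do, the problem is more likely in how the Gemma-4 modeling code assembles K/V or the mask on the sdpa path than in the kernel.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, Q, P, D = 1, 2, 5, 3, 8  # batch, heads, query len, num_virtual_tokens, head dim (toy values)

q = torch.randn(B, H, Q, D)
# Keys/values with P prefix positions prepended along the sequence dimension
k = torch.cat([torch.randn(B, H, P, D), torch.randn(B, H, Q, D)], dim=2)
v = torch.cat([torch.randn(B, H, P, D), torch.randn(B, H, Q, D)], dim=2)

# Bool mask [Q, P + Q]: every query sees all prefix slots, then causal over real tokens.
# tril(diagonal=P) keeps position (i, j) iff j <= i + P.
mask = torch.ones(Q, P + Q, dtype=torch.bool).tril(diagonal=P)


def eager_attention(q, k, v, mask):
    """Explicit softmax(Q K^T / sqrt(d)) V with masked positions set to -inf."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    scores = scores.masked_fill(~mask, float("-inf"))
    return scores.softmax(dim=-1) @ v


out_eager = eager_attention(q, k, v, mask)
# For a bool attn_mask, True means "take part in attention"
out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

print(out_sdpa.shape, torch.allclose(out_eager, out_sdpa, atol=1e-5))
```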
Suggested fix
Mirror the Qwen3 fix from #2883 for Gemma-4's attention implementation. The
fix likely lives near peft/peft_model.py:1964, where the prefix-extended K/V are
fed to the base model's forward. The key length and the attention-mask
construction need to agree on the additional num_virtual_tokens for Gemma-4's
specific attention layout.
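The mask side of that suggestion can be sketched in plain PyTorch. The sketch assumes, as peft's prefix-tuning path appears to do, that `num_virtual_tokens` ones are prepended to the 2D attention mask. Under that assumption, the 4D mask must have shape [B, 1, Q, num_virtual_tokens + Q], with the causal diagonal shifted by `num_virtual_tokens`. `prefix_causal_mask` is a hypothetical helper for illustration, not an existing peft or transformers function. It does not attempt to model Gemma-4's specific attention layout.

```python
import torch


def prefix_causal_mask(attention_mask_2d: torch.Tensor, num_virtual_tokens: int) -> torch.Tensor:
    """Build a [B, 1, Q, P + Q] bool mask (True = attend) for prefix tuning.

    attention_mask_2d: [B, Q] padding mask for the real tokens (1 = keep, 0 = pad).
    Hypothetical helper for illustration; not a peft or transformers API.
    """
    B, Q = attention_mask_2d.shape
    P = num_virtual_tokens
    device = attention_mask_2d.device
    # Prepend P always-visible prefix slots to the padding mask: [B, P + Q]
    prefix = torch.ones(B, P, dtype=attention_mask_2d.dtype, device=device)
    padding = torch.cat([prefix, attention_mask_2d], dim=1).bool()
    # Causal part with a key offset of P: query i may see key j iff j <= i + P
    causal = torch.ones(Q, P + Q, dtype=torch.bool, device=device).tril(diagonal=P)
    # Broadcast [1, 1, Q, P + Q] & [B, 1, 1, P + Q] -> [B, 1, Q, P + Q]
    return causal[None, None] & padding[:, None, None, :]


# One left-padded sequence of length 4 with 2 virtual tokens
mask = prefix_causal_mask(torch.tensor([[0, 1, 1, 1]]), num_virtual_tokens=2)
print(mask.shape)  # torch.Size([1, 1, 4, 6])
```

With this layout the mask's last dimension equals the prefix-extended key length. That is the shape the attention scores need to have for the expand in the error above to succeed.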
Happy to test patches.