PrefixTuningConfig fails on google/gemma-4-e2b-it: tensor expand size mismatch in attention forward (peft 0.19.1, transformers 5.6.2) #3201

@stharrold

Summary

Training with PrefixTuningConfig on google/gemma-4-e2b-it crashes in the
first forward pass with a tensor expand-size mismatch. The error is consistent
with #2881 (Qwen3 GQA prefix tuning), which was fixed for Qwen3 in #2883
(peft 0.18.0). We're on peft 0.19.1, so the Qwen3 fix is included, but Gemma 4
appears to have a different attention implementation that the same fix doesn't
cover.

Reproducer

import torch
from peft import PrefixTuningConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-4-e2b-it"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
)
cfg = PrefixTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=20,
    prefix_projection=False,
)
model = get_peft_model(model, cfg)

tok = AutoTokenizer.from_pretrained(model_id)
text = "Hello, world." * 100  # any non-trivial prompt
inputs = tok(text, return_tensors="pt", truncation=True, max_length=1632).to(model.device)
inputs["labels"] = inputs["input_ids"]

with torch.no_grad():
    out = model(**inputs)
print("loss:", out.loss.item())

Error

RuntimeError: The expanded size of the tensor (1632) must match the existing
size (1652) at non-singleton dimension 3.
Target sizes: [1, 8, 1632, 1632].  Tensor sizes: [1, 1, 1632, 1652]

File "/opt/venv/lib/python3.12/site-packages/peft/peft_model.py", line 1964, in forward
... (wrapped calls into Gemma-4 attention)

The 20-token gap between the target K-dim (1632 = sequence length) and the
tensor K-dim (1652 = sequence length + num_virtual_tokens) matches
num_virtual_tokens=20. That suggests peft has applied the prefix extension,
but Gemma-4's attention implementation is computing a square [B, H, Q, Q]
target shape, i.e., it is not honoring the prefix-extended key length.
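
To make the shape arithmetic concrete, here is a minimal pure-Python sketch of the expand rule that the error message reflects (each existing dimension must equal the target size or be 1). The helper `expand_error` is hypothetical and written for illustration only; it is not a peft or torch function, and it only handles equal-rank shapes.

```python
def expand_error(tensor_shape, target_shape):
    """Return None if a tensor of `tensor_shape` could be expanded to
    `target_shape`, otherwise a description of the first mismatch.

    Rule modelled (equal ranks only): every existing dimension must either
    equal the target size or be 1, in which case it can be broadcast.
    """
    if len(tensor_shape) != len(target_shape):
        raise ValueError("this sketch only handles equal-rank shapes")
    for dim, (have, want) in enumerate(zip(tensor_shape, target_shape)):
        if have != want and have != 1:
            return f"dim {dim}: existing size {have} != expanded size {want}"
    return None


seq_len = 1632            # query length from the traceback
num_virtual_tokens = 20   # from the PrefixTuningConfig in the reproducer
num_heads = 8             # head count seen in the target shape

# Shape of the tensor in the traceback: keys include the virtual tokens.
extended = (1, 1, seq_len, seq_len + num_virtual_tokens)

square_target = (1, num_heads, seq_len, seq_len)
prefix_aware_target = (1, num_heads, seq_len, seq_len + num_virtual_tokens)

print(expand_error(extended, square_target))        # dim 3: existing size 1652 != expanded size 1632
print(expand_error(extended, prefix_aware_target))  # None
```

Only the last dimension conflicts: the head dimension (1 vs 8) broadcasts fine, which is why the error names non-singleton dimension 3.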

Environment

  • peft==0.19.1
  • transformers==5.6.2
  • torch==2.11.0
  • Python 3.12
  • Hardware: NVIDIA A100 80GB; attn_implementation="sdpa" (also expected to
    fail on other CUDA devices with sdpa)

Related issues

  • #2881: same shape-mismatch pattern for Qwen3 GQA prefix tuning. Fixed in
    #2883, shipped in peft 0.18.0. Gemma-4 needs the same class of adaptation.
  • #1901: original GQA support for prefix tuning (Qwen2-era).
  • #3129: open Gemma-4 + peft compatibility tracking. Currently scoped to
    LoRA's Gemma4ClippableLinear, but Gemma-4 + PrefixTuning is the same
    ecosystem question.

Workaround attempt (in progress)

We suspect attn_implementation="eager" may bypass the bug (sdpa's fused
kernel doesn't honor manually-prepended K/V the way the eager implementation
does). Will update with results once verified. If eager works, that points the
fix toward the sdpa attention path in the Gemma-4 modeling code.
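
A small harness along these lines could be used to run the comparison. This is only a sketch: `compare_attn_implementations` and `run_forward` are hypothetical names, and the `stub_forward` stand-in below encodes placeholder behavior so the snippet runs without a GPU or model download. It is NOT a measured result for eager vs. sdpa.

```python
def compare_attn_implementations(run_forward, implementations=("sdpa", "eager")):
    """Call `run_forward(impl)` once per attention implementation and
    record whether it raised a RuntimeError.

    `run_forward` is a caller-supplied callable (hypothetical here) that is
    expected to load the model with attn_implementation=impl, wrap it with
    get_peft_model, and run one forward pass as in the reproducer.
    """
    results = {}
    for impl in implementations:
        try:
            run_forward(impl)
        except RuntimeError as exc:
            results[impl] = f"failed: {exc}"
        else:
            results[impl] = "ok"
    return results


# Placeholder only: mimics the hypothesis above, not an observed outcome.
def stub_forward(attn_implementation):
    if attn_implementation == "sdpa":
        raise RuntimeError("expanded size mismatch")


print(compare_attn_implementations(stub_forward))
```

In real use, `stub_forward` would be replaced by a function that rebuilds the model from the reproducer with the given `attn_implementation` and calls `model(**inputs)`.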

Suggested fix

Mirror the Qwen3 fix from #2883 for Gemma-4's attention implementation. The
fix likely lives near peft/peft_model.py:1964, where prefix-extended K/V is
fed to the base model's forward. The attention-mask construction needs to
account for the additional num_virtual_tokens in Gemma-4's specific attention
layout.
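
For illustration, a prefix-aware causal mask would have Q rows and num_virtual_tokens + Q columns rather than being square. The sketch below builds such a mask in plain Python under the assumption that every query may attend to all virtual tokens plus the usual causal window. `prefix_causal_mask` is a hypothetical helper, not peft's or transformers' implementation, and it does not model any layer-specific masking Gemma-4 may apply.

```python
def prefix_causal_mask(seq_len, num_virtual_tokens):
    """Boolean mask with `seq_len` rows (queries) and
    `num_virtual_tokens + seq_len` columns (keys); True means "may attend".
    """
    mask = []
    for q in range(seq_len):
        prefix_part = [True] * num_virtual_tokens        # prefix is always visible
        causal_part = [k <= q for k in range(seq_len)]   # no attending to future tokens
        mask.append(prefix_part + causal_part)
    return mask


mask = prefix_causal_mask(seq_len=3, num_virtual_tokens=2)
for row in mask:
    print("".join("1" if visible else "0" for visible in row))
# 11100
# 11110
# 11111
```

With seq_len=1632 and num_virtual_tokens=20 this gives a 1632 x 1652 mask, matching the last two dimensions of the tensor in the traceback.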

Happy to test patches.
