Skip to content

FIX Error when prefix tuning Gemma 4#3205

Merged
BenjaminBossan merged 2 commits into
huggingface:mainfrom
BenjaminBossan:fix-prefix-tuning-gemma4
May 5, 2026
Merged

FIX Error when prefix tuning Gemma 4#3205
BenjaminBossan merged 2 commits into
huggingface:mainfrom
BenjaminBossan:fix-prefix-tuning-gemma4

Conversation

@BenjaminBossan
Copy link
Copy Markdown
Member

There was an issue with applying prefix tuning to Gemma 4 because the model uses different head dimensions for layers that use sliding window attention. As prefix tuning only initializes a single projection matrix that is used for all layers, this would lead to a shape mismatch.

The solution is to "overprovision" the matrix and then slice the prefix down to size of the layer is smaller. This is not quite as parameter efficient as it could be, but the overhead shouldn't be too large.

For robustness, we also skip layers if the matrix is underprovisioned, but we warn about it and raise an error if all layers are skipped.

Alternatively, we could implement one project per layer, each with the right size, like in google-deepmind/gemma#631. However, this would be a big refactor and also very hard to make backwards compatible with existing checkpoints, so going with the less efficient solution is preferable.

This PR also contains an independent, single line fix to a prefix tuning test that was referencing a non-existing model.

There was an issue with applying prefix tuning to Gemma 4 because the
model uses different head dimensions for layers that use sliding window
attention. As prefix tuning only initializes a single projection matrix
that is used for all layers, this would lead to a shape mismatch.

The solution is to "overprovision" the matrix and then slice the prefix
down to size of the layer is smaller. This is not quite as parameter
efficient as it could be, but the overhead shouldn't be too large.

For robustness, we also skip layers if the matrix is underprovisioned,
but we warn about it and raise an error if all layers are skipped.

Alternatively, we could implement one project per layer, each with the
right size, like in google-deepmind/gemma#631.
However, this would be a big refactor and also very hard to make
backwards compatible with existing checkpoints, so going with the less
efficient solution is preferable.

This PR also contains an independent, single line fix to a prefix tuning
test that was referencing a non-existing model.
@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Tests passed on Linux but not on Windows. Trying to guess tolerances
that could work.
@stharrold
Copy link
Copy Markdown

Tested peft#3205 (head c020c11c) against Gemma-4-E2B + SDPA + PrefixTuning on NVIDIA H100 80GB. The forward pass succeeds.

Environment:

  • peft @ git+https://github.com/huggingface/peft.git@c020c11c397cdf2d66a34dccceab4246517a28c1 (reports 0.19.2.dev0)
  • transformers==5.7.0
  • torch==2.5.1+cu124, CUDA available
  • attn_implementation="sdpa", torch_dtype=torch.bfloat16, NVIDIA H100 80GB
  • model_id="google/gemma-4-e2b-it", PrefixTuningConfig(num_virtual_tokens=20, prefix_projection=False)

Result:

  • Forward returned loss=15.6689 (finite, no NaN)
  • input shape: torch.Size([1, 400])
  • logits shape: torch.Size([1, 400, 262144])
  • The skip-when-KV-mismatch logic from this PR kicked in cleanly. From the run output:

    UserWarning: Prefix tuning injected into layers [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; skipped [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34] due to KV shape mismatch or shared-KV layers.

Caveat on shape: our test input ("Hello, world." * 100 with max_length=1632) tokenized to 400 tokens, not the 1632 the reproducer in #3201 finding 1 used. So the exact expanded size (1632) must match the existing size (1652) error didn't surface at this shape — but the forward path that previously crashed for prefix-tuning + SDPA + Gemma-4 combinations now runs end-to-end, and the PR's layer-skip logic is observable. Happy to retest with a longer literal prompt that hits 1632+ tokens if you want exact-shape repro confirmation.

For the other two findings in #3201 (eager-attn position_ids overflow at modeling_gemma4.py:2262; P-Tuning v2 PLE 1.69 TB OOM), this PR doesn't change behavior — they remain WONTFIX per the upstream consensus, and we're tracking them on synavistra's side via internal triage.

Thanks for the partial-fix PR — happy to test follow-up shapes (prefix_projection=True, multi-batch, longer sequence length) if useful for the merge.

@BenjaminBossan BenjaminBossan marked this pull request as ready for review May 4, 2026 09:02
@BenjaminBossan
Copy link
Copy Markdown
Member Author

@zucchini-nlp It would be great if you could review the PR.

Copy link
Copy Markdown
Member

@zucchini-nlp zucchini-nlp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lgtm if we want to support gemma4. I think there is no uniform standard way for a unique arch, so hardcoding specific config names is fine

Comment thread src/peft/peft_model.py
Comment on lines +72 to +79
def _get_layer_kv_target_shape(base_config, layer_idx: int) -> tuple[int, int] | None:
"""Per-layer (num_kv_heads, head_dim) for prefix-tuning injection, or None for uniform models.

Models with heterogeneous attention (e.g. Gemma4) expose `global_head_dim` / `num_global_key_value_heads` alongside
the sliding-layer `head_dim` / `num_key_value_heads`. The provisioned prefix is sized for the global footprint;
this returns the shape each layer actually expects so the caller can slice down or skip layers that don't fit.
"""
layer_types = getattr(base_config, "layer_types", None)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so ig we're supporting specifically gemma4 with hardcoded attr names

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, if there is a more general approach, LMK, otherwise I'm okay with a Gemma-specific solution.

Comment thread src/peft/peft_model.py
Comment on lines +852 to +855
if num_kv_shared_layers > 0 and layer_idx >= first_kv_shared_layer_idx:
skipped_layers.append(layer_idx)
continue
key_states, value_states = layer_past_key_values
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, prev gemma3 also used to skip layers, so we shouldn't need a prefix cache for it

)
model = get_peft_model(model, config)
text_config = model.config.get_text_config()
text_config.num_key_value_heads = 999
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious about this. If configs dim are changed, doesn't it mean that key/value cache will also be a larger tensor?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not 100% sure, but I think it works because the model is already initialized and the cache is already created at this point, so changing the config won't affect it. But I haven't checked the full code path.

Copy link
Copy Markdown
Member Author

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for reviewing @zucchini-nlp. I replied to your comments. LMK if there is anything left, otherwise I'll go ahead and merge.

Comment thread src/peft/peft_model.py
Comment on lines +72 to +79
def _get_layer_kv_target_shape(base_config, layer_idx: int) -> tuple[int, int] | None:
"""Per-layer (num_kv_heads, head_dim) for prefix-tuning injection, or None for uniform models.

Models with heterogeneous attention (e.g. Gemma4) expose `global_head_dim` / `num_global_key_value_heads` alongside
the sliding-layer `head_dim` / `num_key_value_heads`. The provisioned prefix is sized for the global footprint;
this returns the shape each layer actually expects so the caller can slice down or skip layers that don't fit.
"""
layer_types = getattr(base_config, "layer_types", None)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, if there is a more general approach, LMK, otherwise I'm okay with a Gemma-specific solution.

)
model = get_peft_model(model, config)
text_config = model.config.get_text_config()
text_config.num_key_value_heads = 999
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not 100% sure, but I think it works because the model is already initialized and the cache is already created at this point, so changing the config won't affect it. But I haven't checked the full code path.

@zucchini-nlp
Copy link
Copy Markdown
Member

Yep, all good for me. The thing about bigger head dim seems to be a specific edge case :)

@BenjaminBossan BenjaminBossan merged commit 17a7a16 into huggingface:main May 5, 2026
10 checks passed
@BenjaminBossan BenjaminBossan deleted the fix-prefix-tuning-gemma4 branch May 5, 2026 10:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants