FIX Error when prefix tuning Gemma 4 #3205
There was an issue with applying prefix tuning to Gemma 4 because the model uses different head dimensions for layers that use sliding window attention. As prefix tuning only initializes a single projection matrix that is used for all layers, this would lead to a shape mismatch.

The solution is to "overprovision" the matrix and then slice the prefix down to size if the layer is smaller. This is not quite as parameter efficient as it could be, but the overhead shouldn't be too large.

For robustness, we also skip layers if the matrix is underprovisioned, but we warn about it and raise an error if all layers are skipped.

Alternatively, we could implement one projection per layer, each with the right size, like in google-deepmind/gemma#631. However, this would be a big refactor and also very hard to make backwards compatible with existing checkpoints, so going with the less efficient solution is preferable.

This PR also contains an independent, single-line fix to a prefix tuning test that was referencing a non-existent model.
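To make the slice/skip/raise behavior described above concrete, here is a minimal sketch of the bookkeeping only. It is not the code in this PR: `plan_prefix_slices` is a hypothetical name, the shapes are toy values, and taking the leading heads and dims as the slice is an assumption.

```python
import warnings


def plan_prefix_slices(provisioned_shape, layer_shapes):
    """Decide, per layer, how to cut a shared prefix down to the layer's size.

    provisioned_shape: (num_kv_heads, head_dim) of the single shared prefix.
    layer_shapes: list of (num_kv_heads, head_dim), one entry per layer.
    Returns (plan, skipped): plan maps layer index -> (head_slice, dim_slice).
    """
    prov_heads, prov_dim = provisioned_shape
    plan, skipped = {}, []
    for layer_idx, (heads, dim) in enumerate(layer_shapes):
        if heads > prov_heads or dim > prov_dim:
            # Underprovisioned: the prefix is too small for this layer, so skip it.
            skipped.append(layer_idx)
            continue
        # Overprovisioned or exact fit: keep the leading heads and dims (assumed choice).
        plan[layer_idx] = (slice(0, heads), slice(0, dim))
    if skipped and not plan:
        raise ValueError("Prefix is too small for every layer; nothing to inject.")
    if skipped:
        warnings.warn(f"Prefix tuning skipped layers {skipped}: prefix too small.")
    return plan, skipped


# Toy sizes: the prefix is provisioned as 2 heads x 8 dims; layer 1 only needs 2 x 4.
plan, skipped = plan_prefix_slices((2, 8), [(2, 8), (2, 4), (2, 8)])
```

With these toy sizes every layer fits, so nothing is skipped and layer 1 is simply sliced down to its smaller head dimension.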
Tests passed on Linux but not on Windows. Trying to guess tolerances that could work.
Tested peft#3205 (head Environment:
Result:
Caveat on shape: our test input ( For the other two findings in #3201 (eager-attn Thanks for the partial-fix PR — happy to test follow-up shapes ( |
@zucchini-nlp It would be great if you could review the PR.
zucchini-nlp
left a comment
Lgtm if we want to support gemma4. I think there is no uniform standard way for a unique arch, so hardcoding specific config names is fine
def _get_layer_kv_target_shape(base_config, layer_idx: int) -> tuple[int, int] | None:
    """Per-layer (num_kv_heads, head_dim) for prefix-tuning injection, or None for uniform models.

    Models with heterogeneous attention (e.g. Gemma4) expose `global_head_dim` / `num_global_key_value_heads` alongside
    the sliding-layer `head_dim` / `num_key_value_heads`. The provisioned prefix is sized for the global footprint;
    this returns the shape each layer actually expects so the caller can slice down or skip layers that don't fit.
    """
    layer_types = getattr(base_config, "layer_types", None)
so ig we're supporting specifically gemma4 with hardcoded attr names
Yeah, if there is a more general approach, LMK, otherwise I'm okay with a Gemma-specific solution.
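For readers following this thread, a rough sketch of what a Gemma-specific lookup along these lines could look like. This is not the function from the diff: `sketch_layer_kv_target_shape` is a hypothetical name, the `"full_attention"` label for global layers and the fallback to `num_key_value_heads` are assumptions, and the config values are made up.

```python
from types import SimpleNamespace


def sketch_layer_kv_target_shape(config, layer_idx):
    """Hypothetical stand-in: (num_kv_heads, head_dim) for a layer, or None if uniform."""
    layer_types = getattr(config, "layer_types", None)
    global_head_dim = getattr(config, "global_head_dim", None)
    if layer_types is None or global_head_dim is None:
        return None  # uniform model: every layer shares one shape
    if layer_types[layer_idx] == "full_attention":  # assumed label for global layers
        # Assumption: reuse the sliding-layer head count when no global one is set.
        heads = getattr(config, "num_global_key_value_heads", None) or config.num_key_value_heads
        return heads, global_head_dim
    # Sliding-window layers use the regular attribute names.
    return config.num_key_value_heads, config.head_dim


# Made-up config values, only to exercise both branches.
toy_config = SimpleNamespace(
    layer_types=["sliding_attention", "full_attention"],
    head_dim=4,
    num_key_value_heads=2,
    global_head_dim=8,
    num_global_key_value_heads=1,
)
sliding_shape = sketch_layer_kv_target_shape(toy_config, 0)
global_shape = sketch_layer_kv_target_shape(toy_config, 1)
```

The hardcoded attribute names are exactly the Gemma-specific part discussed above; a model without them falls through to `None` and keeps the existing uniform behavior.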
if num_kv_shared_layers > 0 and layer_idx >= first_kv_shared_layer_idx:
    skipped_layers.append(layer_idx)
    continue
key_states, value_states = layer_past_key_values
nice, prev gemma3 also used to skip layers, so we shouldn't need a prefix cache for it
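As a small illustration of the skip condition in the snippet above, the sketch below lists which layer indices would get a prefix and which would be skipped. `layers_receiving_prefix` is a hypothetical helper, and computing `first_kv_shared_layer_idx` as `num_hidden_layers - num_kv_shared_layers` is an assumption about how the KV-sharing layers are laid out.

```python
def layers_receiving_prefix(num_hidden_layers, num_kv_shared_layers):
    """Split layer indices into those that get a prefix and those that are skipped."""
    # Assumption: the last `num_kv_shared_layers` layers reuse earlier layers' key/values.
    first_kv_shared_layer_idx = num_hidden_layers - num_kv_shared_layers
    injected, skipped_layers = [], []
    for layer_idx in range(num_hidden_layers):
        if num_kv_shared_layers > 0 and layer_idx >= first_kv_shared_layer_idx:
            # KV-sharing layers have no key/value cache of their own to prepend to.
            skipped_layers.append(layer_idx)
            continue
        injected.append(layer_idx)
    return injected, skipped_layers


# Toy example: 6 layers, the last 2 of which share key/values.
injected, skipped_layers = layers_receiving_prefix(6, 2)
```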
)
model = get_peft_model(model, config)
text_config = model.config.get_text_config()
text_config.num_key_value_heads = 999
Curious about this. If the config dims are changed, doesn't that mean the key/value cache will also be a larger tensor?
I'm not 100% sure, but I think it works because the model is already initialized and the cache is already created at this point, so changing the config won't affect it. But I haven't checked the full code path.
BenjaminBossan
left a comment
Thanks a lot for reviewing @zucchini-nlp. I replied to your comments. LMK if there is anything left, otherwise I'll go ahead and merge.
Yep, all good for me. The thing about bigger head dim seems to be a specific edge case :)