Skip to content

Commit 17a7a16

Browse files
FIX Error when prefix tuning Gemma 4 (#3205)
There was an issue with applying prefix tuning to Gemma 4 because the model uses different head dimensions for layers that use sliding window attention. As prefix tuning only initializes a single projection matrix that is used for all layers, this would lead to a shape mismatch. The solution is to "overprovision" the matrix and then slice the prefix down to size if the layer is smaller. This is not quite as parameter efficient as it could be, but the overhead shouldn't be too large. For robustness, we also skip layers if the matrix is underprovisioned; in that case we emit a warning, and we raise an error if all layers are skipped. Alternatively, we could implement one projection per layer, each with the right size, like in google-deepmind/gemma#631. However, this would be a big refactor and also very hard to make backwards compatible with existing checkpoints, so going with the less efficient solution is preferable. This PR also contains an independent, single-line fix to a prefix tuning test that was referencing a non-existent model.
1 parent 9cda9e3 commit 17a7a16

4 files changed

Lines changed: 160 additions & 6 deletions

File tree

src/peft/peft_model.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,28 @@
6969
)
7070

7171

72+
def _get_layer_kv_target_shape(base_config, layer_idx: int) -> tuple[int, int] | None:
73+
"""Per-layer (num_kv_heads, head_dim) for prefix-tuning injection, or None for uniform models.
74+
75+
Models with heterogeneous attention (e.g. Gemma4) expose `global_head_dim` / `num_global_key_value_heads` alongside
76+
the sliding-layer `head_dim` / `num_key_value_heads`. The provisioned prefix is sized for the global footprint;
77+
this returns the shape each layer actually expects so the caller can slice down or skip layers that don't fit.
78+
"""
79+
layer_types = getattr(base_config, "layer_types", None)
80+
global_head_dim = getattr(base_config, "global_head_dim", None)
81+
if not layer_types or global_head_dim is None:
82+
return None
83+
84+
is_sliding = layer_types[layer_idx] == "sliding_attention"
85+
head_dim = base_config.head_dim if is_sliding else global_head_dim
86+
num_global_kv = getattr(base_config, "num_global_key_value_heads", None)
87+
if not is_sliding and num_global_kv is not None:
88+
num_kv_heads = num_global_kv
89+
else:
90+
num_kv_heads = base_config.num_key_value_heads
91+
return num_kv_heads, head_dim
92+
93+
7294
class PeftModel(PushToHubMixin, torch.nn.Module):
7395
"""
7496
Base model encompassing various Peft methods.
@@ -785,7 +807,7 @@ def get_prompt(
785807
if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None:
786808
post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type]
787809
past_key_values = post_process_fn(past_key_values)
788-
elif ("gemma2" in model_type) or ("gemma3_text" in model_type):
810+
elif ("gemma2" in model_type) or ("gemma3_text" in model_type) or ("gemma4" in model_type):
789811
# TODO: remove this logic once transformers < 4.56 is dropped
790812
transformers_lt_4_56 = packaging.version.parse(transformers.__version__) < packaging.version.parse(
791813
"4.56.0.dev0"
@@ -815,12 +837,54 @@ def get_prompt(
815837
# transformers 4.56+ uses DynamicCache for gemma
816838
new_cache = DynamicCache(config=base_config)
817839
cache_position = torch.arange(peft_config.num_virtual_tokens, device=past_key_values[0].device)
818-
for layer_idx in range(peft_config.num_layers):
819-
key_states, value_states = past_key_values[0][layer_idx], past_key_values[1][layer_idx]
840+
# Layers from `num_hidden_layers - num_kv_shared_layers` onward share KV with an earlier layer (no own
841+
# k_proj/v_proj) and never call `cache.update`; the prefix reaches them transitively via the source
842+
# layer.
843+
num_kv_shared_layers = getattr(base_config, "num_kv_shared_layers", 0)
844+
first_kv_shared_layer_idx = (
845+
getattr(base_config, "num_hidden_layers", peft_config.num_layers) - num_kv_shared_layers
846+
)
847+
injected_layers: list[int] = []
848+
skipped_layers: list[int] = []
849+
# past_key_values is a tuple of `num_layers` per-layer tensors each shaped
850+
# [2, batch, num_heads, num_virtual_tokens, head_dim], where dim 0 stacks K and V.
851+
for layer_idx, layer_past_key_values in enumerate(past_key_values):
852+
if num_kv_shared_layers > 0 and layer_idx >= first_kv_shared_layer_idx:
853+
skipped_layers.append(layer_idx)
854+
continue
855+
key_states, value_states = layer_past_key_values
856+
shape_or_none = _get_layer_kv_target_shape(base_config, layer_idx)
857+
if shape_or_none is not None: # e.g. gemma 4
858+
n_h, d = shape_or_none
859+
# Provisioned shape: [batch, num_heads, num_virtual_tokens, head_dim]. If a layer's KV is wider
860+
# than what we provisioned, we cannot slice up; skip rather than silently truncating to a shape
861+
# the model won't accept.
862+
if n_h > key_states.shape[1] or d > key_states.shape[3]:
863+
skipped_layers.append(layer_idx)
864+
continue
865+
key_states = key_states[:, :n_h, :, :d]
866+
value_states = value_states[:, :n_h, :, :d]
820867
new_cache.update(
821868
key_states, value_states, layer_idx, cache_kwargs={"cache_position": cache_position}
822869
)
870+
injected_layers.append(layer_idx)
823871
past_key_values = new_cache
872+
873+
if not injected_layers:
874+
# raise if no layer was matched; similar logic as in target_modules not matching any layer
875+
raise ValueError(
876+
"Prefix tuning skipped every layer because no layer's KV shape matched the provisioned prefix "
877+
f"(num_attention_heads={peft_config.num_attention_heads}, "
878+
f"head_dim={peft_config.token_dim // peft_config.num_attention_heads}). Override `token_dim` "
879+
"and `num_attention_heads` in `PrefixTuningConfig` to match a layer that should receive the "
880+
"prefix."
881+
)
882+
if skipped_layers:
883+
warnings.warn(
884+
f"Prefix tuning injected into layers {injected_layers}; skipped {skipped_layers} due to KV "
885+
"shape mismatch or shared-KV layers."
886+
)
887+
824888
elif peft_config.num_transformer_submodules == 1:
825889
# Don't apply this to encoder-decoder models or to models requiring special processing.
826890
# TODO: remove from_legacy_cache once transformers < 4.56 is dropped

src/peft/utils/other.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,11 +1161,19 @@ def _prepare_prompt_learning_config(peft_config, model_config):
11611161

11621162
# For grouped-query attention, see #1901.
11631163
if (peft_config.peft_type in {"PREFIX_TUNING", "CARTRIDGE"}) and ("num_key_value_heads" in model_config):
1164-
num_key_value_heads = model_config["num_key_value_heads"]
1165-
if model_config.get("head_dim", None) is not None:
1164+
# Models with heterogeneous attention (e.g. Gemma4) expose distinct shapes for global vs. sliding layers via
1165+
# `global_head_dim` / `num_global_key_value_heads`. Provision the prefix for the global-layer footprint; sliding
1166+
# layers are sliced down to their own KV shape at injection time, or skipped when they need more than was provisioned. Matches the default in
1167+
# google-deepmind/gemma#631.
1168+
if model_config.get("global_head_dim") is not None:
1169+
head_dim = model_config["global_head_dim"]
1170+
num_key_value_heads = model_config.get("num_global_key_value_heads") or model_config["num_key_value_heads"]
1171+
elif model_config.get("head_dim", None) is not None:
11661172
head_dim = model_config["head_dim"]
1173+
num_key_value_heads = model_config["num_key_value_heads"]
11671174
else:
11681175
head_dim = peft_config.token_dim // peft_config.num_attention_heads
1176+
num_key_value_heads = model_config["num_key_value_heads"]
11691177
peft_config.token_dim = head_dim * num_key_value_heads
11701178
peft_config.num_attention_heads = num_key_value_heads
11711179

tests/test_decoder_models.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,3 +1106,85 @@ def test_merge_and_unload_fixes_tie_word_embeddings_config(self):
11061106
assert not merged.config.tie_word_embeddings
11071107
assert merged.lm_head.weight is not merged.model.embed_tokens.weight
11081108
assert merged.lm_head.weight.data_ptr() != merged.model.embed_tokens.weight.data_ptr()
1109+
1110+
def test_prefix_tuning_gemma4_works(self):
    # see #3201
    # The issue was that head dim differs depending on whether sliding window attention is being used or not:
    # https://github.com/huggingface/transformers/blob/223fe5231b783fbfb25296bb0a243dad5d158cde/src/transformers/models/gemma4/modeling_gemma4.py#L1147
    # Prefix tuning could not deal with the different sizes, resulting in a shape error.
    # NOTE(review): indentation was lost in the scraped diff; only the model load is assumed to sit inside the
    # `hub_online_once` block -- confirm against the repository.

    model_id = "google/gemma-4-E2B"
    with hub_online_once(model_id):
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            dtype=torch.bfloat16,
        ).to(self.torch_device)
    config = PrefixTuningConfig(
        task_type=TaskType.CAUSAL_LM,
        num_virtual_tokens=20,
        prefix_projection=False,
    )
    model = get_peft_model(model, config)

    inputs = torch.arange(10).view(1, -1).to(self.torch_device)
    model(inputs)  # does not raise

    # do mini training run
    optim = torch.optim.SGD(model.parameters(), lr=0.001)
    losses = []
    for _ in range(5):
        optim.zero_grad()
        outputs = model(inputs)
        label = torch.zeros_like(outputs.logits)
        label[:, :, 1] = 1
        # NOTE(review): the logits look like (batch, seq, vocab), in which case `cross_entropy` treats dim 1 (seq) as
        # the class axis. That is enough for a smoke test that only checks the loss moves -- confirm before relying
        # on the loss value itself.
        loss = torch.nn.functional.cross_entropy(outputs.logits, label)
        loss.backward()
        optim.step()
        # keep only the value; the stored tensors are compared after the loop and need no autograd history
        losses.append(loss.detach())

    # every step must produce a finite loss, not just the last one
    assert all(torch.isfinite(step_loss) for step_loss in losses)
    assert not torch.isclose(losses[0], losses[-1], atol=1e-6, rtol=1e-3)
1147+
1148+
def test_prefix_tuning_gemma4_warns_if_some_layers_skipped(self):
    # Companion to test_prefix_tuning_gemma4_works. With `num_kv_shared_layers` overridden on the text config, the
    # forward pass is still expected to run, but it must emit the "skipped [...] due to KV shape" warning.
    # NOTE(review): indentation was lost in the scraped diff; only the model load is assumed to sit inside the
    # `hub_online_once` block -- confirm against the repository.
    model_id = "google/gemma-4-E2B"
    with hub_online_once(model_id):
        base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16)
        base_model = base_model.to(self.torch_device)

    prefix_config = PrefixTuningConfig(
        task_type=TaskType.CAUSAL_LM,
        num_virtual_tokens=20,
        prefix_projection=False,
    )
    # Override before wrapping, so the PEFT model is built from the modified config.
    # The original comment states the value is lowered (from 2).
    base_model.config.get_text_config().num_kv_shared_layers = 1
    peft_model = get_peft_model(base_model, prefix_config)

    token_ids = torch.arange(10).view(1, -1).to(self.torch_device)
    with pytest.warns(UserWarning, match=r"skipped \[.*\] due to KV shape"):
        peft_model(token_ids)
1169+
1170+
def test_prefix_tuning_gemma4_raises_if_all_layers_skipped(self):
    # Companion to test_prefix_tuning_gemma4_works. When the prefix embedding matrix is too small to fit any layer
    # targeted by prefix tuning, the forward pass must raise instead of silently doing nothing.
    # NOTE(review): indentation was lost in the scraped diff; only the model load is assumed to sit inside the
    # `hub_online_once` block -- confirm against the repository.
    model_id = "google/gemma-4-E2B"
    with hub_online_once(model_id):
        base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16)
        base_model = base_model.to(self.torch_device)

    prefix_config = PrefixTuningConfig(
        task_type=TaskType.CAUSAL_LM,
        num_virtual_tokens=20,
        prefix_projection=False,
    )
    peft_model = get_peft_model(base_model, prefix_config)
    # Override only after wrapping; presumably the prefix is already sized by then, so no layer can fit it anymore.
    peft_model.config.get_text_config().num_key_value_heads = 999

    token_ids = torch.arange(10).view(1, -1).to(self.torch_device)
    with pytest.raises(ValueError, match="skipped every layer"):
        peft_model(token_ids)

tests/test_gpu_examples.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5401,7 +5401,7 @@ def test_prefix_tuning_multiple_devices_decoder_model(self):
54015401
@require_torch_multi_accelerator
54025402
def test_prefix_tuning_multiple_devices_encoder_decoder_model(self):
54035403
# See issue 2134
5404-
model_id = "peft-internal-testing/tiny-random-T5Model"
5404+
model_id = "peft-internal-testing/tiny-random-t5"
54055405
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
54065406
inputs = tokenizer(["A list of colors: red, blue"], return_tensors="pt").to(self.device)
54075407
device_map = {

0 commit comments

Comments
 (0)