Skip to content

Commit a932fda

Browse files
committed
fix(peft_utils): raise actionable ValueError when rank_dict is empty in get_peft_kwargs
`get_peft_kwargs` is reached after callers (e.g. `_load_lora_into_text_encoder` and the various `load_lora_into_{unet,transformer}` flows) walk the model's `named_modules` and probe the state_dict for `{module}.lora_B.weight` keys. When the saved LoRA state_dict carries an extra or missing prefix versus the target model (e.g. `text_model.encoder.*` vs `encoder.*`), an adapter-name infix between `lora_B` and `weight` (e.g. `.default_0.`), or a non-diffusers serialization format that was not converted upstream, none of the probed keys are present in the state_dict and `rank_dict` arrives here empty. The first statement of the function was `r = lora_alpha = list(rank_dict.values())[0]`, which then raised an unhelpful `IndexError: list index out of range` with no signal about which mismatch was actually responsible (see huggingface/peft#3238 for a real-world report against SDXL + peft 0.19; the report was filed against peft but the failure path is in diffusers). Raise a `ValueError` whose message names the common causes and shows a sample of the state_dict keys, so the next user with a key-mismatch can diagnose their specific case from the exception alone rather than from a 5-frame deep IndexError. Adds focused tests for the new error path, the safe-on-None behavior, and the unchanged fast-path. Refs: huggingface/peft#3238
1 parent de5fcf6 commit a932fda

2 files changed

Lines changed: 110 additions & 0 deletions

File tree

src/diffusers/utils/peft_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,31 @@ def get_peft_kwargs(
155155
):
156156
rank_pattern = {}
157157
alpha_pattern = {}
158+
if not rank_dict:
159+
# rank_dict is populated by the caller by walking the model's named_modules
160+
# and probing the state_dict for `{name}.lora_B.weight` keys (see e.g.
161+
# `_load_lora_into_text_encoder` and `load_lora_into_unet`). When the
162+
# state_dict keys do not match that pattern (typically because of a
163+
# missing or extra prefix on the saved keys, an adapter-name infix such
164+
# as `.default_0.` between `lora_B` and `weight`, or a non-diffusers
165+
# serialization format that was not converted upstream), `rank_dict`
166+
# ends up empty and we would crash here with a cryptic IndexError on
167+
# `list(rank_dict.values())[0]`. Surface the actual problem instead so
168+
# the caller can debug the key mismatch. See issue #3238 on huggingface/peft
169+
# (the original report was filed against peft, but the failure path is
170+
# this function in diffusers).
171+
n_keys = len(peft_state_dict) if peft_state_dict is not None else 0
172+
sample_keys = list(peft_state_dict.keys())[:3] if peft_state_dict else []
173+
raise ValueError(
174+
"Could not extract LoRA rank: `rank_dict` is empty. This means none of the "
175+
"expected `{module_name}.lora_B.weight` keys were found in the state_dict. "
176+
"Usual causes: the saved keys carry an extra or missing prefix versus the "
177+
"target model (e.g. `text_model.encoder.*` vs `encoder.*`); the keys carry "
178+
"an adapter-name infix such as `.default_0.` between `lora_B` and `weight`; "
179+
"or the state_dict was saved in a format that diffusers does not yet "
180+
"convert. "
181+
f"State dict has {n_keys} keys; first 3: {sample_keys}."
182+
)
158183
r = lora_alpha = list(rank_dict.values())[0]
159184

160185
if len(set(rank_dict.values())) > 1:

tests/others/test_peft_utils.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# coding=utf-8
2+
# Copyright 2025 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import unittest
17+
18+
from diffusers.utils.peft_utils import get_peft_kwargs
19+
20+
21+
class GetPeftKwargsTest(unittest.TestCase):
22+
"""Tests for diffusers.utils.peft_utils.get_peft_kwargs."""
23+
24+
def test_empty_rank_dict_raises_actionable_value_error(self):
25+
"""Regression for huggingface/peft#3238 (failure path is in diffusers,
26+
not peft). When the caller's rank-discovery loop produces an empty
27+
`rank_dict` (typical when state_dict keys carry an extra/missing
28+
prefix or an adapter-name infix that the loop did not match), we
29+
used to crash with a cryptic IndexError on
30+
`list(rank_dict.values())[0]`. Now we raise a `ValueError` whose
31+
message names the underlying mismatch and shows a few state_dict
32+
keys so the user can diagnose.
33+
"""
34+
peft_state_dict = {
35+
"encoder.layers.0.self_attn.k_proj.lora_A.default_0.weight": object(),
36+
"encoder.layers.0.self_attn.k_proj.lora_B.default_0.weight": object(),
37+
"encoder.layers.0.self_attn.q_proj.lora_A.default_0.weight": object(),
38+
}
39+
with self.assertRaises(ValueError) as cm:
40+
get_peft_kwargs(
41+
rank_dict={},
42+
network_alpha_dict=None,
43+
peft_state_dict=peft_state_dict,
44+
is_unet=False,
45+
)
46+
message = str(cm.exception)
47+
self.assertIn("`rank_dict` is empty", message)
48+
self.assertIn("lora_B.weight", message)
49+
# The message includes a sample of the state_dict so the user can spot
50+
# the prefix/infix mismatch from the error alone.
51+
self.assertIn("State dict has 3 keys", message)
52+
53+
def test_empty_rank_dict_with_none_state_dict_is_safe(self):
54+
"""The diagnostic message should not crash on a None peft_state_dict."""
55+
with self.assertRaises(ValueError) as cm:
56+
get_peft_kwargs(
57+
rank_dict={},
58+
network_alpha_dict=None,
59+
peft_state_dict=None,
60+
is_unet=True,
61+
)
62+
self.assertIn("State dict has 0 keys", str(cm.exception))
63+
64+
def test_non_empty_rank_dict_unchanged(self):
65+
"""The fast-path (rank_dict populated as before) must remain
66+
functionally identical. Smoke-test that get_peft_kwargs returns the
67+
expected keys for a minimal one-module rank_dict.
68+
"""
69+
rank_dict = {"q_proj.lora_B.weight": 4}
70+
peft_state_dict = {
71+
"q_proj.lora_A.weight": object(),
72+
"q_proj.lora_B.weight": object(),
73+
}
74+
kwargs = get_peft_kwargs(
75+
rank_dict=rank_dict,
76+
network_alpha_dict=None,
77+
peft_state_dict=peft_state_dict,
78+
is_unet=True,
79+
)
80+
self.assertEqual(kwargs["r"], 4)
81+
self.assertEqual(kwargs["lora_alpha"], 4)
82+
self.assertEqual(kwargs["rank_pattern"], {})
83+
self.assertIn("q_proj", kwargs["target_modules"])
84+
self.assertFalse(kwargs["use_dora"])
85+
self.assertFalse(kwargs["lora_bias"])

0 commit comments

Comments
 (0)