Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/diffusers/utils/peft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,31 @@ def get_peft_kwargs(
):
rank_pattern = {}
alpha_pattern = {}
if not rank_dict:
# rank_dict is populated by the caller by walking the model's named_modules
# and probing the state_dict for `{name}.lora_B.weight` keys (see e.g.
# `_load_lora_into_text_encoder` and `load_lora_into_unet`). When the
# state_dict keys do not match that pattern (typically because of a
# missing or extra prefix on the saved keys, an adapter-name infix such
# as `.default_0.` between `lora_B` and `weight`, or a non-diffusers
# serialization format that was not converted upstream), `rank_dict`
# ends up empty and we would crash here with a cryptic IndexError on
# `list(rank_dict.values())[0]`. Surface the actual problem instead so
# the caller can debug the key mismatch. See issue #3238 on huggingface/peft
# (the original report was filed against peft, but the failure path is
# this function in diffusers).
n_keys = len(peft_state_dict) if peft_state_dict is not None else 0
sample_keys = list(peft_state_dict.keys())[:3] if peft_state_dict else []
raise ValueError(
"Could not extract LoRA rank: `rank_dict` is empty. This means none of the "
"expected `{module_name}.lora_B.weight` keys were found in the state_dict. "
"Usual causes: the saved keys carry an extra or missing prefix versus the "
"target model (e.g. `text_model.encoder.*` vs `encoder.*`); the keys carry "
"an adapter-name infix such as `.default_0.` between `lora_B` and `weight`; "
"or the state_dict was saved in a format that diffusers does not yet "
"convert. "
f"State dict has {n_keys} keys; first 3: {sample_keys}."
)
r = lora_alpha = list(rank_dict.values())[0]

if len(set(rank_dict.values())) > 1:
Expand Down
85 changes: 85 additions & 0 deletions tests/others/test_peft_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from diffusers.utils.peft_utils import get_peft_kwargs


class GetPeftKwargsTest(unittest.TestCase):
"""Tests for diffusers.utils.peft_utils.get_peft_kwargs."""

def test_empty_rank_dict_raises_actionable_value_error(self):
"""Regression for huggingface/peft#3238 (failure path is in diffusers,
not peft). When the caller's rank-discovery loop produces an empty
`rank_dict` (typical when state_dict keys carry an extra/missing
prefix or an adapter-name infix that the loop did not match), we
used to crash with a cryptic IndexError on
`list(rank_dict.values())[0]`. Now we raise a `ValueError` whose
message names the underlying mismatch and shows a few state_dict
keys so the user can diagnose.
"""
peft_state_dict = {
"encoder.layers.0.self_attn.k_proj.lora_A.default_0.weight": object(),
"encoder.layers.0.self_attn.k_proj.lora_B.default_0.weight": object(),
"encoder.layers.0.self_attn.q_proj.lora_A.default_0.weight": object(),
}
with self.assertRaises(ValueError) as cm:
get_peft_kwargs(
rank_dict={},
network_alpha_dict=None,
peft_state_dict=peft_state_dict,
is_unet=False,
)
message = str(cm.exception)
self.assertIn("`rank_dict` is empty", message)
self.assertIn("lora_B.weight", message)
# The message includes a sample of the state_dict so the user can spot
# the prefix/infix mismatch from the error alone.
self.assertIn("State dict has 3 keys", message)

def test_empty_rank_dict_with_none_state_dict_is_safe(self):
"""The diagnostic message should not crash on a None peft_state_dict."""
with self.assertRaises(ValueError) as cm:
get_peft_kwargs(
rank_dict={},
network_alpha_dict=None,
peft_state_dict=None,
is_unet=True,
)
self.assertIn("State dict has 0 keys", str(cm.exception))

def test_non_empty_rank_dict_unchanged(self):
"""The fast-path (rank_dict populated as before) must remain
functionally identical. Smoke-test that get_peft_kwargs returns the
expected keys for a minimal one-module rank_dict.
"""
rank_dict = {"q_proj.lora_B.weight": 4}
peft_state_dict = {
"q_proj.lora_A.weight": object(),
"q_proj.lora_B.weight": object(),
}
kwargs = get_peft_kwargs(
rank_dict=rank_dict,
network_alpha_dict=None,
peft_state_dict=peft_state_dict,
is_unet=True,
)
self.assertEqual(kwargs["r"], 4)
self.assertEqual(kwargs["lora_alpha"], 4)
self.assertEqual(kwargs["rank_pattern"], {})
self.assertIn("q_proj", kwargs["target_modules"])
self.assertFalse(kwargs["use_dora"])
self.assertFalse(kwargs["lora_bias"])
Loading