Skip to content

Commit 802bd1a

Browse files
committed
feat: use exclude modules to loraconfig.
1 parent dd28509 commit 802bd1a

2 files changed

Lines changed: 43 additions & 11 deletions

File tree

src/diffusers/loaders/peft.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,13 +243,20 @@ def load_lora_adapter(
243243
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
244244
}
245245

246-
# create LoraConfig
247-
lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)
248-
249246
# adapter_name
250247
if adapter_name is None:
251248
adapter_name = get_adapter_name(self)
252249

250+
# create LoraConfig
251+
lora_config = _create_lora_config(
252+
state_dict,
253+
network_alphas,
254+
metadata,
255+
rank,
256+
model_state_dict=self.state_dict(),
257+
adapter_name=adapter_name,
258+
)
259+
253260
# <Unsafe code
254261
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
255262
# Now we remove any existing hooks to `_pipeline`.

src/diffusers/utils/peft_utils.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
150150
module.set_scale(adapter_name, 1.0)
151151

152152

153-
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
153+
def get_peft_kwargs(
154+
rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
155+
):
154156
rank_pattern = {}
155157
alpha_pattern = {}
156158
r = lora_alpha = list(rank_dict.values())[0]
@@ -180,18 +182,23 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
180182
else:
181183
lora_alpha = set(network_alpha_dict.values()).pop()
182184

183-
# layer names without the Diffusers specific
184185
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
185186
use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
186187
# for now we know that the "bias" keys are only associated with `lora_B`.
187188
lora_bias = any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict)
188189

190+
# Example: load FusionX LoRA into Wan VACE
191+
exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
192+
if not exclude_modules:
193+
exclude_modules = None
194+
189195
lora_config_kwargs = {
190196
"r": r,
191197
"lora_alpha": lora_alpha,
192198
"rank_pattern": rank_pattern,
193199
"alpha_pattern": alpha_pattern,
194200
"target_modules": target_modules,
201+
"exclude_modules": exclude_modules,
195202
"use_dora": use_dora,
196203
"lora_bias": lora_bias,
197204
}
@@ -294,19 +301,20 @@ def check_peft_version(min_version: str) -> None:
294301

295302

296303
def _create_lora_config(
297-
state_dict,
298-
network_alphas,
299-
metadata,
300-
rank_pattern_dict,
301-
is_unet: bool = True,
304+
state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
302305
):
303306
from peft import LoraConfig
304307

305308
if metadata is not None:
306309
lora_config_kwargs = metadata
307310
else:
308311
lora_config_kwargs = get_peft_kwargs(
309-
rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet
312+
rank_pattern_dict,
313+
network_alpha_dict=network_alphas,
314+
peft_state_dict=state_dict,
315+
is_unet=is_unet,
316+
model_state_dict=model_state_dict,
317+
adapter_name=adapter_name,
310318
)
311319

312320
_maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
@@ -371,3 +379,20 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
371379

372380
if warn_msg:
373381
logger.warning(warn_msg)
382+
383+
384+
def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
385+
all_modules = set()
386+
string_to_replace = f"{adapter_name}." if adapter_name else ""
387+
388+
for name in model_state_dict.keys():
389+
if string_to_replace:
390+
name = name.replace(string_to_replace, "")
391+
if "." in name:
392+
module_name = name.rsplit(".", 1)[0]
393+
all_modules.add(module_name)
394+
395+
target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
396+
exclude_modules = list(all_modules - target_modules_set)
397+
398+
return exclude_modules

0 commit comments

Comments
 (0)