@@ -150,7 +150,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
150150 module .set_scale (adapter_name , 1.0 )
151151
152152
153- def get_peft_kwargs (rank_dict , network_alpha_dict , peft_state_dict , is_unet = True ):
153+ def get_peft_kwargs (
154+ rank_dict , network_alpha_dict , peft_state_dict , is_unet = True , model_state_dict = None , adapter_name = None
155+ ):
154156 rank_pattern = {}
155157 alpha_pattern = {}
156158 r = lora_alpha = list (rank_dict .values ())[0 ]
@@ -180,18 +182,23 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
180182 else :
181183 lora_alpha = set (network_alpha_dict .values ()).pop ()
182184
183- # layer names without the Diffusers specific
184185 target_modules = list ({name .split (".lora" )[0 ] for name in peft_state_dict .keys ()})
185186 use_dora = any ("lora_magnitude_vector" in k for k in peft_state_dict )
186187 # for now we know that the "bias" keys are only associated with `lora_B`.
187188 lora_bias = any ("lora_B" in k and k .endswith (".bias" ) for k in peft_state_dict )
188189
190+ # Example: load FusionX LoRA into Wan VACE
191+ exclude_modules = _derive_exclude_modules (model_state_dict , peft_state_dict , adapter_name )
192+ if not exclude_modules :
193+ exclude_modules = None
194+
189195 lora_config_kwargs = {
190196 "r" : r ,
191197 "lora_alpha" : lora_alpha ,
192198 "rank_pattern" : rank_pattern ,
193199 "alpha_pattern" : alpha_pattern ,
194200 "target_modules" : target_modules ,
201+ "exclude_modules" : exclude_modules ,
195202 "use_dora" : use_dora ,
196203 "lora_bias" : lora_bias ,
197204 }
@@ -294,19 +301,20 @@ def check_peft_version(min_version: str) -> None:
294301
295302
296303def _create_lora_config (
297- state_dict ,
298- network_alphas ,
299- metadata ,
300- rank_pattern_dict ,
301- is_unet : bool = True ,
304+ state_dict , network_alphas , metadata , rank_pattern_dict , is_unet = True , model_state_dict = None , adapter_name = None
302305):
303306 from peft import LoraConfig
304307
305308 if metadata is not None :
306309 lora_config_kwargs = metadata
307310 else :
308311 lora_config_kwargs = get_peft_kwargs (
309- rank_pattern_dict , network_alpha_dict = network_alphas , peft_state_dict = state_dict , is_unet = is_unet
312+ rank_pattern_dict ,
313+ network_alpha_dict = network_alphas ,
314+ peft_state_dict = state_dict ,
315+ is_unet = is_unet ,
316+ model_state_dict = model_state_dict ,
317+ adapter_name = adapter_name ,
310318 )
311319
312320 _maybe_raise_error_for_ambiguous_keys (lora_config_kwargs )
@@ -371,3 +379,20 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
371379
372380 if warn_msg :
373381 logger .warning (warn_msg )
382+
383+
384+ def _derive_exclude_modules (model_state_dict , peft_state_dict , adapter_name = None ):
385+ all_modules = set ()
386+ string_to_replace = f"{ adapter_name } ." if adapter_name else ""
387+
388+ for name in model_state_dict .keys ():
389+ if string_to_replace :
390+ name = name .replace (string_to_replace , "" )
391+ if "." in name :
392+ module_name = name .rsplit ("." , 1 )[0 ]
393+ all_modules .add (module_name )
394+
395+ target_modules_set = {name .split (".lora" )[0 ] for name in peft_state_dict .keys ()}
396+ exclude_modules = list (all_modules - target_modules_set )
397+
398+ return exclude_modules
0 commit comments