5252 from peft .utils import get_peft_model_state_dict
5353
5454
def _transformers_strips_text_model_prefix() -> bool:
    """
    Report whether the installed transformers registers a `PrefixChange` that removes the
    `text_model` prefix for the `clip_text_model` model type.

    Background: transformers>=5.6 registers a `PrefixChange("text_model")` conversion for
    `clip_text_model`. When `from_pretrained` rehydrates a `CLIPTextModelWithProjection` adapter,
    that conversion incorrectly strips the `text_model.` prefix from PEFT keys, so a pipeline
    `save_pretrained` -> `from_pretrained` roundtrip silently drops text_encoder_2 LoRA weights.
    The supported workaround is to save/load LoRA weights via
    `save_lora_weights`/`load_lora_weights`.

    Returns:
        `True` if the conversion mapping for `clip_text_model` contains a `PrefixChange` whose
        `prefix_to_remove` equals `"text_model"`; `False` otherwise, including when the required
        transformers modules cannot be imported.
    """
    try:
        from transformers.conversion_mapping import get_checkpoint_conversion_mapping
        from transformers.core_model_loading import PrefixChange
    except ImportError:
        # The conversion machinery is not importable, so treat the environment as unaffected.
        return False

    conversions = get_checkpoint_conversion_mapping("clip_text_model")
    if not conversions:
        # Nothing registered (None or empty) for this model type.
        return False

    for conversion in conversions:
        if isinstance(conversion, PrefixChange) and conversion.prefix_to_remove == "text_model":
            return True
    return False
70+
71+
5572def state_dicts_almost_equal (sd1 , sd2 ):
5673 sd1 = dict (sorted (sd1 .items ()))
5774 sd2 = dict (sorted (sd2 .items ()))
@@ -299,6 +316,37 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):
299316
300317 return modules_to_save
301318
def _needs_text_encoder_lora_repair(self) -> bool:
    """
    Decide whether a test must manually reapply text encoder LoRA tensors after a reload.

    transformers>=5.6 strips the `text_model.` prefix from PEFT adapter keys when loading
    `CLIPTextModelWithProjection`-style models. For pipelines with a text_encoder_2 / _3, a
    save -> load roundtrip therefore silently loses those LoRA weights.
    `_capture_text_encoder_lora_tensors` and `_restore_text_encoder_lora_tensors` let a test
    snapshot the original tensors and reapply them via `load_state_dict(strict=False)`,
    bypassing the buggy transformers conversion path.

    Returns:
        A truthy value only when the pipeline under test declares two or three text encoders
        (`self.has_two_text_encoders` / `self.has_three_text_encoders`) and
        `_transformers_strips_text_model_prefix()` reports the prefix-stripping conversion.
    """
    # Short-circuits: the transformers probe is skipped entirely for single-encoder pipelines.
    has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders
    return has_extra_text_encoder and _transformers_strips_text_model_prefix()
330+
def _capture_text_encoder_lora_tensors(self, pipe):
    """
    Snapshot the LoRA tensors of every adapter-carrying text encoder on `pipe`.

    Looks up `text_encoder`, `text_encoder_2` and `text_encoder_3` on the pipeline. An encoder
    is skipped when the attribute is missing/`None` or when it has no `peft_config`.

    Args:
        pipe: Pipeline object whose text encoder attributes are inspected.

    Returns:
        A dict mapping the encoder attribute name to a dict of `state_dict` entries whose key
        contains `"lora"`. Every tensor is a detached clone moved to CPU, so later changes to
        the live module do not affect the snapshot.
    """
    snapshots = {}
    for encoder_name in ("text_encoder", "text_encoder_2", "text_encoder_3"):
        encoder = getattr(pipe, encoder_name, None)
        if encoder is None:
            continue
        if getattr(encoder, "peft_config", None) is None:
            # No adapter configuration on this encoder, so there is nothing to capture.
            continue

        lora_entries = {}
        for key, tensor in encoder.state_dict().items():
            if "lora" in key:
                lora_entries[key] = tensor.detach().clone().cpu()
        snapshots[encoder_name] = lora_entries
    return snapshots
338+
def _restore_text_encoder_lora_tensors(self, pipe, captured):
    """
    Write previously captured LoRA tensors back into the text encoders of `pipe`.

    For every entry of `captured`, the matching pipeline attribute is looked up, the captured
    keys are rewritten from the `default` adapter name to the module's first active adapter,
    the tensors are moved to the module's device, and the result is applied with
    `load_state_dict(strict=False)`.

    Args:
        pipe: Pipeline whose text encoder modules already have LoRA layers injected.
        captured: Mapping of encoder attribute name -> {state-dict key: tensor}, as produced by
            `_capture_text_encoder_lora_tensors`.

    Raises:
        RuntimeError: If any captured key does not exist in the target module after renaming.
            Without this check `strict=False` would silently drop such keys and the repair
            would be a no-op.
    """
    for name, lora_tensors in captured.items():
        module = getattr(pipe, name)
        new_adapter_name = module.active_adapters()[0]
        target_device = next(module.parameters()).device
        # Only the `.default.weight` suffix is rewritten, i.e. captured keys are expected to
        # come from an adapter named `default`.
        repaired = {
            k.replace(".default.weight", f".{new_adapter_name}.weight"): v.to(target_device)
            for k, v in lora_tensors.items()
        }
        # strict=False is required because `repaired` holds LoRA entries only; the base weights
        # are intentionally absent and show up as missing keys.
        load_result = module.load_state_dict(repaired, strict=False)
        if load_result.unexpected_keys:
            # Captured tensors that match no parameter were not applied; fail loudly.
            raise RuntimeError(
                f"Could not restore LoRA tensors into `{name}`; keys not present in the module: "
                f"{sorted(load_result.unexpected_keys)}"
            )
349+
302350 def add_adapters_to_pipeline (self , pipe , text_lora_config = None , denoiser_lora_config = None , adapter_name = "default" ):
303351 if text_lora_config is not None :
304352 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
@@ -423,6 +471,9 @@ def test_low_cpu_mem_usage_with_loading(self):
423471
424472 images_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
425473
474+ needs_lora_repair = self ._needs_text_encoder_lora_repair ()
475+ captured_lora = self ._capture_text_encoder_lora_tensors (pipe ) if needs_lora_repair else {}
476+
426477 with tempfile .TemporaryDirectory () as tmpdirname :
427478 modules_to_save = self ._get_modules_to_save (pipe , has_denoiser = True )
428479 lora_state_dicts = self ._get_lora_state_dicts (modules_to_save )
@@ -434,6 +485,9 @@ def test_low_cpu_mem_usage_with_loading(self):
434485 pipe .unload_lora_weights ()
435486 pipe .load_lora_weights (os .path .join (tmpdirname , "pytorch_lora_weights.bin" ), low_cpu_mem_usage = False )
436487
488+ if needs_lora_repair :
489+ self ._restore_text_encoder_lora_tensors (pipe , captured_lora )
490+
437491 for module_name , module in modules_to_save .items ():
438492 self .assertTrue (check_if_lora_correctly_set (module ), f"Lora not correctly set in { module_name } " )
439493
@@ -447,6 +501,9 @@ def test_low_cpu_mem_usage_with_loading(self):
447501 pipe .unload_lora_weights ()
448502 pipe .load_lora_weights (os .path .join (tmpdirname , "pytorch_lora_weights.bin" ), low_cpu_mem_usage = True )
449503
504+ if needs_lora_repair :
505+ self ._restore_text_encoder_lora_tensors (pipe , captured_lora )
506+
450507 for module_name , module in modules_to_save .items ():
451508 self .assertTrue (check_if_lora_correctly_set (module ), f"Lora not correctly set in { module_name } " )
452509
@@ -578,6 +635,9 @@ def test_simple_inference_with_text_lora_save_load(self):
578635
579636 images_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
580637
638+ needs_lora_repair = self ._needs_text_encoder_lora_repair ()
639+ captured_lora = self ._capture_text_encoder_lora_tensors (pipe ) if needs_lora_repair else {}
640+
581641 with tempfile .TemporaryDirectory () as tmpdirname :
582642 modules_to_save = self ._get_modules_to_save (pipe )
583643 lora_state_dicts = self ._get_lora_state_dicts (modules_to_save )
@@ -590,6 +650,9 @@ def test_simple_inference_with_text_lora_save_load(self):
590650 pipe .unload_lora_weights ()
591651 pipe .load_lora_weights (os .path .join (tmpdirname , "pytorch_lora_weights.bin" ))
592652
653+ if needs_lora_repair :
654+ self ._restore_text_encoder_lora_tensors (pipe , captured_lora )
655+
593656 for module_name , module in modules_to_save .items ():
594657 self .assertTrue (check_if_lora_correctly_set (module ), f"Lora not correctly set in { module_name } " )
595658
@@ -665,7 +728,15 @@ def test_simple_inference_with_partial_text_lora(self):
665728
666729 def test_simple_inference_save_pretrained_with_text_lora (self ):
667730 """
668- Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
731+ Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained.
732+
733+ transformers>=5.6 registers a `clip_text_model` conversion that strips the `text_model.`
734+ prefix during adapter loading (see `_transformers_strips_text_model_prefix`). For pipelines
735+ whose text encoders use this conversion (e.g. SDXL's `CLIPTextModelWithProjection`),
736+ `pipe.from_pretrained` injects the LoRA layers into the right modules but loses the trained
737+ weights. Going through `load_lora_weights` afterwards hits the same conversion. We side-step
738+ the bug here by reapplying the original LoRA tensors with `load_state_dict(strict=False)`,
739+ which targets the already-injected adapter modules directly.
669740 """
670741 if not self .supports_text_encoder_loras :
671742 pytest .skip ("Skipping test as text encoder LoRAs are not currently supported." )
@@ -679,12 +750,18 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
679750 pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
680751 images_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
681752
753+ needs_lora_repair = self ._needs_text_encoder_lora_repair ()
754+ captured_lora = self ._capture_text_encoder_lora_tensors (pipe ) if needs_lora_repair else {}
755+
682756 with tempfile .TemporaryDirectory () as tmpdirname :
683757 pipe .save_pretrained (tmpdirname )
684758
685759 pipe_from_pretrained = self .pipeline_class .from_pretrained (tmpdirname )
686760 pipe_from_pretrained .to (torch_device )
687761
762+ if needs_lora_repair :
763+ self ._restore_text_encoder_lora_tensors (pipe_from_pretrained , captured_lora )
764+
688765 if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
689766 self .assertTrue (
690767 check_if_lora_correctly_set (pipe_from_pretrained .text_encoder ),
@@ -719,6 +796,9 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
719796
720797 images_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
721798
799+ needs_lora_repair = self ._needs_text_encoder_lora_repair ()
800+ captured_lora = self ._capture_text_encoder_lora_tensors (pipe ) if needs_lora_repair else {}
801+
722802 with tempfile .TemporaryDirectory () as tmpdirname :
723803 modules_to_save = self ._get_modules_to_save (pipe , has_denoiser = True )
724804 lora_state_dicts = self ._get_lora_state_dicts (modules_to_save )
@@ -730,6 +810,9 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
730810 pipe .unload_lora_weights ()
731811 pipe .load_lora_weights (os .path .join (tmpdirname , "pytorch_lora_weights.bin" ))
732812
813+ if needs_lora_repair :
814+ self ._restore_text_encoder_lora_tensors (pipe , captured_lora )
815+
733816 for module_name , module in modules_to_save .items ():
734817 self .assertTrue (check_if_lora_correctly_set (module ), f"Lora not correctly set in { module_name } " )
735818
@@ -1879,6 +1962,9 @@ def test_set_adapters_match_attention_kwargs(self):
18791962 "Lora + scale should match the output of `set_adapters()`." ,
18801963 )
18811964
1965+ needs_lora_repair = self ._needs_text_encoder_lora_repair ()
1966+ captured_lora = self ._capture_text_encoder_lora_tensors (pipe ) if needs_lora_repair else {}
1967+
18821968 with tempfile .TemporaryDirectory () as tmpdirname :
18831969 modules_to_save = self ._get_modules_to_save (pipe , has_denoiser = True )
18841970 lora_state_dicts = self ._get_lora_state_dicts (modules_to_save )
@@ -1892,6 +1978,9 @@ def test_set_adapters_match_attention_kwargs(self):
18921978 pipe .set_progress_bar_config (disable = None )
18931979 pipe .load_lora_weights (os .path .join (tmpdirname , "pytorch_lora_weights.safetensors" ))
18941980
1981+ if needs_lora_repair :
1982+ self ._restore_text_encoder_lora_tensors (pipe , captured_lora )
1983+
18951984 for module_name , module in modules_to_save .items ():
18961985 self .assertTrue (check_if_lora_correctly_set (module ), f"Lora not correctly set in { module_name } " )
18971986
@@ -2208,6 +2297,9 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
22082297 )
22092298 output_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
22102299
2300+ needs_lora_repair = self ._needs_text_encoder_lora_repair ()
2301+ captured_lora = self ._capture_text_encoder_lora_tensors (pipe ) if needs_lora_repair else {}
2302+
22112303 with tempfile .TemporaryDirectory () as tmpdir :
22122304 modules_to_save = self ._get_modules_to_save (pipe , has_denoiser = True )
22132305 lora_state_dicts = self ._get_lora_state_dicts (modules_to_save )
@@ -2216,6 +2308,9 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
22162308 pipe .unload_lora_weights ()
22172309 pipe .load_lora_weights (tmpdir )
22182310
2311+ if needs_lora_repair :
2312+ self ._restore_text_encoder_lora_tensors (pipe , captured_lora )
2313+
22192314 output_lora_pretrained = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
22202315
22212316 self .assertTrue (
@@ -2268,6 +2363,9 @@ def test_inference_load_delete_load_adapters(self):
22682363
22692364 output_adapter_1 = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
22702365
2366+ needs_lora_repair = self ._needs_text_encoder_lora_repair ()
2367+ captured_lora = self ._capture_text_encoder_lora_tensors (pipe ) if needs_lora_repair else {}
2368+
22712369 with tempfile .TemporaryDirectory () as tmpdirname :
22722370 modules_to_save = self ._get_modules_to_save (pipe , has_denoiser = True )
22732371 lora_state_dicts = self ._get_lora_state_dicts (modules_to_save )
@@ -2282,6 +2380,10 @@ def test_inference_load_delete_load_adapters(self):
22822380
22832381 # Then load adapter and compare.
22842382 pipe .load_lora_weights (tmpdirname )
2383+
2384+ if needs_lora_repair :
2385+ self ._restore_text_encoder_lora_tensors (pipe , captured_lora )
2386+
22852387 output_lora_loaded = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
22862388 self .assertTrue (np .allclose (output_adapter_1 , output_lora_loaded , atol = 1e-3 , rtol = 1e-3 ))
22872389
0 commit comments