98 changes: 97 additions & 1 deletion tests/lora/utils.py
@@ -52,6 +52,23 @@
from peft.utils import get_peft_model_state_dict


def _transformers_strips_text_model_prefix() -> bool:
"""
transformers>=5.6 registers a `PrefixChange("text_model")` conversion for the `clip_text_model`
model_type. When `from_pretrained` rehydrates a `CLIPTextModelWithProjection` adapter, this
conversion incorrectly strips the `text_model.` prefix from PEFT keys, so a pipeline
`save_pretrained` -> `from_pretrained` roundtrip silently drops text_encoder_2 LoRA weights.
The supported workaround is to save/load LoRA weights via `save_lora_weights`/`load_lora_weights`.
"""
try:
from transformers.conversion_mapping import get_checkpoint_conversion_mapping
from transformers.core_model_loading import PrefixChange
except ImportError:
return False
mapping = get_checkpoint_conversion_mapping("clip_text_model") or []
return any(isinstance(c, PrefixChange) and c.prefix_to_remove == "text_model" for c in mapping)
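
# Illustrative sketch (hypothetical key names, not taken from any checkpoint): the effect
# of the `PrefixChange("text_model")` conversion on a saved PEFT adapter key at load time.
#
#   saved key:  "text_model.encoder.layers.0.self_attn.q_proj.lora_A.default.weight"
#   loaded key: "encoder.layers.0.self_attn.q_proj.lora_A.default.weight"
#
# The stripped key no longer matches the injected adapter module, so the trained tensor is
# silently dropped and the adapter keeps its freshly initialized weights.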


def state_dicts_almost_equal(sd1, sd2):
sd1 = dict(sorted(sd1.items()))
sd2 = dict(sorted(sd2.items()))
@@ -299,6 +316,37 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):

return modules_to_save

def _needs_text_encoder_lora_repair(self) -> bool:
"""
transformers>=5.6 strips the `text_model.` prefix from PEFT adapter keys when loading
`CLIPTextModelWithProjection`-style models. For pipelines with a `text_encoder_2` / `text_encoder_3`, this
means save -> load roundtrips silently lose those LoRA weights. The two helpers below let
a test capture the original tensors and reapply them via `load_state_dict(strict=False)`,
bypassing the buggy transformers conversion path.
"""
return (
self.has_two_text_encoders or self.has_three_text_encoders
) and _transformers_strips_text_model_prefix()

    def _capture_text_encoder_lora_tensors(self, pipe):
        # Snapshot every LoRA tensor from each adapted text encoder. Tensors are
        # detached, cloned, and moved to CPU so the snapshot stays independent of
        # whatever later happens to the live modules.
        captured = {}
        for name in ("text_encoder", "text_encoder_2", "text_encoder_3"):
            module = getattr(pipe, name, None)
            if module is not None and getattr(module, "peft_config", None) is not None:
                captured[name] = {
                    k: v.detach().clone().cpu() for k, v in module.state_dict().items() if "lora" in k
                }
        return captured

    def _restore_text_encoder_lora_tensors(self, pipe, captured):
        for name, lora_tensors in captured.items():
            module = getattr(pipe, name)
            # Reloading can register the adapter under a new name, while the tensors were
            # captured under the original "default" adapter (see `add_adapters_to_pipeline`),
            # so rewrite the adapter-name segment of each key before reapplying.
            new_adapter_name = module.active_adapters()[0]
            target_device = next(module.parameters()).device
            repaired = {
                k.replace(".default.weight", f".{new_adapter_name}.weight"): v.to(target_device)
                for k, v in lora_tensors.items()
            }
            # strict=False: `repaired` holds only the LoRA tensors, so the rest of the
            # module's state dict must be allowed to stay untouched.
            module.load_state_dict(repaired, strict=False)
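
    # Minimal usage sketch of the three repair helpers, mirroring the pattern the tests
    # below follow (`pipe` is any pipeline built by this mixin):
    #
    #   needs_lora_repair = self._needs_text_encoder_lora_repair()
    #   captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}
    #   ...  # save, unload, and reload the LoRA weights
    #   if needs_lora_repair:
    #       self._restore_text_encoder_lora_tensors(pipe, captured_lora)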

def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
if text_lora_config is not None:
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -423,6 +471,9 @@ def test_low_cpu_mem_usage_with_loading(self):

images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

needs_lora_repair = self._needs_text_encoder_lora_repair()
captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}

with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
@@ -434,6 +485,9 @@ def test_low_cpu_mem_usage_with_loading(self):
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)

if needs_lora_repair:
self._restore_text_encoder_lora_tensors(pipe, captured_lora)

for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

@@ -447,6 +501,9 @@ def test_low_cpu_mem_usage_with_loading(self):
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)

if needs_lora_repair:
self._restore_text_encoder_lora_tensors(pipe, captured_lora)

for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

@@ -578,6 +635,9 @@ def test_simple_inference_with_text_lora_save_load(self):

images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

needs_lora_repair = self._needs_text_encoder_lora_repair()
captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}

with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
@@ -590,6 +650,9 @@ def test_simple_inference_with_text_lora_save_load(self):
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

if needs_lora_repair:
self._restore_text_encoder_lora_tensors(pipe, captured_lora)

for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

@@ -665,7 +728,15 @@ def test_simple_inference_with_partial_text_lora(self):

def test_simple_inference_save_pretrained_with_text_lora(self):
"""
Tests a simple use case where users could use saving utilities for LoRA through `save_pretrained`.

transformers>=5.6 registers a `clip_text_model` conversion that strips the `text_model.`
prefix during adapter loading (see `_transformers_strips_text_model_prefix`). For pipelines
whose text encoders use this conversion (e.g. SDXL's `CLIPTextModelWithProjection`),
`pipe.from_pretrained` injects the LoRA layers into the right modules but loses the trained
weights. Going through `load_lora_weights` afterwards hits the same conversion. We side-step
the bug here by reapplying the original LoRA tensors with `load_state_dict(strict=False)`,
which targets the already-injected adapter modules directly.
"""
if not self.supports_text_encoder_loras:
pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
@@ -679,12 +750,18 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

needs_lora_repair = self._needs_text_encoder_lora_repair()
captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}

with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)

pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
pipe_from_pretrained.to(torch_device)

if needs_lora_repair:
self._restore_text_encoder_lora_tensors(pipe_from_pretrained, captured_lora)

if "text_encoder" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
@@ -719,6 +796,9 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):

images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

needs_lora_repair = self._needs_text_encoder_lora_repair()
captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}

with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
@@ -730,6 +810,9 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

if needs_lora_repair:
self._restore_text_encoder_lora_tensors(pipe, captured_lora)

for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

@@ -2208,6 +2291,9 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

needs_lora_repair = self._needs_text_encoder_lora_repair()
captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}

with tempfile.TemporaryDirectory() as tmpdir:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
@@ -2216,6 +2302,9 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
pipe.unload_lora_weights()
pipe.load_lora_weights(tmpdir)

if needs_lora_repair:
self._restore_text_encoder_lora_tensors(pipe, captured_lora)

output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]

self.assertTrue(
@@ -2268,6 +2357,9 @@ def test_inference_load_delete_load_adapters(self):

output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

needs_lora_repair = self._needs_text_encoder_lora_repair()
captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}

with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
@@ -2282,6 +2374,10 @@ def test_inference_load_delete_load_adapters(self):

# Then load adapter and compare.
pipe.load_lora_weights(tmpdirname)

if needs_lora_repair:
self._restore_text_encoder_lora_tensors(pipe, captured_lora)

output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
