Skip to content

Commit ac8da54

Browse files
committed
fix test_set_adapters_match_attention_kwargs
1 parent 9fdda4a commit ac8da54

2 files changed

Lines changed: 7 additions & 1 deletion

File tree

src/diffusers/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@
114114
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
115115
_import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
116116
_import_structure["transformers.transformer_joyimage"] = [
117-
"JoyImageEditTransformer3DModel",
117+
"JoyImageEditTransformer3DModel"
118118
]
119119
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
120120
_import_structure["transformers.transformer_longcat_audio_dit"] = ["LongCatAudioDiTTransformer"]

tests/lora/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1962,6 +1962,9 @@ def test_set_adapters_match_attention_kwargs(self):
19621962
"Lora + scale should match the output of `set_adapters()`.",
19631963
)
19641964

1965+
needs_lora_repair = self._needs_text_encoder_lora_repair()
1966+
captured_lora = self._capture_text_encoder_lora_tensors(pipe) if needs_lora_repair else {}
1967+
19651968
with tempfile.TemporaryDirectory() as tmpdirname:
19661969
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
19671970
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
@@ -1975,6 +1978,9 @@ def test_set_adapters_match_attention_kwargs(self):
19751978
pipe.set_progress_bar_config(disable=None)
19761979
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
19771980

1981+
if needs_lora_repair:
1982+
self._restore_text_encoder_lora_tensors(pipe, captured_lora)
1983+
19781984
for module_name, module in modules_to_save.items():
19791985
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
19801986

0 commit comments

Comments
 (0)