|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import sys |
| 16 | +import tempfile |
16 | 17 | import unittest |
17 | 18 |
|
| 19 | +import numpy as np |
18 | 20 | import torch |
19 | 21 | from transformers import AutoTokenizer, T5EncoderModel |
20 | 22 |
|
|
24 | 26 | WanPipeline, |
25 | 27 | WanTransformer3DModel, |
26 | 28 | ) |
27 | | -from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps |
| 29 | +from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device |
28 | 30 |
|
29 | 31 |
|
30 | 32 | sys.path.append(".") |
@@ -137,3 +139,38 @@ def test_simple_inference_with_text_lora_fused(self): |
    @unittest.skip("Text encoder LoRA is not supported in Wan.")
    def test_simple_inference_with_text_lora_save_load(self):
        # Wan pipelines attach LoRA only to the denoiser, never to the T5 text
        # encoder, so the inherited save/load round-trip test has nothing to
        # exercise here and is skipped wholesale.
        pass
| 142 | + |
| 143 | + # Refer to |
| 144 | + # https://github.com/huggingface/diffusers/pull/11806 for more details. |
| 145 | + def test_lora_exclude_modules_for_wan(self): |
| 146 | + scheduler_cls = self.scheduler_classes[0] |
| 147 | + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) |
| 148 | + pipe = self.pipeline_class(**components).to(torch_device) |
| 149 | + _, _, inputs = self.get_dummy_inputs(with_generator=False) |
| 150 | + |
| 151 | + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] |
| 152 | + self.assertTrue(output_no_lora.shape == self.output_shape) |
| 153 | + |
| 154 | + pipe, _ = self.check_if_adapters_added_correctly( |
| 155 | + pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config |
| 156 | + ) |
| 157 | + output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0] |
| 158 | + |
| 159 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 160 | + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) |
| 161 | + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) |
| 162 | + lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) |
| 163 | + self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) |
| 164 | + pipe.unload_lora_weights() |
| 165 | + pipe.load_lora_weights(tmpdir) |
| 166 | + |
| 167 | + output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] |
| 168 | + |
| 169 | + self.assertTrue( |
| 170 | + not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), |
| 171 | + "LoRA should change outputs.", |
| 172 | + ) |
| 173 | + self.assertTrue( |
| 174 | + np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), |
| 175 | + "Lora outputs should match.", |
| 176 | + ) |
0 commit comments