|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import os |
15 | 16 | import sys |
16 | 17 | import tempfile |
17 | 18 | import unittest |
18 | 19 |
|
19 | | -import numpy as np |
| 20 | +import safetensors.torch |
20 | 21 | import torch |
21 | 22 | from transformers import AutoTokenizer, T5EncoderModel |
22 | 23 |
|
|
26 | 27 | WanPipeline, |
27 | 28 | WanTransformer3DModel, |
28 | 29 | ) |
29 | | -from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device |
| 30 | +from diffusers.utils.testing_utils import ( |
| 31 | + floats_tensor, |
| 32 | + require_peft_backend, |
| 33 | + require_peft_version_greater, |
| 34 | + skip_mps, |
| 35 | + torch_device, |
| 36 | +) |
30 | 37 |
|
31 | 38 |
|
32 | 39 | sys.path.append(".") |
33 | 40 |
|
34 | | -from utils import PeftLoraLoaderMixinTests # noqa: E402 |
| 41 | +from utils import PeftLoraLoaderMixinTests, check_module_lora_metadata # noqa: E402 |
35 | 42 |
|
36 | 43 |
|
37 | 44 | @require_peft_backend |
@@ -142,35 +149,40 @@ def test_simple_inference_with_text_lora_save_load(self): |
142 | 149 |
|
143 | 150 | # Refer to |
144 | 151 | # https://github.com/huggingface/diffusers/pull/11806 for more details. |
| 152 | + @require_peft_version_greater("0.13.2") |
145 | 153 | def test_lora_exclude_modules_for_wan(self): |
146 | 154 | scheduler_cls = self.scheduler_classes[0] |
147 | 155 | components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) |
148 | 156 | pipe = self.pipeline_class(**components).to(torch_device) |
149 | 157 | _, _, inputs = self.get_dummy_inputs(with_generator=False) |
150 | 158 |
|
151 | | - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] |
152 | | - self.assertTrue(output_no_lora.shape == self.output_shape) |
153 | | - |
154 | | - pipe, _ = self.check_if_adapters_added_correctly( |
| 159 | + # Only denoiser for now. |
| 160 | + denoiser_lora_config.target_modules = ["to_q", "to_k", "to_v", "out"] |
| 161 | + denoiser_lora_config.exclude_modules = ["proj_out"] |
| 162 | + pipe, _ = self.add_adapters_to_pipeline( |
155 | 163 | pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config |
156 | 164 | ) |
157 | | - output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0] |
| 165 | + # Inference works. |
| 166 | + _ = pipe(**inputs, generator=torch.manual_seed(0))[0] |
158 | 167 |
|
159 | 168 | with tempfile.TemporaryDirectory() as tmpdir: |
160 | 169 | modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) |
161 | 170 | lora_state_dicts = self._get_lora_state_dicts(modules_to_save) |
162 | 171 | lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) |
163 | 172 | self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) |
164 | 173 | pipe.unload_lora_weights() |
165 | | - pipe.load_lora_weights(tmpdir) |
166 | | - |
167 | | - output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] |
168 | 174 |
|
169 | | - self.assertTrue( |
170 | | - not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), |
171 | | - "LoRA should change outputs.", |
172 | | - ) |
173 | | - self.assertTrue( |
174 | | - np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), |
175 | | - "Lora outputs should match.", |
| 175 | + # Check the state dict. It should not have any `proj_out` related modules. |
| 176 | + state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) |
| 177 | + # There should not be any `proj_out` modules, but there should still be some modules for `out`. |
| 178 | + self.assertTrue(not any("proj_out" in k for k in state_dict)) |
| 179 | + self.assertTrue("out" in k for k in state_dict) |
| 180 | + |
| 181 | + # Check if the metadata matches. |
| 182 | + out = pipe.lora_state_dict(tmpdir, return_lora_metadata=True) |
| 183 | + _, parsed_metadata = out |
| 184 | + check_module_lora_metadata( |
| 185 | + parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key="transformer" |
176 | 186 | ) |
| 187 | + |
| 188 | + # Inference matching is already tested in `test_lora_exclude_modules`. |
0 commit comments