Skip to content

Commit 0c2e84e

Browse files
committed
update
1 parent 05a1d3a commit 0c2e84e

2 files changed

Lines changed: 43 additions & 18 deletions

File tree

tests/lora/test_lora_layers_wan.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
1516
import sys
1617
import tempfile
1718
import unittest
1819

19-
import numpy as np
20+
import safetensors.torch
2021
import torch
2122
from transformers import AutoTokenizer, T5EncoderModel
2223

@@ -26,12 +27,18 @@
2627
WanPipeline,
2728
WanTransformer3DModel,
2829
)
29-
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device
30+
from diffusers.utils.testing_utils import (
31+
floats_tensor,
32+
require_peft_backend,
33+
require_peft_version_greater,
34+
skip_mps,
35+
torch_device,
36+
)
3037

3138

3239
sys.path.append(".")
3340

34-
from utils import PeftLoraLoaderMixinTests # noqa: E402
41+
from utils import PeftLoraLoaderMixinTests, check_module_lora_metadata # noqa: E402
3542

3643

3744
@require_peft_backend
@@ -142,35 +149,40 @@ def test_simple_inference_with_text_lora_save_load(self):
142149

143150
# Refer to
144151
# https://github.com/huggingface/diffusers/pull/11806 for more details.
152+
@require_peft_version_greater("0.13.2")
145153
def test_lora_exclude_modules_for_wan(self):
146154
scheduler_cls = self.scheduler_classes[0]
147155
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
148156
pipe = self.pipeline_class(**components).to(torch_device)
149157
_, _, inputs = self.get_dummy_inputs(with_generator=False)
150158

151-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
152-
self.assertTrue(output_no_lora.shape == self.output_shape)
153-
154-
pipe, _ = self.check_if_adapters_added_correctly(
159+
# Only denoiser for now.
160+
denoiser_lora_config.target_modules = ["to_q", "to_k", "to_v", "out"]
161+
denoiser_lora_config.exclude_modules = ["proj_out"]
162+
pipe, _ = self.add_adapters_to_pipeline(
155163
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
156164
)
157-
output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
165+
# Inference works.
166+
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
158167

159168
with tempfile.TemporaryDirectory() as tmpdir:
160169
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
161170
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
162171
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
163172
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
164173
pipe.unload_lora_weights()
165-
pipe.load_lora_weights(tmpdir)
166-
167-
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
168174

169-
self.assertTrue(
170-
not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
171-
"LoRA should change outputs.",
172-
)
173-
self.assertTrue(
174-
np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
175-
"Lora outputs should match.",
175+
# Check the state dict. It should not have any `proj_out` related modules.
176+
state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
177+
# There should not be any `proj_out` modules, but there should still be some modules for `out`.
178+
self.assertTrue(not any("proj_out" in k for k in state_dict))
179+
self.assertTrue("out" in k for k in state_dict)
180+
181+
# Check if the metadata matches.
182+
out = pipe.lora_state_dict(tmpdir, return_lora_metadata=True)
183+
_, parsed_metadata = out
184+
check_module_lora_metadata(
185+
parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key="transformer"
176186
)
187+
188+
# Inference matching is already tested in `test_lora_exclude_modules`.

tests/lora/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
import copy
1516
import inspect
1617
import os
1718
import re
@@ -2340,12 +2341,14 @@ def test_lora_unload_add_adapter(self):
23402341
)
23412342
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
23422343

2344+
@require_peft_version_greater("0.13.2")
23432345
def test_lora_exclude_modules(self):
23442346
"""
23452347
Test to check if `exclude_modules` works or not. It works in the following way:
23462348
we first create a pipeline and insert LoRA config into it. We then derive a `set`
23472349
of modules to exclude by investigating its denoiser state dict and denoiser LoRA
23482350
state dict.
2351+
23492352
We then create a new LoRA config to include the `exclude_modules` and perform tests.
23502353
"""
23512354
scheduler_cls = self.scheduler_classes[0]
@@ -2356,6 +2359,16 @@ def test_lora_exclude_modules(self):
23562359
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
23572360
self.assertTrue(output_no_lora.shape == self.output_shape)
23582361

2362+
# only supported for `denoiser` now
2363+
pipe_cp = copy.deepcopy(pipe)
2364+
pipe_cp, _ = self.add_adapters_to_pipeline(
2365+
pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
2366+
)
2367+
denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
2368+
pipe_cp.to("cpu")
2369+
del pipe_cp
2370+
2371+
denoiser_lora_config.exclude_modules = denoiser_exclude_modules
23592372
pipe, _ = self.add_adapters_to_pipeline(
23602373
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
23612374
)

0 commit comments

Comments
 (0)