Skip to content

Commit ac194af

Browse files
committed
update
1 parent 7f582ca commit ac194af

1 file changed

Lines changed: 38 additions & 1 deletion

File tree

tests/lora/test_lora_layers_wan.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
# limitations under the License.
1414

1515
import sys
16+
import tempfile
1617
import unittest
1718

19+
import numpy as np
1820
import torch
1921
from transformers import AutoTokenizer, T5EncoderModel
2022

@@ -24,7 +26,7 @@
2426
WanPipeline,
2527
WanTransformer3DModel,
2628
)
27-
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps
29+
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device
2830

2931

3032
sys.path.append(".")
@@ -137,3 +139,38 @@ def test_simple_inference_with_text_lora_fused(self):
137139
@unittest.skip("Text encoder LoRA is not supported in Wan.")
138140
def test_simple_inference_with_text_lora_save_load(self):
139141
pass
142+
143+
# Regression test for exclude_modules handling in Wan LoRA loading; see
# https://github.com/huggingface/diffusers/pull/11806 for more details.
def test_lora_exclude_modules_for_wan(self):
    """Check that LoRA weights saved and reloaded from disk reproduce the
    in-memory adapter outputs (round-trip through save_lora_weights /
    load_lora_weights), and that the adapter actually changes the output.
    """
    scheduler_class = self.scheduler_classes[0]
    components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_class)
    pipe = self.pipeline_class(**components).to(torch_device)
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    # Baseline run without any adapter attached.
    baseline_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
    self.assertTrue(baseline_output.shape == self.output_shape)

    # Attach LoRA adapters to the denoiser (text encoder LoRA is unsupported in Wan).
    pipe, _ = self.check_if_adapters_added_correctly(
        pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
    )
    adapter_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

    # Round-trip the adapter weights through disk, then reload them.
    with tempfile.TemporaryDirectory() as tmpdir:
        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
        lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
        lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
        self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
        pipe.unload_lora_weights()
        pipe.load_lora_weights(tmpdir)

    reloaded_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

    # The adapter must perturb the baseline output...
    self.assertTrue(
        not np.allclose(baseline_output, adapter_output, atol=1e-3, rtol=1e-3),
        "LoRA should change outputs.",
    )
    # ...and the reloaded weights must reproduce the in-memory adapter output.
    self.assertTrue(
        np.allclose(adapter_output, reloaded_output, atol=1e-3, rtol=1e-3),
        "Lora outputs should match.",
    )

0 commit comments

Comments
 (0)