Skip to content

Commit 5adc544

Browse files
authored
[tests] refactor wan autoencoder tests (#13371)
* refactor wan autoencoder tests * up * address dhruv's feedback.
1 parent a05c8e9 commit 5adc544

File tree

2 files changed

+173
-42
lines changed

2 files changed

+173
-42
lines changed

tests/models/autoencoders/test_models_autoencoder_wan.py

Lines changed: 38 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,34 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
16+
import pytest
17+
import torch
1718

1819
from diffusers import AutoencoderKLWan
20+
from diffusers.utils.torch_utils import randn_tensor
1921

20-
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
21-
from ..test_modeling_common import ModelTesterMixin
22-
from .testing_utils import AutoencoderTesterMixin
22+
from ...testing_utils import enable_full_determinism, torch_device
23+
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
24+
from .testing_utils import NewAutoencoderTesterMixin
2325

2426

2527
enable_full_determinism()
2628

2729

28-
class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
29-
model_class = AutoencoderKLWan
30-
main_input_name = "sample"
31-
base_precision = 1e-2
30+
class AutoencoderKLWanTesterConfig(BaseModelTesterConfig):
31+
@property
32+
def model_class(self):
33+
return AutoencoderKLWan
3234

33-
def get_autoencoder_kl_wan_config(self):
35+
@property
36+
def output_shape(self):
37+
return (3, 9, 16, 16)
38+
39+
@property
40+
def generator(self):
41+
return torch.Generator("cpu").manual_seed(0)
42+
43+
def get_init_dict(self):
3444
return {
3545
"base_dim": 3,
3646
"z_dim": 16,
@@ -39,54 +49,40 @@ def get_autoencoder_kl_wan_config(self):
3949
"temperal_downsample": [False, True, True],
4050
}
4151

42-
@property
43-
def dummy_input(self):
52+
def get_dummy_inputs(self):
4453
batch_size = 2
4554
num_frames = 9
4655
num_channels = 3
4756
sizes = (16, 16)
48-
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
57+
image = randn_tensor(
58+
(batch_size, num_channels, num_frames, *sizes), generator=self.generator, device=torch_device
59+
)
4960
return {"sample": image}
5061

51-
@property
52-
def dummy_input_tiling(self):
53-
batch_size = 2
54-
num_frames = 9
55-
num_channels = 3
56-
sizes = (128, 128)
57-
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
58-
return {"sample": image}
59-
60-
@property
61-
def input_shape(self):
62-
return (3, 9, 16, 16)
6362

64-
@property
65-
def output_shape(self):
66-
return (3, 9, 16, 16)
63+
class TestAutoencoderKLWan(AutoencoderKLWanTesterConfig, ModelTesterMixin):
64+
base_precision = 1e-2
6765

68-
def prepare_init_args_and_inputs_for_common(self):
69-
init_dict = self.get_autoencoder_kl_wan_config()
70-
inputs_dict = self.dummy_input
71-
return init_dict, inputs_dict
7266

73-
def prepare_init_args_and_inputs_for_tiling(self):
74-
init_dict = self.get_autoencoder_kl_wan_config()
75-
inputs_dict = self.dummy_input_tiling
76-
return init_dict, inputs_dict
67+
class TestAutoencoderKLWanTraining(AutoencoderKLWanTesterConfig, TrainingTesterMixin):
68+
"""Training tests for AutoencoderKLWan."""
7769

78-
@unittest.skip("Gradient checkpointing has not been implemented yet")
70+
@pytest.mark.skip(reason="Gradient checkpointing has not been implemented yet")
7971
def test_gradient_checkpointing_is_applied(self):
8072
pass
8173

82-
@unittest.skip("Test not supported")
83-
def test_forward_with_norm_groups(self):
84-
pass
8574

86-
@unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
87-
def test_layerwise_casting_inference(self):
75+
class TestAutoencoderKLWanMemory(AutoencoderKLWanTesterConfig, MemoryTesterMixin):
76+
"""Memory optimization tests for AutoencoderKLWan."""
77+
78+
@pytest.mark.skip(reason="RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
79+
def test_layerwise_casting_memory(self):
8880
pass
8981

90-
@unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
82+
@pytest.mark.skip(reason="RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
9183
def test_layerwise_casting_training(self):
9284
pass
85+
86+
87+
class TestAutoencoderKLWanSlicingTiling(AutoencoderKLWanTesterConfig, NewAutoencoderTesterMixin):
88+
"""Slicing and tiling tests for AutoencoderKLWan."""

tests/models/autoencoders/testing_utils.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,138 @@ def test_enable_disable_slicing(self):
145145
output_without_slicing.detach().cpu().numpy().all(),
146146
output_without_slicing_2.detach().cpu().numpy().all(),
147147
), "Without slicing outputs should match with the outputs when slicing is manually disabled."
148+
149+
150+
class NewAutoencoderTesterMixin:
    """Shared forward, tiling, and slicing checks for autoencoder model tests.

    Every test reads ``self.model_class``, ``self.get_init_dict()`` and
    ``self.get_dummy_inputs()``, so the host test class must provide them.
    """

    @staticmethod
    def _accepts_generator(model):
        """Return True if ``model.forward`` has a parameter named ``generator``."""
        return "generator" in inspect.signature(model.forward).parameters

    @staticmethod
    def _accepts_norm_num_groups(model_class):
        """Return True if ``model_class.__init__`` has a ``norm_num_groups`` parameter."""
        return "norm_num_groups" in inspect.signature(model_class.__init__).parameters

    @staticmethod
    def _seeded_forward(model, inputs_dict, accepts_generator):
        """Run one forward pass after re-seeding the global RNG with 0.

        When ``accepts_generator`` is true, ``inputs_dict["generator"]`` is
        (re)assigned in place before the call. The first element of the model
        output is returned; a ``DecoderOutput`` is unwrapped to its ``sample``.
        """
        torch.manual_seed(0)
        if accepts_generator:
            # torch.manual_seed returns the (freshly seeded) default generator.
            inputs_dict["generator"] = torch.manual_seed(0)
        output = model(**inputs_dict)[0]
        if isinstance(output, DecoderOutput):
            output = output.sample
        return output

    def test_forward_with_norm_groups(self):
        """Check that a model built with custom ``norm_num_groups`` preserves the input shape.

        Skipped when the model class constructor has no ``norm_num_groups`` parameter.
        """
        if not self._accepts_norm_num_groups(self.model_class):
            pytest.skip(f"Test not supported for {self.model_class.__name__}")
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        init_dict["norm_num_groups"] = 16
        init_dict["block_out_channels"] = (16, 32)

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output.to_tuple()[0]

        assert output is not None
        expected_shape = inputs_dict["sample"].shape
        assert output.shape == expected_shape, "Input and output shapes do not match"

    def test_enable_disable_tiling(self):
        """Compare outputs with tiling off, on, and off again.

        The tiled output must stay within 0.5 (absolute) of the untiled one,
        and disabling tiling must reproduce the original untiled output.
        Skipped when the model class or instance does not expose tiling.
        """
        if not hasattr(self.model_class, "enable_tiling"):
            pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")

        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        torch.manual_seed(0)
        model = self.model_class(**init_dict).to(torch_device)

        if not hasattr(model, "use_tiling"):
            pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")

        inputs_dict.update({"return_dict": False})
        _ = inputs_dict.pop("generator", None)
        accepts_generator = self._accepts_generator(model)

        with torch.no_grad():
            output_without_tiling = self._seeded_forward(model, inputs_dict, accepts_generator)

            model.enable_tiling()
            output_with_tiling = self._seeded_forward(model, inputs_dict, accepts_generator)

            # Use the absolute difference: a signed max() would let large
            # negative deviations pass unnoticed.
            assert (output_without_tiling.cpu() - output_with_tiling.cpu()).abs().max() < 0.5, (
                "VAE tiling should not affect the inference results"
            )

            model.disable_tiling()
            output_without_tiling_2 = self._seeded_forward(model, inputs_dict, accepts_generator)

            assert torch.allclose(output_without_tiling.cpu(), output_without_tiling_2.cpu()), (
                "Without tiling outputs should match with the outputs when tiling is manually disabled."
            )

    def test_enable_disable_slicing(self):
        """Compare outputs with slicing off, on, and off again.

        The sliced output must stay within 0.5 (absolute) of the unsliced one,
        and disabling slicing must reproduce the original unsliced output.
        Skipped when the model class or instance does not expose slicing.
        """
        if not hasattr(self.model_class, "enable_slicing"):
            pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support slicing.")

        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        torch.manual_seed(0)
        model = self.model_class(**init_dict).to(torch_device)
        if not hasattr(model, "use_slicing"):
            pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support slicing.")

        inputs_dict.update({"return_dict": False})
        _ = inputs_dict.pop("generator", None)
        accepts_generator = self._accepts_generator(model)

        with torch.no_grad():
            output_without_slicing = self._seeded_forward(model, inputs_dict, accepts_generator)

            model.enable_slicing()
            output_with_slicing = self._seeded_forward(model, inputs_dict, accepts_generator)

            # Use the absolute difference: a signed max() would let large
            # negative deviations pass unnoticed.
            assert (output_without_slicing.cpu() - output_with_slicing.cpu()).abs().max() < 0.5, (
                "VAE slicing should not affect the inference results"
            )

            model.disable_slicing()
            output_without_slicing_2 = self._seeded_forward(model, inputs_dict, accepts_generator)

            assert torch.allclose(output_without_slicing.cpu(), output_without_slicing_2.cpu()), (
                "Without slicing outputs should match with the outputs when slicing is manually disabled."
            )

0 commit comments

Comments
 (0)