Skip to content

Commit 8f14cde

Browse files
authored
[tests] refactor ltx2 autoencoder tests to use latest mixins (#13739)
* refactor ltx2 autoencoder tests to use latest mixins
* fix more.
* fix tests
* is_flaky
1 parent 40a43dd commit 8f14cde

2 files changed

Lines changed: 87 additions & 70 deletions

File tree

tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,35 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
16+
import pytest
17+
import torch
1718

1819
from diffusers import AutoencoderKLLTX2Audio
20+
from diffusers.utils.torch_utils import randn_tensor
1921

20-
from ...testing_utils import (
21-
floats_tensor,
22-
torch_device,
23-
)
24-
from ..test_modeling_common import ModelTesterMixin
25-
from .testing_utils import AutoencoderTesterMixin
22+
from ...testing_utils import is_flaky, torch_device
23+
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
24+
from .testing_utils import NewAutoencoderTesterMixin
2625

2726

28-
class AutoencoderKLLTX2AudioTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
29-
model_class = AutoencoderKLLTX2Audio
30-
main_input_name = "sample"
31-
base_precision = 1e-2
27+
class AutoencoderKLLTX2AudioTesterConfig(BaseModelTesterConfig):
28+
@property
29+
def main_input_name(self):
30+
return "sample"
31+
32+
@property
33+
def model_class(self):
34+
return AutoencoderKLLTX2Audio
3235

33-
def get_autoencoder_kl_ltx_video_config(self):
36+
@property
37+
def output_shape(self):
38+
return (2, 5, 16)
39+
40+
@property
41+
def generator(self):
42+
return torch.Generator("cpu").manual_seed(0)
43+
44+
def get_init_dict(self):
3445
return {
3546
"in_channels": 2, # stereo,
3647
"output_channels": 2,
@@ -50,39 +61,39 @@ def get_autoencoder_kl_ltx_video_config(self):
5061
"double_z": True,
5162
}
5263

53-
@property
54-
def dummy_input(self):
64+
def get_dummy_inputs(self):
5565
batch_size = 2
5666
num_channels = 2
5767
num_frames = 8
5868
num_mel_bins = 16
69+
spectrogram = randn_tensor(
70+
(batch_size, num_channels, num_frames, num_mel_bins),
71+
generator=self.generator,
72+
device=torch_device,
73+
)
74+
return {"sample": spectrogram}
5975

60-
spectrogram = floats_tensor((batch_size, num_channels, num_frames, num_mel_bins)).to(torch_device)
6176

62-
input_dict = {"sample": spectrogram}
63-
return input_dict
77+
class TestAutoencoderKLLTX2Audio(AutoencoderKLLTX2AudioTesterConfig, ModelTesterMixin):
78+
base_precision = 1e-2
6479

65-
@property
66-
def input_shape(self):
67-
return (2, 5, 16)
80+
def test_outputs_equivalence(self):
81+
pytest.skip("Unsupported test.")
6882

69-
@property
70-
def output_shape(self):
71-
return (2, 5, 16)
7283

73-
def prepare_init_args_and_inputs_for_common(self):
74-
init_dict = self.get_autoencoder_kl_ltx_video_config()
75-
inputs_dict = self.dummy_input
76-
return init_dict, inputs_dict
84+
class TestAutoencoderKLLTX2AudioTraining(AutoencoderKLLTX2AudioTesterConfig, TrainingTesterMixin):
85+
"""Training tests for AutoencoderKLLTX2Audio."""
7786

78-
# Overriding as output shape is not the same as input shape for LTX 2.0 audio VAE
79-
def test_output(self):
80-
super().test_output(expected_output_shape=(2, 2, 5, 16))
8187

82-
@unittest.skip("Unsupported test.")
83-
def test_outputs_equivalence(self):
84-
pass
88+
class TestAutoencoderKLLTX2AudioMemory(AutoencoderKLLTX2AudioTesterConfig, MemoryTesterMixin):
89+
"""Memory optimization tests for AutoencoderKLLTX2Audio."""
90+
91+
@is_flaky()
92+
@pytest.mark.parametrize("record_stream", [False, True])
93+
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
94+
def test_group_offloading_with_disk(self, tmp_path, record_stream, offload_type, atol=1e-5, rtol=0):
95+
super().test_group_offloading_with_disk(tmp_path, record_stream, offload_type, atol=atol, rtol=rtol)
96+
8597

86-
@unittest.skip("AutoencoderKLLTX2Audio does not support `norm_num_groups` because it does not use GroupNorm.")
87-
def test_forward_with_norm_groups(self):
88-
pass
98+
class TestAutoencoderKLLTX2AudioSlicingTiling(AutoencoderKLLTX2AudioTesterConfig, NewAutoencoderTesterMixin):
99+
"""Slicing and tiling tests for AutoencoderKLLTX2Audio."""

tests/models/autoencoders/test_models_autoencoder_ltx2_video.py

Lines changed: 40 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,28 +13,38 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
16+
import pytest
17+
import torch
1718

1819
from diffusers import AutoencoderKLLTX2Video
20+
from diffusers.utils.torch_utils import randn_tensor
1921

20-
from ...testing_utils import (
21-
enable_full_determinism,
22-
floats_tensor,
23-
torch_device,
24-
)
25-
from ..test_modeling_common import ModelTesterMixin
26-
from .testing_utils import AutoencoderTesterMixin
22+
from ...testing_utils import enable_full_determinism, torch_device
23+
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
24+
from .testing_utils import NewAutoencoderTesterMixin
2725

2826

2927
enable_full_determinism()
3028

3129

32-
class AutoencoderKLLTX2VideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
33-
model_class = AutoencoderKLLTX2Video
34-
main_input_name = "sample"
35-
base_precision = 1e-2
30+
class AutoencoderKLLTX2VideoTesterConfig(BaseModelTesterConfig):
31+
@property
32+
def main_input_name(self):
33+
return "sample"
34+
35+
@property
36+
def model_class(self):
37+
return AutoencoderKLLTX2Video
3638

37-
def get_autoencoder_kl_ltx_video_config(self):
39+
@property
40+
def output_shape(self):
41+
return (3, 9, 16, 16)
42+
43+
@property
44+
def generator(self):
45+
return torch.Generator("cpu").manual_seed(0)
46+
47+
def get_init_dict(self):
3848
return {
3949
"in_channels": 3,
4050
"out_channels": 3,
@@ -59,30 +69,26 @@ def get_autoencoder_kl_ltx_video_config(self):
5969
"decoder_spatial_padding_mode": "zeros",
6070
}
6171

62-
@property
63-
def dummy_input(self):
72+
def get_dummy_inputs(self):
6473
batch_size = 2
6574
num_frames = 9
6675
num_channels = 3
6776
sizes = (16, 16)
77+
image = randn_tensor(
78+
(batch_size, num_channels, num_frames, *sizes), generator=self.generator, device=torch_device
79+
)
80+
return {"sample": image}
6881

69-
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
7082

71-
input_dict = {"sample": image}
72-
return input_dict
83+
class TestAutoencoderKLLTX2Video(AutoencoderKLLTX2VideoTesterConfig, ModelTesterMixin):
84+
base_precision = 1e-2
7385

74-
@property
75-
def input_shape(self):
76-
return (3, 9, 16, 16)
86+
def test_outputs_equivalence(self):
87+
pytest.skip("Unsupported test.")
7788

78-
@property
79-
def output_shape(self):
80-
return (3, 9, 16, 16)
8189

82-
def prepare_init_args_and_inputs_for_common(self):
83-
init_dict = self.get_autoencoder_kl_ltx_video_config()
84-
inputs_dict = self.dummy_input
85-
return init_dict, inputs_dict
90+
class TestAutoencoderKLLTX2VideoTraining(AutoencoderKLLTX2VideoTesterConfig, TrainingTesterMixin):
91+
"""Training tests for AutoencoderKLLTX2Video."""
8692

8793
def test_gradient_checkpointing_is_applied(self):
8894
expected_set = {
@@ -94,10 +100,10 @@ def test_gradient_checkpointing_is_applied(self):
94100
}
95101
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
96102

97-
@unittest.skip("Unsupported test.")
98-
def test_outputs_equivalence(self):
99-
pass
100103

101-
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
102-
def test_forward_with_norm_groups(self):
103-
pass
104+
class TestAutoencoderKLLTX2VideoMemory(AutoencoderKLLTX2VideoTesterConfig, MemoryTesterMixin):
105+
"""Memory optimization tests for AutoencoderKLLTX2Video."""
106+
107+
108+
class TestAutoencoderKLLTX2VideoSlicingTiling(AutoencoderKLLTX2VideoTesterConfig, NewAutoencoderTesterMixin):
109+
"""Slicing and tiling tests for AutoencoderKLLTX2Video."""

0 commit comments

Comments
 (0)