Skip to content

Commit 334ef1a

Browse files
refactor autoencoder_kl_cogvideox tests (#13840)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 2875896 commit 334ef1a

1 file changed

Lines changed: 43 additions & 36 deletions

File tree

tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py

Lines changed: 43 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,37 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
17-
1816
import torch
1917

2018
from diffusers import AutoencoderKLCogVideoX
19+
from diffusers.utils.torch_utils import randn_tensor
2120

22-
from ...testing_utils import (
23-
enable_full_determinism,
24-
floats_tensor,
25-
torch_device,
26-
)
27-
from ..test_modeling_common import ModelTesterMixin
28-
from .testing_utils import AutoencoderTesterMixin
21+
from ...testing_utils import enable_full_determinism, torch_device
22+
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
23+
from .testing_utils import NewAutoencoderTesterMixin
2924

3025

3126
enable_full_determinism()
3227

3328

34-
class AutoencoderKLCogVideoXTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
35-
model_class = AutoencoderKLCogVideoX
36-
main_input_name = "sample"
37-
base_precision = 1e-2
29+
class AutoencoderKLCogVideoXTesterConfig(BaseModelTesterConfig):
30+
@property
31+
def model_class(self):
32+
return AutoencoderKLCogVideoX
33+
34+
@property
35+
def main_input_name(self) -> str:
36+
return "sample"
3837

39-
def get_autoencoder_kl_cogvideox_config(self):
38+
@property
39+
def output_shape(self) -> tuple:
40+
return (3, 8, 16, 16)
41+
42+
@property
43+
def generator(self):
44+
return torch.Generator("cpu").manual_seed(0)
45+
46+
def get_init_dict(self) -> dict:
4047
return {
4148
"in_channels": 3,
4249
"out_channels": 3,
@@ -59,29 +66,23 @@ def get_autoencoder_kl_cogvideox_config(self):
5966
"temporal_compression_ratio": 4,
6067
}
6168

62-
@property
63-
def dummy_input(self):
69+
def get_dummy_inputs(self) -> dict:
6470
batch_size = 4
6571
num_frames = 8
6672
num_channels = 3
6773
sizes = (16, 16)
68-
69-
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
70-
74+
image = randn_tensor(
75+
(batch_size, num_channels, num_frames, *sizes), generator=self.generator, device=torch_device
76+
)
7177
return {"sample": image}
7278

73-
@property
74-
def input_shape(self):
75-
return (3, 8, 16, 16)
7679

77-
@property
78-
def output_shape(self):
79-
return (3, 8, 16, 16)
80+
class TestAutoencoderKLCogVideoX(AutoencoderKLCogVideoXTesterConfig, ModelTesterMixin):
81+
pass
82+
8083

81-
def prepare_init_args_and_inputs_for_common(self):
82-
init_dict = self.get_autoencoder_kl_cogvideox_config()
83-
inputs_dict = self.dummy_input
84-
return init_dict, inputs_dict
84+
class TestAutoencoderKLCogVideoXTraining(AutoencoderKLCogVideoXTesterConfig, TrainingTesterMixin):
85+
"""Training tests for AutoencoderKLCogVideoX."""
8586

8687
def test_gradient_checkpointing_is_applied(self):
8788
expected_set = {
@@ -93,8 +94,18 @@ def test_gradient_checkpointing_is_applied(self):
9394
}
9495
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
9596

97+
98+
class TestAutoencoderKLCogVideoXMemory(AutoencoderKLCogVideoXTesterConfig, MemoryTesterMixin):
99+
"""Memory optimization tests for AutoencoderKLCogVideoX."""
100+
101+
102+
class TestAutoencoderKLCogVideoXSlicingTiling(AutoencoderKLCogVideoXTesterConfig, NewAutoencoderTesterMixin):
103+
"""Slicing and tiling tests for AutoencoderKLCogVideoX."""
104+
105+
# Overwritten because the base test's block_out_channels doesn't account for the length of down_block_types.
96106
def test_forward_with_norm_groups(self):
97-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
107+
init_dict = self.get_init_dict()
108+
inputs_dict = self.get_dummy_inputs()
98109

99110
init_dict["norm_num_groups"] = 16
100111
init_dict["block_out_channels"] = (16, 32, 32, 32)
@@ -109,10 +120,6 @@ def test_forward_with_norm_groups(self):
109120
if isinstance(output, dict):
110121
output = output.to_tuple()[0]
111122

112-
self.assertIsNotNone(output)
123+
assert output is not None
113124
expected_shape = inputs_dict["sample"].shape
114-
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
115-
116-
@unittest.skip("Unsupported test.")
117-
def test_outputs_equivalence(self):
118-
pass
125+
assert output.shape == expected_shape, "Input and output shapes do not match"

0 commit comments

Comments
 (0)