Skip to content

Commit 5385b2a

Browse files
refactor autoencoder_magvit tests (#13834)
* refactor autoencoder_magvit tests * remove unused base_precision --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 1d61993 commit 5385b2a

1 file changed

Lines changed: 49 additions & 32 deletions

File tree

tests/models/autoencoders/test_models_autoencoder_magvit.py

Lines changed: 49 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,38 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
16+
import pytest
17+
import torch
1718

1819
from diffusers import AutoencoderKLMagvit
20+
from diffusers.utils.torch_utils import randn_tensor
1921

20-
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
21-
from ..test_modeling_common import ModelTesterMixin
22-
from .testing_utils import AutoencoderTesterMixin
22+
from ...testing_utils import enable_full_determinism, torch_device
23+
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
24+
from .testing_utils import NewAutoencoderTesterMixin
2325

2426

2527
enable_full_determinism()
2628

2729

28-
class AutoencoderKLMagvitTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
29-
model_class = AutoencoderKLMagvit
30-
main_input_name = "sample"
31-
base_precision = 1e-2
30+
class AutoencoderKLMagvitTesterConfig(BaseModelTesterConfig):
31+
@property
32+
def model_class(self):
33+
return AutoencoderKLMagvit
34+
35+
@property
36+
def main_input_name(self) -> str:
37+
return "sample"
38+
39+
@property
40+
def output_shape(self) -> tuple:
41+
return (3, 9, 16, 16)
42+
43+
@property
44+
def generator(self):
45+
return torch.Generator("cpu").manual_seed(0)
3246

33-
def get_autoencoder_kl_magvit_config(self):
47+
def get_init_dict(self) -> dict:
3448
return {
3549
"in_channels": 3,
3650
"latent_channels": 4,
@@ -53,45 +67,48 @@ def get_autoencoder_kl_magvit_config(self):
5367
"spatial_group_norm": True,
5468
}
5569

56-
@property
57-
def dummy_input(self):
70+
def get_dummy_inputs(self) -> dict:
5871
batch_size = 2
5972
num_frames = 9
6073
num_channels = 3
6174
height = 16
6275
width = 16
63-
64-
image = floats_tensor((batch_size, num_channels, num_frames, height, width)).to(torch_device)
65-
76+
image = randn_tensor(
77+
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
78+
)
6679
return {"sample": image}
6780

68-
@property
69-
def input_shape(self):
70-
return (3, 9, 16, 16)
7181

72-
@property
73-
def output_shape(self):
74-
return (3, 9, 16, 16)
82+
class TestAutoencoderKLMagvit(AutoencoderKLMagvitTesterConfig, ModelTesterMixin):
83+
pass
7584

76-
def prepare_init_args_and_inputs_for_common(self):
77-
init_dict = self.get_autoencoder_kl_magvit_config()
78-
inputs_dict = self.dummy_input
79-
return init_dict, inputs_dict
85+
86+
class TestAutoencoderKLMagvitTraining(AutoencoderKLMagvitTesterConfig, TrainingTesterMixin):
87+
"""Training tests for AutoencoderKLMagvit."""
8088

8189
def test_gradient_checkpointing_is_applied(self):
8290
expected_set = {"EasyAnimateEncoder", "EasyAnimateDecoder"}
8391
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
8492

85-
@unittest.skip("Not quite sure why this test fails. Revisit later.")
86-
def test_effective_gradient_checkpointing(self):
87-
pass
93+
@pytest.mark.skip("Not quite sure why this test fails. Revisit later.")
94+
def test_gradient_checkpointing_equivalence(self):
95+
super().test_gradient_checkpointing_equivalence()
96+
97+
98+
class TestAutoencoderKLMagvitMemory(AutoencoderKLMagvitTesterConfig, MemoryTesterMixin):
99+
"""Memory optimization tests for AutoencoderKLMagvit."""
100+
101+
102+
class TestAutoencoderKLMagvitSlicingTiling(AutoencoderKLMagvitTesterConfig, NewAutoencoderTesterMixin):
103+
"""Slicing and tiling tests for AutoencoderKLMagvit."""
88104

89-
@unittest.skip("Unsupported test.")
105+
@pytest.mark.skip("Unsupported test.")
90106
def test_forward_with_norm_groups(self):
91-
pass
107+
super().test_forward_with_norm_groups()
92108

93-
@unittest.skip(
94-
"Unsupported test. Error: RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 9 but got size 12 for tensor number 1 in the list."
109+
@pytest.mark.skip(
110+
"Unsupported test. Error: RuntimeError: Sizes of tensors must match except in dimension 0. "
111+
"Expected size 9 but got size 12 for tensor number 1 in the list."
95112
)
96113
def test_enable_disable_slicing(self):
97-
pass
114+
super().test_enable_disable_slicing()

0 commit comments

Comments
 (0)