1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import unittest
16+ import pytest
17+ import torch
1718
1819from diffusers import AutoencoderKLMagvit
20+ from diffusers .utils .torch_utils import randn_tensor
1921
20- from ...testing_utils import enable_full_determinism , floats_tensor , torch_device
21- from ..test_modeling_common import ModelTesterMixin
22- from .testing_utils import AutoencoderTesterMixin
22+ from ...testing_utils import enable_full_determinism , torch_device
23+ from ..testing_utils import BaseModelTesterConfig , MemoryTesterMixin , ModelTesterMixin , TrainingTesterMixin
24+ from .testing_utils import NewAutoencoderTesterMixin
2325
2426
2527enable_full_determinism ()
2628
2729
28- class AutoencoderKLMagvitTests (ModelTesterMixin , AutoencoderTesterMixin , unittest .TestCase ):
29- model_class = AutoencoderKLMagvit
30- main_input_name = "sample"
31- base_precision = 1e-2
30+ class AutoencoderKLMagvitTesterConfig (BaseModelTesterConfig ):
31+ @property
32+ def model_class (self ):
33+ return AutoencoderKLMagvit
34+
35+ @property
36+ def main_input_name (self ) -> str :
37+ return "sample"
38+
39+ @property
40+ def output_shape (self ) -> tuple :
41+ return (3 , 9 , 16 , 16 )
42+
43+ @property
44+ def generator (self ):
45+ return torch .Generator ("cpu" ).manual_seed (0 )
3246
33- def get_autoencoder_kl_magvit_config (self ):
47+ def get_init_dict (self ) -> dict :
3448 return {
3549 "in_channels" : 3 ,
3650 "latent_channels" : 4 ,
@@ -53,45 +67,48 @@ def get_autoencoder_kl_magvit_config(self):
5367 "spatial_group_norm" : True ,
5468 }
5569
56- @property
57- def dummy_input (self ):
70+ def get_dummy_inputs (self ) -> dict :
5871 batch_size = 2
5972 num_frames = 9
6073 num_channels = 3
6174 height = 16
6275 width = 16
63-
64- image = floats_tensor (( batch_size , num_channels , num_frames , height , width )). to ( torch_device )
65-
76+ image = randn_tensor (
77+ ( batch_size , num_channels , num_frames , height , width ), generator = self . generator , device = torch_device
78+ )
6679 return {"sample" : image }
6780
68- @property
69- def input_shape (self ):
70- return (3 , 9 , 16 , 16 )
7181
72- @property
73- def output_shape (self ):
74- return (3 , 9 , 16 , 16 )
82+ class TestAutoencoderKLMagvit (AutoencoderKLMagvitTesterConfig , ModelTesterMixin ):
83+ pass
7584
76- def prepare_init_args_and_inputs_for_common (self ):
77- init_dict = self .get_autoencoder_kl_magvit_config ()
78- inputs_dict = self .dummy_input
79- return init_dict , inputs_dict
85+
86+ class TestAutoencoderKLMagvitTraining (AutoencoderKLMagvitTesterConfig , TrainingTesterMixin ):
87+ """Training tests for AutoencoderKLMagvit."""
8088
8189 def test_gradient_checkpointing_is_applied (self ):
8290 expected_set = {"EasyAnimateEncoder" , "EasyAnimateDecoder" }
8391 super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
8492
85- @unittest .skip ("Not quite sure why this test fails. Revisit later." )
86- def test_effective_gradient_checkpointing (self ):
87- pass
93+ @pytest .mark .skip ("Not quite sure why this test fails. Revisit later." )
94+ def test_gradient_checkpointing_equivalence (self ):
95+ super ().test_gradient_checkpointing_equivalence ()
96+
97+
98+ class TestAutoencoderKLMagvitMemory (AutoencoderKLMagvitTesterConfig , MemoryTesterMixin ):
99+ """Memory optimization tests for AutoencoderKLMagvit."""
100+
101+
102+ class TestAutoencoderKLMagvitSlicingTiling (AutoencoderKLMagvitTesterConfig , NewAutoencoderTesterMixin ):
103+ """Slicing and tiling tests for AutoencoderKLMagvit."""
88104
89- @unittest .skip ("Unsupported test." )
105+ @pytest . mark .skip ("Unsupported test." )
90106 def test_forward_with_norm_groups (self ):
91- pass
107+ super (). test_forward_with_norm_groups ()
92108
93- @unittest .skip (
94- "Unsupported test. Error: RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 9 but got size 12 for tensor number 1 in the list."
109+ @pytest .mark .skip (
110+ "Unsupported test. Error: RuntimeError: Sizes of tensors must match except in dimension 0. "
111+ "Expected size 9 but got size 12 for tensor number 1 in the list."
95112 )
96113 def test_enable_disable_slicing (self ):
97- pass
114+ super (). test_enable_disable_slicing ()
0 commit comments