1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import unittest
17-
1816import torch
1917
2018from diffusers import AutoencoderKLCogVideoX
19+ from diffusers .utils .torch_utils import randn_tensor
2120
22- from ...testing_utils import (
23- enable_full_determinism ,
24- floats_tensor ,
25- torch_device ,
26- )
27- from ..test_modeling_common import ModelTesterMixin
28- from .testing_utils import AutoencoderTesterMixin
21+ from ...testing_utils import enable_full_determinism , torch_device
22+ from ..testing_utils import BaseModelTesterConfig , MemoryTesterMixin , ModelTesterMixin , TrainingTesterMixin
23+ from .testing_utils import NewAutoencoderTesterMixin
2924
3025
3126enable_full_determinism ()
3227
3328
34- class AutoencoderKLCogVideoXTests (ModelTesterMixin , AutoencoderTesterMixin , unittest .TestCase ):
35- model_class = AutoencoderKLCogVideoX
36- main_input_name = "sample"
37- base_precision = 1e-2
29+ class AutoencoderKLCogVideoXTesterConfig (BaseModelTesterConfig ):
30+ @property
31+ def model_class (self ):
32+ return AutoencoderKLCogVideoX
33+
34+ @property
35+ def main_input_name (self ) -> str :
36+ return "sample"
3837
39- def get_autoencoder_kl_cogvideox_config (self ):
38+ @property
39+ def output_shape (self ) -> tuple :
40+ return (3 , 8 , 16 , 16 )
41+
42+ @property
43+ def generator (self ):
44+ return torch .Generator ("cpu" ).manual_seed (0 )
45+
46+ def get_init_dict (self ) -> dict :
4047 return {
4148 "in_channels" : 3 ,
4249 "out_channels" : 3 ,
@@ -59,29 +66,23 @@ def get_autoencoder_kl_cogvideox_config(self):
5966 "temporal_compression_ratio" : 4 ,
6067 }
6168
62- @property
63- def dummy_input (self ):
69+ def get_dummy_inputs (self ) -> dict :
6470 batch_size = 4
6571 num_frames = 8
6672 num_channels = 3
6773 sizes = (16 , 16 )
68-
69- image = floats_tensor (( batch_size , num_channels , num_frames ) + sizes ). to ( torch_device )
70-
74+ image = randn_tensor (
75+ ( batch_size , num_channels , num_frames , * sizes ), generator = self . generator , device = torch_device
76+ )
7177 return {"sample" : image }
7278
73- @property
74- def input_shape (self ):
75- return (3 , 8 , 16 , 16 )
7679
77- @ property
78- def output_shape ( self ):
79- return ( 3 , 8 , 16 , 16 )
80+ class TestAutoencoderKLCogVideoX ( AutoencoderKLCogVideoXTesterConfig , ModelTesterMixin ):
81+ pass
82+
8083
81- def prepare_init_args_and_inputs_for_common (self ):
82- init_dict = self .get_autoencoder_kl_cogvideox_config ()
83- inputs_dict = self .dummy_input
84- return init_dict , inputs_dict
84+ class TestAutoencoderKLCogVideoXTraining (AutoencoderKLCogVideoXTesterConfig , TrainingTesterMixin ):
85+ """Training tests for AutoencoderKLCogVideoX."""
8586
8687 def test_gradient_checkpointing_is_applied (self ):
8788 expected_set = {
@@ -93,8 +94,18 @@ def test_gradient_checkpointing_is_applied(self):
9394 }
9495 super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
9596
97+
98+ class TestAutoencoderKLCogVideoXMemory (AutoencoderKLCogVideoXTesterConfig , MemoryTesterMixin ):
99+ """Memory optimization tests for AutoencoderKLCogVideoX."""
100+
101+
102+ class TestAutoencoderKLCogVideoXSlicingTiling (AutoencoderKLCogVideoXTesterConfig , NewAutoencoderTesterMixin ):
103+ """Slicing and tiling tests for AutoencoderKLCogVideoX."""
104+
105+ # Overwritten because the base test's block_out_channels doesn't account for the length of down_block_types.
96106 def test_forward_with_norm_groups (self ):
97- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
107+ init_dict = self .get_init_dict ()
108+ inputs_dict = self .get_dummy_inputs ()
98109
99110 init_dict ["norm_num_groups" ] = 16
100111 init_dict ["block_out_channels" ] = (16 , 32 , 32 , 32 )
@@ -109,10 +120,6 @@ def test_forward_with_norm_groups(self):
109120 if isinstance (output , dict ):
110121 output = output .to_tuple ()[0 ]
111122
112- self . assertIsNotNone ( output )
123+ assert output is not None
113124 expected_shape = inputs_dict ["sample" ].shape
114- self .assertEqual (output .shape , expected_shape , "Input and output shapes do not match" )
115-
116- @unittest .skip ("Unsupported test." )
117- def test_outputs_equivalence (self ):
118- pass
125+ assert output .shape == expected_shape , "Input and output shapes do not match"
0 commit comments