1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import unittest
17-
16+ import pytest
1817import torch
1918
2019from diffusers import AutoencoderKLLTXVideo
20+ from diffusers .utils .torch_utils import randn_tensor
2121
22- from ...testing_utils import (
23- enable_full_determinism ,
24- floats_tensor ,
25- torch_device ,
26- )
27- from ..test_modeling_common import ModelTesterMixin
28- from .testing_utils import AutoencoderTesterMixin
22+ from ...testing_utils import enable_full_determinism , torch_device
23+ from ..testing_utils import BaseModelTesterConfig , MemoryTesterMixin , ModelTesterMixin , TrainingTesterMixin
24+ from .testing_utils import NewAutoencoderTesterMixin
2925
3026
3127enable_full_determinism ()
3228
3329
34- class AutoencoderKLLTXVideo090Tests (ModelTesterMixin , AutoencoderTesterMixin , unittest .TestCase ):
35- model_class = AutoencoderKLLTXVideo
36- main_input_name = "sample"
37- base_precision = 1e-2
30+ _LTX_VIDEO_GRADIENT_CKPT_EXPECTED = {
31+ "LTXVideoEncoder3d" ,
32+ "LTXVideoDecoder3d" ,
33+ "LTXVideoDownBlock3D" ,
34+ "LTXVideoMidBlock3d" ,
35+ "LTXVideoUpBlock3d" ,
36+ }
37+
38+
39+ class AutoencoderKLLTXVideo090TesterConfig (BaseModelTesterConfig ):
40+ @property
41+ def model_class (self ):
42+ return AutoencoderKLLTXVideo
43+
44+ @property
45+ def main_input_name (self ) -> str :
46+ return "sample"
3847
39- def get_autoencoder_kl_ltx_video_config (self ):
48+ @property
49+ def output_shape (self ) -> tuple :
50+ return (3 , 9 , 16 , 16 )
51+
52+ @property
53+ def generator (self ):
54+ return torch .Generator ("cpu" ).manual_seed (0 )
55+
56+ def get_init_dict (self ) -> dict :
4057 return {
4158 "in_channels" : 3 ,
4259 "out_channels" : 3 ,
@@ -57,55 +74,62 @@ def get_autoencoder_kl_ltx_video_config(self):
5774 "decoder_causal" : False ,
5875 }
5976
60- @property
61- def dummy_input (self ):
77+ def get_dummy_inputs (self ) -> dict :
6278 batch_size = 2
6379 num_frames = 9
6480 num_channels = 3
6581 sizes = (16 , 16 )
82+ image = randn_tensor (
83+ (batch_size , num_channels , num_frames , * sizes ), generator = self .generator , device = torch_device
84+ )
85+ return {"sample" : image }
6686
67- image = floats_tensor ((batch_size , num_channels , num_frames ) + sizes ).to (torch_device )
6887
69- return {"sample" : image }
88+ class TestAutoencoderKLLTXVideo090 (AutoencoderKLLTXVideo090TesterConfig , ModelTesterMixin ):
89+ base_precision = 1e-2
7090
71- @property
72- def input_shape (self ):
73- return ( 3 , 9 , 16 , 16 )
91+ @pytest . mark . skip ( "Unsupported test." )
92+ def test_outputs_equivalence (self ):
93+ super (). test_outputs_equivalence ( )
7494
75- @property
76- def output_shape (self ):
77- return (3 , 9 , 16 , 16 )
7895
79- def prepare_init_args_and_inputs_for_common (self ):
80- init_dict = self .get_autoencoder_kl_ltx_video_config ()
81- inputs_dict = self .dummy_input
82- return init_dict , inputs_dict
96+ class TestAutoencoderKLLTXVideo090Training (AutoencoderKLLTXVideo090TesterConfig , TrainingTesterMixin ):
97+ """Training tests for AutoencoderKLLTXVideo (0.9.0 config)."""
8398
8499 def test_gradient_checkpointing_is_applied (self ):
85- expected_set = {
86- "LTXVideoEncoder3d" ,
87- "LTXVideoDecoder3d" ,
88- "LTXVideoDownBlock3D" ,
89- "LTXVideoMidBlock3d" ,
90- "LTXVideoUpBlock3d" ,
91- }
92- super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
100+ super ().test_gradient_checkpointing_is_applied (expected_set = _LTX_VIDEO_GRADIENT_CKPT_EXPECTED )
101+
102+
103+ class TestAutoencoderKLLTXVideo090Memory (AutoencoderKLLTXVideo090TesterConfig , MemoryTesterMixin ):
104+ """Memory optimization tests for AutoencoderKLLTXVideo (0.9.0 config)."""
93105
94- @unittest .skip ("Unsupported test." )
95- def test_outputs_equivalence (self ):
96- pass
97106
98- @unittest .skip ("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm." )
107+ class TestAutoencoderKLLTXVideo090SlicingTiling (AutoencoderKLLTXVideo090TesterConfig , NewAutoencoderTesterMixin ):
108+ """Slicing and tiling tests for AutoencoderKLLTXVideo (0.9.0 config)."""
109+
110+ @pytest .mark .skip ("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm." )
99111 def test_forward_with_norm_groups (self ):
100- pass
112+ super (). test_forward_with_norm_groups ()
101113
102114
103- class AutoencoderKLLTXVideo091Tests (ModelTesterMixin , unittest .TestCase ):
104- model_class = AutoencoderKLLTXVideo
105- main_input_name = "sample"
106- base_precision = 1e-2
115+ class AutoencoderKLLTXVideo091TesterConfig (BaseModelTesterConfig ):
116+ @property
117+ def model_class (self ):
118+ return AutoencoderKLLTXVideo
119+
120+ @property
121+ def main_input_name (self ) -> str :
122+ return "sample"
123+
124+ @property
125+ def output_shape (self ) -> tuple :
126+ return (3 , 9 , 16 , 16 )
127+
128+ @property
129+ def generator (self ):
130+ return torch .Generator ("cpu" ).manual_seed (0 )
107131
108- def get_autoencoder_kl_ltx_video_config (self ):
132+ def get_init_dict (self ) -> dict :
109133 return {
110134 "in_channels" : 3 ,
111135 "out_channels" : 3 ,
@@ -126,45 +150,32 @@ def get_autoencoder_kl_ltx_video_config(self):
126150 "decoder_causal" : False ,
127151 }
128152
129- @property
130- def dummy_input (self ):
153+ def get_dummy_inputs (self ) -> dict :
131154 batch_size = 2
132155 num_frames = 9
133156 num_channels = 3
134157 sizes = (16 , 16 )
135-
136- image = floats_tensor ((batch_size , num_channels , num_frames ) + sizes ).to (torch_device )
158+ image = randn_tensor (
159+ (batch_size , num_channels , num_frames , * sizes ), generator = self .generator , device = torch_device
160+ )
137161 timestep = torch .tensor ([0.05 ] * batch_size , device = torch_device )
138-
139162 return {"sample" : image , "temb" : timestep }
140163
141- @property
142- def input_shape (self ):
143- return (3 , 9 , 16 , 16 )
144164
145- @property
146- def output_shape (self ):
147- return (3 , 9 , 16 , 16 )
165+ class TestAutoencoderKLLTXVideo091 (AutoencoderKLLTXVideo091TesterConfig , ModelTesterMixin ):
166+ base_precision = 1e-2
148167
149- def prepare_init_args_and_inputs_for_common (self ):
150- init_dict = self .get_autoencoder_kl_ltx_video_config ()
151- inputs_dict = self .dummy_input
152- return init_dict , inputs_dict
168+ @pytest .mark .skip ("Unsupported test." )
169+ def test_outputs_equivalence (self ):
170+ super ().test_outputs_equivalence ()
171+
172+
173+ class TestAutoencoderKLLTXVideo091Training (AutoencoderKLLTXVideo091TesterConfig , TrainingTesterMixin ):
174+ """Training tests for AutoencoderKLLTXVideo (0.9.1 config)."""
153175
154176 def test_gradient_checkpointing_is_applied (self ):
155- expected_set = {
156- "LTXVideoEncoder3d" ,
157- "LTXVideoDecoder3d" ,
158- "LTXVideoDownBlock3D" ,
159- "LTXVideoMidBlock3d" ,
160- "LTXVideoUpBlock3d" ,
161- }
162- super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
177+ super ().test_gradient_checkpointing_is_applied (expected_set = _LTX_VIDEO_GRADIENT_CKPT_EXPECTED )
163178
164- @unittest .skip ("Unsupported test." )
165- def test_outputs_equivalence (self ):
166- pass
167179
168- @unittest .skip ("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm." )
169- def test_forward_with_norm_groups (self ):
170- pass
180+ class TestAutoencoderKLLTXVideo091Memory (AutoencoderKLLTXVideo091TesterConfig , MemoryTesterMixin ):
181+ """Memory optimization tests for AutoencoderKLLTXVideo (0.9.1 config)."""
0 commit comments