1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import unittest
16+ import pytest
17+ import torch
1718
1819from diffusers import AutoencoderKLLTX2Audio
20+ from diffusers .utils .torch_utils import randn_tensor
1921
20- from ...testing_utils import (
21- floats_tensor ,
22- torch_device ,
23- )
24- from ..test_modeling_common import ModelTesterMixin
25- from .testing_utils import AutoencoderTesterMixin
22+ from ...testing_utils import is_flaky , torch_device
23+ from ..testing_utils import BaseModelTesterConfig , MemoryTesterMixin , ModelTesterMixin , TrainingTesterMixin
24+ from .testing_utils import NewAutoencoderTesterMixin
2625
2726
28- class AutoencoderKLLTX2AudioTests (ModelTesterMixin , AutoencoderTesterMixin , unittest .TestCase ):
29- model_class = AutoencoderKLLTX2Audio
30- main_input_name = "sample"
31- base_precision = 1e-2
27+ class AutoencoderKLLTX2AudioTesterConfig (BaseModelTesterConfig ):
28+ @property
29+ def main_input_name (self ):
30+ return "sample"
31+
32+ @property
33+ def model_class (self ):
34+ return AutoencoderKLLTX2Audio
3235
33- def get_autoencoder_kl_ltx_video_config (self ):
36+ @property
37+ def output_shape (self ):
38+ return (2 , 5 , 16 )
39+
40+ @property
41+ def generator (self ):
42+ return torch .Generator ("cpu" ).manual_seed (0 )
43+
44+ def get_init_dict (self ):
3445 return {
3546 "in_channels" : 2 , # stereo,
3647 "output_channels" : 2 ,
@@ -50,39 +61,39 @@ def get_autoencoder_kl_ltx_video_config(self):
5061 "double_z" : True ,
5162 }
5263
53- @property
54- def dummy_input (self ):
64+ def get_dummy_inputs (self ):
5565 batch_size = 2
5666 num_channels = 2
5767 num_frames = 8
5868 num_mel_bins = 16
69+ spectrogram = randn_tensor (
70+ (batch_size , num_channels , num_frames , num_mel_bins ),
71+ generator = self .generator ,
72+ device = torch_device ,
73+ )
74+ return {"sample" : spectrogram }
5975
60- spectrogram = floats_tensor ((batch_size , num_channels , num_frames , num_mel_bins )).to (torch_device )
6176
62- input_dict = { "sample" : spectrogram }
63- return input_dict
77+ class TestAutoencoderKLLTX2Audio ( AutoencoderKLLTX2AudioTesterConfig , ModelTesterMixin ):
78+ base_precision = 1e-2
6479
65- @property
66- def input_shape (self ):
67- return (2 , 5 , 16 )
80+ def test_outputs_equivalence (self ):
81+ pytest .skip ("Unsupported test." )
6882
69- @property
70- def output_shape (self ):
71- return (2 , 5 , 16 )
7283
73- def prepare_init_args_and_inputs_for_common (self ):
74- init_dict = self .get_autoencoder_kl_ltx_video_config ()
75- inputs_dict = self .dummy_input
76- return init_dict , inputs_dict
84+ class TestAutoencoderKLLTX2AudioTraining (AutoencoderKLLTX2AudioTesterConfig , TrainingTesterMixin ):
85+ """Training tests for AutoencoderKLLTX2Audio."""
7786
78- # Overriding as output shape is not the same as input shape for LTX 2.0 audio VAE
79- def test_output (self ):
80- super ().test_output (expected_output_shape = (2 , 2 , 5 , 16 ))
8187
82- @unittest .skip ("Unsupported test." )
83- def test_outputs_equivalence (self ):
84- pass
88+ class TestAutoencoderKLLTX2AudioMemory (AutoencoderKLLTX2AudioTesterConfig , MemoryTesterMixin ):
89+ """Memory optimization tests for AutoencoderKLLTX2Audio."""
90+
91+ @is_flaky ()
92+ @pytest .mark .parametrize ("record_stream" , [False , True ])
93+ @pytest .mark .parametrize ("offload_type" , ["block_level" , "leaf_level" ])
94+ def test_group_offloading_with_disk (self , tmp_path , record_stream , offload_type , atol = 1e-5 , rtol = 0 ):
95+ super ().test_group_offloading_with_disk (tmp_path , record_stream , offload_type , atol = atol , rtol = rtol )
96+
8597
86- @unittest .skip ("AutoencoderKLLTX2Audio does not support `norm_num_groups` because it does not use GroupNorm." )
87- def test_forward_with_norm_groups (self ):
88- pass
98+ class TestAutoencoderKLLTX2AudioSlicingTiling (AutoencoderKLLTX2AudioTesterConfig , NewAutoencoderTesterMixin ):
99+ """Slicing and tiling tests for AutoencoderKLLTX2Audio."""
0 commit comments