Skip to content

Commit 6dbf6e0

Browse files
refactor autoencoder tests (asymmetric_kl, ltx_video) (#13845)
* refactor asymmetric_autoencoder_kl tests * refactor autoencoder_ltx_video tests --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 07de1f6 commit 6dbf6e0

2 files changed

Lines changed: 127 additions & 105 deletions

File tree

tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,18 @@
1616
import gc
1717
import unittest
1818

19+
import pytest
1920
import torch
2021
from parameterized import parameterized
2122

2223
from diffusers import AsymmetricAutoencoderKL
2324
from diffusers.utils.import_utils import is_xformers_available
25+
from diffusers.utils.torch_utils import randn_tensor
2426

2527
from ...testing_utils import (
2628
Expectations,
2729
backend_empty_cache,
2830
enable_full_determinism,
29-
floats_tensor,
3031
load_hf_numpy,
3132
require_torch_accelerator,
3233
require_torch_gpu,
@@ -35,22 +36,33 @@
3536
torch_all_close,
3637
torch_device,
3738
)
38-
from ..test_modeling_common import ModelTesterMixin
39-
from .testing_utils import AutoencoderTesterMixin
39+
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
40+
from .testing_utils import NewAutoencoderTesterMixin
4041

4142

4243
enable_full_determinism()
4344

4445

45-
class AsymmetricAutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
46-
model_class = AsymmetricAutoencoderKL
47-
main_input_name = "sample"
48-
base_precision = 1e-2
46+
class AsymmetricAutoencoderKLTesterConfig(BaseModelTesterConfig):
47+
@property
48+
def model_class(self):
49+
return AsymmetricAutoencoderKL
50+
51+
@property
52+
def main_input_name(self) -> str:
53+
return "sample"
54+
55+
@property
56+
def output_shape(self) -> tuple:
57+
return (3, 32, 32)
4958

50-
def get_asym_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
51-
block_out_channels = block_out_channels or [2, 4]
52-
norm_num_groups = norm_num_groups or 2
53-
init_dict = {
59+
@property
60+
def generator(self):
61+
return torch.Generator("cpu").manual_seed(0)
62+
63+
def get_init_dict(self) -> dict:
64+
block_out_channels = [2, 4]
65+
return {
5466
"in_channels": 3,
5567
"out_channels": 3,
5668
"down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
@@ -61,39 +73,38 @@ def get_asym_autoencoder_kl_config(self, block_out_channels=None, norm_num_group
6173
"layers_per_up_block": 1,
6274
"act_fn": "silu",
6375
"latent_channels": 4,
64-
"norm_num_groups": norm_num_groups,
76+
"norm_num_groups": 2,
6577
"sample_size": 32,
6678
"scaling_factor": 0.18215,
6779
}
68-
return init_dict
6980

70-
@property
71-
def dummy_input(self):
81+
def get_dummy_inputs(self) -> dict:
7282
batch_size = 4
7383
num_channels = 3
7484
sizes = (32, 32)
85+
image = randn_tensor((batch_size, num_channels, *sizes), generator=self.generator, device=torch_device)
86+
mask = torch.ones((batch_size, 1, *sizes)).to(torch_device)
87+
return {"sample": image, "mask": mask}
7588

76-
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
77-
mask = torch.ones((batch_size, 1) + sizes).to(torch_device)
7889

79-
return {"sample": image, "mask": mask}
90+
class TestAsymmetricAutoencoderKL(AsymmetricAutoencoderKLTesterConfig, ModelTesterMixin):
91+
base_precision = 1e-2
8092

81-
@property
82-
def input_shape(self):
83-
return (3, 32, 32)
8493

85-
@property
86-
def output_shape(self):
87-
return (3, 32, 32)
94+
class TestAsymmetricAutoencoderKLTraining(AsymmetricAutoencoderKLTesterConfig, TrainingTesterMixin):
95+
"""Training tests for AsymmetricAutoencoderKL."""
96+
97+
98+
class TestAsymmetricAutoencoderKLMemory(AsymmetricAutoencoderKLTesterConfig, MemoryTesterMixin):
99+
"""Memory optimization tests for AsymmetricAutoencoderKL."""
100+
88101

89-
def prepare_init_args_and_inputs_for_common(self):
90-
init_dict = self.get_asym_autoencoder_kl_config()
91-
inputs_dict = self.dummy_input
92-
return init_dict, inputs_dict
102+
class TestAsymmetricAutoencoderKLSlicingTiling(AsymmetricAutoencoderKLTesterConfig, NewAutoencoderTesterMixin):
103+
"""Slicing and tiling tests for AsymmetricAutoencoderKL."""
93104

94-
@unittest.skip("Unsupported test.")
105+
@pytest.mark.skip("Unsupported test.")
95106
def test_forward_with_norm_groups(self):
96-
pass
107+
super().test_forward_with_norm_groups()
97108

98109

99110
@slow

tests/models/autoencoders/test_models_autoencoder_ltx_video.py

Lines changed: 86 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,47 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
17-
16+
import pytest
1817
import torch
1918

2019
from diffusers import AutoencoderKLLTXVideo
20+
from diffusers.utils.torch_utils import randn_tensor
2121

22-
from ...testing_utils import (
23-
enable_full_determinism,
24-
floats_tensor,
25-
torch_device,
26-
)
27-
from ..test_modeling_common import ModelTesterMixin
28-
from .testing_utils import AutoencoderTesterMixin
22+
from ...testing_utils import enable_full_determinism, torch_device
23+
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
24+
from .testing_utils import NewAutoencoderTesterMixin
2925

3026

3127
enable_full_determinism()
3228

3329

34-
class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
35-
model_class = AutoencoderKLLTXVideo
36-
main_input_name = "sample"
37-
base_precision = 1e-2
30+
_LTX_VIDEO_GRADIENT_CKPT_EXPECTED = {
31+
"LTXVideoEncoder3d",
32+
"LTXVideoDecoder3d",
33+
"LTXVideoDownBlock3D",
34+
"LTXVideoMidBlock3d",
35+
"LTXVideoUpBlock3d",
36+
}
37+
38+
39+
class AutoencoderKLLTXVideo090TesterConfig(BaseModelTesterConfig):
40+
@property
41+
def model_class(self):
42+
return AutoencoderKLLTXVideo
43+
44+
@property
45+
def main_input_name(self) -> str:
46+
return "sample"
3847

39-
def get_autoencoder_kl_ltx_video_config(self):
48+
@property
49+
def output_shape(self) -> tuple:
50+
return (3, 9, 16, 16)
51+
52+
@property
53+
def generator(self):
54+
return torch.Generator("cpu").manual_seed(0)
55+
56+
def get_init_dict(self) -> dict:
4057
return {
4158
"in_channels": 3,
4259
"out_channels": 3,
@@ -57,55 +74,62 @@ def get_autoencoder_kl_ltx_video_config(self):
5774
"decoder_causal": False,
5875
}
5976

60-
@property
61-
def dummy_input(self):
77+
def get_dummy_inputs(self) -> dict:
6278
batch_size = 2
6379
num_frames = 9
6480
num_channels = 3
6581
sizes = (16, 16)
82+
image = randn_tensor(
83+
(batch_size, num_channels, num_frames, *sizes), generator=self.generator, device=torch_device
84+
)
85+
return {"sample": image}
6686

67-
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
6887

69-
return {"sample": image}
88+
class TestAutoencoderKLLTXVideo090(AutoencoderKLLTXVideo090TesterConfig, ModelTesterMixin):
89+
base_precision = 1e-2
7090

71-
@property
72-
def input_shape(self):
73-
return (3, 9, 16, 16)
91+
@pytest.mark.skip("Unsupported test.")
92+
def test_outputs_equivalence(self):
93+
super().test_outputs_equivalence()
7494

75-
@property
76-
def output_shape(self):
77-
return (3, 9, 16, 16)
7895

79-
def prepare_init_args_and_inputs_for_common(self):
80-
init_dict = self.get_autoencoder_kl_ltx_video_config()
81-
inputs_dict = self.dummy_input
82-
return init_dict, inputs_dict
96+
class TestAutoencoderKLLTXVideo090Training(AutoencoderKLLTXVideo090TesterConfig, TrainingTesterMixin):
97+
"""Training tests for AutoencoderKLLTXVideo (0.9.0 config)."""
8398

8499
def test_gradient_checkpointing_is_applied(self):
85-
expected_set = {
86-
"LTXVideoEncoder3d",
87-
"LTXVideoDecoder3d",
88-
"LTXVideoDownBlock3D",
89-
"LTXVideoMidBlock3d",
90-
"LTXVideoUpBlock3d",
91-
}
92-
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
100+
super().test_gradient_checkpointing_is_applied(expected_set=_LTX_VIDEO_GRADIENT_CKPT_EXPECTED)
101+
102+
103+
class TestAutoencoderKLLTXVideo090Memory(AutoencoderKLLTXVideo090TesterConfig, MemoryTesterMixin):
104+
"""Memory optimization tests for AutoencoderKLLTXVideo (0.9.0 config)."""
93105

94-
@unittest.skip("Unsupported test.")
95-
def test_outputs_equivalence(self):
96-
pass
97106

98-
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
107+
class TestAutoencoderKLLTXVideo090SlicingTiling(AutoencoderKLLTXVideo090TesterConfig, NewAutoencoderTesterMixin):
108+
"""Slicing and tiling tests for AutoencoderKLLTXVideo (0.9.0 config)."""
109+
110+
@pytest.mark.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
99111
def test_forward_with_norm_groups(self):
100-
pass
112+
super().test_forward_with_norm_groups()
101113

102114

103-
class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, unittest.TestCase):
104-
model_class = AutoencoderKLLTXVideo
105-
main_input_name = "sample"
106-
base_precision = 1e-2
115+
class AutoencoderKLLTXVideo091TesterConfig(BaseModelTesterConfig):
116+
@property
117+
def model_class(self):
118+
return AutoencoderKLLTXVideo
119+
120+
@property
121+
def main_input_name(self) -> str:
122+
return "sample"
123+
124+
@property
125+
def output_shape(self) -> tuple:
126+
return (3, 9, 16, 16)
127+
128+
@property
129+
def generator(self):
130+
return torch.Generator("cpu").manual_seed(0)
107131

108-
def get_autoencoder_kl_ltx_video_config(self):
132+
def get_init_dict(self) -> dict:
109133
return {
110134
"in_channels": 3,
111135
"out_channels": 3,
@@ -126,45 +150,32 @@ def get_autoencoder_kl_ltx_video_config(self):
126150
"decoder_causal": False,
127151
}
128152

129-
@property
130-
def dummy_input(self):
153+
def get_dummy_inputs(self) -> dict:
131154
batch_size = 2
132155
num_frames = 9
133156
num_channels = 3
134157
sizes = (16, 16)
135-
136-
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
158+
image = randn_tensor(
159+
(batch_size, num_channels, num_frames, *sizes), generator=self.generator, device=torch_device
160+
)
137161
timestep = torch.tensor([0.05] * batch_size, device=torch_device)
138-
139162
return {"sample": image, "temb": timestep}
140163

141-
@property
142-
def input_shape(self):
143-
return (3, 9, 16, 16)
144164

145-
@property
146-
def output_shape(self):
147-
return (3, 9, 16, 16)
165+
class TestAutoencoderKLLTXVideo091(AutoencoderKLLTXVideo091TesterConfig, ModelTesterMixin):
166+
base_precision = 1e-2
148167

149-
def prepare_init_args_and_inputs_for_common(self):
150-
init_dict = self.get_autoencoder_kl_ltx_video_config()
151-
inputs_dict = self.dummy_input
152-
return init_dict, inputs_dict
168+
@pytest.mark.skip("Unsupported test.")
169+
def test_outputs_equivalence(self):
170+
super().test_outputs_equivalence()
171+
172+
173+
class TestAutoencoderKLLTXVideo091Training(AutoencoderKLLTXVideo091TesterConfig, TrainingTesterMixin):
174+
"""Training tests for AutoencoderKLLTXVideo (0.9.1 config)."""
153175

154176
def test_gradient_checkpointing_is_applied(self):
155-
expected_set = {
156-
"LTXVideoEncoder3d",
157-
"LTXVideoDecoder3d",
158-
"LTXVideoDownBlock3D",
159-
"LTXVideoMidBlock3d",
160-
"LTXVideoUpBlock3d",
161-
}
162-
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
177+
super().test_gradient_checkpointing_is_applied(expected_set=_LTX_VIDEO_GRADIENT_CKPT_EXPECTED)
163178

164-
@unittest.skip("Unsupported test.")
165-
def test_outputs_equivalence(self):
166-
pass
167179

168-
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
169-
def test_forward_with_norm_groups(self):
170-
pass
180+
class TestAutoencoderKLLTXVideo091Memory(AutoencoderKLLTXVideo091TesterConfig, MemoryTesterMixin):
181+
"""Memory optimization tests for AutoencoderKLLTXVideo (0.9.1 config)."""

0 commit comments

Comments
 (0)