Skip to content

Commit 685bdd6

Browse files
refactor autoencoder tests (temporal decoder, cosmos, kvae, mochi) (#13832)
* refactor autoencoder_kl_temporal_decoder tests
* refactor autoencoder_kl_cosmos tests
* refactor autoencoder_kl_kvae tests
* fix return_dict propagation in AutoencoderKLMochi.forward
* refactor autoencoder_kl_mochi tests
* add docstrings
* fix return type annotation
* remove unused base_precision and test_outputs_equivalence skip

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 2a57cc4 commit 685bdd6

5 files changed

Lines changed: 167 additions & 137 deletions

File tree

src/diffusers/models/autoencoders/autoencoder_kl_mochi.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,7 +1092,7 @@ def forward(
10921092
sample_posterior: bool = False,
10931093
return_dict: bool = True,
10941094
generator: torch.Generator | None = None,
1095-
) -> torch.Tensor | torch.Tensor:
1095+
) -> DecoderOutput | torch.Tensor:
10961096
r"""
10971097
Args:
10981098
sample (`torch.Tensor`): Input sample.
@@ -1115,7 +1115,5 @@ def forward(
11151115
z = posterior.sample(generator=generator)
11161116
else:
11171117
z = posterior.mode()
1118-
dec = self.decode(z)
1119-
if not return_dict:
1120-
return (dec,)
1118+
dec = self.decode(z, return_dict=return_dict)
11211119
return dec

tests/models/autoencoders/test_models_autoencoder_cosmos.py

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,38 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import unittest
15+
import pytest
16+
import torch
1617

1718
from diffusers import AutoencoderKLCosmos
19+
from diffusers.utils.torch_utils import randn_tensor
1820

19-
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
20-
from ..test_modeling_common import ModelTesterMixin
21-
from .testing_utils import AutoencoderTesterMixin
21+
from ...testing_utils import enable_full_determinism, torch_device
22+
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
23+
from .testing_utils import NewAutoencoderTesterMixin
2224

2325

2426
enable_full_determinism()
2527

2628

27-
class AutoencoderKLCosmosTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
28-
model_class = AutoencoderKLCosmos
29-
main_input_name = "sample"
30-
base_precision = 1e-2
29+
class AutoencoderKLCosmosTesterConfig(BaseModelTesterConfig):
30+
@property
31+
def model_class(self):
32+
return AutoencoderKLCosmos
33+
34+
@property
35+
def main_input_name(self) -> str:
36+
return "sample"
37+
38+
@property
39+
def output_shape(self) -> tuple:
40+
return (3, 9, 32, 32)
3141

32-
def get_autoencoder_kl_cosmos_config(self):
42+
@property
43+
def generator(self):
44+
return torch.Generator("cpu").manual_seed(0)
45+
46+
def get_init_dict(self) -> dict:
3347
return {
3448
"in_channels": 3,
3549
"out_channels": 3,
@@ -46,38 +60,37 @@ def get_autoencoder_kl_cosmos_config(self):
4660
"temporal_compression_ratio": 4,
4761
}
4862

49-
@property
50-
def dummy_input(self):
63+
def get_dummy_inputs(self) -> dict:
5164
batch_size = 2
5265
num_frames = 9
5366
num_channels = 3
5467
height = 32
5568
width = 32
56-
57-
image = floats_tensor((batch_size, num_channels, num_frames, height, width)).to(torch_device)
58-
69+
image = randn_tensor(
70+
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
71+
)
5972
return {"sample": image}
6073

61-
@property
62-
def input_shape(self):
63-
return (3, 9, 32, 32)
6474

65-
@property
66-
def output_shape(self):
67-
return (3, 9, 32, 32)
75+
class TestAutoencoderKLCosmos(AutoencoderKLCosmosTesterConfig, ModelTesterMixin):
76+
pass
77+
6878

69-
def prepare_init_args_and_inputs_for_common(self):
70-
init_dict = self.get_autoencoder_kl_cosmos_config()
71-
inputs_dict = self.dummy_input
72-
return init_dict, inputs_dict
79+
class TestAutoencoderKLCosmosTraining(AutoencoderKLCosmosTesterConfig, TrainingTesterMixin):
80+
"""Training tests for AutoencoderKLCosmos."""
7381

7482
def test_gradient_checkpointing_is_applied(self):
75-
expected_set = {
76-
"CosmosEncoder3d",
77-
"CosmosDecoder3d",
78-
}
83+
expected_set = {"CosmosEncoder3d", "CosmosDecoder3d"}
7984
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
8085

81-
@unittest.skip("Not sure why this test fails. Investigate later.")
82-
def test_effective_gradient_checkpointing(self):
83-
pass
86+
@pytest.mark.skip("Not sure why this test fails. Investigate later.")
87+
def test_gradient_checkpointing_equivalence(self):
88+
super().test_gradient_checkpointing_equivalence()
89+
90+
91+
class TestAutoencoderKLCosmosMemory(AutoencoderKLCosmosTesterConfig, MemoryTesterMixin):
92+
"""Memory optimization tests for AutoencoderKLCosmos."""
93+
94+
95+
class TestAutoencoderKLCosmosSlicingTiling(AutoencoderKLCosmosTesterConfig, NewAutoencoderTesterMixin):
96+
"""Slicing and tiling tests for AutoencoderKLCosmos."""

tests/models/autoencoders/test_models_autoencoder_kl_kvae.py

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,37 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
16+
import torch
1717

1818
from diffusers import AutoencoderKLKVAE
19+
from diffusers.utils.torch_utils import randn_tensor
1920

20-
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
21-
from ..test_modeling_common import ModelTesterMixin
22-
from .testing_utils import AutoencoderTesterMixin
21+
from ...testing_utils import enable_full_determinism, torch_device
22+
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
23+
from .testing_utils import NewAutoencoderTesterMixin
2324

2425

2526
enable_full_determinism()
2627

2728

28-
class AutoencoderKLKVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
29-
model_class = AutoencoderKLKVAE
30-
main_input_name = "sample"
31-
base_precision = 1e-2
29+
class AutoencoderKLKVAETesterConfig(BaseModelTesterConfig):
30+
@property
31+
def model_class(self):
32+
return AutoencoderKLKVAE
33+
34+
@property
35+
def main_input_name(self) -> str:
36+
return "sample"
37+
38+
@property
39+
def output_shape(self) -> tuple:
40+
return (3, 32, 32)
3241

33-
def get_autoencoder_kl_kvae_config(self):
42+
@property
43+
def generator(self):
44+
return torch.Generator("cpu").manual_seed(0)
45+
46+
def get_init_dict(self) -> dict:
3447
return {
3548
"in_channels": 3,
3649
"channels": 32,
@@ -42,32 +55,29 @@ def get_autoencoder_kl_kvae_config(self):
4255
"sample_size": 32,
4356
}
4457

45-
@property
46-
def dummy_input(self):
58+
def get_dummy_inputs(self) -> dict:
4759
batch_size = 2
4860
num_channels = 3
4961
sizes = (32, 32)
50-
51-
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
52-
62+
image = randn_tensor((batch_size, num_channels, *sizes), generator=self.generator, device=torch_device)
5363
return {"sample": image}
5464

55-
@property
56-
def input_shape(self):
57-
return (3, 32, 32)
5865

59-
@property
60-
def output_shape(self):
61-
return (3, 32, 32)
66+
class TestAutoencoderKLKVAE(AutoencoderKLKVAETesterConfig, ModelTesterMixin):
67+
pass
68+
6269

63-
def prepare_init_args_and_inputs_for_common(self):
64-
init_dict = self.get_autoencoder_kl_kvae_config()
65-
inputs_dict = self.dummy_input
66-
return init_dict, inputs_dict
70+
class TestAutoencoderKLKVAETraining(AutoencoderKLKVAETesterConfig, TrainingTesterMixin):
71+
"""Training tests for AutoencoderKLKVAE."""
6772

6873
def test_gradient_checkpointing_is_applied(self):
69-
expected_set = {
70-
"KVAEEncoder2D",
71-
"KVAEDecoder2D",
72-
}
74+
expected_set = {"KVAEEncoder2D", "KVAEDecoder2D"}
7375
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
76+
77+
78+
class TestAutoencoderKLKVAEMemory(AutoencoderKLKVAETesterConfig, MemoryTesterMixin):
79+
"""Memory optimization tests for AutoencoderKLKVAE."""
80+
81+
82+
class TestAutoencoderKLKVAESlicingTiling(AutoencoderKLKVAETesterConfig, NewAutoencoderTesterMixin):
83+
"""Slicing and tiling tests for AutoencoderKLKVAE."""

tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,58 +13,72 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
16+
import torch
1717

1818
from diffusers import AutoencoderKLTemporalDecoder
19+
from diffusers.utils.torch_utils import randn_tensor
1920

20-
from ...testing_utils import (
21-
enable_full_determinism,
22-
floats_tensor,
23-
torch_device,
24-
)
25-
from ..test_modeling_common import ModelTesterMixin
26-
from .testing_utils import AutoencoderTesterMixin
21+
from ...testing_utils import enable_full_determinism, torch_device
22+
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
23+
from .testing_utils import NewAutoencoderTesterMixin
2724

2825

2926
enable_full_determinism()
3027

3128

32-
class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
33-
model_class = AutoencoderKLTemporalDecoder
34-
main_input_name = "sample"
35-
base_precision = 1e-2
36-
29+
class AutoencoderKLTemporalDecoderTesterConfig(BaseModelTesterConfig):
3730
@property
38-
def dummy_input(self):
39-
batch_size = 3
40-
num_channels = 3
41-
sizes = (32, 32)
42-
43-
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
44-
num_frames = 3
31+
def model_class(self):
32+
return AutoencoderKLTemporalDecoder
4533

46-
return {"sample": image, "num_frames": num_frames}
34+
@property
35+
def main_input_name(self) -> str:
36+
return "sample"
4737

4838
@property
49-
def input_shape(self):
39+
def output_shape(self) -> tuple:
5040
return (3, 32, 32)
5141

5242
@property
53-
def output_shape(self):
54-
return (3, 32, 32)
43+
def generator(self):
44+
return torch.Generator("cpu").manual_seed(0)
5545

56-
def prepare_init_args_and_inputs_for_common(self):
57-
init_dict = {
46+
def get_init_dict(self) -> dict:
47+
return {
5848
"block_out_channels": [32, 64],
5949
"in_channels": 3,
6050
"out_channels": 3,
6151
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
6252
"latent_channels": 4,
6353
"layers_per_block": 2,
6454
}
65-
inputs_dict = self.dummy_input
66-
return init_dict, inputs_dict
55+
56+
def get_dummy_inputs(self) -> dict:
57+
batch_size = 3
58+
num_channels = 3
59+
sizes = (32, 32)
60+
image = randn_tensor((batch_size, num_channels, *sizes), generator=self.generator, device=torch_device)
61+
num_frames = 3
62+
return {"sample": image, "num_frames": num_frames}
63+
64+
65+
class TestAutoencoderKLTemporalDecoder(AutoencoderKLTemporalDecoderTesterConfig, ModelTesterMixin):
66+
pass
67+
68+
69+
class TestAutoencoderKLTemporalDecoderTraining(AutoencoderKLTemporalDecoderTesterConfig, TrainingTesterMixin):
70+
"""Training tests for AutoencoderKLTemporalDecoder."""
6771

6872
def test_gradient_checkpointing_is_applied(self):
6973
expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"}
7074
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
75+
76+
77+
class TestAutoencoderKLTemporalDecoderMemory(AutoencoderKLTemporalDecoderTesterConfig, MemoryTesterMixin):
78+
"""Memory optimization tests for AutoencoderKLTemporalDecoder."""
79+
80+
81+
class TestAutoencoderKLTemporalDecoderSlicingTiling(
82+
AutoencoderKLTemporalDecoderTesterConfig, NewAutoencoderTesterMixin
83+
):
84+
"""Slicing and tiling tests for AutoencoderKLTemporalDecoder."""

0 commit comments

Comments
 (0)