Skip to content

Commit 357b681

Browse files
authored
[tests] refactor autoencoderdc tests (#13369)
* refactor autoencoderdc tests * fix * propagate new changes.
1 parent 065e369 commit 357b681

File tree

1 file changed

+33
-27
lines changed

1 file changed

+33
-27
lines changed

tests/models/autoencoders/test_models_autoencoder_dc.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,34 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
16+
import pytest
17+
import torch
1718

1819
from diffusers import AutoencoderDC
20+
from diffusers.utils.torch_utils import randn_tensor
1921

20-
from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, floats_tensor, torch_device
21-
from ..test_modeling_common import ModelTesterMixin
22-
from .testing_utils import AutoencoderTesterMixin
22+
from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, torch_device
23+
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
24+
from .testing_utils import NewAutoencoderTesterMixin
2325

2426

2527
enable_full_determinism()
2628

2729

28-
class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
29-
model_class = AutoencoderDC
30-
main_input_name = "sample"
31-
base_precision = 1e-2
30+
class AutoencoderDCTesterConfig(BaseModelTesterConfig):
31+
@property
32+
def model_class(self):
33+
return AutoencoderDC
34+
35+
@property
36+
def output_shape(self):
37+
return (3, 32, 32)
38+
39+
@property
40+
def generator(self):
41+
return torch.Generator("cpu").manual_seed(0)
3242

33-
def get_autoencoder_dc_config(self):
43+
def get_init_dict(self):
3444
return {
3545
"in_channels": 3,
3646
"latent_channels": 4,
@@ -56,33 +66,29 @@ def get_autoencoder_dc_config(self):
5666
"scaling_factor": 0.41407,
5767
}
5868

59-
@property
60-
def dummy_input(self):
69+
def get_dummy_inputs(self):
6170
batch_size = 4
6271
num_channels = 3
6372
sizes = (32, 32)
73+
image = randn_tensor((batch_size, num_channels, *sizes), generator=self.generator, device=torch_device)
74+
return {"sample": image}
6475

65-
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
6676

67-
return {"sample": image}
77+
class TestAutoencoderDC(AutoencoderDCTesterConfig, ModelTesterMixin):
78+
base_precision = 1e-2
6879

69-
@property
70-
def input_shape(self):
71-
return (3, 32, 32)
7280

73-
@property
74-
def output_shape(self):
75-
return (3, 32, 32)
81+
class TestAutoencoderDCTraining(AutoencoderDCTesterConfig, TrainingTesterMixin):
82+
"""Training tests for AutoencoderDC."""
7683

77-
def prepare_init_args_and_inputs_for_common(self):
78-
init_dict = self.get_autoencoder_dc_config()
79-
inputs_dict = self.dummy_input
80-
return init_dict, inputs_dict
8184

82-
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
83-
def test_layerwise_casting_inference(self):
84-
super().test_layerwise_casting_inference()
85+
class TestAutoencoderDCMemory(AutoencoderDCTesterConfig, MemoryTesterMixin):
86+
"""Memory optimization tests for AutoencoderDC."""
8587

86-
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
88+
@pytest.mark.skipif(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
8789
def test_layerwise_casting_memory(self):
8890
super().test_layerwise_casting_memory()
91+
92+
93+
class TestAutoencoderDCSlicingTiling(AutoencoderDCTesterConfig, NewAutoencoderTesterMixin):
94+
"""Slicing and tiling tests for AutoencoderDC."""

0 commit comments

Comments
 (0)