Commit 192b53a

Merge branch 'main' into fix-flux2-vae-offload-device-mismatch
2 parents: 68fa726 + 357b681

2 files changed: 134 additions & 99 deletions

tests/models/autoencoders/test_models_autoencoder_dc.py

Lines changed: 33 additions & 27 deletions
@@ -13,24 +13,34 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import unittest
+import pytest
+import torch

 from diffusers import AutoencoderDC
+from diffusers.utils.torch_utils import randn_tensor

-from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, floats_tensor, torch_device
-from ..test_modeling_common import ModelTesterMixin
-from .testing_utils import AutoencoderTesterMixin
+from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, torch_device
+from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
+from .testing_utils import NewAutoencoderTesterMixin


 enable_full_determinism()


-class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
-    model_class = AutoencoderDC
-    main_input_name = "sample"
-    base_precision = 1e-2
+class AutoencoderDCTesterConfig(BaseModelTesterConfig):
+    @property
+    def model_class(self):
+        return AutoencoderDC
+
+    @property
+    def output_shape(self):
+        return (3, 32, 32)
+
+    @property
+    def generator(self):
+        return torch.Generator("cpu").manual_seed(0)

-    def get_autoencoder_dc_config(self):
+    def get_init_dict(self):
         return {
             "in_channels": 3,
             "latent_channels": 4,
@@ -56,33 +66,29 @@ def get_autoencoder_dc_config(self):
             "scaling_factor": 0.41407,
         }

-    @property
-    def dummy_input(self):
+    def get_dummy_inputs(self):
         batch_size = 4
         num_channels = 3
         sizes = (32, 32)
+        image = randn_tensor((batch_size, num_channels, *sizes), generator=self.generator, device=torch_device)
+        return {"sample": image}

-        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)

-        return {"sample": image}
+class TestAutoencoderDC(AutoencoderDCTesterConfig, ModelTesterMixin):
+    base_precision = 1e-2

-    @property
-    def input_shape(self):
-        return (3, 32, 32)

-    @property
-    def output_shape(self):
-        return (3, 32, 32)
+class TestAutoencoderDCTraining(AutoencoderDCTesterConfig, TrainingTesterMixin):
+    """Training tests for AutoencoderDC."""

-    def prepare_init_args_and_inputs_for_common(self):
-        init_dict = self.get_autoencoder_dc_config()
-        inputs_dict = self.dummy_input
-        return init_dict, inputs_dict

-    @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
-    def test_layerwise_casting_inference(self):
-        super().test_layerwise_casting_inference()
+class TestAutoencoderDCMemory(AutoencoderDCTesterConfig, MemoryTesterMixin):
+    """Memory optimization tests for AutoencoderDC."""

-    @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
+    @pytest.mark.skipif(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
     def test_layerwise_casting_memory(self):
         super().test_layerwise_casting_memory()
+
+
+class TestAutoencoderDCSlicingTiling(AutoencoderDCTesterConfig, NewAutoencoderTesterMixin):
+    """Slicing and tiling tests for AutoencoderDC."""

tests/models/transformers/test_models_transformer_cosmos.py

Lines changed: 101 additions & 72 deletions
@@ -12,60 +12,46 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import unittest
-
 import torch

 from diffusers import CosmosTransformer3DModel
+from diffusers.utils.torch_utils import randn_tensor

 from ...testing_utils import enable_full_determinism, torch_device
-from ..test_modeling_common import ModelTesterMixin
+from ..testing_utils import (
+    BaseModelTesterConfig,
+    MemoryTesterMixin,
+    ModelTesterMixin,
+    TrainingTesterMixin,
+)


 enable_full_determinism()


-class CosmosTransformer3DModelTests(ModelTesterMixin, unittest.TestCase):
-    model_class = CosmosTransformer3DModel
-    main_input_name = "hidden_states"
-    uses_custom_attn_processor = True
-
+class CosmosTransformerTesterConfig(BaseModelTesterConfig):
     @property
-    def dummy_input(self):
-        batch_size = 1
-        num_channels = 4
-        num_frames = 1
-        height = 16
-        width = 16
-        text_embed_dim = 16
-        sequence_length = 12
-        fps = 30
-
-        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
-        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
-        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
-        attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
-        padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
-
-        return {
-            "hidden_states": hidden_states,
-            "timestep": timestep,
-            "encoder_hidden_states": encoder_hidden_states,
-            "attention_mask": attention_mask,
-            "fps": fps,
-            "padding_mask": padding_mask,
-        }
+    def model_class(self):
+        return CosmosTransformer3DModel

     @property
-    def input_shape(self):
+    def output_shape(self) -> tuple[int, ...]:
         return (4, 1, 16, 16)

     @property
-    def output_shape(self):
+    def input_shape(self) -> tuple[int, ...]:
         return (4, 1, 16, 16)

-    def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
+    @property
+    def main_input_name(self) -> str:
+        return "hidden_states"
+
+    @property
+    def generator(self):
+        return torch.Generator("cpu").manual_seed(0)
+
+    def get_init_dict(self) -> dict[str, int | list | tuple | float | bool | str]:
+        return {
             "in_channels": 4,
             "out_channels": 4,
             "num_attention_heads": 2,
@@ -80,57 +66,68 @@ def prepare_init_args_and_inputs_for_common(self):
             "concat_padding_mask": True,
             "extra_pos_embed_type": "learnable",
         }
-        inputs_dict = self.dummy_input
-        return init_dict, inputs_dict
-
-    def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"CosmosTransformer3DModel"}
-        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
-
-class CosmosTransformer3DModelVideoToWorldTests(ModelTesterMixin, unittest.TestCase):
-    model_class = CosmosTransformer3DModel
-    main_input_name = "hidden_states"
-    uses_custom_attn_processor = True

-    @property
-    def dummy_input(self):
-        batch_size = 1
+    def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
         num_channels = 4
         num_frames = 1
         height = 16
         width = 16
         text_embed_dim = 16
         sequence_length = 12
-        fps = 30
-
-        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
-        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
-        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
-        attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
-        condition_mask = torch.ones(batch_size, 1, num_frames, height, width).to(torch_device)
-        padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)

         return {
-            "hidden_states": hidden_states,
-            "timestep": timestep,
-            "encoder_hidden_states": encoder_hidden_states,
-            "attention_mask": attention_mask,
-            "fps": fps,
-            "condition_mask": condition_mask,
-            "padding_mask": padding_mask,
+            "hidden_states": randn_tensor(
+                (batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
+            ),
+            "timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
+            "encoder_hidden_states": randn_tensor(
+                (batch_size, sequence_length, text_embed_dim), generator=self.generator, device=torch_device
+            ),
+            "attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
+            "fps": 30,
+            "padding_mask": torch.zeros(batch_size, 1, height, width).to(torch_device),
         }

+
+class TestCosmosTransformer(CosmosTransformerTesterConfig, ModelTesterMixin):
+    """Core model tests for Cosmos Transformer."""
+
+
+class TestCosmosTransformerMemory(CosmosTransformerTesterConfig, MemoryTesterMixin):
+    """Memory optimization tests for Cosmos Transformer."""
+
+
+class TestCosmosTransformerTraining(CosmosTransformerTesterConfig, TrainingTesterMixin):
+    """Training tests for Cosmos Transformer."""
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"CosmosTransformer3DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class CosmosTransformerVideoToWorldTesterConfig(BaseModelTesterConfig):
+    @property
+    def model_class(self):
+        return CosmosTransformer3DModel
+
     @property
-    def input_shape(self):
+    def output_shape(self) -> tuple[int, ...]:
         return (4, 1, 16, 16)

     @property
-    def output_shape(self):
+    def input_shape(self) -> tuple[int, ...]:
         return (4, 1, 16, 16)

-    def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
+    @property
+    def main_input_name(self) -> str:
+        return "hidden_states"
+
+    @property
+    def generator(self):
+        return torch.Generator("cpu").manual_seed(0)
+
+    def get_init_dict(self) -> dict[str, int | list | tuple | float | bool | str]:
+        return {
             "in_channels": 4 + 1,
             "out_channels": 4,
             "num_attention_heads": 2,
@@ -145,8 +142,40 @@ def prepare_init_args_and_inputs_for_common(self):
             "concat_padding_mask": True,
             "extra_pos_embed_type": "learnable",
         }
-        inputs_dict = self.dummy_input
-        return init_dict, inputs_dict
+
+    def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
+        num_channels = 4
+        num_frames = 1
+        height = 16
+        width = 16
+        text_embed_dim = 16
+        sequence_length = 12
+
+        return {
+            "hidden_states": randn_tensor(
+                (batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
+            ),
+            "timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
+            "encoder_hidden_states": randn_tensor(
+                (batch_size, sequence_length, text_embed_dim), generator=self.generator, device=torch_device
+            ),
+            "attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
+            "fps": 30,
+            "condition_mask": torch.ones(batch_size, 1, num_frames, height, width).to(torch_device),
+            "padding_mask": torch.zeros(batch_size, 1, height, width).to(torch_device),
+        }
+
+
+class TestCosmosTransformerVideoToWorld(CosmosTransformerVideoToWorldTesterConfig, ModelTesterMixin):
+    """Core model tests for Cosmos Transformer (Video-to-World)."""
+
+
+class TestCosmosTransformerVideoToWorldMemory(CosmosTransformerVideoToWorldTesterConfig, MemoryTesterMixin):
+    """Memory optimization tests for Cosmos Transformer (Video-to-World)."""
+
+
+class TestCosmosTransformerVideoToWorldTraining(CosmosTransformerVideoToWorldTesterConfig, TrainingTesterMixin):
+    """Training tests for Cosmos Transformer (Video-to-World)."""

     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"CosmosTransformer3DModel"}
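Both files also replace torch.randn / floats_tensor with diffusers' randn_tensor fed by a freshly seeded CPU generator, so get_dummy_inputs returns identical tensors on every call. A small sketch of why this makes the dummy inputs reproducible (randn_tensor is the real diffusers utility; the shape mirrors the Cosmos hidden_states input above):

import torch

from diffusers.utils.torch_utils import randn_tensor

# (batch, channels, frames, height, width), matching get_dummy_inputs above.
shape = (1, 4, 1, 16, 16)

# The generator property returns a new CPU generator seeded with 0 on every
# access, so two separate calls draw bit-identical noise.
a = randn_tensor(shape, generator=torch.Generator("cpu").manual_seed(0))
b = randn_tensor(shape, generator=torch.Generator("cpu").manual_seed(0))
assert torch.equal(a, b)

Because the noise is drawn on the CPU generator and only then placed on torch_device, the same values should be produced whether the tests run on CPU or an accelerator, which is the point of passing device= to randn_tensor together with a CPU generator.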
