Skip to content

Commit 065e369

Browse files
DN6 and sayakpaul
authored
[CI] Refactor Cosmos Transformer Tests (#13335)
update

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent fbe8a75 commit 065e369

File tree

1 file changed

+101
-72
lines changed

1 file changed

+101
-72
lines changed

tests/models/transformers/test_models_transformer_cosmos.py

Lines changed: 101 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -12,60 +12,46 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import unittest
16-
1715
import torch
1816

1917
from diffusers import CosmosTransformer3DModel
18+
from diffusers.utils.torch_utils import randn_tensor
2019

2120
from ...testing_utils import enable_full_determinism, torch_device
22-
from ..test_modeling_common import ModelTesterMixin
21+
from ..testing_utils import (
22+
BaseModelTesterConfig,
23+
MemoryTesterMixin,
24+
ModelTesterMixin,
25+
TrainingTesterMixin,
26+
)
2327

2428

2529
enable_full_determinism()
2630

2731

28-
class CosmosTransformer3DModelTests(ModelTesterMixin, unittest.TestCase):
29-
model_class = CosmosTransformer3DModel
30-
main_input_name = "hidden_states"
31-
uses_custom_attn_processor = True
32-
32+
class CosmosTransformerTesterConfig(BaseModelTesterConfig):
3333
@property
34-
def dummy_input(self):
35-
batch_size = 1
36-
num_channels = 4
37-
num_frames = 1
38-
height = 16
39-
width = 16
40-
text_embed_dim = 16
41-
sequence_length = 12
42-
fps = 30
43-
44-
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
45-
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
46-
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
47-
attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
48-
padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
49-
50-
return {
51-
"hidden_states": hidden_states,
52-
"timestep": timestep,
53-
"encoder_hidden_states": encoder_hidden_states,
54-
"attention_mask": attention_mask,
55-
"fps": fps,
56-
"padding_mask": padding_mask,
57-
}
34+
def model_class(self):
35+
return CosmosTransformer3DModel
5836

5937
@property
60-
def input_shape(self):
38+
def output_shape(self) -> tuple[int, ...]:
6139
return (4, 1, 16, 16)
6240

6341
@property
64-
def output_shape(self):
42+
def input_shape(self) -> tuple[int, ...]:
6543
return (4, 1, 16, 16)
6644

67-
def prepare_init_args_and_inputs_for_common(self):
68-
init_dict = {
45+
@property
46+
def main_input_name(self) -> str:
47+
return "hidden_states"
48+
49+
@property
50+
def generator(self):
51+
return torch.Generator("cpu").manual_seed(0)
52+
53+
def get_init_dict(self) -> dict[str, int | list | tuple | float | bool | str]:
54+
return {
6955
"in_channels": 4,
7056
"out_channels": 4,
7157
"num_attention_heads": 2,
@@ -80,57 +66,68 @@ def prepare_init_args_and_inputs_for_common(self):
8066
"concat_padding_mask": True,
8167
"extra_pos_embed_type": "learnable",
8268
}
83-
inputs_dict = self.dummy_input
84-
return init_dict, inputs_dict
85-
86-
def test_gradient_checkpointing_is_applied(self):
87-
expected_set = {"CosmosTransformer3DModel"}
88-
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
89-
90-
91-
class CosmosTransformer3DModelVideoToWorldTests(ModelTesterMixin, unittest.TestCase):
92-
model_class = CosmosTransformer3DModel
93-
main_input_name = "hidden_states"
94-
uses_custom_attn_processor = True
9569

96-
@property
97-
def dummy_input(self):
98-
batch_size = 1
70+
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
9971
num_channels = 4
10072
num_frames = 1
10173
height = 16
10274
width = 16
10375
text_embed_dim = 16
10476
sequence_length = 12
105-
fps = 30
106-
107-
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
108-
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
109-
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
110-
attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
111-
condition_mask = torch.ones(batch_size, 1, num_frames, height, width).to(torch_device)
112-
padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
11377

11478
return {
115-
"hidden_states": hidden_states,
116-
"timestep": timestep,
117-
"encoder_hidden_states": encoder_hidden_states,
118-
"attention_mask": attention_mask,
119-
"fps": fps,
120-
"condition_mask": condition_mask,
121-
"padding_mask": padding_mask,
79+
"hidden_states": randn_tensor(
80+
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
81+
),
82+
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
83+
"encoder_hidden_states": randn_tensor(
84+
(batch_size, sequence_length, text_embed_dim), generator=self.generator, device=torch_device
85+
),
86+
"attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
87+
"fps": 30,
88+
"padding_mask": torch.zeros(batch_size, 1, height, width).to(torch_device),
12289
}
12390

91+
92+
class TestCosmosTransformer(CosmosTransformerTesterConfig, ModelTesterMixin):
93+
"""Core model tests for Cosmos Transformer."""
94+
95+
96+
class TestCosmosTransformerMemory(CosmosTransformerTesterConfig, MemoryTesterMixin):
97+
"""Memory optimization tests for Cosmos Transformer."""
98+
99+
100+
class TestCosmosTransformerTraining(CosmosTransformerTesterConfig, TrainingTesterMixin):
101+
"""Training tests for Cosmos Transformer."""
102+
103+
def test_gradient_checkpointing_is_applied(self):
104+
expected_set = {"CosmosTransformer3DModel"}
105+
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
106+
107+
108+
class CosmosTransformerVideoToWorldTesterConfig(BaseModelTesterConfig):
109+
@property
110+
def model_class(self):
111+
return CosmosTransformer3DModel
112+
124113
@property
125-
def input_shape(self):
114+
def output_shape(self) -> tuple[int, ...]:
126115
return (4, 1, 16, 16)
127116

128117
@property
129-
def output_shape(self):
118+
def input_shape(self) -> tuple[int, ...]:
130119
return (4, 1, 16, 16)
131120

132-
def prepare_init_args_and_inputs_for_common(self):
133-
init_dict = {
121+
@property
122+
def main_input_name(self) -> str:
123+
return "hidden_states"
124+
125+
@property
126+
def generator(self):
127+
return torch.Generator("cpu").manual_seed(0)
128+
129+
def get_init_dict(self) -> dict[str, int | list | tuple | float | bool | str]:
130+
return {
134131
"in_channels": 4 + 1,
135132
"out_channels": 4,
136133
"num_attention_heads": 2,
@@ -145,8 +142,40 @@ def prepare_init_args_and_inputs_for_common(self):
145142
"concat_padding_mask": True,
146143
"extra_pos_embed_type": "learnable",
147144
}
148-
inputs_dict = self.dummy_input
149-
return init_dict, inputs_dict
145+
146+
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
147+
num_channels = 4
148+
num_frames = 1
149+
height = 16
150+
width = 16
151+
text_embed_dim = 16
152+
sequence_length = 12
153+
154+
return {
155+
"hidden_states": randn_tensor(
156+
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
157+
),
158+
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
159+
"encoder_hidden_states": randn_tensor(
160+
(batch_size, sequence_length, text_embed_dim), generator=self.generator, device=torch_device
161+
),
162+
"attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
163+
"fps": 30,
164+
"condition_mask": torch.ones(batch_size, 1, num_frames, height, width).to(torch_device),
165+
"padding_mask": torch.zeros(batch_size, 1, height, width).to(torch_device),
166+
}
167+
168+
169+
class TestCosmosTransformerVideoToWorld(CosmosTransformerVideoToWorldTesterConfig, ModelTesterMixin):
170+
"""Core model tests for Cosmos Transformer (Video-to-World)."""
171+
172+
173+
class TestCosmosTransformerVideoToWorldMemory(CosmosTransformerVideoToWorldTesterConfig, MemoryTesterMixin):
174+
"""Memory optimization tests for Cosmos Transformer (Video-to-World)."""
175+
176+
177+
class TestCosmosTransformerVideoToWorldTraining(CosmosTransformerVideoToWorldTesterConfig, TrainingTesterMixin):
178+
"""Training tests for Cosmos Transformer (Video-to-World)."""
150179

151180
def test_gradient_checkpointing_is_applied(self):
152181
expected_set = {"CosmosTransformer3DModel"}

0 commit comments

Comments (0)