# See the License for the specific language governing permissions and
# limitations under the License.

15- import unittest
16-
1715import torch
1816
1917from diffusers import CosmosTransformer3DModel
18+ from diffusers .utils .torch_utils import randn_tensor
2019
2120from ...testing_utils import enable_full_determinism , torch_device
22- from ..test_modeling_common import ModelTesterMixin
21+ from ..testing_utils import (
22+ BaseModelTesterConfig ,
23+ MemoryTesterMixin ,
24+ ModelTesterMixin ,
25+ TrainingTesterMixin ,
26+ )
2327
2428
# Make test runs reproducible across invocations and devices.
enable_full_determinism()

2731
class CosmosTransformerTesterConfig(BaseModelTesterConfig):
    """Shared tester configuration for the Cosmos Transformer (text-to-world)."""

    @property
    def model_class(self):
        # Model class exercised by all mixin-provided tests.
        return CosmosTransformer3DModel

5937 @property
60- def input_shape (self ):
38+ def output_shape (self ) -> tuple [ int , ...] :
6139 return (4 , 1 , 16 , 16 )
6240
6341 @property
64- def output_shape (self ):
42+ def input_shape (self ) -> tuple [ int , ...] :
6543 return (4 , 1 , 16 , 16 )
6644
67- def prepare_init_args_and_inputs_for_common (self ):
68- init_dict = {
45+ @property
46+ def main_input_name (self ) -> str :
47+ return "hidden_states"
48+
49+ @property
50+ def generator (self ):
51+ return torch .Generator ("cpu" ).manual_seed (0 )
52+
53+ def get_init_dict (self ) -> dict [str , int | list | tuple | float | bool | str ]:
54+ return {
6955 "in_channels" : 4 ,
7056 "out_channels" : 4 ,
7157 "num_attention_heads" : 2 ,
@@ -80,57 +66,68 @@ def prepare_init_args_and_inputs_for_common(self):
8066 "concat_padding_mask" : True ,
8167 "extra_pos_embed_type" : "learnable" ,
8268 }
83- inputs_dict = self .dummy_input
84- return init_dict , inputs_dict
85-
86- def test_gradient_checkpointing_is_applied (self ):
87- expected_set = {"CosmosTransformer3DModel" }
88- super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
89-
90-
91- class CosmosTransformer3DModelVideoToWorldTests (ModelTesterMixin , unittest .TestCase ):
92- model_class = CosmosTransformer3DModel
93- main_input_name = "hidden_states"
94- uses_custom_attn_processor = True
9569
96- @property
97- def dummy_input (self ):
98- batch_size = 1
70+ def get_dummy_inputs (self , batch_size : int = 1 ) -> dict [str , torch .Tensor ]:
9971 num_channels = 4
10072 num_frames = 1
10173 height = 16
10274 width = 16
10375 text_embed_dim = 16
10476 sequence_length = 12
105- fps = 30
106-
107- hidden_states = torch .randn ((batch_size , num_channels , num_frames , height , width )).to (torch_device )
108- timestep = torch .randint (0 , 1000 , size = (batch_size ,)).to (torch_device )
109- encoder_hidden_states = torch .randn ((batch_size , sequence_length , text_embed_dim )).to (torch_device )
110- attention_mask = torch .ones ((batch_size , sequence_length )).to (torch_device )
111- condition_mask = torch .ones (batch_size , 1 , num_frames , height , width ).to (torch_device )
112- padding_mask = torch .zeros (batch_size , 1 , height , width ).to (torch_device )
11377
11478 return {
115- "hidden_states" : hidden_states ,
116- "timestep" : timestep ,
117- "encoder_hidden_states" : encoder_hidden_states ,
118- "attention_mask" : attention_mask ,
119- "fps" : fps ,
120- "condition_mask" : condition_mask ,
121- "padding_mask" : padding_mask ,
79+ "hidden_states" : randn_tensor (
80+ (batch_size , num_channels , num_frames , height , width ), generator = self .generator , device = torch_device
81+ ),
82+ "timestep" : torch .randint (0 , 1000 , size = (batch_size ,), generator = self .generator ).to (torch_device ),
83+ "encoder_hidden_states" : randn_tensor (
84+ (batch_size , sequence_length , text_embed_dim ), generator = self .generator , device = torch_device
85+ ),
86+ "attention_mask" : torch .ones ((batch_size , sequence_length )).to (torch_device ),
87+ "fps" : 30 ,
88+ "padding_mask" : torch .zeros (batch_size , 1 , height , width ).to (torch_device ),
12289 }
12390
91+
class TestCosmosTransformer(CosmosTransformerTesterConfig, ModelTesterMixin):
    """Core model tests for Cosmos Transformer."""

95+
class TestCosmosTransformerMemory(CosmosTransformerTesterConfig, MemoryTesterMixin):
    """Memory optimization tests for Cosmos Transformer."""

99+
class TestCosmosTransformerTraining(CosmosTransformerTesterConfig, TrainingTesterMixin):
    """Training tests for Cosmos Transformer."""

    def test_gradient_checkpointing_is_applied(self):
        # Checkpointing is expected on the top-level transformer module.
        expected_set = {"CosmosTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

107+
class CosmosTransformerVideoToWorldTesterConfig(BaseModelTesterConfig):
    """Tester configuration for the Cosmos Transformer in video-to-world mode."""

    @property
    def model_class(self):
        # Same model class; the video-to-world variant differs only in config/inputs.
        return CosmosTransformer3DModel

124113 @property
125- def input_shape (self ):
114+ def output_shape (self ) -> tuple [ int , ...] :
126115 return (4 , 1 , 16 , 16 )
127116
128117 @property
129- def output_shape (self ):
118+ def input_shape (self ) -> tuple [ int , ...] :
130119 return (4 , 1 , 16 , 16 )
131120
132- def prepare_init_args_and_inputs_for_common (self ):
133- init_dict = {
121+ @property
122+ def main_input_name (self ) -> str :
123+ return "hidden_states"
124+
125+ @property
126+ def generator (self ):
127+ return torch .Generator ("cpu" ).manual_seed (0 )
128+
129+ def get_init_dict (self ) -> dict [str , int | list | tuple | float | bool | str ]:
130+ return {
134131 "in_channels" : 4 + 1 ,
135132 "out_channels" : 4 ,
136133 "num_attention_heads" : 2 ,
@@ -145,8 +142,40 @@ def prepare_init_args_and_inputs_for_common(self):
145142 "concat_padding_mask" : True ,
146143 "extra_pos_embed_type" : "learnable" ,
147144 }
148- inputs_dict = self .dummy_input
149- return init_dict , inputs_dict
145+
146+ def get_dummy_inputs (self , batch_size : int = 1 ) -> dict [str , torch .Tensor ]:
147+ num_channels = 4
148+ num_frames = 1
149+ height = 16
150+ width = 16
151+ text_embed_dim = 16
152+ sequence_length = 12
153+
154+ return {
155+ "hidden_states" : randn_tensor (
156+ (batch_size , num_channels , num_frames , height , width ), generator = self .generator , device = torch_device
157+ ),
158+ "timestep" : torch .randint (0 , 1000 , size = (batch_size ,), generator = self .generator ).to (torch_device ),
159+ "encoder_hidden_states" : randn_tensor (
160+ (batch_size , sequence_length , text_embed_dim ), generator = self .generator , device = torch_device
161+ ),
162+ "attention_mask" : torch .ones ((batch_size , sequence_length )).to (torch_device ),
163+ "fps" : 30 ,
164+ "condition_mask" : torch .ones (batch_size , 1 , num_frames , height , width ).to (torch_device ),
165+ "padding_mask" : torch .zeros (batch_size , 1 , height , width ).to (torch_device ),
166+ }
167+
168+
class TestCosmosTransformerVideoToWorld(CosmosTransformerVideoToWorldTesterConfig, ModelTesterMixin):
    """Core model tests for Cosmos Transformer (Video-to-World)."""

172+
class TestCosmosTransformerVideoToWorldMemory(CosmosTransformerVideoToWorldTesterConfig, MemoryTesterMixin):
    """Memory optimization tests for Cosmos Transformer (Video-to-World)."""

176+
class TestCosmosTransformerVideoToWorldTraining(CosmosTransformerVideoToWorldTesterConfig, TrainingTesterMixin):
    """Training tests for Cosmos Transformer (Video-to-World)."""

    def test_gradient_checkpointing_is_applied(self):
        # Checkpointing is expected on the top-level transformer module.
        expected_set = {"CosmosTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)