Skip to content

Commit c07f09c

Browse files
DN6sayakpaul
andauthored
[CI] Refactor Skyreels, Lumina, Ominigen, Mochi transformer tests (#13348)
* update * update --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 47538fc commit c07f09c

5 files changed

Lines changed: 235 additions & 202 deletions

tests/models/transformers/test_models_transformer_lumina.py

Lines changed: 45 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -13,85 +13,45 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
17-
1816
import torch
1917

2018
from diffusers import LuminaNextDiT2DModel
19+
from diffusers.utils.torch_utils import randn_tensor
2120

22-
from ...testing_utils import (
23-
enable_full_determinism,
24-
torch_device,
21+
from ...testing_utils import enable_full_determinism, torch_device
22+
from ..testing_utils import (
23+
BaseModelTesterConfig,
24+
ModelTesterMixin,
25+
TrainingTesterMixin,
2526
)
26-
from ..test_modeling_common import ModelTesterMixin
2727

2828

2929
enable_full_determinism()
3030

3131

32-
class LuminaNextDiT2DModelTransformerTests(ModelTesterMixin, unittest.TestCase):
33-
model_class = LuminaNextDiT2DModel
34-
main_input_name = "hidden_states"
35-
uses_custom_attn_processor = True
36-
32+
class LuminaNextDiTTesterConfig(BaseModelTesterConfig):
3733
@property
38-
def dummy_input(self):
39-
"""
40-
Args:
41-
None
42-
Returns:
43-
Dict: Dictionary of dummy input tensors
44-
"""
45-
batch_size = 2 # N
46-
num_channels = 4 # C
47-
height = width = 16 # H, W
48-
embedding_dim = 32 # D
49-
sequence_length = 16 # L
50-
51-
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
52-
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
53-
timestep = torch.rand(size=(batch_size,)).to(torch_device)
54-
encoder_mask = torch.randn(size=(batch_size, sequence_length)).to(torch_device)
55-
image_rotary_emb = torch.randn((384, 384, 4)).to(torch_device)
34+
def model_class(self):
35+
return LuminaNextDiT2DModel
5636

57-
return {
58-
"hidden_states": hidden_states,
59-
"encoder_hidden_states": encoder_hidden_states,
60-
"timestep": timestep,
61-
"encoder_mask": encoder_mask,
62-
"image_rotary_emb": image_rotary_emb,
63-
"cross_attention_kwargs": {},
64-
}
37+
@property
38+
def main_input_name(self) -> str:
39+
return "hidden_states"
6540

6641
@property
67-
def input_shape(self):
68-
"""
69-
Args:
70-
None
71-
Returns:
72-
Tuple: (int, int, int)
73-
"""
42+
def output_shape(self) -> tuple:
7443
return (4, 16, 16)
7544

7645
@property
77-
def output_shape(self):
78-
"""
79-
Args:
80-
None
81-
Returns:
82-
Tuple: (int, int, int)
83-
"""
46+
def input_shape(self) -> tuple:
8447
return (4, 16, 16)
8548

86-
def prepare_init_args_and_inputs_for_common(self):
87-
"""
88-
Args:
89-
None
49+
@property
50+
def generator(self):
51+
return torch.Generator("cpu").manual_seed(0)
9052

91-
Returns:
92-
Tuple: (Dict, Dict)
93-
"""
94-
init_dict = {
53+
def get_init_dict(self) -> dict:
54+
return {
9555
"sample_size": 16,
9656
"patch_size": 2,
9757
"in_channels": 4,
@@ -108,5 +68,29 @@ def prepare_init_args_and_inputs_for_common(self):
10868
"scaling_factor": 1.0,
10969
}
11070

111-
inputs_dict = self.dummy_input
112-
return init_dict, inputs_dict
71+
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
72+
num_channels = 4
73+
height = width = 16
74+
embedding_dim = 32
75+
sequence_length = 16
76+
77+
return {
78+
"hidden_states": randn_tensor(
79+
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
80+
),
81+
"encoder_hidden_states": randn_tensor(
82+
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
83+
),
84+
"timestep": torch.rand(size=(batch_size,), generator=self.generator).to(torch_device),
85+
"encoder_mask": randn_tensor((batch_size, sequence_length), generator=self.generator, device=torch_device),
86+
"image_rotary_emb": randn_tensor((384, 384, 4), generator=self.generator, device=torch_device),
87+
"cross_attention_kwargs": {},
88+
}
89+
90+
91+
class TestLuminaNextDiT(LuminaNextDiTTesterConfig, ModelTesterMixin):
92+
pass
93+
94+
95+
class TestLuminaNextDiTTraining(LuminaNextDiTTesterConfig, TrainingTesterMixin):
96+
pass

tests/models/transformers/test_models_transformer_lumina2.py

Lines changed: 42 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,57 +13,45 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
17-
1816
import torch
1917

2018
from diffusers import Lumina2Transformer2DModel
19+
from diffusers.utils.torch_utils import randn_tensor
2120

22-
from ...testing_utils import (
23-
enable_full_determinism,
24-
torch_device,
21+
from ...testing_utils import enable_full_determinism, torch_device
22+
from ..testing_utils import (
23+
BaseModelTesterConfig,
24+
ModelTesterMixin,
25+
TrainingTesterMixin,
2526
)
26-
from ..test_modeling_common import ModelTesterMixin
2727

2828

2929
enable_full_determinism()
3030

3131

32-
class Lumina2Transformer2DModelTransformerTests(ModelTesterMixin, unittest.TestCase):
33-
model_class = Lumina2Transformer2DModel
34-
main_input_name = "hidden_states"
35-
uses_custom_attn_processor = True
36-
32+
class Lumina2TransformerTesterConfig(BaseModelTesterConfig):
3733
@property
38-
def dummy_input(self):
39-
batch_size = 2 # N
40-
num_channels = 4 # C
41-
height = width = 16 # H, W
42-
embedding_dim = 32 # D
43-
sequence_length = 16 # L
44-
45-
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
46-
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
47-
timestep = torch.rand(size=(batch_size,)).to(torch_device)
48-
attention_mask = torch.ones(size=(batch_size, sequence_length), dtype=torch.bool).to(torch_device)
34+
def model_class(self):
35+
return Lumina2Transformer2DModel
4936

50-
return {
51-
"hidden_states": hidden_states,
52-
"encoder_hidden_states": encoder_hidden_states,
53-
"timestep": timestep,
54-
"encoder_attention_mask": attention_mask,
55-
}
37+
@property
38+
def main_input_name(self) -> str:
39+
return "hidden_states"
5640

5741
@property
58-
def input_shape(self):
42+
def output_shape(self) -> tuple:
5943
return (4, 16, 16)
6044

6145
@property
62-
def output_shape(self):
46+
def input_shape(self) -> tuple:
6347
return (4, 16, 16)
6448

65-
def prepare_init_args_and_inputs_for_common(self):
66-
init_dict = {
49+
@property
50+
def generator(self):
51+
return torch.Generator("cpu").manual_seed(0)
52+
53+
def get_init_dict(self) -> dict:
54+
return {
6755
"sample_size": 16,
6856
"patch_size": 2,
6957
"in_channels": 4,
@@ -81,9 +69,29 @@ def prepare_init_args_and_inputs_for_common(self):
8169
"cap_feat_dim": 32,
8270
}
8371

84-
inputs_dict = self.dummy_input
85-
return init_dict, inputs_dict
72+
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
73+
num_channels = 4
74+
height = width = 16
75+
embedding_dim = 32
76+
sequence_length = 16
77+
78+
return {
79+
"hidden_states": randn_tensor(
80+
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
81+
),
82+
"encoder_hidden_states": randn_tensor(
83+
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
84+
),
85+
"timestep": torch.rand(size=(batch_size,), generator=self.generator).to(torch_device),
86+
"encoder_attention_mask": torch.ones((batch_size, sequence_length), dtype=torch.bool, device=torch_device),
87+
}
88+
89+
90+
class TestLumina2Transformer(Lumina2TransformerTesterConfig, ModelTesterMixin):
91+
pass
92+
8693

94+
class TestLumina2TransformerTraining(Lumina2TransformerTesterConfig, TrainingTesterMixin):
8795
def test_gradient_checkpointing_is_applied(self):
8896
expected_set = {"Lumina2Transformer2DModel"}
8997
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

tests/models/transformers/test_models_transformer_mochi.py

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,58 +13,49 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
17-
1816
import torch
1917

2018
from diffusers import MochiTransformer3DModel
19+
from diffusers.utils.torch_utils import randn_tensor
2120

2221
from ...testing_utils import enable_full_determinism, torch_device
23-
from ..test_modeling_common import ModelTesterMixin
22+
from ..testing_utils import (
23+
BaseModelTesterConfig,
24+
ModelTesterMixin,
25+
TrainingTesterMixin,
26+
)
2427

2528

2629
enable_full_determinism()
2730

2831

29-
class MochiTransformerTests(ModelTesterMixin, unittest.TestCase):
30-
model_class = MochiTransformer3DModel
31-
main_input_name = "hidden_states"
32-
uses_custom_attn_processor = True
33-
# Overriding it because of the transformer size.
34-
model_split_percents = [0.7, 0.6, 0.6]
35-
32+
class MochiTransformerTesterConfig(BaseModelTesterConfig):
3633
@property
37-
def dummy_input(self):
38-
batch_size = 2
39-
num_channels = 4
40-
num_frames = 2
41-
height = 16
42-
width = 16
43-
embedding_dim = 16
44-
sequence_length = 16
34+
def model_class(self):
35+
return MochiTransformer3DModel
4536

46-
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
47-
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
48-
encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
49-
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
37+
@property
38+
def main_input_name(self) -> str:
39+
return "hidden_states"
5040

51-
return {
52-
"hidden_states": hidden_states,
53-
"encoder_hidden_states": encoder_hidden_states,
54-
"timestep": timestep,
55-
"encoder_attention_mask": encoder_attention_mask,
56-
}
41+
@property
42+
def model_split_percents(self) -> list:
43+
return [0.7, 0.6, 0.6]
5744

5845
@property
59-
def input_shape(self):
46+
def output_shape(self) -> tuple:
6047
return (4, 2, 16, 16)
6148

6249
@property
63-
def output_shape(self):
50+
def input_shape(self) -> tuple:
6451
return (4, 2, 16, 16)
6552

66-
def prepare_init_args_and_inputs_for_common(self):
67-
init_dict = {
53+
@property
54+
def generator(self):
55+
return torch.Generator("cpu").manual_seed(0)
56+
57+
def get_init_dict(self) -> dict:
58+
return {
6859
"patch_size": 2,
6960
"num_attention_heads": 2,
7061
"attention_head_dim": 8,
@@ -78,9 +69,32 @@ def prepare_init_args_and_inputs_for_common(self):
7869
"activation_fn": "swiglu",
7970
"max_sequence_length": 16,
8071
}
81-
inputs_dict = self.dummy_input
82-
return init_dict, inputs_dict
8372

73+
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
74+
num_channels = 4
75+
num_frames = 2
76+
height = 16
77+
width = 16
78+
embedding_dim = 16
79+
sequence_length = 16
80+
81+
return {
82+
"hidden_states": randn_tensor(
83+
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
84+
),
85+
"encoder_hidden_states": randn_tensor(
86+
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
87+
),
88+
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
89+
"encoder_attention_mask": torch.ones((batch_size, sequence_length), dtype=torch.bool).to(torch_device),
90+
}
91+
92+
93+
class TestMochiTransformer(MochiTransformerTesterConfig, ModelTesterMixin):
94+
pass
95+
96+
97+
class TestMochiTransformerTraining(MochiTransformerTesterConfig, TrainingTesterMixin):
8498
def test_gradient_checkpointing_is_applied(self):
8599
expected_set = {"MochiTransformer3DModel"}
86100
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

0 commit comments

Comments
 (0)