Skip to content

Commit 25b85c1

Browse files
DN6sayakpaul
andauthored
[CI] Refactor Chronoedit, PRX, EasyAnimate, Ovis transformer tests (#13347)
* update * update --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent d6c360d commit 25b85c1

4 files changed

Lines changed: 287 additions & 69 deletions

File tree

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# coding=utf-8
2+
# Copyright 2025 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import torch
17+
18+
from diffusers import ChronoEditTransformer3DModel
19+
from diffusers.utils.torch_utils import randn_tensor
20+
21+
from ...testing_utils import enable_full_determinism, torch_device
22+
from ..testing_utils import (
23+
BaseModelTesterConfig,
24+
ModelTesterMixin,
25+
TrainingTesterMixin,
26+
)
27+
28+
29+
enable_full_determinism()
30+
31+
32+
class ChronoEditTransformerTesterConfig(BaseModelTesterConfig):
33+
@property
34+
def model_class(self):
35+
return ChronoEditTransformer3DModel
36+
37+
@property
38+
def main_input_name(self) -> str:
39+
return "hidden_states"
40+
41+
@property
42+
def output_shape(self) -> tuple:
43+
return (16, 8, 8)
44+
45+
@property
46+
def input_shape(self) -> tuple:
47+
return (16, 8, 8)
48+
49+
@property
50+
def generator(self):
51+
return torch.Generator("cpu").manual_seed(0)
52+
53+
def get_init_dict(self) -> dict:
54+
return {
55+
"patch_size": (1, 2, 2),
56+
"num_attention_heads": 2,
57+
"attention_head_dim": 8,
58+
"in_channels": 16,
59+
"out_channels": 16,
60+
"text_dim": 32,
61+
"freq_dim": 16,
62+
"ffn_dim": 32,
63+
"num_layers": 2,
64+
"cross_attn_norm": True,
65+
"qk_norm": "rms_norm_across_heads",
66+
"eps": 1e-06,
67+
"image_dim": None,
68+
"added_kv_proj_dim": None,
69+
"rope_max_seq_len": 64,
70+
"pos_embed_seq_len": None,
71+
"rope_temporal_skip_len": 8,
72+
}
73+
74+
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
75+
num_channels = 16
76+
num_frames = 2
77+
height = 8
78+
width = 8
79+
embedding_dim = 32
80+
sequence_length = 12
81+
82+
return {
83+
"hidden_states": randn_tensor(
84+
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
85+
),
86+
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
87+
"encoder_hidden_states": randn_tensor(
88+
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
89+
),
90+
"encoder_hidden_states_image": None,
91+
}
92+
93+
94+
class TestChronoEditTransformer(ChronoEditTransformerTesterConfig, ModelTesterMixin):
95+
pass
96+
97+
98+
class TestChronoEditTransformerTraining(ChronoEditTransformerTesterConfig, TrainingTesterMixin):
99+
def test_gradient_checkpointing_is_applied(self):
100+
expected_set = {"ChronoEditTransformer3DModel"}
101+
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

tests/models/transformers/test_models_transformer_easyanimate.py

Lines changed: 48 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,58 +13,45 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
17-
1816
import torch
1917

2018
from diffusers import EasyAnimateTransformer3DModel
19+
from diffusers.utils.torch_utils import randn_tensor
2120

2221
from ...testing_utils import enable_full_determinism, torch_device
23-
from ..test_modeling_common import ModelTesterMixin
22+
from ..testing_utils import (
23+
BaseModelTesterConfig,
24+
ModelTesterMixin,
25+
TrainingTesterMixin,
26+
)
2427

2528

2629
enable_full_determinism()
2730

2831

29-
class EasyAnimateTransformerTests(ModelTesterMixin, unittest.TestCase):
30-
model_class = EasyAnimateTransformer3DModel
31-
main_input_name = "hidden_states"
32-
uses_custom_attn_processor = True
33-
32+
class EasyAnimateTransformerTesterConfig(BaseModelTesterConfig):
3433
@property
35-
def dummy_input(self):
36-
batch_size = 2
37-
num_channels = 4
38-
num_frames = 2
39-
height = 16
40-
width = 16
41-
embedding_dim = 16
42-
sequence_length = 16
43-
44-
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
45-
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
46-
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
34+
def model_class(self):
35+
return EasyAnimateTransformer3DModel
4736

48-
return {
49-
"hidden_states": hidden_states,
50-
"timestep": timestep,
51-
"timestep_cond": None,
52-
"encoder_hidden_states": encoder_hidden_states,
53-
"encoder_hidden_states_t5": None,
54-
"inpaint_latents": None,
55-
"control_latents": None,
56-
}
37+
@property
38+
def main_input_name(self) -> str:
39+
return "hidden_states"
5740

5841
@property
59-
def input_shape(self):
42+
def output_shape(self) -> tuple:
6043
return (4, 2, 16, 16)
6144

6245
@property
63-
def output_shape(self):
46+
def input_shape(self) -> tuple:
6447
return (4, 2, 16, 16)
6548

66-
def prepare_init_args_and_inputs_for_common(self):
67-
init_dict = {
49+
@property
50+
def generator(self):
51+
return torch.Generator("cpu").manual_seed(0)
52+
53+
def get_init_dict(self) -> dict:
54+
return {
6855
"attention_head_dim": 16,
6956
"num_attention_heads": 2,
7057
"in_channels": 4,
@@ -79,9 +66,35 @@ def prepare_init_args_and_inputs_for_common(self):
7966
"time_position_encoding_type": "3d_rope",
8067
"timestep_activation_fn": "silu",
8168
}
82-
inputs_dict = self.dummy_input
83-
return init_dict, inputs_dict
8469

70+
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
71+
num_channels = 4
72+
num_frames = 2
73+
height = 16
74+
width = 16
75+
embedding_dim = 16
76+
sequence_length = 16
77+
78+
return {
79+
"hidden_states": randn_tensor(
80+
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
81+
),
82+
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
83+
"timestep_cond": None,
84+
"encoder_hidden_states": randn_tensor(
85+
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
86+
),
87+
"encoder_hidden_states_t5": None,
88+
"inpaint_latents": None,
89+
"control_latents": None,
90+
}
91+
92+
93+
class TestEasyAnimateTransformer(EasyAnimateTransformerTesterConfig, ModelTesterMixin):
94+
pass
95+
96+
97+
class TestEasyAnimateTransformerTraining(EasyAnimateTransformerTesterConfig, TrainingTesterMixin):
8598
def test_gradient_checkpointing_is_applied(self):
8699
expected_set = {"EasyAnimateTransformer3DModel"}
87100
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# coding=utf-8
2+
# Copyright 2025 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import torch
17+
18+
from diffusers import OvisImageTransformer2DModel
19+
from diffusers.utils.torch_utils import randn_tensor
20+
21+
from ...testing_utils import enable_full_determinism, torch_device
22+
from ..testing_utils import (
23+
BaseModelTesterConfig,
24+
ModelTesterMixin,
25+
TrainingTesterMixin,
26+
)
27+
28+
29+
enable_full_determinism()
30+
31+
32+
class OvisImageTransformerTesterConfig(BaseModelTesterConfig):
33+
@property
34+
def model_class(self):
35+
return OvisImageTransformer2DModel
36+
37+
@property
38+
def main_input_name(self) -> str:
39+
return "hidden_states"
40+
41+
@property
42+
def output_shape(self) -> tuple:
43+
return (16, 4)
44+
45+
@property
46+
def input_shape(self) -> tuple:
47+
return (16, 4)
48+
49+
@property
50+
def generator(self):
51+
return torch.Generator("cpu").manual_seed(0)
52+
53+
def get_init_dict(self) -> dict:
54+
return {
55+
"patch_size": 1,
56+
"in_channels": 4,
57+
"out_channels": 4,
58+
"num_layers": 1,
59+
"num_single_layers": 1,
60+
"attention_head_dim": 16,
61+
"num_attention_heads": 2,
62+
"joint_attention_dim": 32,
63+
"axes_dims_rope": (4, 4, 8),
64+
}
65+
66+
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
67+
num_latent_channels = 4
68+
num_image_channels = 3
69+
height = width = 4
70+
sequence_length = 48
71+
embedding_dim = 32
72+
73+
return {
74+
"hidden_states": randn_tensor(
75+
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
76+
),
77+
"encoder_hidden_states": randn_tensor(
78+
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
79+
),
80+
"img_ids": randn_tensor(
81+
(height * width, num_image_channels), generator=self.generator, device=torch_device
82+
),
83+
"txt_ids": randn_tensor(
84+
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
85+
),
86+
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
87+
}
88+
89+
90+
class TestOvisImageTransformer(OvisImageTransformerTesterConfig, ModelTesterMixin):
91+
pass
92+
93+
94+
class TestOvisImageTransformerTraining(OvisImageTransformerTesterConfig, TrainingTesterMixin):
95+
def test_gradient_checkpointing_is_applied(self):
96+
expected_set = {"OvisImageTransformer2DModel"}
97+
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

0 commit comments

Comments
 (0)