Skip to content

Commit e3fc602

Browse files
refactor sana transformer tests (#13826)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent e87b2a7 commit e3fc602

1 file changed

Lines changed: 65 additions & 33 deletions

File tree

tests/models/transformers/test_models_transformer_sana.py

Lines changed: 65 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,57 +12,55 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import unittest
16-
1715
import torch
1816

1917
from diffusers import SanaTransformer2DModel
18+
from diffusers.utils.torch_utils import randn_tensor
2019

21-
from ...testing_utils import (
22-
enable_full_determinism,
23-
torch_device,
20+
from ...testing_utils import enable_full_determinism, torch_device
21+
from ..testing_utils import (
22+
AttentionTesterMixin,
23+
BaseModelTesterConfig,
24+
MemoryTesterMixin,
25+
ModelTesterMixin,
26+
TrainingTesterMixin,
2427
)
25-
from ..test_modeling_common import ModelTesterMixin
2628

2729

2830
enable_full_determinism()
2931

3032

31-
class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
32-
model_class = SanaTransformer2DModel
33-
main_input_name = "hidden_states"
34-
uses_custom_attn_processor = True
35-
model_split_percents = [0.7, 0.7, 0.9]
36-
33+
class SanaTransformerTesterConfig(BaseModelTesterConfig):
3734
@property
38-
def dummy_input(self):
39-
batch_size = 2
40-
num_channels = 4
41-
height = 32
42-
width = 32
43-
embedding_dim = 8
44-
sequence_length = 8
35+
def model_class(self):
36+
return SanaTransformer2DModel
4537

46-
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
47-
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
48-
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
38+
@property
39+
def main_input_name(self) -> str:
40+
return "hidden_states"
4941

50-
return {
51-
"hidden_states": hidden_states,
52-
"encoder_hidden_states": encoder_hidden_states,
53-
"timestep": timestep,
54-
}
42+
@property
43+
def uses_custom_attn_processor(self) -> bool:
44+
return True
5545

5646
@property
57-
def input_shape(self):
47+
def output_shape(self) -> tuple:
5848
return (4, 32, 32)
5949

6050
@property
61-
def output_shape(self):
51+
def input_shape(self) -> tuple:
6252
return (4, 32, 32)
6353

64-
def prepare_init_args_and_inputs_for_common(self):
65-
init_dict = {
54+
@property
55+
def model_split_percents(self) -> list:
56+
return [0.7, 0.7, 0.9]
57+
58+
@property
59+
def generator(self):
60+
return torch.Generator("cpu").manual_seed(0)
61+
62+
def get_init_dict(self) -> dict:
63+
return {
6664
"patch_size": 1,
6765
"in_channels": 4,
6866
"out_channels": 4,
@@ -75,9 +73,43 @@ def prepare_init_args_and_inputs_for_common(self):
7573
"caption_channels": 8,
7674
"sample_size": 32,
7775
}
78-
inputs_dict = self.dummy_input
79-
return init_dict, inputs_dict
8076

77+
def get_dummy_inputs(self) -> dict:
78+
batch_size = 2
79+
num_channels = 4
80+
height = 32
81+
width = 32
82+
embedding_dim = 8
83+
sequence_length = 8
84+
85+
hidden_states = randn_tensor(
86+
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
87+
)
88+
encoder_hidden_states = randn_tensor(
89+
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
90+
)
91+
timestep = torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device)
92+
93+
return {
94+
"hidden_states": hidden_states,
95+
"encoder_hidden_states": encoder_hidden_states,
96+
"timestep": timestep,
97+
}
98+
99+
100+
class TestSanaTransformer(SanaTransformerTesterConfig, ModelTesterMixin):
101+
pass
102+
103+
104+
class TestSanaTransformerMemory(SanaTransformerTesterConfig, MemoryTesterMixin):
105+
pass
106+
107+
108+
class TestSanaTransformerTraining(SanaTransformerTesterConfig, TrainingTesterMixin):
81109
def test_gradient_checkpointing_is_applied(self):
82110
expected_set = {"SanaTransformer2DModel"}
83111
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
112+
113+
114+
class TestSanaTransformerAttention(SanaTransformerTesterConfig, AttentionTesterMixin):
115+
pass

0 commit comments

Comments
 (0)