1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import unittest
17-
16+ import pytest
1817import torch
1918
2019from diffusers import PixArtTransformer2DModel , Transformer2DModel
21-
22- from ...testing_utils import (
23- enable_full_determinism ,
24- floats_tensor ,
25- slow ,
26- torch_device ,
20+ from diffusers .utils .torch_utils import randn_tensor
21+
22+ from ...testing_utils import enable_full_determinism , slow , torch_device
23+ from ..testing_utils import (
24+ AttentionTesterMixin ,
25+ BaseModelTesterConfig ,
26+ MemoryTesterMixin ,
27+ ModelTesterMixin ,
28+ TrainingTesterMixin ,
2729)
28- from ..test_modeling_common import ModelTesterMixin
2930
3031
3132enable_full_determinism ()
3233
3334
34- class PixArtTransformer2DModelTests (ModelTesterMixin , unittest .TestCase ):
35- model_class = PixArtTransformer2DModel
36- main_input_name = "hidden_states"
37- # We override the items here because the transformer under consideration is small.
38- model_split_percents = [0.7 , 0.6 , 0.6 ]
39-
35+ class PixArtTransformer2DTesterConfig (BaseModelTesterConfig ):
4036 @property
41- def dummy_input (self ):
42- batch_size = 4
43- in_channels = 4
44- sample_size = 8
45- scheduler_num_train_steps = 1000
46- cross_attention_dim = 8
47- seq_len = 8
37+ def model_class (self ):
38+ return PixArtTransformer2DModel
4839
49- hidden_states = floats_tensor ((batch_size , in_channels , sample_size , sample_size )).to (torch_device )
50- timesteps = torch .randint (0 , scheduler_num_train_steps , size = (batch_size ,)).to (torch_device )
51- encoder_hidden_states = floats_tensor ((batch_size , seq_len , cross_attention_dim )).to (torch_device )
52-
53- return {
54- "hidden_states" : hidden_states ,
55- "timestep" : timesteps ,
56- "encoder_hidden_states" : encoder_hidden_states ,
57- "added_cond_kwargs" : {"aspect_ratio" : None , "resolution" : None },
58- }
@property
def main_input_name(self) -> str:
    """Return the key of the primary model input, ``"hidden_states"``."""
    return "hidden_states"
5943
@property
def input_shape(self) -> tuple:
    """Return the fixed input shape ``(4, 8, 8)``.

    The tuple has no batch dimension; it matches the per-sample
    ``(in_channels, sample_size, sample_size)`` values (4, 8, 8) that
    ``get_dummy_inputs`` uses for ``hidden_states``.
    """
    return (4, 8, 8)
6347
@property
def output_shape(self) -> tuple:
    """Return the fixed expected output shape ``(8, 8, 8)``.

    NOTE(review): presumably per-sample (no batch dimension), mirroring
    ``input_shape`` -- confirm against the mixin that consumes it.
    """
    return (8, 8, 8)
6751
68- def prepare_init_args_and_inputs_for_common (self ):
69- init_dict = {
@property
def model_split_percents(self) -> list:
    """Return the split percentages ``[0.7, 0.6, 0.6]`` used for this model.

    We override the items here because the transformer under consideration
    is small.
    """
    return [0.7, 0.6, 0.6]
56+
@property
def generator(self):
    """Return a new CPU ``torch.Generator`` seeded with 0.

    A fresh generator is constructed on every access, so each consumer
    starts drawing from the same seed-0 state rather than sharing one
    advancing random stream.
    """
    return torch.Generator("cpu").manual_seed(0)
60+
61+ def get_init_dict (self ) -> dict :
62+ return {
7063 "sample_size" : 8 ,
7164 "num_layers" : 1 ,
7265 "patch_size" : 2 ,
@@ -84,20 +77,37 @@ def prepare_init_args_and_inputs_for_common(self):
8477 "use_additional_conditions" : False ,
8578 "caption_channels" : None ,
8679 }
87- inputs_dict = self .dummy_input
88- return init_dict , inputs_dict
8980
90- def test_output (self ):
91- super ().test_output (
92- expected_output_shape = (self .dummy_input [self .main_input_name ].shape [0 ],) + self .output_shape
93- )
def get_dummy_inputs(self, batch_size: int = 4) -> dict[str, torch.Tensor]:
    """Build the dictionary of dummy inputs used by the tests.

    Args:
        batch_size: Size of the leading dimension of every generated tensor.

    Returns:
        A dict with the keys:
        ``"hidden_states"`` -- ``randn_tensor`` output of shape
        ``(batch_size, 4, 8, 8)`` created on ``torch_device``;
        ``"timestep"`` -- integers drawn from ``[0, 1000)`` with shape
        ``(batch_size,)``, sampled on CPU and then moved to ``torch_device``;
        ``"encoder_hidden_states"`` -- ``randn_tensor`` output of shape
        ``(batch_size, 8, 8)`` created on ``torch_device``;
        ``"added_cond_kwargs"`` -- a plain dict whose ``"aspect_ratio"`` and
        ``"resolution"`` entries are both ``None`` (so not every value in the
        returned dict is a tensor).
    """
    in_channels = 4
    sample_size = 8
    scheduler_num_train_steps = 1000
    cross_attention_dim = 8
    seq_len = 8

    # ``self.generator`` is read separately for each draw; as defined in this
    # file it returns a newly seeded generator on every access, so each draw
    # below starts from seed 0.
    hidden_states = randn_tensor(
        (batch_size, in_channels, sample_size, sample_size),
        generator=self.generator,
        device=torch_device,
    )
    timestep = torch.randint(
        0, scheduler_num_train_steps, size=(batch_size,), generator=self.generator
    ).to(torch_device)
    encoder_hidden_states = randn_tensor(
        (batch_size, seq_len, cross_attention_dim),
        generator=self.generator,
        device=torch_device,
    )

    return {
        "hidden_states": hidden_states,
        "timestep": timestep,
        "encoder_hidden_states": encoder_hidden_states,
        "added_cond_kwargs": {"aspect_ratio": None, "resolution": None},
    }
100+
94101
95- def test_gradient_checkpointing_is_applied (self ):
96- expected_set = {"PixArtTransformer2DModel" }
97- super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
102+ class TestPixArtTransformer2D (PixArtTransformer2DTesterConfig , ModelTesterMixin ):
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
    """Skip the dtype-inference round-trip test for both fp16 and bf16.

    The parametrization is kept so that each dtype is reported as its own
    skipped case; ``tmp_path`` and ``dtype`` are accepted but unused.
    """
    # Skip: fp16/bf16 require very high atol to pass, providing little signal.
    # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
    pytest.skip("Tolerance requirements too high for meaningful test")
98108
def test_correct_class_remapping_from_dict_config(self):
    """Check that ``Transformer2DModel.from_config`` remaps to the PixArt class.

    The config passed in is the dict returned by ``self.get_init_dict()``;
    the resulting object must be a ``PixArtTransformer2DModel`` instance.
    """
    config = self.get_init_dict()
    remapped = Transformer2DModel.from_config(config)
    assert isinstance(remapped, PixArtTransformer2DModel)
103113
@@ -110,3 +120,17 @@ def test_correct_class_remapping_from_pretrained_config(self):
110120 def test_correct_class_remapping (self ):
111121 model = Transformer2DModel .from_pretrained ("PixArt-alpha/PixArt-XL-2-1024-MS" , subfolder = "transformer" )
112122 assert isinstance (model , PixArtTransformer2DModel )
123+
124+
class TestPixArtTransformer2DMemory(PixArtTransformer2DTesterConfig, MemoryTesterMixin):
    """Combine the PixArt tester config with ``MemoryTesterMixin``.

    Defines no tests of its own; everything comes from the two base classes.
    """
127+
128+
class TestPixArtTransformer2DAttention(PixArtTransformer2DTesterConfig, AttentionTesterMixin):
    """Combine the PixArt tester config with ``AttentionTesterMixin``.

    Defines no tests of its own; everything comes from the two base classes.
    """
131+
132+
class TestPixArtTransformer2DTraining(PixArtTransformer2DTesterConfig, TrainingTesterMixin):
    """Combine the PixArt tester config with ``TrainingTesterMixin``."""

    def test_gradient_checkpointing_is_applied(self):
        """Run the inherited check with ``PixArtTransformer2DModel`` as the expected set."""
        super().test_gradient_checkpointing_is_applied(expected_set={"PixArtTransformer2DModel"})
0 commit comments