1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import unittest
17-
1816import torch
1917
2018from diffusers import SD3Transformer2DModel
21- from diffusers .utils .import_utils import is_xformers_available
22-
23- from ...testing_utils import (
24- enable_full_determinism ,
25- torch_device ,
19+ from diffusers .utils .torch_utils import randn_tensor
20+
21+ from ...testing_utils import enable_full_determinism , torch_device
22+ from ..testing_utils import (
23+ BaseModelTesterConfig ,
24+ BitsAndBytesTesterMixin ,
25+ ModelTesterMixin ,
26+ TorchAoTesterMixin ,
27+ TorchCompileTesterMixin ,
28+ TrainingTesterMixin ,
2629)
27- from ..test_modeling_common import ModelTesterMixin
2830
2931
3032enable_full_determinism ()
3133
3234
33- class SD3TransformerTests (ModelTesterMixin , unittest .TestCase ):
34- model_class = SD3Transformer2DModel
35- main_input_name = "hidden_states"
36- model_split_percents = [0.8 , 0.8 , 0.9 ]
35+ # ======================== SD3 Transformer ========================
3736
37+
38+ class SD3TransformerTesterConfig (BaseModelTesterConfig ):
3839 @property
39- def dummy_input (self ):
40- batch_size = 2
41- num_channels = 4
42- height = width = embedding_dim = 32
43- pooled_embedding_dim = embedding_dim * 2
44- sequence_length = 154
40+ def model_class (self ):
41+ return SD3Transformer2DModel
42+
43+ @ property
44+ def pretrained_model_name_or_path ( self ):
45+ return "hf-internal-testing/tiny-sd3-pipe"
4546
46- hidden_states = torch .randn ((batch_size , num_channels , height , width )).to (torch_device )
47- encoder_hidden_states = torch .randn ((batch_size , sequence_length , embedding_dim )).to (torch_device )
48- pooled_prompt_embeds = torch .randn ((batch_size , pooled_embedding_dim )).to (torch_device )
49- timestep = torch .randint (0 , 1000 , size = (batch_size ,)).to (torch_device )
47+ @property
48+ def pretrained_model_kwargs (self ):
49+ return {"subfolder" : "transformer" }
5050
51- return {
52- "hidden_states" : hidden_states ,
53- "encoder_hidden_states" : encoder_hidden_states ,
54- "pooled_projections" : pooled_prompt_embeds ,
55- "timestep" : timestep ,
56- }
51+ @property
52+ def main_input_name (self ) -> str :
53+ return "hidden_states"
54+
55+ @property
56+ def model_split_percents (self ) -> list :
57+ return [0.8 , 0.8 , 0.9 ]
5758
5859 @property
59- def input_shape (self ):
60+ def output_shape (self ) -> tuple :
6061 return (4 , 32 , 32 )
6162
6263 @property
63- def output_shape (self ):
64+ def input_shape (self ) -> tuple :
6465 return (4 , 32 , 32 )
6566
66- def prepare_init_args_and_inputs_for_common (self ):
67- init_dict = {
67+ @property
68+ def generator (self ):
69+ return torch .Generator ("cpu" ).manual_seed (0 )
70+
71+ def get_init_dict (self ) -> dict :
72+ return {
6873 "sample_size" : 32 ,
6974 "patch_size" : 1 ,
7075 "in_channels" : 4 ,
@@ -79,67 +84,79 @@ def prepare_init_args_and_inputs_for_common(self):
7984 "dual_attention_layers" : (),
8085 "qk_norm" : None ,
8186 }
82- inputs_dict = self .dummy_input
83- return init_dict , inputs_dict
8487
85- @unittest .skipIf (
86- torch_device != "cuda" or not is_xformers_available (),
87- reason = "XFormers attention is only available with CUDA and `xformers` installed" ,
88- )
89- def test_xformers_enable_works (self ):
90- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
91- model = self .model_class (** init_dict )
88+ def get_dummy_inputs (self , batch_size : int = 2 ) -> dict [str , torch .Tensor ]:
89+ num_channels = 4
90+ height = width = embedding_dim = 32
91+ pooled_embedding_dim = embedding_dim * 2
92+ sequence_length = 154
93+
94+ return {
95+ "hidden_states" : randn_tensor (
96+ (batch_size , num_channels , height , width ), generator = self .generator , device = torch_device
97+ ),
98+ "encoder_hidden_states" : randn_tensor (
99+ (batch_size , sequence_length , embedding_dim ), generator = self .generator , device = torch_device
100+ ),
101+ "pooled_projections" : randn_tensor (
102+ (batch_size , pooled_embedding_dim ), generator = self .generator , device = torch_device
103+ ),
104+ "timestep" : torch .randint (0 , 1000 , size = (batch_size ,), generator = self .generator ).to (torch_device ),
105+ }
92106
93- model .enable_xformers_memory_efficient_attention ()
94107
95- assert model .transformer_blocks [0 ].attn .processor .__class__ .__name__ == "XFormersJointAttnProcessor" , (
96- "xformers is not enabled"
97- )
108+ class TestSD3Transformer (SD3TransformerTesterConfig , ModelTesterMixin ):
109+ pass
98110
99- @unittest .skip ("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply" )
100- def test_set_attn_processor_for_determinism (self ):
101- pass
102111
112+ class TestSD3TransformerTraining (SD3TransformerTesterConfig , TrainingTesterMixin ):
103113 def test_gradient_checkpointing_is_applied (self ):
104114 expected_set = {"SD3Transformer2DModel" }
105115 super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
106116
107117
108- class SD35TransformerTests (ModelTesterMixin , unittest .TestCase ):
109- model_class = SD3Transformer2DModel
110- main_input_name = "hidden_states"
111- model_split_percents = [0.8 , 0.8 , 0.9 ]
118+ class TestSD3TransformerCompile (SD3TransformerTesterConfig , TorchCompileTesterMixin ):
119+ pass
120+
121+
122+ # ======================== SD3.5 Transformer ========================
123+
112124
125+ class SD35TransformerTesterConfig (BaseModelTesterConfig ):
113126 @property
114- def dummy_input (self ):
115- batch_size = 2
116- num_channels = 4
117- height = width = embedding_dim = 32
118- pooled_embedding_dim = embedding_dim * 2
119- sequence_length = 154
127+ def model_class (self ):
128+ return SD3Transformer2DModel
120129
121- hidden_states = torch .randn ((batch_size , num_channels , height , width )).to (torch_device )
122- encoder_hidden_states = torch .randn ((batch_size , sequence_length , embedding_dim )).to (torch_device )
123- pooled_prompt_embeds = torch .randn ((batch_size , pooled_embedding_dim )).to (torch_device )
124- timestep = torch .randint (0 , 1000 , size = (batch_size ,)).to (torch_device )
130+ @property
131+ def pretrained_model_name_or_path (self ):
132+ return "hf-internal-testing/tiny-sd35-pipe"
125133
126- return {
127- "hidden_states" : hidden_states ,
128- "encoder_hidden_states" : encoder_hidden_states ,
129- "pooled_projections" : pooled_prompt_embeds ,
130- "timestep" : timestep ,
131- }
134+ @property
135+ def pretrained_model_kwargs (self ):
136+ return {"subfolder" : "transformer" }
137+
138+ @property
139+ def main_input_name (self ) -> str :
140+ return "hidden_states"
141+
142+ @property
143+ def model_split_percents (self ) -> list :
144+ return [0.8 , 0.8 , 0.9 ]
132145
133146 @property
134- def input_shape (self ):
147+ def output_shape (self ) -> tuple :
135148 return (4 , 32 , 32 )
136149
137150 @property
138- def output_shape (self ):
151+ def input_shape (self ) -> tuple :
139152 return (4 , 32 , 32 )
140153
141- def prepare_init_args_and_inputs_for_common (self ):
142- init_dict = {
154+ @property
155+ def generator (self ):
156+ return torch .Generator ("cpu" ).manual_seed (0 )
157+
158+ def get_init_dict (self ) -> dict :
159+ return {
143160 "sample_size" : 32 ,
144161 "patch_size" : 1 ,
145162 "in_channels" : 4 ,
@@ -154,47 +171,56 @@ def prepare_init_args_and_inputs_for_common(self):
154171 "dual_attention_layers" : (0 ,),
155172 "qk_norm" : "rms_norm" ,
156173 }
157- inputs_dict = self .dummy_input
158- return init_dict , inputs_dict
159-
160- @unittest .skipIf (
161- torch_device != "cuda" or not is_xformers_available (),
162- reason = "XFormers attention is only available with CUDA and `xformers` installed" ,
163- )
164- def test_xformers_enable_works (self ):
165- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
166- model = self .model_class (** init_dict )
167174
168- model . enable_xformers_memory_efficient_attention ()
169-
170- assert model . transformer_blocks [ 0 ]. attn . processor . __class__ . __name__ == "XFormersJointAttnProcessor" , (
171- "xformers is not enabled"
172- )
175+ def get_dummy_inputs ( self , batch_size : int = 2 ) -> dict [ str , torch . Tensor ]:
176+ num_channels = 4
177+ height = width = embedding_dim = 32
178+ pooled_embedding_dim = embedding_dim * 2
179+ sequence_length = 154
173180
174- @unittest .skip ("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply" )
175- def test_set_attn_processor_for_determinism (self ):
176- pass
181+ return {
182+ "hidden_states" : randn_tensor (
183+ (batch_size , num_channels , height , width ), generator = self .generator , device = torch_device
184+ ),
185+ "encoder_hidden_states" : randn_tensor (
186+ (batch_size , sequence_length , embedding_dim ), generator = self .generator , device = torch_device
187+ ),
188+ "pooled_projections" : randn_tensor (
189+ (batch_size , pooled_embedding_dim ), generator = self .generator , device = torch_device
190+ ),
191+ "timestep" : torch .randint (0 , 1000 , size = (batch_size ,), generator = self .generator ).to (torch_device ),
192+ }
177193
178- def test_gradient_checkpointing_is_applied (self ):
179- expected_set = {"SD3Transformer2DModel" }
180- super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
181194
195+ class TestSD35Transformer (SD35TransformerTesterConfig , ModelTesterMixin ):
182196 def test_skip_layers (self ):
183- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
197+ init_dict = self .get_init_dict ()
198+ inputs_dict = self .get_dummy_inputs ()
184199 model = self .model_class (** init_dict ).to (torch_device )
185200
186- # Forward pass without skipping layers
187201 output_full = model (** inputs_dict ).sample
188202
189- # Forward pass with skipping layers 0 (since there's only one layer in this test setup)
190203 inputs_dict_with_skip = inputs_dict .copy ()
191204 inputs_dict_with_skip ["skip_layers" ] = [0 ]
192205 output_skip = model (** inputs_dict_with_skip ).sample
193206
194- # Check that the outputs are different
195- self .assertFalse (
196- torch .allclose (output_full , output_skip , atol = 1e-5 ), "Outputs should differ when layers are skipped"
197- )
207+ assert not torch .allclose (output_full , output_skip , atol = 1e-5 ), "Outputs should differ when layers are skipped"
208+ assert output_full .shape == output_skip .shape , "Outputs should have the same shape"
209+
210+
211+ class TestSD35TransformerTraining (SD35TransformerTesterConfig , TrainingTesterMixin ):
212+ def test_gradient_checkpointing_is_applied (self ):
213+ expected_set = {"SD3Transformer2DModel" }
214+ super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
215+
216+
217+ class TestSD35TransformerCompile (SD35TransformerTesterConfig , TorchCompileTesterMixin ):
218+ pass
219+
220+
221+ class TestSD35TransformerBitsAndBytes (SD35TransformerTesterConfig , BitsAndBytesTesterMixin ):
222+ """BitsAndBytes quantization tests for SD3.5 Transformer."""
223+
198224
199- # Check that the outputs have the same shape
200- self . assertEqual ( output_full . shape , output_skip . shape , "Outputs should have the same shape" )
225+ class TestSD35TransformerTorchAo ( SD35TransformerTesterConfig , TorchAoTesterMixin ):
226+ """TorchAO quantization tests for SD3.5 Transformer."""
0 commit comments