1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- import unittest
16-
1715import torch
1816
1917from diffusers import SanaTransformer2DModel
18+ from diffusers .utils .torch_utils import randn_tensor
2019
21- from ...testing_utils import (
22- enable_full_determinism ,
23- torch_device ,
20+ from ...testing_utils import enable_full_determinism , torch_device
21+ from ..testing_utils import (
22+ AttentionTesterMixin ,
23+ BaseModelTesterConfig ,
24+ MemoryTesterMixin ,
25+ ModelTesterMixin ,
26+ TrainingTesterMixin ,
2427)
25- from ..test_modeling_common import ModelTesterMixin
2628
2729
2830enable_full_determinism ()
2931
3032
31- class SanaTransformerTests (ModelTesterMixin , unittest .TestCase ):
32- model_class = SanaTransformer2DModel
33- main_input_name = "hidden_states"
34- uses_custom_attn_processor = True
35- model_split_percents = [0.7 , 0.7 , 0.9 ]
36-
33+ class SanaTransformerTesterConfig (BaseModelTesterConfig ):
3734 @property
38- def dummy_input (self ):
39- batch_size = 2
40- num_channels = 4
41- height = 32
42- width = 32
43- embedding_dim = 8
44- sequence_length = 8
35+ def model_class (self ):
36+ return SanaTransformer2DModel
4537
46- hidden_states = torch . randn (( batch_size , num_channels , height , width )). to ( torch_device )
47- encoder_hidden_states = torch . randn (( batch_size , sequence_length , embedding_dim )). to ( torch_device )
48- timestep = torch . randint ( 0 , 1000 , size = ( batch_size ,)). to ( torch_device )
38+ @ property
39+ def main_input_name ( self ) -> str :
40+ return "hidden_states"
4941
50- return {
51- "hidden_states" : hidden_states ,
52- "encoder_hidden_states" : encoder_hidden_states ,
53- "timestep" : timestep ,
54- }
42+ @property
43+ def uses_custom_attn_processor (self ) -> bool :
44+ return True
5545
5646 @property
57- def input_shape (self ):
47+ def output_shape (self ) -> tuple :
5848 return (4 , 32 , 32 )
5949
6050 @property
61- def output_shape (self ):
51+ def input_shape (self ) -> tuple :
6252 return (4 , 32 , 32 )
6353
64- def prepare_init_args_and_inputs_for_common (self ):
65- init_dict = {
54+ @property
55+ def model_split_percents (self ) -> list :
56+ return [0.7 , 0.7 , 0.9 ]
57+
58+ @property
59+ def generator (self ):
60+ return torch .Generator ("cpu" ).manual_seed (0 )
61+
62+ def get_init_dict (self ) -> dict :
63+ return {
6664 "patch_size" : 1 ,
6765 "in_channels" : 4 ,
6866 "out_channels" : 4 ,
@@ -75,9 +73,43 @@ def prepare_init_args_and_inputs_for_common(self):
7573 "caption_channels" : 8 ,
7674 "sample_size" : 32 ,
7775 }
78- inputs_dict = self .dummy_input
79- return init_dict , inputs_dict
8076
77+ def get_dummy_inputs (self ) -> dict :
78+ batch_size = 2
79+ num_channels = 4
80+ height = 32
81+ width = 32
82+ embedding_dim = 8
83+ sequence_length = 8
84+
85+ hidden_states = randn_tensor (
86+ (batch_size , num_channels , height , width ), generator = self .generator , device = torch_device
87+ )
88+ encoder_hidden_states = randn_tensor (
89+ (batch_size , sequence_length , embedding_dim ), generator = self .generator , device = torch_device
90+ )
91+ timestep = torch .randint (0 , 1000 , size = (batch_size ,), generator = self .generator ).to (torch_device )
92+
93+ return {
94+ "hidden_states" : hidden_states ,
95+ "encoder_hidden_states" : encoder_hidden_states ,
96+ "timestep" : timestep ,
97+ }
98+
99+
100+ class TestSanaTransformer (SanaTransformerTesterConfig , ModelTesterMixin ):
101+ pass
102+
103+
104+ class TestSanaTransformerMemory (SanaTransformerTesterConfig , MemoryTesterMixin ):
105+ pass
106+
107+
108+ class TestSanaTransformerTraining (SanaTransformerTesterConfig , TrainingTesterMixin ):
81109 def test_gradient_checkpointing_is_applied (self ):
82110 expected_set = {"SanaTransformer2DModel" }
83111 super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
112+
113+
114+ class TestSanaTransformerAttention (SanaTransformerTesterConfig , AttentionTesterMixin ):
115+ pass
0 commit comments