1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import unittest
17-
1816import torch
1917
2018from diffusers import BriaTransformer2DModel
21- from diffusers .models .attention_processor import FluxIPAdapterJointAttnProcessor2_0
22- from diffusers .models .embeddings import ImageProjection
19+ from diffusers .utils .torch_utils import randn_tensor
2320
2421from ...testing_utils import enable_full_determinism , torch_device
25- from ..test_modeling_common import LoraHotSwappingForModelTesterMixin , ModelTesterMixin , TorchCompileTesterMixin
22+ from ..testing_utils import (
23+ BaseModelTesterConfig ,
24+ ModelTesterMixin ,
25+ TrainingTesterMixin ,
26+ )
2627
2728
2829enable_full_determinism ()
2930
3031
31- def create_bria_ip_adapter_state_dict (model ):
32- # "ip_adapter" (cross-attention weights)
33- ip_cross_attn_state_dict = {}
34- key_id = 0
35-
36- for name in model .attn_processors .keys ():
37- if name .startswith ("single_transformer_blocks" ):
38- continue
39-
40- joint_attention_dim = model .config ["joint_attention_dim" ]
41- hidden_size = model .config ["num_attention_heads" ] * model .config ["attention_head_dim" ]
42- sd = FluxIPAdapterJointAttnProcessor2_0 (
43- hidden_size = hidden_size , cross_attention_dim = joint_attention_dim , scale = 1.0
44- ).state_dict ()
45- ip_cross_attn_state_dict .update (
46- {
47- f"{ key_id } .to_k_ip.weight" : sd ["to_k_ip.0.weight" ],
48- f"{ key_id } .to_v_ip.weight" : sd ["to_v_ip.0.weight" ],
49- f"{ key_id } .to_k_ip.bias" : sd ["to_k_ip.0.bias" ],
50- f"{ key_id } .to_v_ip.bias" : sd ["to_v_ip.0.bias" ],
51- }
52- )
53-
54- key_id += 1
55-
56- # "image_proj" (ImageProjection layer weights)
57-
58- image_projection = ImageProjection (
59- cross_attention_dim = model .config ["joint_attention_dim" ],
60- image_embed_dim = model .config ["pooled_projection_dim" ],
61- num_image_text_embeds = 4 ,
62- )
63-
64- ip_image_projection_state_dict = {}
65- sd = image_projection .state_dict ()
66- ip_image_projection_state_dict .update (
67- {
68- "proj.weight" : sd ["image_embeds.weight" ],
69- "proj.bias" : sd ["image_embeds.bias" ],
70- "norm.weight" : sd ["norm.weight" ],
71- "norm.bias" : sd ["norm.bias" ],
72- }
73- )
74-
75- del sd
76- ip_state_dict = {}
77- ip_state_dict .update ({"image_proj" : ip_image_projection_state_dict , "ip_adapter" : ip_cross_attn_state_dict })
78- return ip_state_dict
79-
80-
81- class BriaTransformerTests (ModelTesterMixin , unittest .TestCase ):
82- model_class = BriaTransformer2DModel
83- main_input_name = "hidden_states"
84- # We override the items here because the transformer under consideration is small.
85- model_split_percents = [0.8 , 0.7 , 0.7 ]
86-
87- # Skip setting testing with default: AttnProcessor
88- uses_custom_attn_processor = True
89-
32+ class BriaTransformerTesterConfig (BaseModelTesterConfig ):
9033 @property
91- def dummy_input (self ):
92- batch_size = 1
93- num_latent_channels = 4
94- num_image_channels = 3
95- height = width = 4
96- sequence_length = 48
97- embedding_dim = 32
34+ def model_class (self ):
35+ return BriaTransformer2DModel
9836
99- hidden_states = torch .randn ((batch_size , height * width , num_latent_channels )).to (torch_device )
100- encoder_hidden_states = torch .randn ((batch_size , sequence_length , embedding_dim )).to (torch_device )
101- text_ids = torch .randn ((sequence_length , num_image_channels )).to (torch_device )
102- image_ids = torch .randn ((height * width , num_image_channels )).to (torch_device )
103- timestep = torch .tensor ([1.0 ]).to (torch_device ).expand (batch_size )
37+ @property
38+ def main_input_name (self ) -> str :
39+ return "hidden_states"
10440
105- return {
106- "hidden_states" : hidden_states ,
107- "encoder_hidden_states" : encoder_hidden_states ,
108- "img_ids" : image_ids ,
109- "txt_ids" : text_ids ,
110- "timestep" : timestep ,
111- }
41+ @property
42+ def model_split_percents (self ) -> list :
43+ return [0.8 , 0.7 , 0.7 ]
11244
11345 @property
114- def input_shape (self ):
46+ def output_shape (self ) -> tuple :
11547 return (16 , 4 )
11648
11749 @property
118- def output_shape (self ):
119- return ( 16 , 4 )
50+ def generator (self ):
51+ return torch . Generator ( "cpu" ). manual_seed ( 0 )
12052
121- def prepare_init_args_and_inputs_for_common (self ):
122- init_dict = {
53+ def get_init_dict (self ) -> dict :
54+ return {
12355 "patch_size" : 1 ,
12456 "in_channels" : 4 ,
12557 "num_layers" : 1 ,
@@ -131,19 +63,42 @@ def prepare_init_args_and_inputs_for_common(self):
13163 "axes_dims_rope" : [0 , 4 , 4 ],
13264 }
13365
134- inputs_dict = self .dummy_input
135- return init_dict , inputs_dict
66+ def get_dummy_inputs (self , batch_size : int = 1 ) -> dict [str , torch .Tensor ]:
67+ num_latent_channels = 4
68+ num_image_channels = 3
69+ height = width = 4
70+ sequence_length = 48
71+ embedding_dim = 32
72+
73+ return {
74+ "hidden_states" : randn_tensor (
75+ (batch_size , height * width , num_latent_channels ), generator = self .generator , device = torch_device
76+ ),
77+ "encoder_hidden_states" : randn_tensor (
78+ (batch_size , sequence_length , embedding_dim ), generator = self .generator , device = torch_device
79+ ),
80+ "img_ids" : randn_tensor (
81+ (height * width , num_image_channels ), generator = self .generator , device = torch_device
82+ ),
83+ "txt_ids" : randn_tensor (
84+ (sequence_length , num_image_channels ), generator = self .generator , device = torch_device
85+ ),
86+ "timestep" : torch .tensor ([1.0 ]).to (torch_device ).expand (batch_size ),
87+ }
88+
13689
90+ class TestBriaTransformer (BriaTransformerTesterConfig , ModelTesterMixin ):
13791 def test_deprecated_inputs_img_txt_ids_3d (self ):
138- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
92+ init_dict = self .get_init_dict ()
93+ inputs_dict = self .get_dummy_inputs ()
94+
13995 model = self .model_class (** init_dict )
14096 model .to (torch_device )
14197 model .eval ()
14298
14399 with torch .no_grad ():
144100 output_1 = model (** inputs_dict ).to_tuple ()[0 ]
145101
146- # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
147102 text_ids_3d = inputs_dict ["txt_ids" ].unsqueeze (0 )
148103 image_ids_3d = inputs_dict ["img_ids" ].unsqueeze (0 )
149104
@@ -156,26 +111,14 @@ def test_deprecated_inputs_img_txt_ids_3d(self):
156111 with torch .no_grad ():
157112 output_2 = model (** inputs_dict ).to_tuple ()[0 ]
158113
159- self . assertEqual ( output_1 .shape , output_2 .shape )
160- self . assertTrue (
161- torch . allclose ( output_1 , output_2 , atol = 1e-5 ),
162- msg = "output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
114+ assert output_1 .shape == output_2 .shape
115+ assert torch . allclose ( output_1 , output_2 , atol = 1e-5 ), (
116+ "output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
117+ " are not equal as them as 2d inputs"
163118 )
164119
120+
121+ class TestBriaTransformerTraining (BriaTransformerTesterConfig , TrainingTesterMixin ):
165122 def test_gradient_checkpointing_is_applied (self ):
166123 expected_set = {"BriaTransformer2DModel" }
167124 super ().test_gradient_checkpointing_is_applied (expected_set = expected_set )
168-
169-
170- class BriaTransformerCompileTests (TorchCompileTesterMixin , unittest .TestCase ):
171- model_class = BriaTransformer2DModel
172-
173- def prepare_init_args_and_inputs_for_common (self ):
174- return BriaTransformerTests ().prepare_init_args_and_inputs_for_common ()
175-
176-
177- class BriaTransformerLoRAHotSwapTests (LoraHotSwappingForModelTesterMixin , unittest .TestCase ):
178- model_class = BriaTransformer2DModel
179-
180- def prepare_init_args_and_inputs_for_common (self ):
181- return BriaTransformerTests ().prepare_init_args_and_inputs_for_common ()
0 commit comments