1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import unittest
17-
1816import torch
1917
2018from diffusers import ChromaTransformer2DModel
21- from diffusers .models .attention_processor import FluxIPAdapterJointAttnProcessor2_0
22- from diffusers .models .embeddings import ImageProjection
19+ from diffusers .utils .torch_utils import randn_tensor
2320
2421from ...testing_utils import enable_full_determinism , torch_device
25- from ..test_modeling_common import LoraHotSwappingForModelTesterMixin , ModelTesterMixin , TorchCompileTesterMixin
22+ from ..testing_utils import (
23+ BaseModelTesterConfig ,
24+ LoraHotSwappingForModelTesterMixin ,
25+ LoraTesterMixin ,
26+ ModelTesterMixin ,
27+ TrainingTesterMixin ,
28+ )
2629
2730
2831enable_full_determinism ()
2932
3033
31- def create_chroma_ip_adapter_state_dict (model ):
32- # "ip_adapter" (cross-attention weights)
33- ip_cross_attn_state_dict = {}
34- key_id = 0
35-
36- for name in model .attn_processors .keys ():
37- if name .startswith ("single_transformer_blocks" ):
38- continue
39-
40- joint_attention_dim = model .config ["joint_attention_dim" ]
41- hidden_size = model .config ["num_attention_heads" ] * model .config ["attention_head_dim" ]
42- sd = FluxIPAdapterJointAttnProcessor2_0 (
43- hidden_size = hidden_size , cross_attention_dim = joint_attention_dim , scale = 1.0
44- ).state_dict ()
45- ip_cross_attn_state_dict .update (
46- {
47- f"{ key_id } .to_k_ip.weight" : sd ["to_k_ip.0.weight" ],
48- f"{ key_id } .to_v_ip.weight" : sd ["to_v_ip.0.weight" ],
49- f"{ key_id } .to_k_ip.bias" : sd ["to_k_ip.0.bias" ],
50- f"{ key_id } .to_v_ip.bias" : sd ["to_v_ip.0.bias" ],
51- }
52- )
53-
54- key_id += 1
55-
56- # "image_proj" (ImageProjection layer weights)
57-
58- image_projection = ImageProjection (
59- cross_attention_dim = model .config ["joint_attention_dim" ],
60- image_embed_dim = model .config ["pooled_projection_dim" ],
61- num_image_text_embeds = 4 ,
62- )
63-
64- ip_image_projection_state_dict = {}
65- sd = image_projection .state_dict ()
66- ip_image_projection_state_dict .update (
67- {
68- "proj.weight" : sd ["image_embeds.weight" ],
69- "proj.bias" : sd ["image_embeds.bias" ],
70- "norm.weight" : sd ["norm.weight" ],
71- "norm.bias" : sd ["norm.bias" ],
72- }
73- )
74-
75- del sd
76- ip_state_dict = {}
77- ip_state_dict .update ({"image_proj" : ip_image_projection_state_dict , "ip_adapter" : ip_cross_attn_state_dict })
78- return ip_state_dict
79-
80-
81- class ChromaTransformerTests (ModelTesterMixin , unittest .TestCase ):
82- model_class = ChromaTransformer2DModel
83- main_input_name = "hidden_states"
84- # We override the items here because the transformer under consideration is small.
85- model_split_percents = [0.8 , 0.7 , 0.7 ]
86-
87- # Skip setting testing with default: AttnProcessor
88- uses_custom_attn_processor = True
89-
34+ class ChromaTransformerTesterConfig (BaseModelTesterConfig ):
9035 @property
91- def dummy_input (self ):
92- batch_size = 1
93- num_latent_channels = 4
94- num_image_channels = 3
95- height = width = 4
96- sequence_length = 48
97- embedding_dim = 32
36+ def model_class (self ):
37+ return ChromaTransformer2DModel
9838
99- hidden_states = torch .randn ((batch_size , height * width , num_latent_channels )).to (torch_device )
100- encoder_hidden_states = torch .randn ((batch_size , sequence_length , embedding_dim )).to (torch_device )
101- text_ids = torch .randn ((sequence_length , num_image_channels )).to (torch_device )
102- image_ids = torch .randn ((height * width , num_image_channels )).to (torch_device )
103- timestep = torch .tensor ([1.0 ]).to (torch_device ).expand (batch_size )
39+ @property
40+ def main_input_name (self ) -> str :
41+ return "hidden_states"
10442
105- return {
106- "hidden_states" : hidden_states ,
107- "encoder_hidden_states" : encoder_hidden_states ,
108- "img_ids" : image_ids ,
109- "txt_ids" : text_ids ,
110- "timestep" : timestep ,
111- }
43+ @property
44+ def model_split_percents (self ) -> list :
45+ return [0.8 , 0.7 , 0.7 ]
11246
11347 @property
114- def input_shape (self ):
48+ def output_shape (self ) -> tuple :
11549 return (16 , 4 )
11650
11751 @property
118- def output_shape (self ):
52+ def input_shape (self ) -> tuple :
11953 return (16 , 4 )
12054
121- def prepare_init_args_and_inputs_for_common (self ):
122- init_dict = {
55+ @property
56+ def generator (self ):
57+ return torch .Generator ("cpu" ).manual_seed (0 )
58+
59+ def get_init_dict (self ) -> dict :
60+ return {
12361 "patch_size" : 1 ,
12462 "in_channels" : 4 ,
12563 "num_layers" : 1 ,
@@ -133,51 +71,87 @@ def prepare_init_args_and_inputs_for_common(self):
13371 "approximator_layers" : 1 ,
13472 }
13573
136- inputs_dict = self .dummy_input
137- return init_dict , inputs_dict
74+ def get_dummy_inputs (self , batch_size : int = 1 ) -> dict [str , torch .Tensor ]:
75+ num_latent_channels = 4
76+ num_image_channels = 3
77+ height = width = 4
78+ sequence_length = 48
79+ embedding_dim = 32
80+
81+ return {
82+ "hidden_states" : randn_tensor (
83+ (batch_size , height * width , num_latent_channels ), generator = self .generator , device = torch_device
84+ ),
85+ "encoder_hidden_states" : randn_tensor (
86+ (batch_size , sequence_length , embedding_dim ), generator = self .generator , device = torch_device
87+ ),
88+ "img_ids" : randn_tensor (
89+ (height * width , num_image_channels ), generator = self .generator , device = torch_device
90+ ),
91+ "txt_ids" : randn_tensor (
92+ (sequence_length , num_image_channels ), generator = self .generator , device = torch_device
93+ ),
94+ "timestep" : torch .tensor ([1.0 ]).to (torch_device ).expand (batch_size ),
95+ }
13896
97+
class TestChromaTransformer(ChromaTransformerTesterConfig, ModelTesterMixin):
    """Model tests for ``ChromaTransformer2DModel``.

    Combines the Chroma tester configuration with ``ModelTesterMixin`` and adds
    one Chroma-specific check for the deprecated 3d ``img_ids``/``txt_ids`` form.
    """

    def test_deprecated_inputs_img_txt_ids_3d(self):
        """Check that 3d ``img_ids``/``txt_ids`` give the same output as the 2d form."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        # Reference output with the 2d ids produced by get_dummy_inputs().
        with torch.no_grad():
            output_1 = model(**inputs_dict).to_tuple()[0]

        # Update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated form).
        text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
        image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)

        assert text_ids_3d.ndim == 3
        assert image_ids_3d.ndim == 3

        inputs_dict["txt_ids"] = text_ids_3d
        inputs_dict["img_ids"] = image_ids_3d

        with torch.no_grad():
            output_2 = model(**inputs_dict).to_tuple()[0]

        assert output_1.shape == output_2.shape
        # The two message fragments are concatenated verbatim, so only the first
        # one carries the separating space.
        assert torch.allclose(output_1, output_2, atol=1e-5), (
            "output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
            "is not equal to the output with 2d inputs"
        )
166127
128+
class TestChromaTransformerTraining(ChromaTransformerTesterConfig, TrainingTesterMixin):
    """Combine the Chroma tester configuration with ``TrainingTesterMixin``."""

    def test_gradient_checkpointing_is_applied(self):
        """Run the inherited check with ``ChromaTransformer2DModel`` as the only expected name."""
        super().test_gradient_checkpointing_is_applied(expected_set={"ChromaTransformer2DModel"})
170133
171134
172- class ChromaTransformerCompileTests ( TorchCompileTesterMixin , unittest . TestCase ):
173- model_class = ChromaTransformer2DModel
class TestChromaTransformerLoRA(ChromaTransformerTesterConfig, LoraTesterMixin):
    """Combine the Chroma tester configuration with ``LoraTesterMixin``; no overrides are defined."""
174137
175- def prepare_init_args_and_inputs_for_common (self ):
176- return ChromaTransformerTests ().prepare_init_args_and_inputs_for_common ()
177138
class TestChromaTransformerLoRAHotSwap(ChromaTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
    """Combine the Chroma tester configuration with ``LoraHotSwappingForModelTesterMixin``.

    Overrides ``get_dummy_inputs`` so that the spatial size of the inputs can
    be varied through ``height`` and ``width``.
    """

    @property
    def different_shapes_for_compilation(self):
        """Return the list of shape pairs used for the compilation tests."""
        # NOTE(review): presumably (height, width) pairs passed on to
        # get_dummy_inputs by the mixin — confirm against the mixin.
        shapes = [(4, 4), (4, 8), (8, 8)]
        return shapes

    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
        """Build forward-pass inputs whose patch count is ``height * width``.

        Args:
            height: First spatial factor of the patch count.
            width: Second spatial factor of the patch count.

        Returns:
            A dict with the keys ``hidden_states``, ``encoder_hidden_states``,
            ``img_ids``, ``txt_ids`` and ``timestep``; the batch size is fixed at 1.
        """
        batch = 1
        latent_channels, id_channels = 4, 3
        text_len, text_dim = 24, 32
        num_patches = height * width

        # No generator is passed here. The tensors are drawn in the same order
        # as the keys of the returned dict.
        hidden_states = randn_tensor((batch, num_patches, latent_channels), device=torch_device)
        encoder_hidden_states = randn_tensor((batch, text_len, text_dim), device=torch_device)
        img_ids = randn_tensor((num_patches, id_channels), device=torch_device)
        txt_ids = randn_tensor((text_len, id_channels), device=torch_device)
        timestep = torch.tensor([1.0]).to(torch_device).expand(batch)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "img_ids": img_ids,
            "txt_ids": txt_ids,
            "timestep": timestep,
        }
0 commit comments