Skip to content

Commit d6c360d

Browse files
DN6sayakpaul
andauthored
[CI] Refactor Bria Transformer Tests (#13341)
* update * update * update * update --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 3b80bcf commit d6c360d

2 files changed

Lines changed: 104 additions & 150 deletions

File tree

tests/models/transformers/test_models_transformer_bria.py

Lines changed: 53 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -13,113 +13,45 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
17-
1816
import torch
1917

2018
from diffusers import BriaTransformer2DModel
21-
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
22-
from diffusers.models.embeddings import ImageProjection
19+
from diffusers.utils.torch_utils import randn_tensor
2320

2421
from ...testing_utils import enable_full_determinism, torch_device
25-
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
22+
from ..testing_utils import (
23+
BaseModelTesterConfig,
24+
ModelTesterMixin,
25+
TrainingTesterMixin,
26+
)
2627

2728

2829
enable_full_determinism()
2930

3031

31-
def create_bria_ip_adapter_state_dict(model):
32-
# "ip_adapter" (cross-attention weights)
33-
ip_cross_attn_state_dict = {}
34-
key_id = 0
35-
36-
for name in model.attn_processors.keys():
37-
if name.startswith("single_transformer_blocks"):
38-
continue
39-
40-
joint_attention_dim = model.config["joint_attention_dim"]
41-
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
42-
sd = FluxIPAdapterJointAttnProcessor2_0(
43-
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
44-
).state_dict()
45-
ip_cross_attn_state_dict.update(
46-
{
47-
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
48-
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
49-
f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
50-
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
51-
}
52-
)
53-
54-
key_id += 1
55-
56-
# "image_proj" (ImageProjection layer weights)
57-
58-
image_projection = ImageProjection(
59-
cross_attention_dim=model.config["joint_attention_dim"],
60-
image_embed_dim=model.config["pooled_projection_dim"],
61-
num_image_text_embeds=4,
62-
)
63-
64-
ip_image_projection_state_dict = {}
65-
sd = image_projection.state_dict()
66-
ip_image_projection_state_dict.update(
67-
{
68-
"proj.weight": sd["image_embeds.weight"],
69-
"proj.bias": sd["image_embeds.bias"],
70-
"norm.weight": sd["norm.weight"],
71-
"norm.bias": sd["norm.bias"],
72-
}
73-
)
74-
75-
del sd
76-
ip_state_dict = {}
77-
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
78-
return ip_state_dict
79-
80-
81-
class BriaTransformerTests(ModelTesterMixin, unittest.TestCase):
82-
model_class = BriaTransformer2DModel
83-
main_input_name = "hidden_states"
84-
# We override the items here because the transformer under consideration is small.
85-
model_split_percents = [0.8, 0.7, 0.7]
86-
87-
# Skip setting testing with default: AttnProcessor
88-
uses_custom_attn_processor = True
89-
32+
class BriaTransformerTesterConfig(BaseModelTesterConfig):
9033
@property
91-
def dummy_input(self):
92-
batch_size = 1
93-
num_latent_channels = 4
94-
num_image_channels = 3
95-
height = width = 4
96-
sequence_length = 48
97-
embedding_dim = 32
34+
def model_class(self):
35+
return BriaTransformer2DModel
9836

99-
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
100-
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
101-
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
102-
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
103-
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
37+
@property
38+
def main_input_name(self) -> str:
39+
return "hidden_states"
10440

105-
return {
106-
"hidden_states": hidden_states,
107-
"encoder_hidden_states": encoder_hidden_states,
108-
"img_ids": image_ids,
109-
"txt_ids": text_ids,
110-
"timestep": timestep,
111-
}
41+
@property
42+
def model_split_percents(self) -> list:
43+
return [0.8, 0.7, 0.7]
11244

11345
@property
114-
def input_shape(self):
46+
def output_shape(self) -> tuple:
11547
return (16, 4)
11648

11749
@property
118-
def output_shape(self):
119-
return (16, 4)
50+
def generator(self):
51+
return torch.Generator("cpu").manual_seed(0)
12052

121-
def prepare_init_args_and_inputs_for_common(self):
122-
init_dict = {
53+
def get_init_dict(self) -> dict:
54+
return {
12355
"patch_size": 1,
12456
"in_channels": 4,
12557
"num_layers": 1,
@@ -131,19 +63,42 @@ def prepare_init_args_and_inputs_for_common(self):
13163
"axes_dims_rope": [0, 4, 4],
13264
}
13365

134-
inputs_dict = self.dummy_input
135-
return init_dict, inputs_dict
66+
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
67+
num_latent_channels = 4
68+
num_image_channels = 3
69+
height = width = 4
70+
sequence_length = 48
71+
embedding_dim = 32
72+
73+
return {
74+
"hidden_states": randn_tensor(
75+
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
76+
),
77+
"encoder_hidden_states": randn_tensor(
78+
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
79+
),
80+
"img_ids": randn_tensor(
81+
(height * width, num_image_channels), generator=self.generator, device=torch_device
82+
),
83+
"txt_ids": randn_tensor(
84+
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
85+
),
86+
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
87+
}
88+
13689

90+
class TestBriaTransformer(BriaTransformerTesterConfig, ModelTesterMixin):
13791
def test_deprecated_inputs_img_txt_ids_3d(self):
138-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
92+
init_dict = self.get_init_dict()
93+
inputs_dict = self.get_dummy_inputs()
94+
13995
model = self.model_class(**init_dict)
14096
model.to(torch_device)
14197
model.eval()
14298

14399
with torch.no_grad():
144100
output_1 = model(**inputs_dict).to_tuple()[0]
145101

146-
# update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
147102
text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
148103
image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
149104

@@ -156,26 +111,14 @@ def test_deprecated_inputs_img_txt_ids_3d(self):
156111
with torch.no_grad():
157112
output_2 = model(**inputs_dict).to_tuple()[0]
158113

159-
self.assertEqual(output_1.shape, output_2.shape)
160-
self.assertTrue(
161-
torch.allclose(output_1, output_2, atol=1e-5),
162-
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
114+
assert output_1.shape == output_2.shape
115+
assert torch.allclose(output_1, output_2, atol=1e-5), (
116+
"output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
117+
"are not equal as them as 2d inputs"
163118
)
164119

120+
121+
class TestBriaTransformerTraining(BriaTransformerTesterConfig, TrainingTesterMixin):
165122
def test_gradient_checkpointing_is_applied(self):
166123
expected_set = {"BriaTransformer2DModel"}
167124
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
168-
169-
170-
class BriaTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
171-
model_class = BriaTransformer2DModel
172-
173-
def prepare_init_args_and_inputs_for_common(self):
174-
return BriaTransformerTests().prepare_init_args_and_inputs_for_common()
175-
176-
177-
class BriaTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
178-
model_class = BriaTransformer2DModel
179-
180-
def prepare_init_args_and_inputs_for_common(self):
181-
return BriaTransformerTests().prepare_init_args_and_inputs_for_common()

tests/models/transformers/test_models_transformer_bria_fibo.py

Lines changed: 51 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -13,62 +13,45 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
17-
1816
import torch
1917

2018
from diffusers import BriaFiboTransformer2DModel
19+
from diffusers.utils.torch_utils import randn_tensor
2120

2221
from ...testing_utils import enable_full_determinism, torch_device
23-
from ..test_modeling_common import ModelTesterMixin
22+
from ..testing_utils import (
23+
BaseModelTesterConfig,
24+
ModelTesterMixin,
25+
TrainingTesterMixin,
26+
)
2427

2528

2629
enable_full_determinism()
2730

2831

29-
class BriaFiboTransformerTests(ModelTesterMixin, unittest.TestCase):
30-
model_class = BriaFiboTransformer2DModel
31-
main_input_name = "hidden_states"
32-
# We override the items here because the transformer under consideration is small.
33-
model_split_percents = [0.8, 0.7, 0.7]
34-
35-
# Skip setting testing with default: AttnProcessor
36-
uses_custom_attn_processor = True
37-
32+
class BriaFiboTransformerTesterConfig(BaseModelTesterConfig):
3833
@property
39-
def dummy_input(self):
40-
batch_size = 1
41-
num_latent_channels = 48
42-
num_image_channels = 3
43-
height = width = 16
44-
sequence_length = 32
45-
embedding_dim = 64
34+
def model_class(self):
35+
return BriaFiboTransformer2DModel
4636

47-
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
48-
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
49-
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
50-
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
51-
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
52-
53-
return {
54-
"hidden_states": hidden_states,
55-
"encoder_hidden_states": encoder_hidden_states,
56-
"img_ids": image_ids,
57-
"txt_ids": text_ids,
58-
"timestep": timestep,
59-
"text_encoder_layers": [encoder_hidden_states[:, :, :32], encoder_hidden_states[:, :, :32]],
60-
}
37+
@property
38+
def main_input_name(self) -> str:
39+
return "hidden_states"
6140

6241
@property
63-
def input_shape(self):
64-
return (16, 16)
42+
def model_split_percents(self) -> list:
43+
return [0.8, 0.7, 0.7]
6544

6645
@property
67-
def output_shape(self):
46+
def output_shape(self) -> tuple:
6847
return (256, 48)
6948

70-
def prepare_init_args_and_inputs_for_common(self):
71-
init_dict = {
49+
@property
50+
def generator(self):
51+
return torch.Generator("cpu").manual_seed(0)
52+
53+
def get_init_dict(self) -> dict:
54+
return {
7255
"patch_size": 1,
7356
"in_channels": 48,
7457
"num_layers": 1,
@@ -81,9 +64,37 @@ def prepare_init_args_and_inputs_for_common(self):
8164
"axes_dims_rope": [0, 4, 4],
8265
}
8366

84-
inputs_dict = self.dummy_input
85-
return init_dict, inputs_dict
67+
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
68+
num_latent_channels = 48
69+
num_image_channels = 3
70+
height = width = 16
71+
sequence_length = 32
72+
embedding_dim = 64
73+
74+
encoder_hidden_states = randn_tensor(
75+
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
76+
)
77+
return {
78+
"hidden_states": randn_tensor(
79+
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
80+
),
81+
"encoder_hidden_states": encoder_hidden_states,
82+
"img_ids": randn_tensor(
83+
(height * width, num_image_channels), generator=self.generator, device=torch_device
84+
),
85+
"txt_ids": randn_tensor(
86+
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
87+
),
88+
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
89+
"text_encoder_layers": [encoder_hidden_states[:, :, :32], encoder_hidden_states[:, :, :32]],
90+
}
91+
92+
93+
class TestBriaFiboTransformer(BriaFiboTransformerTesterConfig, ModelTesterMixin):
94+
pass
95+
8696

97+
class TestBriaFiboTransformerTraining(BriaFiboTransformerTesterConfig, TrainingTesterMixin):
8798
def test_gradient_checkpointing_is_applied(self):
8899
expected_set = {"BriaFiboTransformer2DModel"}
89100
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

0 commit comments

Comments
 (0)