Skip to content

Commit 47538fc

Browse files
DN6sayakpaul
andauthored
[CI] Refactor Chroma , LongCat and HiDream Transformer Tests (#13345)
* update * update * update --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent e377c0a commit 47538fc

3 files changed

Lines changed: 230 additions & 146 deletions

File tree

tests/models/transformers/test_models_transformer_chroma.py

Lines changed: 79 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -13,113 +13,51 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
17-
1816
import torch
1917

2018
from diffusers import ChromaTransformer2DModel
21-
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
22-
from diffusers.models.embeddings import ImageProjection
19+
from diffusers.utils.torch_utils import randn_tensor
2320

2421
from ...testing_utils import enable_full_determinism, torch_device
25-
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
22+
from ..testing_utils import (
23+
BaseModelTesterConfig,
24+
LoraHotSwappingForModelTesterMixin,
25+
LoraTesterMixin,
26+
ModelTesterMixin,
27+
TrainingTesterMixin,
28+
)
2629

2730

2831
enable_full_determinism()
2932

3033

31-
def create_chroma_ip_adapter_state_dict(model):
32-
# "ip_adapter" (cross-attention weights)
33-
ip_cross_attn_state_dict = {}
34-
key_id = 0
35-
36-
for name in model.attn_processors.keys():
37-
if name.startswith("single_transformer_blocks"):
38-
continue
39-
40-
joint_attention_dim = model.config["joint_attention_dim"]
41-
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
42-
sd = FluxIPAdapterJointAttnProcessor2_0(
43-
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
44-
).state_dict()
45-
ip_cross_attn_state_dict.update(
46-
{
47-
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
48-
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
49-
f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
50-
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
51-
}
52-
)
53-
54-
key_id += 1
55-
56-
# "image_proj" (ImageProjection layer weights)
57-
58-
image_projection = ImageProjection(
59-
cross_attention_dim=model.config["joint_attention_dim"],
60-
image_embed_dim=model.config["pooled_projection_dim"],
61-
num_image_text_embeds=4,
62-
)
63-
64-
ip_image_projection_state_dict = {}
65-
sd = image_projection.state_dict()
66-
ip_image_projection_state_dict.update(
67-
{
68-
"proj.weight": sd["image_embeds.weight"],
69-
"proj.bias": sd["image_embeds.bias"],
70-
"norm.weight": sd["norm.weight"],
71-
"norm.bias": sd["norm.bias"],
72-
}
73-
)
74-
75-
del sd
76-
ip_state_dict = {}
77-
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
78-
return ip_state_dict
79-
80-
81-
class ChromaTransformerTests(ModelTesterMixin, unittest.TestCase):
82-
model_class = ChromaTransformer2DModel
83-
main_input_name = "hidden_states"
84-
# We override the items here because the transformer under consideration is small.
85-
model_split_percents = [0.8, 0.7, 0.7]
86-
87-
# Skip setting testing with default: AttnProcessor
88-
uses_custom_attn_processor = True
89-
34+
class ChromaTransformerTesterConfig(BaseModelTesterConfig):
9035
@property
91-
def dummy_input(self):
92-
batch_size = 1
93-
num_latent_channels = 4
94-
num_image_channels = 3
95-
height = width = 4
96-
sequence_length = 48
97-
embedding_dim = 32
36+
def model_class(self):
37+
return ChromaTransformer2DModel
9838

99-
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
100-
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
101-
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
102-
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
103-
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
39+
@property
40+
def main_input_name(self) -> str:
41+
return "hidden_states"
10442

105-
return {
106-
"hidden_states": hidden_states,
107-
"encoder_hidden_states": encoder_hidden_states,
108-
"img_ids": image_ids,
109-
"txt_ids": text_ids,
110-
"timestep": timestep,
111-
}
43+
@property
44+
def model_split_percents(self) -> list:
45+
return [0.8, 0.7, 0.7]
11246

11347
@property
114-
def input_shape(self):
48+
def output_shape(self) -> tuple:
11549
return (16, 4)
11650

11751
@property
118-
def output_shape(self):
52+
def input_shape(self) -> tuple:
11953
return (16, 4)
12054

121-
def prepare_init_args_and_inputs_for_common(self):
122-
init_dict = {
55+
@property
56+
def generator(self):
57+
return torch.Generator("cpu").manual_seed(0)
58+
59+
def get_init_dict(self) -> dict:
60+
return {
12361
"patch_size": 1,
12462
"in_channels": 4,
12563
"num_layers": 1,
@@ -133,51 +71,87 @@ def prepare_init_args_and_inputs_for_common(self):
13371
"approximator_layers": 1,
13472
}
13573

136-
inputs_dict = self.dummy_input
137-
return init_dict, inputs_dict
74+
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
75+
num_latent_channels = 4
76+
num_image_channels = 3
77+
height = width = 4
78+
sequence_length = 48
79+
embedding_dim = 32
80+
81+
return {
82+
"hidden_states": randn_tensor(
83+
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
84+
),
85+
"encoder_hidden_states": randn_tensor(
86+
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
87+
),
88+
"img_ids": randn_tensor(
89+
(height * width, num_image_channels), generator=self.generator, device=torch_device
90+
),
91+
"txt_ids": randn_tensor(
92+
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
93+
),
94+
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
95+
}
13896

97+
98+
class TestChromaTransformer(ChromaTransformerTesterConfig, ModelTesterMixin):
13999
def test_deprecated_inputs_img_txt_ids_3d(self):
140-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
100+
init_dict = self.get_init_dict()
101+
inputs_dict = self.get_dummy_inputs()
102+
141103
model = self.model_class(**init_dict)
142104
model.to(torch_device)
143105
model.eval()
144106

145107
with torch.no_grad():
146108
output_1 = model(**inputs_dict).to_tuple()[0]
147109

148-
# update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
149110
text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
150111
image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
151112

152-
assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
153-
assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"
113+
assert text_ids_3d.ndim == 3
114+
assert image_ids_3d.ndim == 3
154115

155116
inputs_dict["txt_ids"] = text_ids_3d
156117
inputs_dict["img_ids"] = image_ids_3d
157118

158119
with torch.no_grad():
159120
output_2 = model(**inputs_dict).to_tuple()[0]
160121

161-
self.assertEqual(output_1.shape, output_2.shape)
162-
self.assertTrue(
163-
torch.allclose(output_1, output_2, atol=1e-5),
164-
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
122+
assert output_1.shape == output_2.shape
123+
assert torch.allclose(output_1, output_2, atol=1e-5), (
124+
"output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
125+
"are not equal as them as 2d inputs"
165126
)
166127

128+
129+
class TestChromaTransformerTraining(ChromaTransformerTesterConfig, TrainingTesterMixin):
167130
def test_gradient_checkpointing_is_applied(self):
168131
expected_set = {"ChromaTransformer2DModel"}
169132
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
170133

171134

172-
class ChromaTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
173-
model_class = ChromaTransformer2DModel
135+
class TestChromaTransformerLoRA(ChromaTransformerTesterConfig, LoraTesterMixin):
136+
pass
174137

175-
def prepare_init_args_and_inputs_for_common(self):
176-
return ChromaTransformerTests().prepare_init_args_and_inputs_for_common()
177138

139+
class TestChromaTransformerLoRAHotSwap(ChromaTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
140+
@property
141+
def different_shapes_for_compilation(self):
142+
return [(4, 4), (4, 8), (8, 8)]
178143

179-
class ChromaTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
180-
model_class = ChromaTransformer2DModel
144+
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
145+
batch_size = 1
146+
num_latent_channels = 4
147+
num_image_channels = 3
148+
sequence_length = 24
149+
embedding_dim = 32
181150

182-
def prepare_init_args_and_inputs_for_common(self):
183-
return ChromaTransformerTests().prepare_init_args_and_inputs_for_common()
151+
return {
152+
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
153+
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
154+
"img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
155+
"txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
156+
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
157+
}

tests/models/transformers/test_models_transformer_hidream.py

Lines changed: 53 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,61 +13,49 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
17-
1816
import torch
1917

2018
from diffusers import HiDreamImageTransformer2DModel
19+
from diffusers.utils.torch_utils import randn_tensor
2120

22-
from ...testing_utils import (
23-
enable_full_determinism,
24-
torch_device,
21+
from ...testing_utils import enable_full_determinism, torch_device
22+
from ..testing_utils import (
23+
BaseModelTesterConfig,
24+
ModelTesterMixin,
25+
TrainingTesterMixin,
2526
)
26-
from ..test_modeling_common import ModelTesterMixin
2727

2828

2929
enable_full_determinism()
3030

3131

32-
class HiDreamTransformerTests(ModelTesterMixin, unittest.TestCase):
33-
model_class = HiDreamImageTransformer2DModel
34-
main_input_name = "hidden_states"
35-
model_split_percents = [0.8, 0.8, 0.9]
36-
32+
class HiDreamTransformerTesterConfig(BaseModelTesterConfig):
3733
@property
38-
def dummy_input(self):
39-
batch_size = 2
40-
num_channels = 4
41-
height = width = 32
42-
embedding_dim_t5, embedding_dim_llama, embedding_dim_pooled = 8, 4, 8
43-
sequence_length = 8
34+
def model_class(self):
35+
return HiDreamImageTransformer2DModel
4436

45-
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
46-
encoder_hidden_states_t5 = torch.randn((batch_size, sequence_length, embedding_dim_t5)).to(torch_device)
47-
encoder_hidden_states_llama3 = torch.randn((batch_size, batch_size, sequence_length, embedding_dim_llama)).to(
48-
torch_device
49-
)
50-
pooled_embeds = torch.randn((batch_size, embedding_dim_pooled)).to(torch_device)
51-
timesteps = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
37+
@property
38+
def main_input_name(self) -> str:
39+
return "hidden_states"
5240

53-
return {
54-
"hidden_states": hidden_states,
55-
"encoder_hidden_states_t5": encoder_hidden_states_t5,
56-
"encoder_hidden_states_llama3": encoder_hidden_states_llama3,
57-
"pooled_embeds": pooled_embeds,
58-
"timesteps": timesteps,
59-
}
41+
@property
42+
def model_split_percents(self) -> list:
43+
return [0.8, 0.8, 0.9]
6044

6145
@property
62-
def input_shape(self):
46+
def output_shape(self) -> tuple:
6347
return (4, 32, 32)
6448

6549
@property
66-
def output_shape(self):
50+
def input_shape(self) -> tuple:
6751
return (4, 32, 32)
6852

69-
def prepare_init_args_and_inputs_for_common(self):
70-
init_dict = {
53+
@property
54+
def generator(self):
55+
return torch.Generator("cpu").manual_seed(0)
56+
57+
def get_init_dict(self) -> dict:
58+
return {
7159
"patch_size": 2,
7260
"in_channels": 4,
7361
"out_channels": 4,
@@ -82,15 +70,39 @@ def prepare_init_args_and_inputs_for_common(self):
8270
"axes_dims_rope": (4, 2, 2),
8371
"max_resolution": (32, 32),
8472
"llama_layers": (0, 1),
85-
"force_inference_output": True, # TODO: as we don't implement MoE loss in training tests.
73+
"force_inference_output": True,
8674
}
87-
inputs_dict = self.dummy_input
88-
return init_dict, inputs_dict
8975

90-
@unittest.skip("HiDreamImageTransformer2DModel uses a dedicated attention processor. This test doesn't apply")
91-
def test_set_attn_processor_for_determinism(self):
92-
pass
76+
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
77+
num_channels = 4
78+
height = width = 32
79+
embedding_dim_t5, embedding_dim_llama, embedding_dim_pooled = 8, 4, 8
80+
sequence_length = 8
81+
82+
return {
83+
"hidden_states": randn_tensor(
84+
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
85+
),
86+
"encoder_hidden_states_t5": randn_tensor(
87+
(batch_size, sequence_length, embedding_dim_t5), generator=self.generator, device=torch_device
88+
),
89+
"encoder_hidden_states_llama3": randn_tensor(
90+
(batch_size, batch_size, sequence_length, embedding_dim_llama),
91+
generator=self.generator,
92+
device=torch_device,
93+
),
94+
"pooled_embeds": randn_tensor(
95+
(batch_size, embedding_dim_pooled), generator=self.generator, device=torch_device
96+
),
97+
"timesteps": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
98+
}
99+
100+
101+
class TestHiDreamTransformer(HiDreamTransformerTesterConfig, ModelTesterMixin):
102+
pass
103+
93104

105+
class TestHiDreamTransformerTraining(HiDreamTransformerTesterConfig, TrainingTesterMixin):
94106
def test_gradient_checkpointing_is_applied(self):
95107
expected_set = {"HiDreamImageTransformer2DModel"}
96108
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

0 commit comments

Comments
 (0)