Skip to content

Commit 2c7efb9

Browse files
DN6 and sayakpaul authored
[CI] Refactor SD3 Transformer Test (#13340)
* update
* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent c07f09c commit 2c7efb9

1 file changed

Lines changed: 128 additions & 102 deletions

File tree

tests/models/transformers/test_models_transformer_sd3.py

Lines changed: 128 additions & 102 deletions
Original file line number | Diff line number | Diff line change
@@ -13,58 +13,63 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
17-
1816
import torch
1917

2018
from diffusers import SD3Transformer2DModel
21-
from diffusers.utils.import_utils import is_xformers_available
22-
23-
from ...testing_utils import (
24-
enable_full_determinism,
25-
torch_device,
19+
from diffusers.utils.torch_utils import randn_tensor
20+
21+
from ...testing_utils import enable_full_determinism, torch_device
22+
from ..testing_utils import (
23+
BaseModelTesterConfig,
24+
BitsAndBytesTesterMixin,
25+
ModelTesterMixin,
26+
TorchAoTesterMixin,
27+
TorchCompileTesterMixin,
28+
TrainingTesterMixin,
2629
)
27-
from ..test_modeling_common import ModelTesterMixin
2830

2931

3032
enable_full_determinism()
3133

3234

33-
class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
34-
model_class = SD3Transformer2DModel
35-
main_input_name = "hidden_states"
36-
model_split_percents = [0.8, 0.8, 0.9]
35+
# ======================== SD3 Transformer ========================
3736

37+
38+
class SD3TransformerTesterConfig(BaseModelTesterConfig):
3839
@property
39-
def dummy_input(self):
40-
batch_size = 2
41-
num_channels = 4
42-
height = width = embedding_dim = 32
43-
pooled_embedding_dim = embedding_dim * 2
44-
sequence_length = 154
40+
def model_class(self):
41+
return SD3Transformer2DModel
42+
43+
@property
44+
def pretrained_model_name_or_path(self):
45+
return "hf-internal-testing/tiny-sd3-pipe"
4546

46-
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
47-
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
48-
pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device)
49-
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
47+
@property
48+
def pretrained_model_kwargs(self):
49+
return {"subfolder": "transformer"}
5050

51-
return {
52-
"hidden_states": hidden_states,
53-
"encoder_hidden_states": encoder_hidden_states,
54-
"pooled_projections": pooled_prompt_embeds,
55-
"timestep": timestep,
56-
}
51+
@property
52+
def main_input_name(self) -> str:
53+
return "hidden_states"
54+
55+
@property
56+
def model_split_percents(self) -> list:
57+
return [0.8, 0.8, 0.9]
5758

5859
@property
59-
def input_shape(self):
60+
def output_shape(self) -> tuple:
6061
return (4, 32, 32)
6162

6263
@property
63-
def output_shape(self):
64+
def input_shape(self) -> tuple:
6465
return (4, 32, 32)
6566

66-
def prepare_init_args_and_inputs_for_common(self):
67-
init_dict = {
67+
@property
68+
def generator(self):
69+
return torch.Generator("cpu").manual_seed(0)
70+
71+
def get_init_dict(self) -> dict:
72+
return {
6873
"sample_size": 32,
6974
"patch_size": 1,
7075
"in_channels": 4,
@@ -79,67 +84,79 @@ def prepare_init_args_and_inputs_for_common(self):
7984
"dual_attention_layers": (),
8085
"qk_norm": None,
8186
}
82-
inputs_dict = self.dummy_input
83-
return init_dict, inputs_dict
8487

85-
@unittest.skipIf(
86-
torch_device != "cuda" or not is_xformers_available(),
87-
reason="XFormers attention is only available with CUDA and `xformers` installed",
88-
)
89-
def test_xformers_enable_works(self):
90-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
91-
model = self.model_class(**init_dict)
88+
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
89+
num_channels = 4
90+
height = width = embedding_dim = 32
91+
pooled_embedding_dim = embedding_dim * 2
92+
sequence_length = 154
93+
94+
return {
95+
"hidden_states": randn_tensor(
96+
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
97+
),
98+
"encoder_hidden_states": randn_tensor(
99+
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
100+
),
101+
"pooled_projections": randn_tensor(
102+
(batch_size, pooled_embedding_dim), generator=self.generator, device=torch_device
103+
),
104+
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
105+
}
92106

93-
model.enable_xformers_memory_efficient_attention()
94107

95-
assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", (
96-
"xformers is not enabled"
97-
)
108+
class TestSD3Transformer(SD3TransformerTesterConfig, ModelTesterMixin):
109+
pass
98110

99-
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
100-
def test_set_attn_processor_for_determinism(self):
101-
pass
102111

112+
class TestSD3TransformerTraining(SD3TransformerTesterConfig, TrainingTesterMixin):
103113
def test_gradient_checkpointing_is_applied(self):
104114
expected_set = {"SD3Transformer2DModel"}
105115
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
106116

107117

108-
class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
109-
model_class = SD3Transformer2DModel
110-
main_input_name = "hidden_states"
111-
model_split_percents = [0.8, 0.8, 0.9]
118+
class TestSD3TransformerCompile(SD3TransformerTesterConfig, TorchCompileTesterMixin):
119+
pass
120+
121+
122+
# ======================== SD3.5 Transformer ========================
123+
112124

125+
class SD35TransformerTesterConfig(BaseModelTesterConfig):
113126
@property
114-
def dummy_input(self):
115-
batch_size = 2
116-
num_channels = 4
117-
height = width = embedding_dim = 32
118-
pooled_embedding_dim = embedding_dim * 2
119-
sequence_length = 154
127+
def model_class(self):
128+
return SD3Transformer2DModel
120129

121-
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
122-
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
123-
pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device)
124-
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
130+
@property
131+
def pretrained_model_name_or_path(self):
132+
return "hf-internal-testing/tiny-sd35-pipe"
125133

126-
return {
127-
"hidden_states": hidden_states,
128-
"encoder_hidden_states": encoder_hidden_states,
129-
"pooled_projections": pooled_prompt_embeds,
130-
"timestep": timestep,
131-
}
134+
@property
135+
def pretrained_model_kwargs(self):
136+
return {"subfolder": "transformer"}
137+
138+
@property
139+
def main_input_name(self) -> str:
140+
return "hidden_states"
141+
142+
@property
143+
def model_split_percents(self) -> list:
144+
return [0.8, 0.8, 0.9]
132145

133146
@property
134-
def input_shape(self):
147+
def output_shape(self) -> tuple:
135148
return (4, 32, 32)
136149

137150
@property
138-
def output_shape(self):
151+
def input_shape(self) -> tuple:
139152
return (4, 32, 32)
140153

141-
def prepare_init_args_and_inputs_for_common(self):
142-
init_dict = {
154+
@property
155+
def generator(self):
156+
return torch.Generator("cpu").manual_seed(0)
157+
158+
def get_init_dict(self) -> dict:
159+
return {
143160
"sample_size": 32,
144161
"patch_size": 1,
145162
"in_channels": 4,
@@ -154,47 +171,56 @@ def prepare_init_args_and_inputs_for_common(self):
154171
"dual_attention_layers": (0,),
155172
"qk_norm": "rms_norm",
156173
}
157-
inputs_dict = self.dummy_input
158-
return init_dict, inputs_dict
159-
160-
@unittest.skipIf(
161-
torch_device != "cuda" or not is_xformers_available(),
162-
reason="XFormers attention is only available with CUDA and `xformers` installed",
163-
)
164-
def test_xformers_enable_works(self):
165-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
166-
model = self.model_class(**init_dict)
167174

168-
model.enable_xformers_memory_efficient_attention()
169-
170-
assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", (
171-
"xformers is not enabled"
172-
)
175+
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
176+
num_channels = 4
177+
height = width = embedding_dim = 32
178+
pooled_embedding_dim = embedding_dim * 2
179+
sequence_length = 154
173180

174-
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
175-
def test_set_attn_processor_for_determinism(self):
176-
pass
181+
return {
182+
"hidden_states": randn_tensor(
183+
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
184+
),
185+
"encoder_hidden_states": randn_tensor(
186+
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
187+
),
188+
"pooled_projections": randn_tensor(
189+
(batch_size, pooled_embedding_dim), generator=self.generator, device=torch_device
190+
),
191+
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
192+
}
177193

178-
def test_gradient_checkpointing_is_applied(self):
179-
expected_set = {"SD3Transformer2DModel"}
180-
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
181194

195+
class TestSD35Transformer(SD35TransformerTesterConfig, ModelTesterMixin):
182196
def test_skip_layers(self):
183-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
197+
init_dict = self.get_init_dict()
198+
inputs_dict = self.get_dummy_inputs()
184199
model = self.model_class(**init_dict).to(torch_device)
185200

186-
# Forward pass without skipping layers
187201
output_full = model(**inputs_dict).sample
188202

189-
# Forward pass with skipping layers 0 (since there's only one layer in this test setup)
190203
inputs_dict_with_skip = inputs_dict.copy()
191204
inputs_dict_with_skip["skip_layers"] = [0]
192205
output_skip = model(**inputs_dict_with_skip).sample
193206

194-
# Check that the outputs are different
195-
self.assertFalse(
196-
torch.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped"
197-
)
207+
assert not torch.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped"
208+
assert output_full.shape == output_skip.shape, "Outputs should have the same shape"
209+
210+
211+
class TestSD35TransformerTraining(SD35TransformerTesterConfig, TrainingTesterMixin):
212+
def test_gradient_checkpointing_is_applied(self):
213+
expected_set = {"SD3Transformer2DModel"}
214+
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
215+
216+
217+
class TestSD35TransformerCompile(SD35TransformerTesterConfig, TorchCompileTesterMixin):
218+
pass
219+
220+
221+
class TestSD35TransformerBitsAndBytes(SD35TransformerTesterConfig, BitsAndBytesTesterMixin):
222+
"""BitsAndBytes quantization tests for SD3.5 Transformer."""
223+
198224

199-
# Check that the outputs have the same shape
200-
self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape")
225+
class TestSD35TransformerTorchAo(SD35TransformerTesterConfig, TorchAoTesterMixin):
226+
"""TorchAO quantization tests for SD3.5 Transformer."""

0 commit comments

Comments
 (0)