Skip to content

Commit 5e7540f

Browse files
authored
[tests] port final set of model tests and others (#13974)
* port final set of model tests and others * fix extracter.
1 parent a487e2f commit 5e7540f

16 files changed

Lines changed: 807 additions & 2643 deletions

tests/models/test_modeling_common.py

Lines changed: 30 additions & 2142 deletions
Large diffs are not rendered by default.

tests/models/transformers/test_models_dit_transformer2d.py

Lines changed: 67 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -13,52 +13,48 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
17-
16+
import pytest
1817
import torch
1918

2019
from diffusers import DiTTransformer2DModel, Transformer2DModel
21-
22-
from ...testing_utils import (
23-
enable_full_determinism,
24-
floats_tensor,
25-
slow,
26-
torch_device,
20+
from diffusers.utils.torch_utils import randn_tensor
21+
22+
from ...testing_utils import enable_full_determinism, slow, torch_device
23+
from ..testing_utils import (
24+
AttentionTesterMixin,
25+
BaseModelTesterConfig,
26+
MemoryTesterMixin,
27+
ModelTesterMixin,
28+
TrainingTesterMixin,
2729
)
28-
from ..test_modeling_common import ModelTesterMixin
2930

3031

3132
enable_full_determinism()
3233

3334

34-
class DiTTransformer2DModelTests(ModelTesterMixin, unittest.TestCase):
35-
model_class = DiTTransformer2DModel
36-
main_input_name = "hidden_states"
37-
35+
class DiTTransformer2DTesterConfig(BaseModelTesterConfig):
3836
@property
39-
def dummy_input(self):
40-
batch_size = 4
41-
in_channels = 4
42-
sample_size = 8
43-
scheduler_num_train_steps = 1000
44-
num_class_labels = 4
37+
def model_class(self):
38+
return DiTTransformer2DModel
4539

46-
hidden_states = floats_tensor((batch_size, in_channels, sample_size, sample_size)).to(torch_device)
47-
timesteps = torch.randint(0, scheduler_num_train_steps, size=(batch_size,)).to(torch_device)
48-
class_label_ids = torch.randint(0, num_class_labels, size=(batch_size,)).to(torch_device)
49-
50-
return {"hidden_states": hidden_states, "timestep": timesteps, "class_labels": class_label_ids}
40+
@property
41+
def main_input_name(self) -> str:
42+
return "hidden_states"
5143

5244
@property
53-
def input_shape(self):
45+
def input_shape(self) -> tuple:
5446
return (4, 8, 8)
5547

5648
@property
57-
def output_shape(self):
49+
def output_shape(self) -> tuple:
5850
return (8, 8, 8)
5951

60-
def prepare_init_args_and_inputs_for_common(self):
61-
init_dict = {
52+
@property
53+
def generator(self):
54+
return torch.Generator("cpu").manual_seed(0)
55+
56+
def get_init_dict(self) -> dict:
57+
return {
6258
"in_channels": 4,
6359
"out_channels": 8,
6460
"activation_fn": "gelu-approximate",
@@ -71,26 +67,38 @@ def prepare_init_args_and_inputs_for_common(self):
7167
"patch_size": 2,
7268
"sample_size": 8,
7369
}
74-
inputs_dict = self.dummy_input
75-
return init_dict, inputs_dict
7670

77-
def test_output(self):
78-
super().test_output(
79-
expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
80-
)
71+
def get_dummy_inputs(self, batch_size: int = 4) -> dict[str, torch.Tensor]:
72+
in_channels = 4
73+
sample_size = 8
74+
scheduler_num_train_steps = 1000
75+
num_class_labels = 4
76+
77+
return {
78+
"hidden_states": randn_tensor(
79+
(batch_size, in_channels, sample_size, sample_size), generator=self.generator, device=torch_device
80+
),
81+
"timestep": torch.randint(0, scheduler_num_train_steps, size=(batch_size,), generator=self.generator).to(
82+
torch_device
83+
),
84+
"class_labels": torch.randint(0, num_class_labels, size=(batch_size,), generator=self.generator).to(
85+
torch_device
86+
),
87+
}
88+
89+
90+
class TestDiTTransformer2D(DiTTransformer2DTesterConfig, ModelTesterMixin):
91+
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
92+
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
93+
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
94+
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
95+
pytest.skip("Tolerance requirements too high for meaningful test")
8196

8297
def test_correct_class_remapping_from_dict_config(self):
83-
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
98+
init_dict = self.get_init_dict()
8499
model = Transformer2DModel.from_config(init_dict)
85100
assert isinstance(model, DiTTransformer2DModel)
86101

87-
def test_gradient_checkpointing_is_applied(self):
88-
expected_set = {"DiTTransformer2DModel"}
89-
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
90-
91-
def test_effective_gradient_checkpointing(self):
92-
super().test_effective_gradient_checkpointing(loss_tolerance=1e-4)
93-
94102
def test_correct_class_remapping_from_pretrained_config(self):
95103
config = DiTTransformer2DModel.load_config("facebook/DiT-XL-2-256", subfolder="transformer")
96104
model = Transformer2DModel.from_config(config)
@@ -100,3 +108,20 @@ def test_correct_class_remapping_from_pretrained_config(self):
100108
def test_correct_class_remapping(self):
101109
model = Transformer2DModel.from_pretrained("facebook/DiT-XL-2-256", subfolder="transformer")
102110
assert isinstance(model, DiTTransformer2DModel)
111+
112+
113+
class TestDiTTransformer2DMemory(DiTTransformer2DTesterConfig, MemoryTesterMixin):
114+
pass
115+
116+
117+
class TestDiTTransformer2DAttention(DiTTransformer2DTesterConfig, AttentionTesterMixin):
118+
pass
119+
120+
121+
class TestDiTTransformer2DTraining(DiTTransformer2DTesterConfig, TrainingTesterMixin):
122+
def test_gradient_checkpointing_is_applied(self):
123+
expected_set = {"DiTTransformer2DModel"}
124+
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
125+
126+
def test_gradient_checkpointing_equivalence(self):
127+
super().test_gradient_checkpointing_equivalence(loss_tolerance=1e-4)

tests/models/transformers/test_models_pixart_transformer2d.py

Lines changed: 70 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -13,60 +13,53 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
17-
16+
import pytest
1817
import torch
1918

2019
from diffusers import PixArtTransformer2DModel, Transformer2DModel
21-
22-
from ...testing_utils import (
23-
enable_full_determinism,
24-
floats_tensor,
25-
slow,
26-
torch_device,
20+
from diffusers.utils.torch_utils import randn_tensor
21+
22+
from ...testing_utils import enable_full_determinism, slow, torch_device
23+
from ..testing_utils import (
24+
AttentionTesterMixin,
25+
BaseModelTesterConfig,
26+
MemoryTesterMixin,
27+
ModelTesterMixin,
28+
TrainingTesterMixin,
2729
)
28-
from ..test_modeling_common import ModelTesterMixin
2930

3031

3132
enable_full_determinism()
3233

3334

34-
class PixArtTransformer2DModelTests(ModelTesterMixin, unittest.TestCase):
35-
model_class = PixArtTransformer2DModel
36-
main_input_name = "hidden_states"
37-
# We override the items here because the transformer under consideration is small.
38-
model_split_percents = [0.7, 0.6, 0.6]
39-
35+
class PixArtTransformer2DTesterConfig(BaseModelTesterConfig):
4036
@property
41-
def dummy_input(self):
42-
batch_size = 4
43-
in_channels = 4
44-
sample_size = 8
45-
scheduler_num_train_steps = 1000
46-
cross_attention_dim = 8
47-
seq_len = 8
37+
def model_class(self):
38+
return PixArtTransformer2DModel
4839

49-
hidden_states = floats_tensor((batch_size, in_channels, sample_size, sample_size)).to(torch_device)
50-
timesteps = torch.randint(0, scheduler_num_train_steps, size=(batch_size,)).to(torch_device)
51-
encoder_hidden_states = floats_tensor((batch_size, seq_len, cross_attention_dim)).to(torch_device)
52-
53-
return {
54-
"hidden_states": hidden_states,
55-
"timestep": timesteps,
56-
"encoder_hidden_states": encoder_hidden_states,
57-
"added_cond_kwargs": {"aspect_ratio": None, "resolution": None},
58-
}
40+
@property
41+
def main_input_name(self) -> str:
42+
return "hidden_states"
5943

6044
@property
61-
def input_shape(self):
45+
def input_shape(self) -> tuple:
6246
return (4, 8, 8)
6347

6448
@property
65-
def output_shape(self):
49+
def output_shape(self) -> tuple:
6650
return (8, 8, 8)
6751

68-
def prepare_init_args_and_inputs_for_common(self):
69-
init_dict = {
52+
@property
53+
def model_split_percents(self) -> list:
54+
# We override the items here because the transformer under consideration is small.
55+
return [0.7, 0.6, 0.6]
56+
57+
@property
58+
def generator(self):
59+
return torch.Generator("cpu").manual_seed(0)
60+
61+
def get_init_dict(self) -> dict:
62+
return {
7063
"sample_size": 8,
7164
"num_layers": 1,
7265
"patch_size": 2,
@@ -84,20 +77,37 @@ def prepare_init_args_and_inputs_for_common(self):
8477
"use_additional_conditions": False,
8578
"caption_channels": None,
8679
}
87-
inputs_dict = self.dummy_input
88-
return init_dict, inputs_dict
8980

90-
def test_output(self):
91-
super().test_output(
92-
expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
93-
)
81+
def get_dummy_inputs(self, batch_size: int = 4) -> dict[str, torch.Tensor]:
82+
in_channels = 4
83+
sample_size = 8
84+
scheduler_num_train_steps = 1000
85+
cross_attention_dim = 8
86+
seq_len = 8
87+
88+
return {
89+
"hidden_states": randn_tensor(
90+
(batch_size, in_channels, sample_size, sample_size), generator=self.generator, device=torch_device
91+
),
92+
"timestep": torch.randint(0, scheduler_num_train_steps, size=(batch_size,), generator=self.generator).to(
93+
torch_device
94+
),
95+
"encoder_hidden_states": randn_tensor(
96+
(batch_size, seq_len, cross_attention_dim), generator=self.generator, device=torch_device
97+
),
98+
"added_cond_kwargs": {"aspect_ratio": None, "resolution": None},
99+
}
100+
94101

95-
def test_gradient_checkpointing_is_applied(self):
96-
expected_set = {"PixArtTransformer2DModel"}
97-
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
102+
class TestPixArtTransformer2D(PixArtTransformer2DTesterConfig, ModelTesterMixin):
103+
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
104+
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
105+
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
106+
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
107+
pytest.skip("Tolerance requirements too high for meaningful test")
98108

99109
def test_correct_class_remapping_from_dict_config(self):
100-
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
110+
init_dict = self.get_init_dict()
101111
model = Transformer2DModel.from_config(init_dict)
102112
assert isinstance(model, PixArtTransformer2DModel)
103113

@@ -110,3 +120,17 @@ def test_correct_class_remapping_from_pretrained_config(self):
110120
def test_correct_class_remapping(self):
111121
model = Transformer2DModel.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="transformer")
112122
assert isinstance(model, PixArtTransformer2DModel)
123+
124+
125+
class TestPixArtTransformer2DMemory(PixArtTransformer2DTesterConfig, MemoryTesterMixin):
126+
pass
127+
128+
129+
class TestPixArtTransformer2DAttention(PixArtTransformer2DTesterConfig, AttentionTesterMixin):
130+
pass
131+
132+
133+
class TestPixArtTransformer2DTraining(PixArtTransformer2DTesterConfig, TrainingTesterMixin):
134+
def test_gradient_checkpointing_is_applied(self):
135+
expected_set = {"PixArtTransformer2DModel"}
136+
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

0 commit comments

Comments
 (0)