Skip to content

Commit 97b5c7f

Browse files
authored
refactor unet tests (3d_condition, motion, controlnetxs) (#13897)
* refactor unet_3d_condition tests * refactor unet_motion tests * refactor unet_controlnetxs tests
1 parent 2c7efb9 commit 97b5c7f

3 files changed

Lines changed: 228 additions & 412 deletions

File tree

tests/models/unets/test_models_unet_3d_condition.py

Lines changed: 54 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -13,52 +13,43 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
17-
18-
import numpy as np
1916
import torch
2017

21-
from diffusers.models import ModelMixin, UNet3DConditionModel
22-
from diffusers.utils import logging
23-
from diffusers.utils.import_utils import is_xformers_available
18+
from diffusers import UNet3DConditionModel
19+
from diffusers.utils.torch_utils import randn_tensor
2420

25-
from ...testing_utils import enable_full_determinism, floats_tensor, skip_mps, torch_device
26-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
21+
from ...testing_utils import enable_full_determinism, torch_device
22+
from ..testing_utils import (
23+
AttentionTesterMixin,
24+
BaseModelTesterConfig,
25+
MemoryTesterMixin,
26+
ModelTesterMixin,
27+
TrainingTesterMixin,
28+
)
2729

2830

2931
enable_full_determinism()
3032

31-
logger = logging.get_logger(__name__)
32-
33-
34-
@skip_mps
35-
class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
36-
model_class = UNet3DConditionModel
37-
main_input_name = "sample"
3833

34+
class UNet3DConditionModelTesterConfig(BaseModelTesterConfig):
3935
@property
40-
def dummy_input(self):
41-
batch_size = 4
42-
num_channels = 4
43-
num_frames = 4
44-
sizes = (16, 16)
36+
def model_class(self):
37+
return UNet3DConditionModel
4538

46-
noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
47-
time_step = torch.tensor([10]).to(torch_device)
48-
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
49-
50-
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
39+
@property
40+
def main_input_name(self) -> str:
41+
return "sample"
5142

5243
@property
53-
def input_shape(self):
44+
def output_shape(self) -> tuple:
5445
return (4, 4, 16, 16)
5546

5647
@property
57-
def output_shape(self):
58-
return (4, 4, 16, 16)
48+
def generator(self):
49+
return torch.Generator("cpu").manual_seed(0)
5950

60-
def prepare_init_args_and_inputs_for_common(self):
61-
init_dict = {
51+
def get_init_dict(self) -> dict:
52+
return {
6253
"block_out_channels": (4, 8),
6354
"norm_num_groups": 4,
6455
"down_block_types": (
@@ -73,111 +64,57 @@ def prepare_init_args_and_inputs_for_common(self):
7364
"layers_per_block": 1,
7465
"sample_size": 16,
7566
}
76-
inputs_dict = self.dummy_input
77-
return init_dict, inputs_dict
78-
79-
@unittest.skipIf(
80-
torch_device != "cuda" or not is_xformers_available(),
81-
reason="XFormers attention is only available with CUDA and `xformers` installed",
82-
)
83-
def test_xformers_enable_works(self):
84-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
85-
model = self.model_class(**init_dict)
8667

87-
model.enable_xformers_memory_efficient_attention()
68+
def get_dummy_inputs(self) -> dict:
69+
batch_size = 4
70+
num_channels = 4
71+
num_frames = 4
72+
sizes = (16, 16)
73+
noise = randn_tensor(
74+
(batch_size, num_channels, num_frames, *sizes), generator=self.generator, device=torch_device
75+
)
76+
timestep = torch.tensor([10], device=torch_device)
77+
encoder_hidden_states = randn_tensor((batch_size, 4, 8), generator=self.generator, device=torch_device)
78+
return {"sample": noise, "timestep": timestep, "encoder_hidden_states": encoder_hidden_states}
8879

89-
assert (
90-
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
91-
== "XFormersAttnProcessor"
92-
), "xformers is not enabled"
9380

94-
# Overriding to set `norm_num_groups` needs to be different for this model.
81+
class TestUNet3DConditionModel(UNet3DConditionModelTesterConfig, ModelTesterMixin):
82+
# Overridden because UNet3DConditionModel needs a different `norm_num_groups`.
9583
def test_forward_with_norm_groups(self):
96-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
84+
init_dict = self.get_init_dict()
9785
init_dict["block_out_channels"] = (32, 64)
9886
init_dict["norm_num_groups"] = 32
99-
100-
model = self.model_class(**init_dict)
101-
model.to(torch_device)
102-
model.eval()
87+
model = self.model_class(**init_dict).to(torch_device).eval()
10388

10489
with torch.no_grad():
105-
output = model(**inputs_dict)
90+
output = model(**self.get_dummy_inputs()).sample
10691

107-
if isinstance(output, dict):
108-
output = output.sample
92+
assert output.shape == self.get_dummy_inputs()["sample"].shape, "Input and output shapes do not match"
10993

110-
self.assertIsNotNone(output)
111-
expected_shape = inputs_dict["sample"].shape
112-
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
113-
114-
# Overriding since the UNet3D outputs a different structure.
115-
def test_determinism(self):
116-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
117-
model = self.model_class(**init_dict)
118-
model.to(torch_device)
119-
model.eval()
94+
def test_feed_forward_chunking(self):
95+
init_dict = self.get_init_dict()
96+
init_dict["block_out_channels"] = (32, 64)
97+
init_dict["norm_num_groups"] = 32
98+
model = self.model_class(**init_dict).to(torch_device).eval()
12099

121100
with torch.no_grad():
122-
# Warmup pass when using mps (see #372)
123-
if torch_device == "mps" and isinstance(model, ModelMixin):
124-
model(**self.dummy_input)
125-
126-
first = model(**inputs_dict)
127-
if isinstance(first, dict):
128-
first = first.sample
129-
130-
second = model(**inputs_dict)
131-
if isinstance(second, dict):
132-
second = second.sample
133-
134-
out_1 = first.cpu().numpy()
135-
out_2 = second.cpu().numpy()
136-
out_1 = out_1[~np.isnan(out_1)]
137-
out_2 = out_2[~np.isnan(out_2)]
138-
max_diff = np.amax(np.abs(out_1 - out_2))
139-
self.assertLessEqual(max_diff, 1e-5)
101+
output = model(**self.get_dummy_inputs())[0]
140102

141-
def test_model_attention_slicing(self):
142-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
143-
144-
init_dict["block_out_channels"] = (16, 32)
145-
init_dict["attention_head_dim"] = 8
146-
147-
model = self.model_class(**init_dict)
148-
model.to(torch_device)
149-
model.eval()
150-
151-
model.set_attention_slice("auto")
103+
model.enable_forward_chunking()
152104
with torch.no_grad():
153-
output = model(**inputs_dict)
154-
assert output is not None
105+
output_2 = model(**self.get_dummy_inputs())[0]
155106

156-
model.set_attention_slice("max")
157-
with torch.no_grad():
158-
output = model(**inputs_dict)
159-
assert output is not None
107+
assert output.shape == output_2.shape, "Shape doesn't match"
108+
assert (output - output_2).abs().max() < 1e-2
160109

161-
model.set_attention_slice(2)
162-
with torch.no_grad():
163-
output = model(**inputs_dict)
164-
assert output is not None
165110

166-
def test_feed_forward_chunking(self):
167-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
168-
init_dict["block_out_channels"] = (32, 64)
169-
init_dict["norm_num_groups"] = 32
111+
class TestUNet3DConditionModelTraining(UNet3DConditionModelTesterConfig, TrainingTesterMixin):
112+
"""Training tests for UNet3DConditionModel."""
170113

171-
model = self.model_class(**init_dict)
172-
model.to(torch_device)
173-
model.eval()
174114

175-
with torch.no_grad():
176-
output = model(**inputs_dict)[0]
115+
class TestUNet3DConditionModelMemory(UNet3DConditionModelTesterConfig, MemoryTesterMixin):
116+
"""Memory optimization tests for UNet3DConditionModel."""
177117

178-
model.enable_forward_chunking()
179-
with torch.no_grad():
180-
output_2 = model(**inputs_dict)[0]
181118

182-
self.assertEqual(output.shape, output_2.shape, "Shape doesn't match")
183-
assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2
119+
class TestUNet3DConditionModelAttention(UNet3DConditionModelTesterConfig, AttentionTesterMixin):
120+
"""Attention processor tests for UNet3DConditionModel."""

0 commit comments

Comments
 (0)