1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import unittest
17-
18- import numpy as np
1916import torch
2017
21- from diffusers .models import ModelMixin , UNet3DConditionModel
22- from diffusers .utils import logging
23- from diffusers .utils .import_utils import is_xformers_available
18+ from diffusers import UNet3DConditionModel
19+ from diffusers .utils .torch_utils import randn_tensor
2420
25- from ...testing_utils import enable_full_determinism , floats_tensor , skip_mps , torch_device
26- from ..test_modeling_common import ModelTesterMixin , UNetTesterMixin
21+ from ...testing_utils import enable_full_determinism , torch_device
22+ from ..testing_utils import (
23+ AttentionTesterMixin ,
24+ BaseModelTesterConfig ,
25+ MemoryTesterMixin ,
26+ ModelTesterMixin ,
27+ TrainingTesterMixin ,
28+ )
2729
2830
2931enable_full_determinism ()
3032
31- logger = logging .get_logger (__name__ )
32-
33-
34- @skip_mps
35- class UNet3DConditionModelTests (ModelTesterMixin , UNetTesterMixin , unittest .TestCase ):
36- model_class = UNet3DConditionModel
37- main_input_name = "sample"
3833
34+ class UNet3DConditionModelTesterConfig (BaseModelTesterConfig ):
3935 @property
40- def dummy_input (self ):
41- batch_size = 4
42- num_channels = 4
43- num_frames = 4
44- sizes = (16 , 16 )
36+ def model_class (self ):
37+ return UNet3DConditionModel
4538
46- noise = floats_tensor ((batch_size , num_channels , num_frames ) + sizes ).to (torch_device )
47- time_step = torch .tensor ([10 ]).to (torch_device )
48- encoder_hidden_states = floats_tensor ((batch_size , 4 , 8 )).to (torch_device )
49-
50- return {"sample" : noise , "timestep" : time_step , "encoder_hidden_states" : encoder_hidden_states }
39+ @property
40+ def main_input_name (self ) -> str :
41+ return "sample"
5142
5243 @property
53- def input_shape (self ):
44+ def output_shape (self ) -> tuple :
5445 return (4 , 4 , 16 , 16 )
5546
5647 @property
57- def output_shape (self ):
58- return ( 4 , 4 , 16 , 16 )
48+ def generator (self ):
49+ return torch . Generator ( "cpu" ). manual_seed ( 0 )
5950
60- def prepare_init_args_and_inputs_for_common (self ):
61- init_dict = {
51+ def get_init_dict (self ) -> dict :
52+ return {
6253 "block_out_channels" : (4 , 8 ),
6354 "norm_num_groups" : 4 ,
6455 "down_block_types" : (
@@ -73,111 +64,57 @@ def prepare_init_args_and_inputs_for_common(self):
7364 "layers_per_block" : 1 ,
7465 "sample_size" : 16 ,
7566 }
76- inputs_dict = self .dummy_input
77- return init_dict , inputs_dict
78-
79- @unittest .skipIf (
80- torch_device != "cuda" or not is_xformers_available (),
81- reason = "XFormers attention is only available with CUDA and `xformers` installed" ,
82- )
83- def test_xformers_enable_works (self ):
84- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
85- model = self .model_class (** init_dict )
8667
87- model .enable_xformers_memory_efficient_attention ()
68+ def get_dummy_inputs (self ) -> dict :
69+ batch_size = 4
70+ num_channels = 4
71+ num_frames = 4
72+ sizes = (16 , 16 )
73+ noise = randn_tensor (
74+ (batch_size , num_channels , num_frames , * sizes ), generator = self .generator , device = torch_device
75+ )
76+ timestep = torch .tensor ([10 ], device = torch_device )
77+ encoder_hidden_states = randn_tensor ((batch_size , 4 , 8 ), generator = self .generator , device = torch_device )
78+ return {"sample" : noise , "timestep" : timestep , "encoder_hidden_states" : encoder_hidden_states }
8879
89- assert (
90- model .mid_block .attentions [0 ].transformer_blocks [0 ].attn1 .processor .__class__ .__name__
91- == "XFormersAttnProcessor"
92- ), "xformers is not enabled"
9380
94- # Overriding to set `norm_num_groups` needs to be different for this model.
81+ class TestUNet3DConditionModel (UNet3DConditionModelTesterConfig , ModelTesterMixin ):
82+ # Overridden because UNet3DConditionModel needs a different `norm_num_groups`.
9583 def test_forward_with_norm_groups (self ):
96- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
84+ init_dict = self .get_init_dict ()
9785 init_dict ["block_out_channels" ] = (32 , 64 )
9886 init_dict ["norm_num_groups" ] = 32
99-
100- model = self .model_class (** init_dict )
101- model .to (torch_device )
102- model .eval ()
87+ model = self .model_class (** init_dict ).to (torch_device ).eval ()
10388
10489 with torch .no_grad ():
105- output = model (** inputs_dict )
90+ output = model (** self . get_dummy_inputs ()). sample
10691
107- if isinstance (output , dict ):
108- output = output .sample
92+ assert output .shape == self .get_dummy_inputs ()["sample" ].shape , "Input and output shapes do not match"
10993
110- self .assertIsNotNone (output )
111- expected_shape = inputs_dict ["sample" ].shape
112- self .assertEqual (output .shape , expected_shape , "Input and output shapes do not match" )
113-
114- # Overriding since the UNet3D outputs a different structure.
115- def test_determinism (self ):
116- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
117- model = self .model_class (** init_dict )
118- model .to (torch_device )
119- model .eval ()
94+ def test_feed_forward_chunking (self ):
95+ init_dict = self .get_init_dict ()
96+ init_dict ["block_out_channels" ] = (32 , 64 )
97+ init_dict ["norm_num_groups" ] = 32
98+ model = self .model_class (** init_dict ).to (torch_device ).eval ()
12099
121100 with torch .no_grad ():
122- # Warmup pass when using mps (see #372)
123- if torch_device == "mps" and isinstance (model , ModelMixin ):
124- model (** self .dummy_input )
125-
126- first = model (** inputs_dict )
127- if isinstance (first , dict ):
128- first = first .sample
129-
130- second = model (** inputs_dict )
131- if isinstance (second , dict ):
132- second = second .sample
133-
134- out_1 = first .cpu ().numpy ()
135- out_2 = second .cpu ().numpy ()
136- out_1 = out_1 [~ np .isnan (out_1 )]
137- out_2 = out_2 [~ np .isnan (out_2 )]
138- max_diff = np .amax (np .abs (out_1 - out_2 ))
139- self .assertLessEqual (max_diff , 1e-5 )
101+ output = model (** self .get_dummy_inputs ())[0 ]
140102
141- def test_model_attention_slicing (self ):
142- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
143-
144- init_dict ["block_out_channels" ] = (16 , 32 )
145- init_dict ["attention_head_dim" ] = 8
146-
147- model = self .model_class (** init_dict )
148- model .to (torch_device )
149- model .eval ()
150-
151- model .set_attention_slice ("auto" )
103+ model .enable_forward_chunking ()
152104 with torch .no_grad ():
153- output = model (** inputs_dict )
154- assert output is not None
105+ output_2 = model (** self .get_dummy_inputs ())[0 ]
155106
156- model .set_attention_slice ("max" )
157- with torch .no_grad ():
158- output = model (** inputs_dict )
159- assert output is not None
107+ assert output .shape == output_2 .shape , "Shape doesn't match"
108+ assert (output - output_2 ).abs ().max () < 1e-2
160109
161- model .set_attention_slice (2 )
162- with torch .no_grad ():
163- output = model (** inputs_dict )
164- assert output is not None
165110
166- def test_feed_forward_chunking (self ):
167- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
168- init_dict ["block_out_channels" ] = (32 , 64 )
169- init_dict ["norm_num_groups" ] = 32
111+ class TestUNet3DConditionModelTraining (UNet3DConditionModelTesterConfig , TrainingTesterMixin ):
112+ """Training tests for UNet3DConditionModel."""
170113
171- model = self .model_class (** init_dict )
172- model .to (torch_device )
173- model .eval ()
174114
175- with torch . no_grad ( ):
176- output = model ( ** inputs_dict )[ 0 ]
115+ class TestUNet3DConditionModelMemory ( UNet3DConditionModelTesterConfig , MemoryTesterMixin ):
116+ """Memory optimization tests for UNet3DConditionModel."""
177117
178- model .enable_forward_chunking ()
179- with torch .no_grad ():
180- output_2 = model (** inputs_dict )[0 ]
181118
182- self . assertEqual ( output . shape , output_2 . shape , "Shape doesn't match" )
183- assert np . abs ( output . cpu () - output_2 . cpu ()). max () < 1e-2
119+ class TestUNet3DConditionModelAttention ( UNet3DConditionModelTesterConfig , AttentionTesterMixin ):
120+ """Attention processor tests for UNet3DConditionModel."""
0 commit comments