1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import unittest
17-
18- import pytest
1916import torch
2017
2118from diffusers import UNet1DModel
19+ from diffusers .utils .torch_utils import randn_tensor
2220
23- from ...testing_utils import (
24- backend_manual_seed ,
25- floats_tensor ,
26- slow ,
27- torch_device ,
28- )
29- from ..test_modeling_common import ModelTesterMixin , UNetTesterMixin
30-
21+ from ...testing_utils import backend_manual_seed , enable_full_determinism , slow , torch_device
22+ from ..testing_utils import BaseModelTesterConfig , ModelTesterMixin
3123
32- class UNet1DModelTests (ModelTesterMixin , UNetTesterMixin , unittest .TestCase ):
33- model_class = UNet1DModel
34- main_input_name = "sample"
35-
36- @property
37- def dummy_input (self ):
38- batch_size = 4
39- num_features = 14
40- seq_len = 16
4124
42- noise = floats_tensor ((batch_size , num_features , seq_len )).to (torch_device )
43- time_step = torch .tensor ([10 ] * batch_size ).to (torch_device )
25+ enable_full_determinism ()
4426
45- return {"sample" : noise , "timestep" : time_step }
4627
28+ class UNet1DModelTesterConfig (BaseModelTesterConfig ):
4729 @property
48- def input_shape (self ):
49- return ( 4 , 14 , 16 )
30+ def model_class (self ):
31+ return UNet1DModel
5032
5133 @property
52- def output_shape (self ):
53- return (4 , 14 , 16 )
54-
55- @unittest .skip ("Test not supported." )
56- def test_ema_training (self ):
57- pass
58-
59- @unittest .skip ("Test not supported." )
60- def test_training (self ):
61- pass
62-
63- @unittest .skip ("Test not supported." )
64- def test_layerwise_casting_training (self ):
65- pass
66-
67- def test_determinism (self ):
68- super ().test_determinism ()
69-
70- def test_outputs_equivalence (self ):
71- super ().test_outputs_equivalence ()
72-
73- def test_from_save_pretrained (self ):
74- super ().test_from_save_pretrained ()
75-
76- def test_from_save_pretrained_variant (self ):
77- super ().test_from_save_pretrained_variant ()
34+ def main_input_name (self ) -> str :
35+ return "sample"
7836
79- def test_model_from_pretrained (self ):
80- super ().test_model_from_pretrained ()
37+ @property
38+ def output_shape (self ) -> tuple :
39+ return (14 , 16 )
8140
82- def test_output (self ):
83- super ().test_output ()
41+ @property
42+ def generator (self ):
43+ return torch .Generator ("cpu" ).manual_seed (0 )
8444
85- def prepare_init_args_and_inputs_for_common (self ):
86- init_dict = {
45+ def get_init_dict (self ) -> dict :
46+ return {
8747 "block_out_channels" : (8 , 8 , 16 , 16 ),
8848 "in_channels" : 14 ,
8949 "out_channels" : 14 ,
@@ -97,19 +57,26 @@ def prepare_init_args_and_inputs_for_common(self):
9757 "up_block_types" : ("UpResnetBlock1D" , "UpResnetBlock1D" , "UpResnetBlock1D" ),
9858 "act_fn" : "swish" ,
9959 }
100- inputs_dict = self .dummy_input
101- return init_dict , inputs_dict
10260
61+ def get_dummy_inputs (self ) -> dict :
62+ batch_size = 4
63+ num_features = 14
64+ seq_len = 16
65+ noise = randn_tensor ((batch_size , num_features , seq_len ), generator = self .generator , device = torch_device )
66+ timestep = torch .tensor ([10 ] * batch_size , device = torch_device )
67+ return {"sample" : noise , "timestep" : timestep }
68+
69+
70+ class TestUNet1DModel (UNet1DModelTesterConfig , ModelTesterMixin ):
10371 def test_from_pretrained_hub (self ):
10472 model , loading_info = UNet1DModel .from_pretrained (
10573 "bglick13/hopper-medium-v2-value-function-hor32" , output_loading_info = True , subfolder = "unet"
10674 )
107- self . assertIsNotNone ( model )
108- self . assertEqual ( len (loading_info ["missing_keys" ]), 0 )
75+ assert model is not None
76+ assert len (loading_info ["missing_keys" ]) == 0
10977
11078 model .to (torch_device )
111- image = model (** self .dummy_input )
112-
79+ image = model (** self .get_dummy_inputs ())
11380 assert image is not None , "Make sure output is not None"
11481
11582 def test_output_pretrained (self ):
@@ -119,9 +86,7 @@ def test_output_pretrained(self):
11986
12087 num_features = model .config .in_channels
12188 seq_len = 16
122- noise = torch .randn ((1 , seq_len , num_features )).permute (
123- 0 , 2 , 1
124- ) # match original, we can update values and remove
89+ noise = torch .randn ((1 , seq_len , num_features )).permute (0 , 2 , 1 )
12590 time_step = torch .full ((num_features ,), 0 )
12691
12792 with torch .no_grad ():
@@ -131,12 +96,7 @@ def test_output_pretrained(self):
13196 # fmt: off
13297 expected_output_slice = torch .tensor ([- 2.137172 , 1.1426016 , 0.3688687 , - 0.766922 , 0.7303146 , 0.11038864 , - 0.4760633 , 0.13270172 , 0.02591348 ])
13398 # fmt: on
134- self .assertTrue (torch .allclose (output_slice , expected_output_slice , rtol = 1e-3 ))
135-
136- @unittest .skip ("Test not supported." )
137- def test_forward_with_norm_groups (self ):
138- # Not implemented yet for this UNet
139- pass
99+ assert torch .allclose (output_slice , expected_output_slice , rtol = 1e-3 )
140100
141101 @slow
142102 def test_unet_1d_maestro (self ):
@@ -157,98 +117,26 @@ def test_unet_1d_maestro(self):
157117 assert (output_sum - 224.0896 ).abs () < 0.5
158118 assert (output_max - 0.0607 ).abs () < 4e-4
159119
160- @pytest .mark .xfail (
161- reason = (
162- "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
163- "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n "
164- "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n "
165- "2. Unskip this test."
166- ),
167- )
168- def test_layerwise_casting_inference (self ):
169- super ().test_layerwise_casting_inference ()
170-
171- @pytest .mark .xfail (
172- reason = (
173- "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
174- "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n "
175- "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n "
176- "2. Unskip this test."
177- ),
178- )
179- def test_layerwise_casting_memory (self ):
180- pass
181-
182-
183- class UNetRLModelTests (ModelTesterMixin , UNetTesterMixin , unittest .TestCase ):
184- model_class = UNet1DModel
185- main_input_name = "sample"
186120
121+ class UNetRLModelTesterConfig (BaseModelTesterConfig ):
187122 @property
188- def dummy_input (self ):
189- batch_size = 4
190- num_features = 14
191- seq_len = 16
192-
193- noise = floats_tensor ((batch_size , num_features , seq_len )).to (torch_device )
194- time_step = torch .tensor ([10 ] * batch_size ).to (torch_device )
195-
196- return {"sample" : noise , "timestep" : time_step }
123+ def model_class (self ):
124+ return UNet1DModel
197125
198126 @property
199- def input_shape (self ):
200- return ( 4 , 14 , 16 )
127+ def main_input_name (self ) -> str :
128+ return "sample"
201129
202130 @property
203- def output_shape (self ):
204- return (4 , 14 , 1 )
205-
206- def test_determinism (self ):
207- super ().test_determinism ()
208-
209- def test_outputs_equivalence (self ):
210- super ().test_outputs_equivalence ()
211-
212- def test_from_save_pretrained (self ):
213- super ().test_from_save_pretrained ()
214-
215- def test_from_save_pretrained_variant (self ):
216- super ().test_from_save_pretrained_variant ()
217-
218- def test_model_from_pretrained (self ):
219- super ().test_model_from_pretrained ()
220-
221- def test_output (self ):
222- # UNetRL is a value-function is different output shape
223- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
224- model = self .model_class (** init_dict )
225- model .to (torch_device )
226- model .eval ()
227-
228- with torch .no_grad ():
229- output = model (** inputs_dict )
230-
231- if isinstance (output , dict ):
232- output = output .sample
233-
234- self .assertIsNotNone (output )
235- expected_shape = torch .Size ((inputs_dict ["sample" ].shape [0 ], 1 ))
236- self .assertEqual (output .shape , expected_shape , "Input and output shapes do not match" )
237-
238- @unittest .skip ("Test not supported." )
239- def test_ema_training (self ):
240- pass
131+ def output_shape (self ) -> tuple :
132+ return (1 ,)
241133
242- @unittest .skip ("Test not supported." )
243- def test_training (self ):
244- pass
245-
246- @unittest .skip ("Test not supported." )
247- def test_layerwise_casting_training (self ):
248- pass
134+ @property
135+ def generator (self ):
136+ return torch .Generator ("cpu" ).manual_seed (0 )
249137
250- def prepare_init_args_and_inputs_for_common (self ):
251- init_dict = {
138+ def get_init_dict (self ) -> dict :
139+ return {
252140 "in_channels" : 14 ,
253141 "out_channels" : 14 ,
254142 "down_block_types" : ["DownResnetBlock1D" , "DownResnetBlock1D" , "DownResnetBlock1D" , "DownResnetBlock1D" ],
@@ -264,19 +152,36 @@ def prepare_init_args_and_inputs_for_common(self):
264152 "time_embedding_type" : "positional" ,
265153 "act_fn" : "mish" ,
266154 }
267- inputs_dict = self .dummy_input
268- return init_dict , inputs_dict
155+
156+ def get_dummy_inputs (self ) -> dict :
157+ batch_size = 4
158+ num_features = 14
159+ seq_len = 16
160+ noise = randn_tensor ((batch_size , num_features , seq_len ), generator = self .generator , device = torch_device )
161+ timestep = torch .tensor ([10 ] * batch_size , device = torch_device )
162+ return {"sample" : noise , "timestep" : timestep }
163+
164+
165+ class TestUNetRLModel (UNetRLModelTesterConfig , ModelTesterMixin ):
166+ # UNetRL is a value function, so it has a different output shape.
167+ def test_output (self ):
168+ model = self .model_class (** self .get_init_dict ()).to (torch_device ).eval ()
169+
170+ inputs = self .get_dummy_inputs ()
171+ with torch .no_grad ():
172+ output = model (** inputs ).sample
173+
174+ assert output .shape == (inputs ["sample" ].shape [0 ], 1 ), "Input and output shapes do not match"
269175
270176 def test_from_pretrained_hub (self ):
271177 value_function , vf_loading_info = UNet1DModel .from_pretrained (
272178 "bglick13/hopper-medium-v2-value-function-hor32" , output_loading_info = True , subfolder = "value_function"
273179 )
274- self . assertIsNotNone ( value_function )
275- self . assertEqual ( len (vf_loading_info ["missing_keys" ]), 0 )
180+ assert value_function is not None
181+ assert len (vf_loading_info ["missing_keys" ]) == 0
276182
277183 value_function .to (torch_device )
278- image = value_function (** self .dummy_input )
279-
184+ image = value_function (** self .get_dummy_inputs ())
280185 assert image is not None , "Make sure output is not None"
281186
282187 def test_output_pretrained (self ):
@@ -288,9 +193,7 @@ def test_output_pretrained(self):
288193
289194 num_features = value_function .config .in_channels
290195 seq_len = 14
291- noise = torch .randn ((1 , seq_len , num_features )).permute (
292- 0 , 2 , 1
293- ) # match original, we can update values and remove
196+ noise = torch .randn ((1 , seq_len , num_features )).permute (0 , 2 , 1 )
294197 time_step = torch .full ((num_features ,), 0 )
295198
296199 with torch .no_grad ():
@@ -299,31 +202,4 @@ def test_output_pretrained(self):
299202 # fmt: off
300203 expected_output_slice = torch .tensor ([165.25 ] * seq_len )
301204 # fmt: on
302- self .assertTrue (torch .allclose (output , expected_output_slice , rtol = 1e-3 ))
303-
304- @unittest .skip ("Test not supported." )
305- def test_forward_with_norm_groups (self ):
306- # Not implemented yet for this UNet
307- pass
308-
309- @pytest .mark .xfail (
310- reason = (
311- "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
312- "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n "
313- "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n "
314- "2. Unskip this test."
315- ),
316- )
317- def test_layerwise_casting_inference (self ):
318- pass
319-
320- @pytest .mark .xfail (
321- reason = (
322- "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
323- "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n "
324- "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n "
325- "2. Unskip this test."
326- ),
327- )
328- def test_layerwise_casting_memory (self ):
329- pass
205+ assert torch .allclose (output , expected_output_slice , rtol = 1e-3 )
0 commit comments