Skip to content

Commit 0d56193

Browse files
refactor unet_1d tests (#13898)
* refactor unet_1d tests * use per-sample output_shape for unet_1d tests --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 97b5c7f commit 0d56193

1 file changed

Lines changed: 69 additions & 193 deletions

File tree

tests/models/unets/test_models_unet_1d.py

Lines changed: 69 additions & 193 deletions
Original file line numberDiff line numberDiff line change
@@ -13,77 +13,37 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import unittest
17-
18-
import pytest
1916
import torch
2017

2118
from diffusers import UNet1DModel
19+
from diffusers.utils.torch_utils import randn_tensor
2220

23-
from ...testing_utils import (
24-
backend_manual_seed,
25-
floats_tensor,
26-
slow,
27-
torch_device,
28-
)
29-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
30-
21+
from ...testing_utils import backend_manual_seed, enable_full_determinism, slow, torch_device
22+
from ..testing_utils import BaseModelTesterConfig, ModelTesterMixin
3123

32-
class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
33-
model_class = UNet1DModel
34-
main_input_name = "sample"
35-
36-
@property
37-
def dummy_input(self):
38-
batch_size = 4
39-
num_features = 14
40-
seq_len = 16
4124

42-
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
43-
time_step = torch.tensor([10] * batch_size).to(torch_device)
25+
enable_full_determinism()
4426

45-
return {"sample": noise, "timestep": time_step}
4627

28+
class UNet1DModelTesterConfig(BaseModelTesterConfig):
4729
@property
48-
def input_shape(self):
49-
return (4, 14, 16)
30+
def model_class(self):
31+
return UNet1DModel
5032

5133
@property
52-
def output_shape(self):
53-
return (4, 14, 16)
54-
55-
@unittest.skip("Test not supported.")
56-
def test_ema_training(self):
57-
pass
58-
59-
@unittest.skip("Test not supported.")
60-
def test_training(self):
61-
pass
62-
63-
@unittest.skip("Test not supported.")
64-
def test_layerwise_casting_training(self):
65-
pass
66-
67-
def test_determinism(self):
68-
super().test_determinism()
69-
70-
def test_outputs_equivalence(self):
71-
super().test_outputs_equivalence()
72-
73-
def test_from_save_pretrained(self):
74-
super().test_from_save_pretrained()
75-
76-
def test_from_save_pretrained_variant(self):
77-
super().test_from_save_pretrained_variant()
34+
def main_input_name(self) -> str:
35+
return "sample"
7836

79-
def test_model_from_pretrained(self):
80-
super().test_model_from_pretrained()
37+
@property
38+
def output_shape(self) -> tuple:
39+
return (14, 16)
8140

82-
def test_output(self):
83-
super().test_output()
41+
@property
42+
def generator(self):
43+
return torch.Generator("cpu").manual_seed(0)
8444

85-
def prepare_init_args_and_inputs_for_common(self):
86-
init_dict = {
45+
def get_init_dict(self) -> dict:
46+
return {
8747
"block_out_channels": (8, 8, 16, 16),
8848
"in_channels": 14,
8949
"out_channels": 14,
@@ -97,19 +57,26 @@ def prepare_init_args_and_inputs_for_common(self):
9757
"up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
9858
"act_fn": "swish",
9959
}
100-
inputs_dict = self.dummy_input
101-
return init_dict, inputs_dict
10260

61+
def get_dummy_inputs(self) -> dict:
62+
batch_size = 4
63+
num_features = 14
64+
seq_len = 16
65+
noise = randn_tensor((batch_size, num_features, seq_len), generator=self.generator, device=torch_device)
66+
timestep = torch.tensor([10] * batch_size, device=torch_device)
67+
return {"sample": noise, "timestep": timestep}
68+
69+
70+
class TestUNet1DModel(UNet1DModelTesterConfig, ModelTesterMixin):
10371
def test_from_pretrained_hub(self):
10472
model, loading_info = UNet1DModel.from_pretrained(
10573
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
10674
)
107-
self.assertIsNotNone(model)
108-
self.assertEqual(len(loading_info["missing_keys"]), 0)
75+
assert model is not None
76+
assert len(loading_info["missing_keys"]) == 0
10977

11078
model.to(torch_device)
111-
image = model(**self.dummy_input)
112-
79+
image = model(**self.get_dummy_inputs())
11380
assert image is not None, "Make sure output is not None"
11481

11582
def test_output_pretrained(self):
@@ -119,9 +86,7 @@ def test_output_pretrained(self):
11986

12087
num_features = model.config.in_channels
12188
seq_len = 16
122-
noise = torch.randn((1, seq_len, num_features)).permute(
123-
0, 2, 1
124-
) # match original, we can update values and remove
89+
noise = torch.randn((1, seq_len, num_features)).permute(0, 2, 1)
12590
time_step = torch.full((num_features,), 0)
12691

12792
with torch.no_grad():
@@ -131,12 +96,7 @@ def test_output_pretrained(self):
13196
# fmt: off
13297
expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348])
13398
# fmt: on
134-
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
135-
136-
@unittest.skip("Test not supported.")
137-
def test_forward_with_norm_groups(self):
138-
# Not implemented yet for this UNet
139-
pass
99+
assert torch.allclose(output_slice, expected_output_slice, rtol=1e-3)
140100

141101
@slow
142102
def test_unet_1d_maestro(self):
@@ -157,98 +117,26 @@ def test_unet_1d_maestro(self):
157117
assert (output_sum - 224.0896).abs() < 0.5
158118
assert (output_max - 0.0607).abs() < 4e-4
159119

160-
@pytest.mark.xfail(
161-
reason=(
162-
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
163-
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
164-
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
165-
"2. Unskip this test."
166-
),
167-
)
168-
def test_layerwise_casting_inference(self):
169-
super().test_layerwise_casting_inference()
170-
171-
@pytest.mark.xfail(
172-
reason=(
173-
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
174-
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
175-
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
176-
"2. Unskip this test."
177-
),
178-
)
179-
def test_layerwise_casting_memory(self):
180-
pass
181-
182-
183-
class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
184-
model_class = UNet1DModel
185-
main_input_name = "sample"
186120

121+
class UNetRLModelTesterConfig(BaseModelTesterConfig):
187122
@property
188-
def dummy_input(self):
189-
batch_size = 4
190-
num_features = 14
191-
seq_len = 16
192-
193-
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
194-
time_step = torch.tensor([10] * batch_size).to(torch_device)
195-
196-
return {"sample": noise, "timestep": time_step}
123+
def model_class(self):
124+
return UNet1DModel
197125

198126
@property
199-
def input_shape(self):
200-
return (4, 14, 16)
127+
def main_input_name(self) -> str:
128+
return "sample"
201129

202130
@property
203-
def output_shape(self):
204-
return (4, 14, 1)
205-
206-
def test_determinism(self):
207-
super().test_determinism()
208-
209-
def test_outputs_equivalence(self):
210-
super().test_outputs_equivalence()
211-
212-
def test_from_save_pretrained(self):
213-
super().test_from_save_pretrained()
214-
215-
def test_from_save_pretrained_variant(self):
216-
super().test_from_save_pretrained_variant()
217-
218-
def test_model_from_pretrained(self):
219-
super().test_model_from_pretrained()
220-
221-
def test_output(self):
222-
# UNetRL is a value-function is different output shape
223-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
224-
model = self.model_class(**init_dict)
225-
model.to(torch_device)
226-
model.eval()
227-
228-
with torch.no_grad():
229-
output = model(**inputs_dict)
230-
231-
if isinstance(output, dict):
232-
output = output.sample
233-
234-
self.assertIsNotNone(output)
235-
expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
236-
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
237-
238-
@unittest.skip("Test not supported.")
239-
def test_ema_training(self):
240-
pass
131+
def output_shape(self) -> tuple:
132+
return (1,)
241133

242-
@unittest.skip("Test not supported.")
243-
def test_training(self):
244-
pass
245-
246-
@unittest.skip("Test not supported.")
247-
def test_layerwise_casting_training(self):
248-
pass
134+
@property
135+
def generator(self):
136+
return torch.Generator("cpu").manual_seed(0)
249137

250-
def prepare_init_args_and_inputs_for_common(self):
251-
init_dict = {
138+
def get_init_dict(self) -> dict:
139+
return {
252140
"in_channels": 14,
253141
"out_channels": 14,
254142
"down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"],
@@ -264,19 +152,36 @@ def prepare_init_args_and_inputs_for_common(self):
264152
"time_embedding_type": "positional",
265153
"act_fn": "mish",
266154
}
267-
inputs_dict = self.dummy_input
268-
return init_dict, inputs_dict
155+
156+
def get_dummy_inputs(self) -> dict:
157+
batch_size = 4
158+
num_features = 14
159+
seq_len = 16
160+
noise = randn_tensor((batch_size, num_features, seq_len), generator=self.generator, device=torch_device)
161+
timestep = torch.tensor([10] * batch_size, device=torch_device)
162+
return {"sample": noise, "timestep": timestep}
163+
164+
165+
class TestUNetRLModel(UNetRLModelTesterConfig, ModelTesterMixin):
166+
# UNetRL is a value function, so it has a different output shape.
167+
def test_output(self):
168+
model = self.model_class(**self.get_init_dict()).to(torch_device).eval()
169+
170+
inputs = self.get_dummy_inputs()
171+
with torch.no_grad():
172+
output = model(**inputs).sample
173+
174+
assert output.shape == (inputs["sample"].shape[0], 1), "Input and output shapes do not match"
269175

270176
def test_from_pretrained_hub(self):
271177
value_function, vf_loading_info = UNet1DModel.from_pretrained(
272178
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
273179
)
274-
self.assertIsNotNone(value_function)
275-
self.assertEqual(len(vf_loading_info["missing_keys"]), 0)
180+
assert value_function is not None
181+
assert len(vf_loading_info["missing_keys"]) == 0
276182

277183
value_function.to(torch_device)
278-
image = value_function(**self.dummy_input)
279-
184+
image = value_function(**self.get_dummy_inputs())
280185
assert image is not None, "Make sure output is not None"
281186

282187
def test_output_pretrained(self):
@@ -288,9 +193,7 @@ def test_output_pretrained(self):
288193

289194
num_features = value_function.config.in_channels
290195
seq_len = 14
291-
noise = torch.randn((1, seq_len, num_features)).permute(
292-
0, 2, 1
293-
) # match original, we can update values and remove
196+
noise = torch.randn((1, seq_len, num_features)).permute(0, 2, 1)
294197
time_step = torch.full((num_features,), 0)
295198

296199
with torch.no_grad():
@@ -299,31 +202,4 @@ def test_output_pretrained(self):
299202
# fmt: off
300203
expected_output_slice = torch.tensor([165.25] * seq_len)
301204
# fmt: on
302-
self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))
303-
304-
@unittest.skip("Test not supported.")
305-
def test_forward_with_norm_groups(self):
306-
# Not implemented yet for this UNet
307-
pass
308-
309-
@pytest.mark.xfail(
310-
reason=(
311-
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
312-
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
313-
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
314-
"2. Unskip this test."
315-
),
316-
)
317-
def test_layerwise_casting_inference(self):
318-
pass
319-
320-
@pytest.mark.xfail(
321-
reason=(
322-
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
323-
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
324-
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
325-
"2. Unskip this test."
326-
),
327-
)
328-
def test_layerwise_casting_memory(self):
329-
pass
205+
assert torch.allclose(output, expected_output_slice, rtol=1e-3)

0 commit comments

Comments
 (0)