Skip to content

Commit d1f8e55

Browse files
authored
[tests] fix consistency decoder tests (#13905)
* fix consistency decoder tests * address feedback * feedback * up
1 parent 9b249df commit d1f8e55

3 files changed

Lines changed: 25 additions & 38 deletions

File tree

tests/models/autoencoders/test_models_consistency_decoder_vae.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# limitations under the License.
1515

1616
import gc
17-
import unittest
1817

1918
import numpy as np
2019
import torch
@@ -103,14 +102,12 @@ class TestConsistencyDecoderVAESlicingTiling(ConsistencyDecoderVAETesterConfig,
103102

104103

105104
@slow
106-
class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
107-
def setUp(self):
108-
super().setUp()
105+
class TestConsistencyDecoderVAEIntegration:
106+
def setup_method(self):
109107
gc.collect()
110108
backend_empty_cache(torch_device)
111109

112-
def tearDown(self):
113-
super().tearDown()
110+
def teardown_method(self):
114111
gc.collect()
115112
backend_empty_cache(torch_device)
116113

tests/models/testing_utils/common.py

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,9 @@ def get_dummy_inputs(self) -> Dict[str, Any]:
242242
"""
243243
Returns dict of inputs to pass to the model forward pass.
244244
245+
Implementations must be deterministic: every call must return identical inputs (seed any random
246+
tensors and generators), since tests call this once per forward pass to compare outputs.
247+
245248
Returns:
246249
Dict[str, Any]: Input tensors/values for model.forward().
247250
@@ -292,9 +295,8 @@ def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
292295
f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
293296
)
294297

295-
inputs_dict = self.get_dummy_inputs()
296-
image = model(**inputs_dict, return_dict=False)[0]
297-
new_image = new_model(**inputs_dict, return_dict=False)[0]
298+
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
299+
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
298300

299301
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
300302

@@ -314,9 +316,8 @@ def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):
314316

315317
new_model.to(torch_device)
316318

317-
inputs_dict = self.get_dummy_inputs()
318-
image = model(**inputs_dict, return_dict=False)[0]
319-
new_image = new_model(**inputs_dict, return_dict=False)[0]
319+
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
320+
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
320321

321322
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
322323

@@ -344,9 +345,8 @@ def test_determinism(self, atol=1e-5, rtol=0):
344345
model.to(torch_device)
345346
model.eval()
346347

347-
inputs_dict = self.get_dummy_inputs()
348-
first = model(**inputs_dict, return_dict=False)[0]
349-
second = model(**inputs_dict, return_dict=False)[0]
348+
first = model(**self.get_dummy_inputs(), return_dict=False)[0]
349+
second = model(**self.get_dummy_inputs(), return_dict=False)[0]
350350

351351
first_flat = first.flatten()
352352
second_flat = second.flatten()
@@ -403,9 +403,8 @@ def recursive_check(tuple_object, dict_object):
403403
model.to(torch_device)
404404
model.eval()
405405

406-
inputs_dict = self.get_dummy_inputs()
407-
outputs_dict = model(**inputs_dict)
408-
outputs_tuple = model(**inputs_dict, return_dict=False)
406+
outputs_dict = model(**self.get_dummy_inputs())
407+
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
409408

410409
recursive_check(outputs_tuple, outputs_dict)
411410

@@ -509,11 +508,10 @@ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4,
509508
def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0):
510509
torch.manual_seed(0)
511510
config = self.get_init_dict()
512-
inputs_dict = self.get_dummy_inputs()
513511
model = self.model_class(**config).eval()
514512
model = model.to(torch_device)
515513

516-
base_output = model(**inputs_dict, return_dict=False)[0]
514+
base_output = model(**self.get_dummy_inputs(), return_dict=False)[0]
517515

518516
model_size = compute_module_persistent_sizes(model)[""]
519517
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
@@ -532,10 +530,7 @@ def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0):
532530
new_model = new_model.to(torch_device)
533531

534532
torch.manual_seed(0)
535-
# Re-create inputs only if they contain a generator (which needs to be reset)
536-
if "generator" in inputs_dict:
537-
inputs_dict = self.get_dummy_inputs()
538-
new_output = new_model(**inputs_dict, return_dict=False)[0]
533+
new_output = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
539534

540535
assert_tensors_close(
541536
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load"
@@ -546,11 +541,10 @@ def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0):
546541
def test_sharded_checkpoints_with_variant(self, tmp_path, atol=1e-5, rtol=0):
547542
torch.manual_seed(0)
548543
config = self.get_init_dict()
549-
inputs_dict = self.get_dummy_inputs()
550544
model = self.model_class(**config).eval()
551545
model = model.to(torch_device)
552546

553-
base_output = model(**inputs_dict, return_dict=False)[0]
547+
base_output = model(**self.get_dummy_inputs(), return_dict=False)[0]
554548

555549
model_size = compute_module_persistent_sizes(model)[""]
556550
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
@@ -574,10 +568,7 @@ def test_sharded_checkpoints_with_variant(self, tmp_path, atol=1e-5, rtol=0):
574568
new_model = new_model.to(torch_device)
575569

576570
torch.manual_seed(0)
577-
# Re-create inputs only if they contain a generator (which needs to be reset)
578-
if "generator" in inputs_dict:
579-
inputs_dict = self.get_dummy_inputs()
580-
new_output = new_model(**inputs_dict, return_dict=False)[0]
571+
new_output = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
581572

582573
assert_tensors_close(
583574
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load"
@@ -589,11 +580,10 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rt
589580

590581
torch.manual_seed(0)
591582
config = self.get_init_dict()
592-
inputs_dict = self.get_dummy_inputs()
593583
model = self.model_class(**config).eval()
594584
model = model.to(torch_device)
595585

596-
base_output = model(**inputs_dict, return_dict=False)[0]
586+
base_output = model(**self.get_dummy_inputs(), return_dict=False)[0]
597587

598588
model_size = compute_module_persistent_sizes(model)[""]
599589
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
@@ -627,10 +617,7 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rt
627617
model_parallel = model_parallel.to(torch_device)
628618

629619
torch.manual_seed(0)
630-
# Re-create inputs only if they contain a generator (which needs to be reset)
631-
if "generator" in inputs_dict:
632-
inputs_dict = self.get_dummy_inputs()
633-
output_parallel = model_parallel(**inputs_dict, return_dict=False)[0]
620+
output_parallel = model_parallel(**self.get_dummy_inputs(), return_dict=False)[0]
634621

635622
assert_tensors_close(
636623
base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading"

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from diffusers.models.embeddings import ImageProjection, IPAdapterFaceIDImageProjection, IPAdapterPlusImageProjection
3636
from diffusers.utils import logging
3737
from diffusers.utils.import_utils import is_xformers_available
38+
from diffusers.utils.torch_utils import randn_tensor
3839

3940
from ...testing_utils import (
4041
backend_empty_cache,
@@ -391,11 +392,13 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
391392
batch_size = 4
392393
num_channels = 4
393394
sizes = (16, 16)
395+
# Seed locally so repeated calls (e.g. one per forward pass in the mixins) yield identical inputs.
396+
generator = torch.Generator("cpu").manual_seed(0)
394397

395398
return {
396-
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
399+
"sample": randn_tensor((batch_size, num_channels) + sizes, generator=generator, device=torch_device),
397400
"timestep": torch.tensor([10]).to(torch_device),
398-
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
401+
"encoder_hidden_states": randn_tensor((batch_size, 4, 8), generator=generator, device=torch_device),
399402
}
400403

401404

0 commit comments

Comments
 (0)