Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/diffusers/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def randn_tensor(
gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
if gen_device_type != device.type and gen_device_type == "cpu":
rand_device = "cpu"
if device != "mps":
if device.type != "mps":
logger.info(
f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
Expand Down
37 changes: 37 additions & 0 deletions tests/others/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,43 @@ def test_fourier_filter_preserves_dtype_and_shape(self):
assert out.shape == x.shape


class RandnTensorTester(unittest.TestCase):
    """Tests for :func:`diffusers.utils.torch_utils.randn_tensor`."""

    def test_mps_suppresses_cpu_generator_info_log(self):
        """Check both sides of the CPU-generator fallback log branch.

        When a CPU generator is paired with a non-CPU target device,
        ``randn_tensor`` emits an info log about creating the tensor on CPU
        and moving it — except when the target is MPS, where the fallback is
        intentionally silent.
        """
        import torch

        from diffusers.utils import logging as diffusers_logging
        from diffusers.utils import torch_utils

        from ..testing_utils import CaptureLogger

        gen = torch.Generator(device="cpu")
        # The fallback message is emitted at INFO level; raise verbosity so it
        # is not filtered out before CaptureLogger can see it.
        diffusers_logging.set_verbosity_info()

        def _capture(target_device):
            # randn_tensor reads ``device.type``, so pass a real torch.device.
            # Forwarding the raw string would raise AttributeError — which is
            # not in the except tuple below — before the log line is reached.
            device = torch.device(target_device)
            with CaptureLogger(torch_utils.logger) as cl:
                try:
                    torch_utils.randn_tensor((1, 2), generator=gen, device=device, dtype=torch.float32)
                except (AssertionError, RuntimeError):
                    # The target backend may be unavailable on this runner;
                    # the info log (if any) fires before tensor creation fails.
                    pass
            return cl.out

        mps_out = _capture("mps")
        self.assertNotIn(
            "moved to",
            mps_out,
            f"MPS target should not emit the CPU-fallback info log, got: {mps_out}",
        )

        cuda_out = _capture("cuda")
        self.assertIn(
            "moved to",
            cuda_out,
            f"Non-MPS target should still emit the CPU-fallback info log, got: {cuda_out}",
        )


# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
class ExpectationsTester(unittest.TestCase):
def test_expectations(self):
Expand Down
Loading