Skip to content

Commit 55f0fa3

Browse files
committed
fix(randn_tensor): compare device.type, not torch.device to str, when suppressing MPS info log
When a CPU generator is passed with an MPS target, randn_tensor intentionally skips the 'generator was on cpu, tensor will be moved to <device>' info log — MPS doesn't support device-side generators, so the suggestion to create one on MPS would be misleading. The guard was written as `if device != "mps"`, but a few lines earlier `device` is coerced to a `torch.device` object, and `torch.device("mps") == "mps"` is False: neither operand's `__eq__` handles the other's type (each returns NotImplemented), so Python falls back to the default identity comparison, which is False for two distinct objects. Result: the guard is effectively always True, so MPS users get the spurious log whenever they pass a CPU generator — the opposite of the documented intent. Fix: compare `device.type` (a str) against "mps". Added a regression test in tests/others/test_utils.py that exercises both the MPS and non-MPS paths via `assertLogs` on the diffusers logger.
1 parent c8c8401 commit 55f0fa3

2 files changed

Lines changed: 51 additions & 1 deletion

File tree

src/diffusers/utils/torch_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def randn_tensor(
173173
gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
174174
if gen_device_type != device.type and gen_device_type == "cpu":
175175
rand_device = "cpu"
176-
if device != "mps":
176+
if device.type != "mps":
177177
logger.info(
178178
f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
179179
f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"

tests/others/test_utils.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,56 @@ def test_deprecate_testing_utils_module(self):
204204
), f"Expected deprecation message substring not found, got: {messages}"
205205

206206

207+
class RandnTensorTester(unittest.TestCase):
208+
"""Tests for :func:`diffusers.utils.torch_utils.randn_tensor`."""
209+
210+
def test_mps_suppresses_cpu_generator_info_log(self):
211+
"""
212+
When a CPU generator targets MPS, the informational log about falling back to CPU
213+
should be suppressed — MPS does not support device-side generators, so the message
214+
would be misleading. Prior to fixing the ``device != "mps"`` string/``torch.device``
215+
comparison this log fired on every device, including MPS.
216+
"""
217+
import logging as py_logging
218+
219+
import torch
220+
221+
from diffusers.utils import logging as diffusers_logging
222+
from diffusers.utils import torch_utils
223+
224+
gen = torch.Generator(device="cpu")
225+
226+
diffusers_logging.set_verbosity_info()
227+
prev_level = torch_utils.logger.level
228+
torch_utils.logger.setLevel(py_logging.INFO)
229+
230+
def _capture(target_device):
231+
with self.assertLogs(torch_utils.logger, level="INFO") as cm:
232+
torch_utils.logger.info("sentinel")
233+
try:
234+
torch_utils.randn_tensor((1, 2), generator=gen, device=target_device, dtype=torch.float32)
235+
except (AssertionError, RuntimeError):
236+
# Accelerator may be unavailable in the test env; we only care about the log path
237+
# before tensor materialization.
238+
pass
239+
return [m for m in cm.output if "sentinel" not in m]
240+
241+
try:
242+
mps_logs = _capture("mps")
243+
self.assertFalse(
244+
any("moved to" in m and "generator" in m for m in mps_logs),
245+
f"MPS target should not emit the CPU-fallback info log, got: {mps_logs}",
246+
)
247+
248+
cuda_logs = _capture("cuda")
249+
self.assertTrue(
250+
any("moved to" in m and "generator" in m for m in cuda_logs),
251+
f"Non-MPS target should still emit the CPU-fallback info log, got: {cuda_logs}",
252+
)
253+
finally:
254+
torch_utils.logger.setLevel(prev_level)
255+
256+
207257
# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
208258
class ExpectationsTester(unittest.TestCase):
209259
def test_expectations(self):

0 commit comments

Comments (0)