Skip to content

fix(randn_tensor): compare device.type, not torch.device, when suppressing MPS info log#13508

Open
Ricardo-M-L wants to merge 2 commits intohuggingface:mainfrom
Ricardo-M-L:fix/randn-tensor-mps-device-comparison
Open

fix(randn_tensor): compare device.type, not torch.device, when suppressing MPS info log#13508
Ricardo-M-L wants to merge 2 commits intohuggingface:mainfrom
Ricardo-M-L:fix/randn-tensor-mps-device-comparison

Conversation

@Ricardo-M-L
Copy link
Copy Markdown
Contributor

What does this PR do?

Fixes a latent bug in randn_tensor where the MPS-specific info-log suppression has been broken since PR #1902 (Jan 2023).

The bug

# src/diffusers/utils/torch_utils.py
if isinstance(device, str):
    device = torch.device(device)             # ← device is now a torch.device object
...
if generator is not None:
    gen_device_type = ...
    if gen_device_type != device.type and gen_device_type == "cpu":
        rand_device = "cpu"
        if device != "mps":                   # ← always True
            logger.info("The passed generator was created on 'cpu' ...")

The if device != "mps": guard was meant to silence the info log on MPS, because MPS doesn't support device-side generators — suggesting the user "create a generator on the mps device" would be misleading.

But by this point device is a torch.device object (coerced a few lines earlier). torch.device("mps") == "mps" is Falsetorch.device.__eq__ returns NotImplemented when compared with a string, Python falls back to identity, and the two are different types. So the guard is effectively always true, and MPS users get the spurious log on every call where a CPU generator is passed — the opposite of the documented intent.

Sanity check:

>>> import torch
>>> torch.device("mps") == "mps"
False
>>> torch.device("mps") != "mps"
True
>>> torch.device("mps").type == "mps"
True

The fix

One-character change: compare device.type (a str) to "mps".

-            if device != "mps":
+            if device.type != "mps":

Test

Added RandnTensorTester.test_mps_suppresses_cpu_generator_info_log in tests/others/test_utils.py, which:

  • Asserts no CPU-fallback info log is emitted when a CPU generator targets MPS.
  • Asserts the log is emitted when a CPU generator targets a non-MPS accelerator.

Verified that the test fails on main and passes with this fix. All existing tests in tests/others/test_utils.py still pass (14/14).

Before submitting

  • This PR fixes a bug that affects MPS users.
  • Did you read the contributor guideline?
  • Did you write any new necessary tests? — Yes, regression test added.

Who can review?

@sayakpaul @DN6

@github-actions github-actions Bot added tests utils size/M PR with diff < 200 LOC labels Apr 20, 2026
Copy link
Copy Markdown
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks
I left a comment

prev_level = torch_utils.logger.level
torch_utils.logger.setLevel(py_logging.INFO)

def _capture(target_device):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call! Refactored the test to use CaptureLogger — much cleaner. Also rebased on latest main to resolve the conflict.

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Ricardo-M-L Ricardo-M-L force-pushed the fix/randn-tensor-mps-device-comparison branch from 55f0fa3 to 53c4793 Compare April 24, 2026 06:17
@github-actions github-actions Bot added size/S PR with diff < 50 LOC and removed size/M PR with diff < 200 LOC labels Apr 24, 2026
Ricardo-M-L and others added 2 commits April 27, 2026 22:51
… suppressing MPS info log

When a CPU generator is passed with an MPS target, randn_tensor intentionally skips
the 'generator was on cpu, tensor will be moved to <device>' info log — MPS doesn't
support device-side generators, so the suggestion to create one on MPS would be
misleading. The guard was written as `if device != "mps"`, but a few lines
earlier `device` is coerced to a `torch.device` object, and
`torch.device("mps") == "mps"` is False (torch.device's __eq__ with a string
returns NotImplemented, falling back to identity — they're different types).

Result: the guard is effectively always True, so MPS users get the spurious log
whenever they pass a CPU generator — the opposite of the documented intent.

Fix: compare `device.type` (a str) against "mps". Added a regression test in
tests/others/test_utils.py that exercises both the MPS and non-MPS paths via
`assertLogs` on the diffusers logger.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@Ricardo-M-L Ricardo-M-L force-pushed the fix/randn-tensor-mps-device-comparison branch from 53c4793 to 9ec6385 Compare April 27, 2026 14:51
@github-actions github-actions Bot added size/S PR with diff < 50 LOC and removed size/S PR with diff < 50 LOC labels Apr 27, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants