fix(randn_tensor): compare device.type, not torch.device, when suppressing MPS info log#13508
Open
Ricardo-M-L wants to merge 2 commits intohuggingface:mainfrom
Open
fix(randn_tensor): compare device.type, not torch.device, when suppressing MPS info log#13508Ricardo-M-L wants to merge 2 commits intohuggingface:mainfrom
Ricardo-M-L wants to merge 2 commits intohuggingface:mainfrom
Conversation
yiyixuxu
approved these changes
Apr 21, 2026
Collaborator
yiyixuxu
left a comment
There was a problem hiding this comment.
thanks
I left a comment
| prev_level = torch_utils.logger.level | ||
| torch_utils.logger.setLevel(py_logging.INFO) | ||
|
|
||
| def _capture(target_device): |
Collaborator
There was a problem hiding this comment.
Contributor
Author
There was a problem hiding this comment.
Good call! Refactored the test to use CaptureLogger — much cleaner. Also rebased on latest main to resolve the conflict.
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
55f0fa3 to
53c4793
Compare
… suppressing MPS info log
When a CPU generator is passed with an MPS target, randn_tensor intentionally skips
the 'generator was on cpu, tensor will be moved to <device>' info log — MPS doesn't
support device-side generators, so the suggestion to create one on MPS would be
misleading. The guard was written as `if device != "mps"`, but a few lines
earlier `device` is coerced to a `torch.device` object, and
`torch.device("mps") == "mps"` is False (torch.device's __eq__ with a string
returns NotImplemented, falling back to identity — they're different types).
Result: the guard is effectively always True, so MPS users get the spurious log
whenever they pass a CPU generator — the opposite of the documented intent.
Fix: compare `device.type` (a str) against "mps". Added a regression test in
tests/others/test_utils.py that exercises both the MPS and non-MPS paths via
`assertLogs` on the diffusers logger.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
53c4793 to
9ec6385
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Fixes a latent bug in
randn_tensorwhere the MPS-specific info-log suppression has been broken since PR #1902 (Jan 2023).The bug
The
if device != "mps":guard was meant to silence the info log on MPS, because MPS doesn't support device-side generators — suggesting the user "create a generator on the mps device" would be misleading.But by this point
deviceis atorch.deviceobject (coerced a few lines earlier).torch.device("mps") == "mps"isFalse—torch.device.__eq__returnsNotImplementedwhen compared with a string, Python falls back to identity, and the two are different types. So the guard is effectively always true, and MPS users get the spurious log on every call where a CPU generator is passed — the opposite of the documented intent.Sanity check:
The fix
One-character change: compare
device.type(astr) to"mps".Test
Added
RandnTensorTester.test_mps_suppresses_cpu_generator_info_logintests/others/test_utils.py, which:Verified that the test fails on
mainand passes with this fix. All existing tests intests/others/test_utils.pystill pass (14/14).Before submitting
Who can review?
@sayakpaul @DN6