Skip to content

Commit 9ec6385

Browse files
Ricardo-M-Lclaude
andcommitted
refactor: use CaptureLogger instead of assertLogs in randn_tensor test
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 49370e3 commit 9ec6385

1 file changed

Lines changed: 16 additions & 27 deletions

File tree

tests/others/test_utils.py

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -251,48 +251,37 @@ class RandnTensorTester(unittest.TestCase):
251251
"""Tests for :func:`diffusers.utils.torch_utils.randn_tensor`."""
252252

253253
def test_mps_suppresses_cpu_generator_info_log(self):
254-
"""
255-
When a CPU generator targets MPS, the informational log about falling back to CPU
256-
should be suppressed — MPS does not support device-side generators, so the message
257-
would be misleading. Prior to fixing the ``device != "mps"`` string/``torch.device``
258-
comparison this log fired on every device, including MPS.
259-
"""
260-
import logging as py_logging
261-
262254
import torch
263255

264256
from diffusers.utils import logging as diffusers_logging
265257
from diffusers.utils import torch_utils
266258

267-
gen = torch.Generator(device="cpu")
259+
from ..testing_utils import CaptureLogger
268260

261+
gen = torch.Generator(device="cpu")
269262
diffusers_logging.set_verbosity_info()
270-
prev_level = torch_utils.logger.level
271-
torch_utils.logger.setLevel(py_logging.INFO)
272263

273264
def _capture(target_device):
274-
with self.assertLogs(torch_utils.logger, level="INFO") as cm:
275-
torch_utils.logger.info("sentinel")
265+
with CaptureLogger(torch_utils.logger) as cl:
276266
try:
277267
torch_utils.randn_tensor((1, 2), generator=gen, device=target_device, dtype=torch.float32)
278268
except (AssertionError, RuntimeError):
279269
pass
280-
return [m for m in cm.output if "sentinel" not in m]
270+
return cl.out
281271

282-
try:
283-
mps_logs = _capture("mps")
284-
self.assertFalse(
285-
any("moved to" in m and "generator" in m for m in mps_logs),
286-
f"MPS target should not emit the CPU-fallback info log, got: {mps_logs}",
287-
)
272+
mps_out = _capture("mps")
273+
self.assertNotIn(
274+
"moved to",
275+
mps_out,
276+
f"MPS target should not emit the CPU-fallback info log, got: {mps_out}",
277+
)
288278

289-
cuda_logs = _capture("cuda")
290-
self.assertTrue(
291-
any("moved to" in m and "generator" in m for m in cuda_logs),
292-
f"Non-MPS target should still emit the CPU-fallback info log, got: {cuda_logs}",
293-
)
294-
finally:
295-
torch_utils.logger.setLevel(prev_level)
279+
cuda_out = _capture("cuda")
280+
self.assertIn(
281+
"moved to",
282+
cuda_out,
283+
f"Non-MPS target should still emit the CPU-fallback info log, got: {cuda_out}",
284+
)
296285

297286

298287
# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py

0 commit comments

Comments
 (0)