@@ -251,48 +251,37 @@ class RandnTensorTester(unittest.TestCase):
251 251     """Tests for :func:`diffusers.utils.torch_utils.randn_tensor`."""
252 252
253 253     def test_mps_suppresses_cpu_generator_info_log(self):
254- """
255- When a CPU generator targets MPS, the informational log about falling back to CPU
256- should be suppressed — MPS does not support device-side generators, so the message
257- would be misleading. Prior to fixing the ``device != "mps"`` string/``torch.device``
258- comparison this log fired on every device, including MPS.
259- """
260- import logging as py_logging
261-
262 254         import torch
263 255
264 256         from diffusers.utils import logging as diffusers_logging
265 257         from diffusers.utils import torch_utils
266 258
267 -           gen = torch.Generator(device="cpu")
259 +           from ..testing_utils import CaptureLogger
268 260
261 +           gen = torch.Generator(device="cpu")
269 262         diffusers_logging.set_verbosity_info()
270 -           prev_level = torch_utils.logger.level
271 -           torch_utils.logger.setLevel(py_logging.INFO)
272 263
273 264         def _capture(target_device):
274 -               with self.assertLogs(torch_utils.logger, level="INFO") as cm:
275 -                   torch_utils.logger.info("sentinel")
265 +               with CaptureLogger(torch_utils.logger) as cl:
276 266                 try:
277 267                     torch_utils.randn_tensor((1, 2), generator=gen, device=target_device, dtype=torch.float32)
278 268                 except (AssertionError, RuntimeError):
279 269                     pass
280 -               return [m for m in cm.output if "sentinel" not in m]
270 +           return cl.out
281271
282 -           try:
283 -               mps_logs = _capture("mps")
284 -               self.assertFalse(
285 -                   any("moved to" in m and "generator" in m for m in mps_logs),
286 -                   f"MPS target should not emit the CPU-fallback info log, got: {mps_logs}",
287 -               )
272 +           mps_out = _capture("mps")
273 +           self.assertNotIn(
274 +               "moved to",
275 +               mps_out,
276 +               f"MPS target should not emit the CPU-fallback info log, got: {mps_out}",
277 +           )
288278
289 -               cuda_logs = _capture("cuda")
290 -               self.assertTrue(
291 -                   any("moved to" in m and "generator" in m for m in cuda_logs),
292 -                   f"Non-MPS target should still emit the CPU-fallback info log, got: {cuda_logs}",
293 -               )
294 -           finally:
295 -               torch_utils.logger.setLevel(prev_level)
279 +           cuda_out = _capture("cuda")
280 +           self.assertIn(
281 +               "moved to",
282 +               cuda_out,
283 +               f"Non-MPS target should still emit the CPU-fallback info log, got: {cuda_out}",
284 +           )
296285
297286
298287# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
0 commit comments