@@ -312,7 +312,9 @@ def test_op(op_meta, shape, dtype, device, rtol, atol):
312312 if aten_name in _RANDOM_OPS :
313313 pytest .skip (f"`{ aten_name } ` is non-deterministic (independent draws diverge)" )
314314 if device == "cuda" and aten_name in _DEVICE_ASSERTING_OPS :
315- pytest .skip (f"`{ aten_name } ` triggers a CUDA device-side assert on random inputs" )
315+ pytest .skip (
316+ f"`{ aten_name } ` triggers a CUDA device-side assert on random inputs"
317+ )
316318
317319 in_params = [p for p in op_meta ["params" ] if not p ["is_out" ]]
318320 out_params = [p for p in op_meta ["params" ] if p ["is_out" ]]
@@ -333,7 +335,13 @@ def test_op(op_meta, shape, dtype, device, rtol, atol):
333335 # not in the InfiniOps wrapper.
334336 try :
335337 ref = _torch_func (aten_name )(* inputs )
336- except (RuntimeError , TypeError , ValueError , IndexError , NotImplementedError ) as exc :
338+ except (
339+ RuntimeError ,
340+ TypeError ,
341+ ValueError ,
342+ IndexError ,
343+ NotImplementedError ,
344+ ) as exc :
337345 pytest .skip (f"`torch.{ aten_name } ` rejects these inputs: { exc } " )
338346
339347 ref_outs = ref if isinstance (ref , tuple ) else (ref ,)
@@ -357,13 +365,17 @@ def test_op(op_meta, shape, dtype, device, rtol, atol):
357365 (t .dtype for t in tensors if t .dtype not in _SUPPORTED_DTYPES ), None
358366 )
359367 if unsupported is not None :
360- pytest .skip (f"`{ op_name } ` uses dtype { unsupported } — not in InfiniOps `DataType`" )
368+ pytest .skip (
369+ f"`{ op_name } ` uses dtype { unsupported } — not in InfiniOps `DataType`"
370+ )
361371
362372 # On CUDA, `torch.empty_like` of a 0-element tensor gives a tensor
363373 # whose `data_ptr()` is unregistered with the device; passing it
364374 # through to the wrapper trips "pointer resides on host memory".
365375 if any (t .numel () == 0 for t in ref_outs ):
366- pytest .skip (f"`{ op_name } ` produced 0-element output (unregistered data_ptr on cuda)" )
376+ pytest .skip (
377+ f"`{ op_name } ` produced 0-element output (unregistered data_ptr on cuda)"
378+ )
367379
368380 outs = [torch .empty_like (t ) for t in ref_outs ]
369381 _call_infini (op_name , * inputs , * outs )
0 commit comments