
Commit 1914971

style: satisfy generated torch op checks
1 parent 5ebacfb · commit 1914971

2 files changed: 18 additions & 5 deletions


src/cuda/swiglu/kernel.cuh

Lines changed: 2 additions & 1 deletion
@@ -77,7 +77,8 @@ __global__ void SwigluKernel(T* __restrict__ out, const T* __restrict__ a,
       out[out_idx] = Caster<kDev>::template Cast<T>(
           __fmul_rn(__fmul_rn(gatef, sigf), upf));
     } else if constexpr (std::is_same_v<T, float>) {
-      out[out_idx] = __fmul_rn(__fmul_rn(gate, detail::Sigmoid<kDev>(gate)), up);
+      out[out_idx] =
+          __fmul_rn(__fmul_rn(gate, detail::Sigmoid<kDev>(gate)), up);
     } else {
       out[out_idx] = gate * detail::Sigmoid<kDev>(gate) * up;
     }
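
For context, every branch of this kernel computes the same elementwise SwiGLU, out = gate * sigmoid(gate) * up (i.e. SiLU(gate) scaled by up); the diff above only re-wraps the float branch for line length. A minimal PyTorch reference sketch, with the function name and shapes purely illustrative:

import torch

def swiglu_ref(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # Elementwise SwiGLU: silu(gate) * up == gate * sigmoid(gate) * up,
    # matching the kernel's fallback branch above.
    return gate * torch.sigmoid(gate) * up

# Usage: both operands share a shape; the output matches it.
out = swiglu_ref(torch.randn(4, 8), torch.randn(4, 8))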

tests/test_torch_ops.py

Lines changed: 16 additions & 4 deletions
@@ -312,7 +312,9 @@ def test_op(op_meta, shape, dtype, device, rtol, atol):
     if aten_name in _RANDOM_OPS:
         pytest.skip(f"`{aten_name}` is non-deterministic (independent draws diverge)")
     if device == "cuda" and aten_name in _DEVICE_ASSERTING_OPS:
-        pytest.skip(f"`{aten_name}` triggers a CUDA device-side assert on random inputs")
+        pytest.skip(
+            f"`{aten_name}` triggers a CUDA device-side assert on random inputs"
+        )

     in_params = [p for p in op_meta["params"] if not p["is_out"]]
     out_params = [p for p in op_meta["params"] if p["is_out"]]
@@ -333,7 +335,13 @@ def test_op(op_meta, shape, dtype, device, rtol, atol):
     # not in the InfiniOps wrapper.
     try:
         ref = _torch_func(aten_name)(*inputs)
-    except (RuntimeError, TypeError, ValueError, IndexError, NotImplementedError) as exc:
+    except (
+        RuntimeError,
+        TypeError,
+        ValueError,
+        IndexError,
+        NotImplementedError,
+    ) as exc:
         pytest.skip(f"`torch.{aten_name}` rejects these inputs: {exc}")

     ref_outs = ref if isinstance(ref, tuple) else (ref,)
@@ -357,13 +365,17 @@ def test_op(op_meta, shape, dtype, device, rtol, atol):
         (t.dtype for t in tensors if t.dtype not in _SUPPORTED_DTYPES), None
     )
     if unsupported is not None:
-        pytest.skip(f"`{op_name}` uses dtype {unsupported} — not in InfiniOps `DataType`")
+        pytest.skip(
+            f"`{op_name}` uses dtype {unsupported} — not in InfiniOps `DataType`"
+        )

     # On CUDA, `torch.empty_like` of a 0-element tensor gives a tensor
     # whose `data_ptr()` is unregistered with the device; passing it
     # through to the wrapper trips "pointer resides on host memory".
     if any(t.numel() == 0 for t in ref_outs):
-        pytest.skip(f"`{op_name}` produced 0-element output (unregistered data_ptr on cuda)")
+        pytest.skip(
+            f"`{op_name}` produced 0-element output (unregistered data_ptr on cuda)"
+        )

     outs = [torch.empty_like(t) for t in ref_outs]
     _call_infini(op_name, *inputs, *outs)
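
The hunk stops just before the outputs are compared. A hypothetical sketch of the comparison such a test would typically perform with the parametrized rtol/atol (torch.testing.assert_close is standard PyTorch; the helper name is my own, not from this repo):

import torch

def _assert_matches(outs, ref_outs, rtol, atol):
    # Check each wrapper output against the corresponding torch
    # reference output within the test's parametrized tolerances.
    for out, ref in zip(outs, ref_outs):
        torch.testing.assert_close(out, ref, rtol=rtol, atol=atol)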
