diff --git a/src/ntops/kernels/isinf.py b/src/ntops/kernels/isinf.py
new file mode 100644
index 0000000..9e5fdcb
--- /dev/null
+++ b/src/ntops/kernels/isinf.py
@@ -0,0 +1,22 @@
+import functools
+
+import ninetoothed
+from ninetoothed import Tensor
+
+from ntops.kernels.element_wise import arrangement
+
+
+def application(input, output):
+    """Mark elements of ``input`` equal to positive or negative infinity."""
+    pos_result = input == float("+inf")
+    neg_result = input == float("-inf")
+    # NOTE(review): relies on the kernel compiler lowering `or` to an
+    # element-wise logical OR -- confirm. `output` is only rebound, never
+    # read, hence the unused-variable suppression.
+    output = pos_result or neg_result  # noqa: F841
+
+
+@functools.cache
+def make(ndim):
+    """Return the cached isinf kernel for ``ndim``-dimensional tensors."""
+    return ninetoothed.make(arrangement, application, (Tensor(ndim), Tensor(ndim)))
diff --git a/src/ntops/torch.py b/src/ntops/torch.py
index 9a4c895..41ae3f2 100644
--- a/src/ntops/torch.py
+++ b/src/ntops/torch.py
@@ -8,6 +8,7 @@
 import ntops.kernels.div
 import ntops.kernels.exp
 import ntops.kernels.gelu
+import ntops.kernels.isinf
 import ntops.kernels.mm
 import ntops.kernels.mul
 import ntops.kernels.relu
@@ -109,6 +110,23 @@ def gelu(input, approximate="none"):
     return output
 
 
+def isinf(input):
+    """Return a boolean tensor marking the infinite elements of ``input``.
+
+    Mirrors ``torch.isinf``: the result has the shape of ``input`` and dtype
+    ``torch.bool``.
+    """
+    # Allocate a bool result explicitly; ``empty_like`` alone would inherit
+    # the dtype of ``input``, unlike ``torch.isinf``.
+    output = torch.empty_like(input, dtype=torch.bool)
+
+    kernel = ntops.kernels.isinf.make(input.ndim)
+
+    kernel(input, output)
+
+    return output
+
+
 def mm(input, mat2, *, out=None):
     m, _ = input.shape
     _, n = mat2.shape
diff --git a/tests/test_isinf.py b/tests/test_isinf.py
new file mode 100644
index 0000000..c7739db
--- /dev/null
+++ b/tests/test_isinf.py
@@ -0,0 +1,38 @@
+import pytest
+import torch
+
+import ntops.torch
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_cuda(shape, dtype, atol, rtol):
+    """Check ``ntops.torch.isinf`` against ``torch.isinf`` on CUDA."""
+    device = "cuda"
+
+    def generate_inf_tensor(shape, dtype, device):
+        """Return a random tensor with elements replaced by ``+inf``/``-inf``."""
+        x = torch.randn(shape, dtype=dtype, device=device)
+
+        # Uniform draws in [0, 1): (0.2, 0.6) -> +inf, above 0.6 -> -inf,
+        # everything else keeps its finite random value.
+        probs = (0.2, 0.6)
+        prob_tensor = torch.rand(shape, device=device)
+
+        mask = (probs[0] < prob_tensor) & (prob_tensor < probs[1])
+        x[mask] = float("inf")
+        mask = probs[1] < prob_tensor
+        x[mask] = float("-inf")
+
+        return x
+
+    input = generate_inf_tensor(shape, dtype, device)
+
+    ninetoothed_output = ntops.torch.isinf(input)
+    reference_output = torch.isinf(input)
+
+    # ``torch.isinf`` yields a bool tensor; the wrapper must match it.
+    assert ninetoothed_output.dtype == reference_output.dtype
+    assert torch.equal(ninetoothed_output, reference_output)