diff --git a/src/ntops/kernels/sigmoid.py b/src/ntops/kernels/sigmoid.py
new file mode 100644
index 0000000..092c6f9
--- /dev/null
+++ b/src/ntops/kernels/sigmoid.py
@@ -0,0 +1,22 @@
+import functools
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+from ntops.kernels.element_wise import arrangement
+
+
+# Element-wise sigmoid. The result is expressed by assigning to the `output`
+# parameter, which linters report as an unused local (hence the `noqa`).
+def application(input, output):
+    output = ntl.sigmoid(input)  # noqa: F841
+
+
+@functools.cache
+def make(ndim):
+    """Build the sigmoid kernel for tensors with ``ndim`` dimensions.
+
+    ``functools.cache`` memoizes the result, so each ``ndim`` is built once.
+    """
+    return ninetoothed.make(arrangement, application, (Tensor(ndim), Tensor(ndim)))
diff --git a/src/ntops/torch.py b/src/ntops/torch.py
index c00f65a..9a4c895 100644
--- a/src/ntops/torch.py
+++ b/src/ntops/torch.py
@@ -12,6 +12,7 @@
 import ntops.kernels.mul
 import ntops.kernels.relu
 import ntops.kernels.rsqrt
+import ntops.kernels.sigmoid
 import ntops.kernels.sin
 
 
@@ -157,6 +158,23 @@ def rsqrt(input, *, out=None):
     return out
 
 
+def sigmoid(input, *, out=None):
+    """Apply the sigmoid kernel to ``input`` element-wise.
+
+    When ``out`` is ``None`` a new tensor is allocated with
+    ``torch.empty_like(input)``; otherwise the kernel is called with ``out``
+    as its output argument. Returns ``out``.
+    """
+    if out is None:
+        out = torch.empty_like(input)
+
+    kernel = ntops.kernels.sigmoid.make(input.ndim)
+
+    kernel(input, out)
+
+    return out
+
+
 def sin(input, *, out=None):
     if out is None:
         out = torch.empty_like(input)
diff --git a/tests/test_sigmoid.py b/tests/test_sigmoid.py
new file mode 100644
index 0000000..1906eaf
--- /dev/null
+++ b/tests/test_sigmoid.py
@@ -0,0 +1,25 @@
+import pytest
+import torch
+
+import ntops.torch
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_cuda(shape, dtype, atol, rtol):
+    """Check ``ntops.torch.sigmoid`` against ``torch.sigmoid`` on CUDA."""
+    # TODO: Test for `float16` later.
+    if dtype is torch.float16:
+        # Report the coverage gap as a skip instead of a silent pass.
+        pytest.skip("`float16` is not tested yet.")
+
+    device = "cuda"
+
+    input = torch.randn(shape, dtype=dtype, device=device)
+
+    ninetoothed_output = ntops.torch.sigmoid(input)
+    reference_output = torch.sigmoid(input)
+
+    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)