From be8133c0b79f40d08fce3b9910e7c04a2c413295 Mon Sep 17 00:00:00 2001 From: stever178 <2874146120@qq.com> Date: Sat, 10 May 2025 00:09:17 +0800 Subject: [PATCH] Add `sigmoid` operator --- src/ntops/kernels/sigmoid.py | 16 ++++++++++++++++ src/ntops/torch.py | 12 ++++++++++++ tests/test_sigmoid.py | 23 +++++++++++++++++++++++ 3 files changed, 51 insertions(+) create mode 100644 src/ntops/kernels/sigmoid.py create mode 100644 tests/test_sigmoid.py diff --git a/src/ntops/kernels/sigmoid.py b/src/ntops/kernels/sigmoid.py new file mode 100644 index 0000000..092c6f9 --- /dev/null +++ b/src/ntops/kernels/sigmoid.py @@ -0,0 +1,16 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(input, output): + output = ntl.sigmoid(input) # noqa: F841 + + +@functools.cache +def make(ndim): + return ninetoothed.make(arrangement, application, (Tensor(ndim), Tensor(ndim))) diff --git a/src/ntops/torch.py b/src/ntops/torch.py index 6d4e487..28fd974 100644 --- a/src/ntops/torch.py +++ b/src/ntops/torch.py @@ -10,6 +10,7 @@ import ntops.kernels.mm import ntops.kernels.mul import ntops.kernels.rsqrt +import ntops.kernels.sigmoid def abs(input, *, out=None): @@ -128,3 +129,14 @@ def rsqrt(input, *, out=None): kernel(input, out) return out + + +def sigmoid(input, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = ntops.kernels.sigmoid.make(input.ndim) + + kernel(input, out) + + return out diff --git a/tests/test_sigmoid.py b/tests/test_sigmoid.py new file mode 100644 index 0000000..1906eaf --- /dev/null +++ b/tests/test_sigmoid.py @@ -0,0 +1,23 @@ +import pytest +import torch + +import ntops.torch +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_cuda(shape, dtype, atol, rtol): + # TODO: Test for `float16` later. + if dtype is torch.float16: + return + + device = "cuda" + + input = torch.randn(shape, dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.sigmoid(input) + reference_output = torch.sigmoid(input) + + assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)