From 90fa74de976cb41b34c1cfc94991355c6e3528d5 Mon Sep 17 00:00:00 2001
From: Zimin Li
Date: Sun, 27 Apr 2025 21:09:08 +0800
Subject: [PATCH] Add `silu` operator

---
Review notes (this section is not part of the commit message):
- Line structure and indentation of the patch were restored.
- Comments and docstrings were added to the new definitions; hunk lengths
  and the diffstat below were updated to match.
- The `index` blob ids are those of the original commit and were not
  recomputed for the documented content.

 src/ntops/kernels/silu.py | 27 +++++++++++++++++++++++++++
 src/ntops/torch.py        | 20 ++++++++++++++++++++
 tests/test_silu.py        | 22 ++++++++++++++++++++++
 3 files changed, 69 insertions(+)
 create mode 100644 src/ntops/kernels/silu.py
 create mode 100644 tests/test_silu.py

diff --git a/src/ntops/kernels/silu.py b/src/ntops/kernels/silu.py
new file mode 100644
index 0000000..37a946c
--- /dev/null
+++ b/src/ntops/kernels/silu.py
@@ -0,0 +1,27 @@
+import functools
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+from ntops.kernels.element_wise import arrangement
+
+
+# SiLU, element-wise: input * sigmoid(input) == input / (1 + exp(-input)).
+# The exponential is evaluated on a float32 cast of `input`.
+def application(input, output):
+    # NOTE(review): `output` is rebound but never read, which linters flag
+    # as F841; presumably ninetoothed treats the assignment as the store
+    # to the output tensor -- confirm against the ninetoothed docs.
+    output = input / (1 + ntl.exp(-ntl.cast(input, ntl.float32)))  # noqa: F841
+
+
+@functools.cache
+def make(ndim):
+    """Build the SiLU kernel over two `ndim`-dimensional tensors.
+
+    `functools.cache` memoizes the result per `ndim`.
+    """
+    tensors = (Tensor(ndim), Tensor(ndim))
+
+    return ninetoothed.make(arrangement, application, tensors)
diff --git a/src/ntops/torch.py b/src/ntops/torch.py
index 8e3aa50..5509ade 100644
--- a/src/ntops/torch.py
+++ b/src/ntops/torch.py
@@ -30,6 +30,7 @@
 import ntops.kernels.relu
 import ntops.kernels.rsqrt
 import ntops.kernels.sigmoid
+import ntops.kernels.silu
 import ntops.kernels.sin
 import ntops.kernels.softmax
 import ntops.kernels.sub
@@ -362,6 +363,25 @@ def sigmoid(input, *, out=None):
     return out
 
 
+def silu(input, inplace=False):
+    """Apply SiLU to `input` with the ninetoothed kernel.
+
+    When `inplace` is true the result is written into `input` and `input`
+    is returned; otherwise a new tensor from `torch.empty_like(input)` is
+    filled and returned.
+    """
+    if inplace:
+        output = input
+    else:
+        output = torch.empty_like(input)
+
+    kernel = ntops.kernels.silu.make(input.ndim)
+
+    kernel(input, output)
+
+    return output
+
+
 def sin(input, *, out=None):
     if out is None:
         out = torch.empty_like(input)
diff --git a/tests/test_silu.py b/tests/test_silu.py
new file mode 100644
index 0000000..037f9b9
--- /dev/null
+++ b/tests/test_silu.py
@@ -0,0 +1,22 @@
+import pytest
+import torch
+import torch.nn.functional as F
+
+import ntops.torch
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_cuda(shape, dtype, atol, rtol):
+    """Check `ntops.torch.silu` against `torch.nn.functional.silu` on CUDA."""
+    device = "cuda"
+
+    input = torch.randn(shape, dtype=dtype, device=device)
+
+    # TODO: Add `inplace` tests later.
+    ninetoothed_output = ntops.torch.silu(input)
+    reference_output = F.silu(input)
+
+    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)