diff --git a/src/ntops/kernels/relu.py b/src/ntops/kernels/relu.py
new file mode 100644
index 0000000..2e4cc3e
--- /dev/null
+++ b/src/ntops/kernels/relu.py
@@ -0,0 +1,21 @@
+import functools
+
+import ninetoothed
+from ninetoothed import Tensor
+
+from ntops.kernels.element_wise import arrangement
+
+
+# Element-wise ReLU: clamp negative values to zero. The assignment to
+# `output` looks unused to linters (hence the noqa); presumably ninetoothed
+# turns it into the kernel's store -- confirm against the ninetoothed docs.
+def application(input, output):
+    output = max(0.0, input)  # noqa: F841
+
+
+@functools.cache
+def make(ndim):
+    """Return the ReLU kernel for `ndim`-dimensional tensors, cached per `ndim`."""
+    tensors = (Tensor(ndim), Tensor(ndim))
+
+    return ninetoothed.make(arrangement, application, tensors)
diff --git a/src/ntops/torch.py b/src/ntops/torch.py
index 6d4e487..28e314a 100644
--- a/src/ntops/torch.py
+++ b/src/ntops/torch.py
@@ -9,6 +9,7 @@
 import ntops.kernels.gelu
 import ntops.kernels.mm
 import ntops.kernels.mul
+import ntops.kernels.relu
 import ntops.kernels.rsqrt
 
 
@@ -119,6 +120,24 @@ def mul(input, other, *, out=None):
     return out
 
 
+def relu(input, inplace=False):
+    """Apply ReLU element-wise to `input`.
+
+    With `inplace=True` the result is written into `input`, which is
+    returned; otherwise a new tensor from `torch.empty_like` is returned.
+    """
+    if inplace:
+        output = input
+    else:
+        output = torch.empty_like(input)
+
+    kernel = ntops.kernels.relu.make(input.ndim)
+
+    kernel(input, output)
+
+    return output
+
+
 def rsqrt(input, *, out=None):
     if out is None:
         out = torch.empty_like(input)
diff --git a/tests/test_relu.py b/tests/test_relu.py
new file mode 100644
index 0000000..6f16cb6
--- /dev/null
+++ b/tests/test_relu.py
@@ -0,0 +1,33 @@
+import pytest
+import torch
+import torch.nn.functional as F
+
+import ntops.torch
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_cuda(shape, dtype, atol, rtol):
+    """Check `ntops.torch.relu` against `F.relu`, out-of-place and in-place."""
+    device = "cuda"
+
+    input = torch.randn(shape, dtype=dtype, device=device)
+
+    for inplace in (False, True):
+        # Give each implementation its own copy: in-place calls return their
+        # argument, so sharing one tensor would compare it against itself.
+        ninetoothed_input = input.clone()
+        reference_input = input.clone()
+
+        ninetoothed_output = ntops.torch.relu(ninetoothed_input, inplace)
+        reference_output = F.relu(reference_input, inplace)
+
+        assert torch.allclose(
+            ninetoothed_output, reference_output, atol=atol, rtol=rtol
+        )
+
+        if inplace:
+            # The in-place variant must return its argument, already updated.
+            assert ninetoothed_output is ninetoothed_input