Skip to content

Commit ed31c58

Browse files
committed
Add relu operator
1 parent a7fa478 commit ed31c58

3 files changed

Lines changed: 47 additions & 0 deletions

File tree

src/ntops/kernels/relu.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import functools
2+
3+
import ninetoothed
4+
from ninetoothed import Tensor
5+
6+
from ntops.kernels.element_wise import arrangement
7+
8+
9+
def application(input, output):
10+
output = max(0.0, input) # noqa: F841
11+
12+
13+
@functools.cache
14+
def make(ndim):
15+
return ninetoothed.make(arrangement, application,(Tensor(ndim), Tensor(ndim)))

src/ntops/torch.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import ntops.kernels.gelu
1010
import ntops.kernels.mm
1111
import ntops.kernels.mul
12+
import ntops.kernels.relu
1213
import ntops.kernels.rsqrt
1314

1415

@@ -119,6 +120,17 @@ def mul(input, other, *, out=None):
119120
return out
120121

121122

123+
def relu(input, *, out=None):
124+
if out is None:
125+
out = torch.empty_like(input)
126+
127+
kernel = ntops.kernels.relu.make(input.ndim)
128+
129+
kernel(input, out)
130+
131+
return out
132+
133+
122134
def rsqrt(input, *, out=None):
123135
if out is None:
124136
out = torch.empty_like(input)

tests/test_relu.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import pytest
2+
import torch
3+
import torch.nn.functional as F
4+
5+
import ntops.torch
6+
from tests.skippers import skip_if_cuda_not_available
7+
from tests.utils import generate_arguments
8+
9+
10+
@skip_if_cuda_not_available
11+
@pytest.mark.parametrize(*generate_arguments())
12+
def test_cuda(shape, dtype, atol, rtol):
13+
device = "cuda"
14+
15+
input = torch.randn(shape, dtype=dtype, device=device)
16+
17+
ninetoothed_output = ntops.torch.relu(input)
18+
reference_output = F.relu(input)
19+
20+
assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)

0 commit comments

Comments
 (0)