Skip to content

Commit 0222cab

Browse files
authored
Merge pull request #9 from stever178/develop-relu
Add `relu` operator
2 parents a7fa478 + 54e67af commit 0222cab

3 files changed

Lines changed: 54 additions & 0 deletions

File tree

src/ntops/kernels/relu.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import functools
2+
3+
import ninetoothed
4+
from ninetoothed import Tensor
5+
6+
from ntops.kernels.element_wise import arrangement
7+
8+
9+
# Element-wise ReLU body: computes max(0.0, input) and assigns it to `output`.
# NOTE(review): this function is passed to `ninetoothed.make` rather than
# called directly; presumably ninetoothed interprets the assignment to
# `output` as a store into the output tensor -- confirm against its docs.
def application(input, output):
    # In plain Python this rebinding of `output` would look unused, hence the
    # F841 ("assigned but never used") lint suppression.
    output = max(0.0, input)  # noqa: F841
11+
12+
13+
@functools.cache
def make(ndim):
    """Build the ReLU kernel for tensors with ``ndim`` dimensions.

    The result is memoized per ``ndim`` by ``functools.cache``, so repeated
    calls with the same dimensionality reuse the object returned by
    ``ninetoothed.make``.

    Args:
        ndim: Number of dimensions used for both the input and the output
            tensor specifications.

    Returns:
        Whatever ``ninetoothed.make`` returns for the element-wise
        ``arrangement`` combined with the ReLU ``application``.
    """
    input_spec = Tensor(ndim)
    output_spec = Tensor(ndim)

    return ninetoothed.make(arrangement, application, (input_spec, output_spec))

src/ntops/torch.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import ntops.kernels.gelu
1010
import ntops.kernels.mm
1111
import ntops.kernels.mul
12+
import ntops.kernels.relu
1213
import ntops.kernels.rsqrt
1314

1415

@@ -119,6 +120,19 @@ def mul(input, other, *, out=None):
119120
return out
120121

121122

123+
def relu(input, inplace=False):
    """Apply ReLU to ``input`` using the ninetoothed kernel.

    Args:
        input: Tensor to rectify.
        inplace: When true, the result is written back into ``input``;
            otherwise a new tensor created with ``torch.empty_like`` receives
            the result.

    Returns:
        The tensor holding the result: ``input`` itself when ``inplace`` is
        true, the newly allocated tensor otherwise.
    """
    # Pick the destination buffer first so the kernel call is the same in
    # both modes.
    destination = input if inplace else torch.empty_like(input)

    relu_kernel = ntops.kernels.relu.make(input.ndim)
    relu_kernel(input, destination)

    return destination
134+
135+
122136
def rsqrt(input, *, out=None):
123137
if out is None:
124138
out = torch.empty_like(input)

tests/test_relu.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
import torch
3+
import torch.nn.functional as F
4+
5+
import ntops.torch
6+
from tests.skippers import skip_if_cuda_not_available
7+
from tests.utils import generate_arguments
8+
9+
10+
@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol):
    """Check ``ntops.torch.relu`` against ``F.relu`` in both modes on CUDA.

    Each implementation gets its own clone of the random input. If both
    shared one tensor, the in-place ninetoothed call would rectify it first,
    the reference would then run on already-rectified data and return the
    very same tensor object, and the comparison could never fail.
    """
    device = "cuda"

    input = torch.randn(shape, dtype=dtype, device=device)

    for inplace in (False, True):
        # Independent copies keep an in-place call on one side from leaking
        # into the other side or into the next loop iteration.
        ninetoothed_input = input.clone()
        reference_input = input.clone()

        ninetoothed_output = ntops.torch.relu(ninetoothed_input, inplace)
        reference_output = F.relu(reference_input, inplace)

        assert torch.allclose(
            ninetoothed_output, reference_output, atol=atol, rtol=rtol
        )

        # The input buffers must agree afterwards too: both untouched when
        # out-of-place, both overwritten with the result when in-place.
        assert torch.allclose(
            ninetoothed_input, reference_input, atol=atol, rtol=rtol
        )

0 commit comments

Comments
 (0)