Skip to content

Commit 7495ffe

Browse files
authored
Merge pull request #15 from stever178/develop-tanh
Add `tanh` operator
2 parents b7d325f + 55ad183 commit 7495ffe

3 files changed

Lines changed: 53 additions & 0 deletions

File tree

src/ntops/kernels/tanh.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import functools
2+
3+
import ninetoothed
4+
import ninetoothed.language as ntl
5+
from ninetoothed import Tensor
6+
7+
from ntops.kernels.element_wise import arrangement
8+
9+
10+
# Element-wise hyperbolic tangent: assigns tanh(input) to `output`.
#
# Uses the identity tanh(x) = 1 - 2 / (exp(2x) + 1) rather than
# (exp(x) - exp(-x)) / (exp(x) + exp(-x)). In the latter, both the numerator
# and the denominator overflow to infinity for large |x|, so the quotient is
# inf / inf = NaN. Here a large positive x gives 1 - 2 / inf = 1 and a large
# negative x gives 1 - 2 / (0 + 1) = -1, so the result saturates to +/-1.
def application(input, output):
    # A single exponential is enough for this form of the identity.
    exp_double_input = ntl.exp(2 * input)
    output = 1 - 2 / (exp_double_input + 1)  # noqa: F841
14+
15+
16+
@functools.cache
def make(ndim):
    """Build the tanh kernel for tensors described by ``Tensor(ndim)``.

    The kernel is created by ``ninetoothed.make`` from the element-wise
    ``arrangement`` and this module's ``application``, with one tensor
    parameter for the input and one for the output. ``functools.cache``
    memoizes the result, so each distinct ``ndim`` is built only once.
    """
    input_tensor = Tensor(ndim)
    output_tensor = Tensor(ndim)
    tensors = (input_tensor, output_tensor)

    return ninetoothed.make(arrangement, application, tensors)

src/ntops/torch.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import ntops.kernels.rsqrt
1515
import ntops.kernels.sigmoid
1616
import ntops.kernels.sin
17+
import ntops.kernels.tanh
1718

1819

1920
def abs(input, *, out=None):
@@ -178,3 +179,14 @@ def sin(input, *, out=None):
178179
kernel(input, out)
179180

180181
return out
182+
183+
184+
def tanh(input, *, out=None):
    """Apply the ntops tanh kernel to ``input`` and return the result tensor.

    Args:
        input: Tensor to transform; its ``ndim`` selects the kernel.
        out: Optional tensor that receives the result. When ``None``, a new
            tensor is allocated with ``torch.empty_like(input)``.

    Returns:
        The tensor the kernel was invoked on as output (``out`` if given,
        otherwise the newly allocated tensor).
    """
    result = torch.empty_like(input) if out is None else out

    # Kernels are created (and cached) per number of dimensions.
    tanh_kernel = ntops.kernels.tanh.make(input.ndim)
    tanh_kernel(input, result)

    return result

tests/test_tanh.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
import torch
3+
4+
import ntops.torch
5+
from tests.skippers import skip_if_cuda_not_available
6+
from tests.utils import generate_arguments
7+
8+
9+
@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol):
    """Compare ``ntops.torch.tanh`` with ``torch.tanh`` on a random CUDA tensor.

    The parameters come from ``generate_arguments()``; ``atol`` and ``rtol``
    are the tolerances passed to ``torch.allclose``.
    """
    # TODO: Test for `float16` later.
    if dtype is torch.float16:
        return

    sample = torch.randn(shape, dtype=dtype, device="cuda")

    actual = ntops.torch.tanh(sample)
    expected = torch.tanh(sample)

    outputs_match = torch.allclose(actual, expected, atol=atol, rtol=rtol)
    assert outputs_match

0 commit comments

Comments
 (0)