Skip to content

Commit fa33ecc

Browse files
committed
T1-1-19 Add cosh operator
1 parent e49dd52 commit fa33ecc

File tree

5 files changed

+55
-0
lines changed

5 files changed

+55
-0
lines changed

src/ntops/kernels/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
clamp,
1111
conv2d,
1212
cos,
13+
cosh,
1314
div,
1415
dropout,
1516
eq,
@@ -54,6 +55,7 @@
5455
"clamp",
5556
"conv2d",
5657
"cos",
58+
"cosh",
5759
"div",
5860
"dropout",
5961
"eq",

src/ntops/kernels/cosh.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import functools
2+
3+
import ninetoothed.language as ntl
4+
from ninetoothed import Tensor
5+
from ninetoothed.language import libdevice
6+
7+
from ntops.kernels.element_wise import arrangement
8+
9+
10+
def application(input, output):
    """Element-wise cosh body executed by the ninetoothed kernel.

    NOTE(review): assigning to the ``output`` parameter looks dead to linters
    (hence the F841 suppression), but the ninetoothed DSL traces this
    assignment as the store into the output tensor — do not remove it.
    The input is upcast to float32 before ``libdevice.cosh`` for precision;
    presumably the store back to ``output`` downcasts to its dtype — confirm
    against other element-wise kernels (e.g. ``cos``).
    """
    output = libdevice.cosh(ntl.cast(input, ntl.float32))  # noqa: F841
12+
13+
14+
def premake(ndim, dtype=None, block_size=None):
    """Prepare the pieces needed to build the cosh kernel.

    Returns the ``(arrangement, application, tensors)`` triple consumed by
    the kernel factory: the arrangement partial with ``block_size`` bound,
    the element-wise ``application`` function, and the symbolic
    (input, output) tensor pair of rank ``ndim``.
    """
    bound_arrangement = functools.partial(arrangement, block_size=block_size)

    input_tensor = Tensor(ndim, dtype=dtype)
    output_tensor = Tensor(ndim, dtype=dtype)

    return bound_arrangement, application, (input_tensor, output_tensor)

src/ntops/torch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ntops.torch.clamp import clamp
1010
from ntops.torch.conv2d import conv2d
1111
from ntops.torch.cos import cos
12+
from ntops.torch.cosh import cosh
1213
from ntops.torch.div import div
1314
from ntops.torch.dropout import dropout
1415
from ntops.torch.eq import eq
@@ -53,6 +54,7 @@
5354
"clamp",
5455
"conv2d",
5556
"cos",
57+
"cosh",
5658
"div",
5759
"dropout",
5860
"eq",

src/ntops/torch/cosh.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import torch
2+
3+
import ntops
4+
from ntops.torch.utils import _cached_make
5+
6+
7+
def cosh(input, *, out=None):
    """Compute the element-wise hyperbolic cosine of *input*.

    Mirrors ``torch.cosh``: if *out* is omitted, a fresh tensor shaped like
    *input* is allocated; otherwise results are written into *out*.
    Returns the output tensor.
    """
    output = torch.empty_like(input) if out is None else out

    # Kernels are specialized (and cached) per input rank.
    kernel = _cached_make(ntops.kernels.cosh.premake, input.ndim)
    kernel(input, output)

    return output

tests/test_cosh.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import pytest
2+
import torch
3+
4+
import ntops
5+
from tests.skippers import skip_if_cuda_not_available
6+
from tests.utils import generate_arguments
7+
8+
9+
@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cosh(shape, dtype, device, rtol, atol):
    """Check ntops.torch.cosh against torch.cosh on random input."""
    x = torch.randn(shape, dtype=dtype, device=device)

    actual = ntops.torch.cosh(x)
    expected = torch.cosh(x)

    assert torch.allclose(actual, expected, rtol=rtol, atol=atol)

0 commit comments

Comments (0)