diff --git a/src/ntops/kernels/cos.py b/src/ntops/kernels/cos.py new file mode 100644 index 0000000..72fcbcf --- /dev/null +++ b/src/ntops/kernels/cos.py @@ -0,0 +1,16 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(input, output): + output = ntl.cos(input) # noqa: F841 + + +@functools.cache +def make(ndim): + return ninetoothed.make(arrangement, application, (Tensor(ndim), Tensor(ndim))) diff --git a/src/ntops/torch.py b/src/ntops/torch.py index 6d4e487..c27c2af 100644 --- a/src/ntops/torch.py +++ b/src/ntops/torch.py @@ -4,6 +4,7 @@ import ntops.kernels.add import ntops.kernels.addmm import ntops.kernels.bmm +import ntops.kernels.cos import ntops.kernels.div import ntops.kernels.exp import ntops.kernels.gelu @@ -62,6 +63,17 @@ def bmm(input, mat2, *, out=None): return out +def cos(input, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = ntops.kernels.cos.make(input.ndim) + + kernel(input, out) + + return out + + def div(input, other, *, rounding_mode=None, out=None): if out is None: out = torch.empty_like(input) diff --git a/tests/test_cos.py b/tests/test_cos.py new file mode 100644 index 0000000..6246ac1 --- /dev/null +++ b/tests/test_cos.py @@ -0,0 +1,23 @@ +import pytest +import torch + +import ntops.torch +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_cuda(shape, dtype, atol, rtol): + # TODO: Test for `float16` later. + if dtype is torch.float16: + return + + device = "cuda" + + input = torch.randn(shape, dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.cos(input) + reference_output = torch.cos(input) + + assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)