diff --git a/src/ntops/kernels/eq.py b/src/ntops/kernels/eq.py new file mode 100644 index 0000000..2f0c0bb --- /dev/null +++ b/src/ntops/kernels/eq.py @@ -0,0 +1,17 @@ +import functools + +import ninetoothed +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(input, other, output): + output = input == other # noqa: F841 + + +@functools.cache +def make(ndim): + tensors = (Tensor(ndim), Tensor(ndim), Tensor(ndim)) + + return ninetoothed.make(arrangement, application, tensors) diff --git a/src/ntops/torch.py b/src/ntops/torch.py index 9a4c895..4b39273 100644 --- a/src/ntops/torch.py +++ b/src/ntops/torch.py @@ -6,6 +6,7 @@ import ntops.kernels.bmm import ntops.kernels.cos import ntops.kernels.div +import ntops.kernels.eq import ntops.kernels.exp import ntops.kernels.gelu import ntops.kernels.mm @@ -99,6 +100,17 @@ def exp(input, *, out=None): return out +def eq(input, other, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = ntops.kernels.eq.make(input.ndim) + + kernel(input, other, out) + + return out + + def gelu(input, approximate="none"): output = torch.empty_like(input) diff --git a/tests/test_eq.py b/tests/test_eq.py new file mode 100644 index 0000000..6a58ffc --- /dev/null +++ b/tests/test_eq.py @@ -0,0 +1,20 @@ +import pytest +import torch + +import ntops.torch +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_cuda(shape, dtype, atol, rtol): + device = "cuda" + + input = torch.randn(shape, dtype=dtype, device=device) + other = torch.randn(shape, dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.eq(input, other) + reference_output = torch.eq(input, other) + + assert torch.equal(ninetoothed_output, reference_output)