Skip to content

Commit 047345a

Browse files
Ziminlivoltjia
authored andcommitted
Add sub operator
1 parent d21f180 commit 047345a

3 files changed

Lines changed: 50 additions & 0 deletions

File tree

src/ntops/kernels/sub.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import functools
2+
3+
import ninetoothed
4+
from ninetoothed import Tensor
5+
6+
from ntops.kernels.element_wise import arrangement
7+
8+
9+
def application(input, other, alpha, output):
10+
output = input - alpha * other # noqa: F841
11+
12+
13+
@functools.cache
14+
def make(ndim):
15+
tensors = (Tensor(ndim), Tensor(ndim), Tensor(0), Tensor(ndim))
16+
17+
return ninetoothed.make(arrangement, application, tensors)

src/ntops/torch.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import ntops.kernels.sigmoid
3333
import ntops.kernels.sin
3434
import ntops.kernels.softmax
35+
import ntops.kernels.sub
3536
import ntops.kernels.tanh
3637

3738

@@ -384,6 +385,17 @@ def softmax(input, dim, dtype=None):
384385
return output
385386

386387

388+
def sub(input, other, *, alpha=1, out=None):
389+
if out is None:
390+
out = torch.empty_like(input)
391+
392+
kernel = ntops.kernels.sub.make(input.ndim)
393+
394+
kernel(input, other, alpha, out)
395+
396+
return out
397+
398+
387399
def tanh(input, *, out=None):
388400
if out is None:
389401
out = torch.empty_like(input)

tests/test_sub.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import pytest
2+
import torch
3+
4+
import ntops.torch
5+
from tests.skippers import skip_if_cuda_not_available
6+
from tests.utils import gauss, generate_arguments
7+
8+
9+
@skip_if_cuda_not_available
10+
@pytest.mark.parametrize(*generate_arguments())
11+
def test_cuda(shape, dtype, atol, rtol):
12+
device = "cuda"
13+
14+
input = torch.randn(shape, dtype=dtype, device=device)
15+
other = torch.randn(shape, dtype=dtype, device=device)
16+
alpha = gauss()
17+
18+
ninetoothed_output = ntops.torch.sub(input, other, alpha=alpha)
19+
reference_output = torch.sub(input, other, alpha=alpha)
20+
21+
assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)

0 commit comments

Comments
 (0)