Skip to content

Commit 2098fe0

Browse files
authored
Merge pull request #6 from InfiniTensor/develop-rsqrt
Add `rsqrt` operator
2 parents 91947dc + 766dc46 commit 2098fe0

3 files changed

Lines changed: 55 additions & 1 deletion

File tree

src/ntops/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@
77
from ntops.gelu import gelu
88
from ntops.mm import mm
99
from ntops.mul import mul
10+
from ntops.rsqrt import rsqrt
1011

11-
__all__ = ["abs", "add", "addmm", "bmm", "div", "exp", "gelu", "mm", "mul"]
12+
__all__ = ["abs", "add", "addmm", "bmm", "div", "exp", "gelu", "mm", "mul", "rsqrt"]

src/ntops/rsqrt.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import functools
2+
3+
import ninetoothed
4+
import ninetoothed.language as ntl
5+
import torch
6+
from ninetoothed import Tensor
7+
8+
from ntops import element_wise
9+
10+
11+
def application(input, output):
12+
output = ntl.rsqrt(ntl.cast(input, ntl.float32)) # noqa: F841
13+
14+
15+
def rsqrt(input, output=None):
16+
if output is None:
17+
output = torch.empty_like(input)
18+
19+
kernel = _make(input.ndim)
20+
21+
kernel(input, output)
22+
23+
return output
24+
25+
26+
@functools.cache
27+
def _make(ndim):
28+
return ninetoothed.make(
29+
element_wise.arrangement,
30+
application,
31+
(Tensor(ndim), Tensor(ndim)),
32+
)

tests/test_rsqrt.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import pytest
2+
import torch
3+
4+
import ntops
5+
from tests.skippers import skip_if_cuda_not_available
6+
from tests.utils import generate_arguments
7+
8+
9+
@skip_if_cuda_not_available
10+
@pytest.mark.parametrize(*generate_arguments())
11+
def test_cuda(shape, dtype, atol, rtol):
12+
device = "cuda"
13+
14+
input = torch.randn(shape, dtype=dtype, device=device)
15+
16+
ninetoothed_output = ntops.rsqrt(input)
17+
reference_output = torch.rsqrt(input)
18+
19+
assert torch.allclose(
20+
ninetoothed_output, reference_output, atol=atol, rtol=rtol, equal_nan=True
21+
)

0 commit comments

Comments
 (0)