Skip to content

Commit e49dd52

Browse files
committed
Add round operator implementation
1 parent 6bc90d5 commit e49dd52

File tree

5 files changed

+76
-0
lines changed

5 files changed

+76
-0
lines changed

src/ntops/kernels/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
relu,
3232
rms_norm,
3333
rotary_position_embedding,
34+
round,
3435
rsqrt,
3536
scaled_dot_product_attention,
3637
sigmoid,
@@ -74,6 +75,7 @@
7475
"relu",
7576
"rms_norm",
7677
"rotary_position_embedding",
78+
"round",
7779
"rsqrt",
7880
"scaled_dot_product_attention",
7981
"sigmoid",

src/ntops/kernels/round.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import functools
2+
3+
import ninetoothed
4+
import ninetoothed.language as ntl
5+
from ninetoothed import Tensor
6+
from ninetoothed.language import libdevice
7+
8+
from ntops.kernels.element_wise import arrangement
9+
10+
11+
def application(input, output):
    """Element-wise round for the ``decimals == 0`` case.

    Widens to float32 before calling ``libdevice.nearbyint`` (which rounds
    to the nearest integer under the device's default rounding mode —
    presumably ties-to-even, matching ``torch.round``; confirm against
    libdevice docs). In this DSL, assigning to ``output`` performs the
    store, hence the ``noqa``.
    """
    widened = ntl.cast(input, ntl.float32)
    output = libdevice.nearbyint(widened)  # noqa: F841
13+
14+
15+
def application_with_decimals(input, factor, inv_factor, output):
    """Round to a fixed number of decimal places: scale, round, unscale.

    ``factor`` is ``10 ** decimals`` and ``inv_factor`` its reciprocal;
    both arrive as 0-d scalar tensors (see ``premake``).
    """
    # Multiply in input's original precision to match torch's behavior.
    scaled = input * ntl.cast(factor, input.dtype)
    rounded = libdevice.nearbyint(ntl.cast(scaled, ntl.float32))
    output = rounded * inv_factor  # noqa: F841
20+
21+
22+
def premake(ndim, decimals=0, dtype=None, block_size=None):
    """Build the (arrangement, application, tensors) triple for ``round``.

    Args:
        ndim: Rank of the input/output tensors.
        decimals: When 0, the plain rounding application is used; any other
            value selects the scale/round/unscale variant, which takes two
            extra 0-d float64 scalars (factor and its reciprocal) at call
            time.
        dtype: Element dtype of the input/output tensors.
        block_size: Forwarded to the shared element-wise ``arrangement``.

    Returns:
        A tuple ``(arrangement, application, tensors)`` ready for kernel
        construction.
    """
    arranged = functools.partial(arrangement, block_size=block_size)

    if decimals == 0:
        io_pair = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))
        return arranged, application, io_pair

    # Non-zero decimals: input, factor, 1/factor (0-d float64), output —
    # in the positional order application_with_decimals expects.
    tensors = (
        Tensor(ndim, dtype=dtype),
        Tensor(0, dtype=ninetoothed.float64),
        Tensor(0, dtype=ninetoothed.float64),
        Tensor(ndim, dtype=dtype),
    )
    return arranged, application_with_decimals, tensors

src/ntops/torch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from ntops.torch.relu import relu
3232
from ntops.torch.rms_norm import rms_norm
3333
from ntops.torch.rotary_position_embedding import rotary_position_embedding
34+
from ntops.torch.round import round
3435
from ntops.torch.rsqrt import rsqrt
3536
from ntops.torch.scaled_dot_product_attention import scaled_dot_product_attention
3637
from ntops.torch.sigmoid import sigmoid
@@ -74,6 +75,7 @@
7475
"relu",
7576
"rms_norm",
7677
"rotary_position_embedding",
78+
"round",
7779
"rsqrt",
7880
"scaled_dot_product_attention",
7981
"sigmoid",

src/ntops/torch/round.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import torch
2+
3+
import ntops
4+
from ntops.torch.utils import _cached_make
5+
6+
7+
def round(input, decimals=0, *, out=None):
    """Round ``input`` element-wise, mirroring ``torch.round``.

    Args:
        input: Tensor to round.
        decimals: Number of decimal places to round to (may be negative).
        out: Optional pre-allocated output tensor; allocated like ``input``
            when omitted.

    Returns:
        The rounded tensor (``out`` when it was provided).
    """
    result = torch.empty_like(input) if out is None else out

    if decimals == 0:
        kernel = _cached_make(ntops.kernels.round.premake, input.ndim)
        kernel(input, result)
        return result

    # Scale factors are computed on the host and passed at call time, so a
    # single cached kernel (keyed by decimals=True, a mere flag) serves
    # every non-zero decimals value.
    scale = 10.0**decimals
    kernel = _cached_make(ntops.kernels.round.premake, input.ndim, decimals=True)
    kernel(input, scale, 1.0 / scale, result)
    return result

tests/test_round.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import pytest
2+
import torch
3+
4+
import ntops
5+
from tests.skippers import skip_if_cuda_not_available
6+
from tests.utils import generate_arguments
7+
8+
9+
@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_round(shape, dtype, device, rtol, atol):
    """Check the ninetoothed round kernel against ``torch.round``."""
    x = torch.randn(shape, dtype=dtype, device=device)

    actual = ntops.torch.round(x)
    expected = torch.round(x)

    assert torch.allclose(actual, expected, rtol=rtol, atol=atol)

0 commit comments

Comments
 (0)