Skip to content

Commit 045d2cf

Browse files
authored
Merge pull request #11 from InfiniTensor/develop-softmax
Add `softmax` operator
2 parents c244ca1 + b947ad5 commit 045d2cf

3 files changed

Lines changed: 105 additions & 0 deletions

File tree

src/ntops/kernels/softmax.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import functools
2+
3+
import ninetoothed
4+
import ninetoothed.language as ntl
5+
from ninetoothed import Tensor
6+
7+
BLOCK_SIZE = ninetoothed.block_size()
8+
9+
10+
def arrangement(input, output, dim):
    """Tile ``input`` and ``output`` along ``dim`` for the softmax kernel.

    Each tensor is tiled twice. The inner tile has extent ``BLOCK_SIZE``
    along ``dim`` and ``1`` along every other axis. The outer tile has
    extent ``-1`` along ``dim`` (presumably "the whole axis" in
    ninetoothed's tiling API -- confirm against its documentation) and
    ``1`` elsewhere. The axes other than ``dim`` are then squeezed out of
    both nested levels, leaving a one-dimensional sequence of
    one-dimensional blocks.

    Args:
        input: Symbolic input tensor.
        output: Symbolic output tensor; must have the same rank as ``input``.
        dim: Axis to reduce over. Negative values count from the last axis,
            following the usual Python/PyTorch convention.

    Returns:
        A ``(input_arranged, output_arranged)`` pair.
    """
    assert input.ndim == output.ndim

    ndim = input.ndim

    # A negative axis would otherwise make ``range(dim)`` empty and
    # ``ndim - dim - 1`` too large, yielding a tile shape of the wrong rank,
    # and ``d != dim`` would squeeze *every* axis. Map it to its
    # non-negative equivalent first.
    if dim < 0:
        dim += ndim

    # Every axis except the softmax axis has extent 1 in the tiles and is
    # squeezed away afterwards.
    other_axes = tuple(d for d in range(ndim) if d != dim)

    def create_axis_tile_shape(dim, dim_block):
        # All ones, except ``dim_block`` at position ``dim``.
        return (
            tuple(1 for _ in range(dim))
            + (dim_block,)
            + tuple(1 for _ in range(ndim - dim - 1))
        )

    inner_block_shape = create_axis_tile_shape(dim, BLOCK_SIZE)
    outer_block_shape = create_axis_tile_shape(dim, -1)

    def arrange(tensor):
        arranged = tensor.tile(inner_block_shape).tile(outer_block_shape)

        # Drop the singleton axes from the outer tile level first, then from
        # the inner block level nested inside it.
        arranged.dtype = arranged.dtype.squeeze(other_axes)
        arranged.dtype.dtype = arranged.dtype.dtype.squeeze(other_axes)

        return arranged

    return arrange(input), arrange(output)
35+
36+
37+
def _exp(x, dtype):
    # Exponentiate `x` and return the result cast to `dtype`.
    #
    # When `dtype` is float16 the exponential itself is evaluated in float32
    # (presumably for precision and/or backend support -- confirm); any other
    # `dtype` is used directly for the computation.
    compute_dtype = ntl.float32 if dtype == ntl.float16 else dtype
    widened = ntl.cast(x, compute_dtype)
    exponentiated = ntl.exp(widened)
    return ntl.cast(exponentiated, dtype)
40+
41+
42+
def application(input, output):
    # Two-pass softmax over the blocks of `input`, written into `output`.
    #
    # Pass 1 walks the blocks once while maintaining a running maximum and a
    # running sum of exponentials; whenever the maximum grows, the sum
    # accumulated so far is rescaled by exp(old_max - new_max).
    # Pass 2 re-reads the blocks and writes exp(x - max) / sum.
    #
    # All arithmetic is carried out in the element dtype of `output`.
    dtype = output.dtype.dtype
    prev_max = ntl.cast(float("-inf"), dtype)
    denominator = ntl.cast(0, dtype)

    for i in range(input.shape[0]):
        input_i = ntl.cast(input[i], dtype)
        curr_max = ntl.cast(ntl.maximum(prev_max, ntl.max(input_i)), dtype)
        input_max_diff_exp = _exp(input_i - curr_max, dtype)
        # On the first iteration prev_max is -inf, so this factor is
        # exp(-inf) and the (zero) running sum is left unchanged.
        prev_curr_max_diff_exp = _exp(prev_max - curr_max, dtype)
        denominator = denominator * prev_curr_max_diff_exp + ntl.sum(input_max_diff_exp)
        prev_max = curr_max

    for i in range(input.shape[0]):
        # Cast the input to the output dtype before subtracting, exactly as
        # in the first pass. Otherwise the numerator is computed from
        # un-rounded values while the maximum and denominator were computed
        # from rounded ones, so a down-cast input can exceed `prev_max`.
        numerator = _exp(ntl.cast(input[i], dtype) - prev_max, dtype)
        output[i] = numerator / denominator
58+
59+
60+
@functools.cache
def make(ndim, dim):
    """Build the softmax kernel for ``ndim``-dimensional tensors along ``dim``.

    Results are memoized per ``(ndim, dim)`` pair by ``functools.cache``.

    Args:
        ndim: Rank of the input and output tensors.
        dim: Axis bound into ``arrangement``.

    Returns:
        Whatever ``ninetoothed.make`` returns for this arrangement,
        application, and tensor pair.
    """
    bound_arrangement = functools.partial(arrangement, dim=dim)

    # The input is declared with ``other=-inf`` -- presumably the fill value
    # for out-of-bounds elements, so padding cannot affect the max or the
    # sum of exponentials (confirm against ninetoothed's Tensor docs).
    input_spec = Tensor(ndim, other=float("-inf"), shape_options={"constexpr": True})
    output_spec = Tensor(ndim)

    return ninetoothed.make(bound_arrangement, application, (input_spec, output_spec))

src/ntops/torch.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import ntops.kernels.rsqrt
2424
import ntops.kernels.sigmoid
2525
import ntops.kernels.sin
26+
import ntops.kernels.softmax
2627
import ntops.kernels.tanh
2728

2829

@@ -287,6 +288,18 @@ def sin(input, *, out=None):
287288
return out
288289

289290

291+
def softmax(input, dim, dtype=None):
    """Compute softmax of ``input`` along ``dim`` using the ninetoothed kernel.

    Args:
        input: Tensor to normalize.
        dim: Axis along which softmax is computed. Negative values count
            from the last axis, as in ``torch.nn.functional.softmax``.
        dtype: Optional dtype of the returned tensor. Defaults to
            ``input.dtype``.

    Returns:
        A newly allocated tensor shaped like ``input`` with dtype ``dtype``
        (or ``input.dtype`` when ``dtype`` is ``None``).
    """
    # Canonicalize negative axes before building the kernel: the kernel's
    # tiling is derived from ``range(dim)``, which is only meaningful for a
    # non-negative axis. It also lets ``-1`` and ``ndim - 1`` share a single
    # cached kernel.
    if dim < 0:
        dim += input.ndim

    tensor_dtype = dtype if dtype is not None else input.dtype

    output = torch.empty_like(input, dtype=tensor_dtype)

    kernel = ntops.kernels.softmax.make(input.ndim, dim)

    kernel(input, output)

    return output
301+
302+
290303
def tanh(input, *, out=None):
291304
if out is None:
292305
out = torch.empty_like(input)

tests/test_softmax.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import random
2+
3+
import pytest
4+
import torch
5+
6+
import ntops.torch
7+
from tests.skippers import skip_if_cuda_not_available
8+
from tests.utils import generate_arguments
9+
10+
11+
@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol):
    """Compare ``ntops.torch.softmax`` with PyTorch's softmax on CUDA.

    The input is drawn with the parametrized ``shape`` and ``dtype``; the
    softmax axis and the output dtype are picked at random for each case.
    """
    input = torch.randn(shape, dtype=dtype, device="cuda")

    # Random non-negative axis, then a random output dtype (which may differ
    # from the input dtype).
    dim = random.randint(0, input.ndim - 1)
    output_dtype = random.choice([torch.float16, torch.float32, torch.float64])

    actual = ntops.torch.softmax(input, dim, output_dtype)
    expected = torch.nn.functional.softmax(input, dim=dim, dtype=output_dtype)

    assert torch.allclose(actual, expected, atol=atol, rtol=rtol)

0 commit comments

Comments
 (0)