Skip to content

Commit 682c5a0

Browse files
committed
Accommodate softmax to handle arbitrary axis dim length
1 parent 29dfb38 commit 682c5a0

4 files changed

Lines changed: 86 additions & 67 deletions

File tree

src/ntops/kernels/softmax.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import functools
2+
3+
import ninetoothed
4+
import ninetoothed.language as ntl
5+
from ninetoothed import Tensor
6+
7+
BLOCK_SIZE = ninetoothed.block_size()
8+
9+
10+
def arrangement(input, output, dim):
11+
assert input.ndim == output.ndim
12+
13+
def create_axis_tile_shape(dim, dim_block):
14+
return (
15+
tuple(1 for _ in range(dim))
16+
+ (dim_block,)
17+
+ tuple(1 for _ in range(input.ndim - dim - 1))
18+
)
19+
20+
def arrange(input):
21+
input_arranged = input.tile(inner_block_shape).tile(outer_block_shape)
22+
23+
input_arranged.dtype = input_arranged.dtype.squeeze(
24+
tuple(d for d in range(input.ndim) if d != dim)
25+
)
26+
input_arranged.dtype.dtype = input_arranged.dtype.dtype.squeeze(
27+
tuple(d for d in range(input.ndim) if d != dim)
28+
)
29+
return input_arranged
30+
31+
inner_block_shape = create_axis_tile_shape(dim, BLOCK_SIZE)
32+
outer_block_shape = create_axis_tile_shape(dim, -1)
33+
34+
return arrange(input), arrange(output)
35+
36+
37+
def _exp(x, dtype):
38+
exp_dtype = dtype if dtype != ntl.float16 else ntl.float32
39+
return ntl.cast(ntl.exp(ntl.cast(x, exp_dtype)), dtype)
40+
41+
42+
def application(input, output):
43+
dtype = output.dtype.dtype
44+
prev_max = ntl.cast(float("-inf"), dtype)
45+
denominator = ntl.cast(0, dtype)
46+
47+
for i in range(input.shape[0]):
48+
input_i = ntl.cast(input[i], dtype)
49+
curr_max = ntl.cast(ntl.maximum(prev_max, ntl.max(input_i)), dtype)
50+
input_max_diff_exp = _exp(input_i - curr_max, dtype)
51+
prev_curr_max_diff_exp = _exp(prev_max - curr_max, dtype)
52+
denominator = denominator * prev_curr_max_diff_exp + ntl.sum(input_max_diff_exp)
53+
prev_max = curr_max
54+
55+
for i in range(input.shape[0]):
56+
numerator = _exp(input[i] - prev_max, dtype)
57+
output[i] = numerator / denominator
58+
59+
60+
@functools.cache
61+
def make(ndim, dim):
62+
return ninetoothed.make(
63+
functools.partial(arrangement, dim=dim),
64+
application,
65+
(
66+
Tensor(ndim, other=float("-inf"), shape_options={"constexpr": True}),
67+
Tensor(ndim),
68+
),
69+
)

src/ntops/softmax.py

Lines changed: 0 additions & 64 deletions
This file was deleted.

src/ntops/torch.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import ntops.kernels.mm
1111
import ntops.kernels.mul
1212
import ntops.kernels.rsqrt
13+
import ntops.kernels.softmax
1314

1415

1516
def abs(input, *, out=None):
@@ -128,3 +129,15 @@ def rsqrt(input, *, out=None):
128129
kernel(input, out)
129130

130131
return out
132+
133+
134+
def softmax(input, dim, dtype=None):
135+
tensor_dtype = dtype if dtype is not None else input.dtype
136+
137+
output = torch.empty_like(input, dtype=tensor_dtype)
138+
139+
kernel = ntops.kernels.softmax.make(input.ndim, dim)
140+
141+
kernel(input, output)
142+
143+
return output

tests/test_softmax.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
import torch
55

6-
import ntops
6+
import ntops.torch
77
from tests.skippers import skip_if_cuda_not_available
88
from tests.utils import generate_arguments
99

@@ -15,8 +15,9 @@ def test_cuda(shape, dtype, atol, rtol):
1515

1616
input = torch.randn(shape, dtype=dtype, device=device)
1717
dim = random.randint(0, input.ndim - 1)
18+
dtype = random.choice([torch.float16, torch.float32, torch.float64])
1819

19-
ninetoothed_output = ntops.softmax(input, dim)
20-
reference_output = torch.nn.functional.softmax(input, dim=dim)
20+
ninetoothed_output = ntops.torch.softmax(input, dim, dtype)
21+
reference_output = torch.nn.functional.softmax(input, dim=dim, dtype=dtype)
2122

2223
assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)

0 commit comments

Comments
 (0)