Skip to content

Commit 5169151

Browse files
committed
Add basic softmax implementation that handles limited axis dim length
1 parent c09f4f7 commit 5169151

3 files changed

Lines changed: 88 additions & 1 deletion

File tree

src/ntops/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@
77
from ntops.gelu import gelu
88
from ntops.mm import mm
99
from ntops.mul import mul
10+
from ntops.softmax import softmax
1011

11-
__all__ = ["abs", "add", "addmm", "bmm", "div", "exp", "gelu", "mm", "mul"]
12+
__all__ = ["abs", "add", "addmm", "bmm", "div", "exp", "gelu", "mm", "mul", "softmax"]

src/ntops/softmax.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import functools
2+
3+
import ninetoothed
4+
import ninetoothed.language as ntl
5+
import torch
6+
from ninetoothed import Tensor
7+
8+
9+
def arrangement(input, output, dim):
    """Arrange ``input`` and ``output`` for a softmax taken along ``dim``.

    Both tensors are tiled identically: an inner tile that spans the full
    length of the ``dim`` axis of ``input`` (extent 1 on every other axis),
    wrapped in an outer tile that uses ``-1`` along ``dim`` (again extent 1
    elsewhere). The extent-1 axes are then squeezed out of both tile levels,
    so only the ``dim`` axis remains at each level.

    NOTE(review): because the inner tile covers the entire axis, the axis
    length is presumably limited by what a single tile can hold -- confirm
    against the ninetoothed documentation.

    Args:
        input: Symbolic input tensor.
        output: Symbolic output tensor; must have the same rank as ``input``.
        dim: Axis to normalize over. Negative values count from the last
            axis.

    Returns:
        A ``(input_arranged, output_arranged)`` tuple.

    Raises:
        IndexError: If ``dim`` is outside ``[-input.ndim, input.ndim - 1]``.
    """
    assert input.ndim == output.ndim

    ndim = input.ndim

    if not -ndim <= dim < ndim:
        raise IndexError(
            "Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        )

    # A negative axis would leave the leading run of 1s empty and make the
    # trailing run too long, producing a tile shape of the wrong rank, so
    # map it to its non-negative equivalent first.
    axis = dim % ndim

    def create_axis_tile_shape(position, extent):
        # Extent 1 everywhere except at ``position``.
        return (1,) * position + (extent,) + (1,) * (ndim - position - 1)

    inner_block_shape = create_axis_tile_shape(axis, input.shape[axis])
    outer_block_shape = create_axis_tile_shape(axis, -1)

    # Every axis other than the softmax axis has extent 1 after tiling.
    singleton_axes = tuple(d for d in range(ndim) if d != axis)

    def arrange(tensor):
        arranged = tensor.tile(inner_block_shape).tile(outer_block_shape)

        arranged.dtype = arranged.dtype.squeeze(singleton_axes)
        arranged.dtype.dtype = arranged.dtype.dtype.squeeze(singleton_axes)

        return arranged

    input_arranged = arrange(input)
    output_arranged = arrange(output)

    return input_arranged, output_arranged
33+
34+
35+
def application(input, output):
    # Kernel body: for each index along the first axis of the arranged
    # ``input``, write the softmax of that block to the matching slot of
    # ``output``.
    for index in range(input.shape[0]):
        block = input[index]
        # Subtract the maximum so every exponent is <= 0 and ``exp`` cannot
        # overflow; the result of the division is unchanged by this shift.
        shifted = block - ntl.max(block)
        # The exponentials and their sum are computed in float32.
        exponentials = ntl.exp(ntl.cast(shifted, ntl.float32))
        output[index] = exponentials / ntl.sum(exponentials)  # noqa: F841
42+
43+
44+
def softmax(input, dim, output=None):
    """Compute the softmax of ``input`` along ``dim``.

    Args:
        input: Tensor to normalize.
        dim: Axis to normalize over. Negative values count from the last
            axis, so ``-1`` selects the final axis.
        output: Optional tensor to write the result into. When omitted, a
            new tensor is allocated with ``torch.empty_like(input)``.

    Returns:
        The tensor holding the result (``output`` itself when it was given).

    Raises:
        IndexError: If ``dim`` is outside ``[-input.ndim, input.ndim - 1]``.
    """
    ndim = input.ndim

    if not -ndim <= dim < ndim:
        raise IndexError(
            "Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        )

    # Normalize before the kernel lookup: the kernel is built from a
    # non-negative axis index, and ``-1`` and ``ndim - 1`` should resolve to
    # the same kernel rather than two separate ones.
    dim = dim % ndim

    if output is None:
        output = torch.empty_like(input)

    kernel = _make(ndim, dim)

    kernel(input, output)

    return output
53+
54+
55+
@functools.cache
def _make(ndim, dim):
    """Build the softmax kernel for ``ndim``-dimensional tensors along ``dim``.

    Results are memoized per ``(ndim, dim)`` pair by ``functools.cache``.

    Args:
        ndim: Rank of the input and output tensors.
        dim: Axis forwarded to ``arrangement``.

    Returns:
        Whatever ``ninetoothed.make`` returns; ``softmax`` calls it as
        ``kernel(input, output)``.
    """
    arrange_along_dim = functools.partial(arrangement, dim=dim)

    # NOTE(review): ``other=-inf`` is presumably the fill value for elements
    # outside the tensor bounds, which would leave the max unaffected and
    # contribute exp(-inf) == 0 to the sum -- confirm against ninetoothed.
    input_spec = Tensor(ndim, other=float("-inf"), shape_options={"constexpr": True})
    output_spec = Tensor(ndim)

    return ninetoothed.make(
        arrange_along_dim,
        application,
        (input_spec, output_spec),
    )

tests/test_softmax.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import random
2+
3+
import pytest
4+
import torch
5+
6+
import ntops
7+
from tests.skippers import skip_if_cuda_not_available
8+
from tests.utils import generate_arguments
9+
10+
11+
@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol):
    """Compare ``ntops.softmax`` with PyTorch's softmax on a random axis."""
    sample = torch.randn(shape, dtype=dtype, device="cuda")
    # Pick one valid non-negative axis at random for this case.
    axis = random.randint(0, sample.ndim - 1)

    actual = ntops.softmax(sample, axis)
    expected = torch.nn.functional.softmax(sample, dim=axis)

    assert torch.allclose(actual, expected, atol=atol, rtol=rtol)

0 commit comments

Comments
 (0)