diff --git a/src/ntops/kernels/softmax.py b/src/ntops/kernels/softmax.py
new file mode 100644
index 0000000..9abe814
--- /dev/null
+++ b/src/ntops/kernels/softmax.py
@@ -0,0 +1,96 @@
+import functools
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+BLOCK_SIZE = ninetoothed.block_size()
+
+
+def arrangement(input, output, dim):
+    """Arrange ``input`` and ``output`` for a softmax along axis ``dim``.
+
+    Both tensors are tiled twice with shapes that are 1 on every axis except
+    ``dim``: first with ``BLOCK_SIZE`` along ``dim``, then with ``-1`` along
+    ``dim`` (presumably "the whole axis" -- confirm against ninetoothed's
+    ``tile`` documentation). Every axis other than ``dim`` is then squeezed
+    out of both nested tile levels.
+
+    ``dim`` must satisfy ``0 <= dim < input.ndim``; a negative value would
+    build tile shapes of the wrong rank, so callers normalize it first.
+    """
+    assert input.ndim == output.ndim
+
+    def create_axis_tile_shape(dim, dim_block):
+        # ``dim_block`` at position ``dim`` and 1 on every other axis.
+        return (
+            tuple(1 for _ in range(dim))
+            + (dim_block,)
+            + tuple(1 for _ in range(input.ndim - dim - 1))
+        )
+
+    def arrange(input):
+        input_arranged = input.tile(inner_block_shape).tile(outer_block_shape)
+
+        # Squeeze every axis except ``dim`` at both nested tile levels.
+        input_arranged.dtype = input_arranged.dtype.squeeze(
+            tuple(d for d in range(input.ndim) if d != dim)
+        )
+        input_arranged.dtype.dtype = input_arranged.dtype.dtype.squeeze(
+            tuple(d for d in range(input.ndim) if d != dim)
+        )
+        return input_arranged
+
+    inner_block_shape = create_axis_tile_shape(dim, BLOCK_SIZE)
+    outer_block_shape = create_axis_tile_shape(dim, -1)
+
+    return arrange(input), arrange(output)
+
+
+# Exponentiate ``x`` and return the result as ``dtype``; float16 is computed
+# in float32 and cast back afterwards.
+def _exp(x, dtype):
+    exp_dtype = dtype if dtype != ntl.float16 else ntl.float32
+    return ntl.cast(ntl.exp(ntl.cast(x, exp_dtype)), dtype)
+
+
+# Two-pass softmax over the tiles indexed by ``input[i]``, carried out in
+# ``output.dtype.dtype``.
+def application(input, output):
+    dtype = output.dtype.dtype
+    prev_max = ntl.cast(float("-inf"), dtype)
+    denominator = ntl.cast(0, dtype)
+
+    # Pass 1: track the running maximum and a running sum of exponentials that
+    # is rescaled whenever the maximum changes.
+    for i in range(input.shape[0]):
+        input_i = ntl.cast(input[i], dtype)
+        curr_max = ntl.cast(ntl.maximum(prev_max, ntl.max(input_i)), dtype)
+        input_max_diff_exp = _exp(input_i - curr_max, dtype)
+        prev_curr_max_diff_exp = _exp(prev_max - curr_max, dtype)
+        denominator = denominator * prev_curr_max_diff_exp + ntl.sum(input_max_diff_exp)
+        prev_max = curr_max
+
+    # Pass 2: normalize each block. The input is cast to ``dtype`` first so
+    # this pass works on the same values pass 1 used for the maximum and sum.
+    for i in range(input.shape[0]):
+        input_i = ntl.cast(input[i], dtype)
+        numerator = _exp(input_i - prev_max, dtype)
+        output[i] = numerator / denominator
+
+
+@functools.cache
+def make(ndim, dim):
+    """Build the softmax kernel for ``ndim``-dimensional tensors along ``dim``.
+
+    Results are memoized per ``(ndim, dim)`` by ``functools.cache``. ``dim``
+    must be non-negative, as required by ``arrangement``.
+    """
+    return ninetoothed.make(
+        functools.partial(arrangement, dim=dim),
+        application,
+        (
+            Tensor(ndim, other=float("-inf"), shape_options={"constexpr": True}),
+            Tensor(ndim),
+        ),
+    )
diff --git a/src/ntops/torch.py b/src/ntops/torch.py
index aba0574..a579cd1 100644
--- a/src/ntops/torch.py
+++ b/src/ntops/torch.py
@@ -22,6 +22,7 @@
 import ntops.kernels.rsqrt
 import ntops.kernels.sigmoid
 import ntops.kernels.sin
+import ntops.kernels.softmax
 import ntops.kernels.tanh
 
 
@@ -275,6 +276,40 @@ def sin(input, *, out=None):
     return out
 
 
+def softmax(input, dim, dtype=None):
+    """Compute the softmax of ``input`` along ``dim``.
+
+    Args:
+        input: Tensor to normalize.
+        dim: Axis to normalize over; negative values count from the last axis.
+        dtype: Element type of the result. Defaults to ``input.dtype``.
+
+    Returns:
+        A new tensor with the shape of ``input`` holding the softmax values.
+
+    Raises:
+        IndexError: If ``dim`` is outside ``[-input.ndim, input.ndim - 1]``.
+    """
+    ndim = input.ndim
+
+    if not -ndim <= dim < ndim:
+        raise IndexError(f"dim {dim} is out of range for a {ndim}-dimensional input")
+
+    # The kernel arrangement only handles non-negative axes; normalizing here
+    # also lets ``dim=-1`` and ``dim=ndim - 1`` share one cached kernel.
+    dim %= ndim
+
+    tensor_dtype = dtype if dtype is not None else input.dtype
+
+    output = torch.empty_like(input, dtype=tensor_dtype)
+
+    kernel = ntops.kernels.softmax.make(ndim, dim)
+
+    kernel(input, output)
+
+    return output
+
+
 def tanh(input, *, out=None):
     if out is None:
         out = torch.empty_like(input)
diff --git a/tests/test_softmax.py b/tests/test_softmax.py
new file mode 100644
index 0000000..dfbd440
--- /dev/null
+++ b/tests/test_softmax.py
@@ -0,0 +1,24 @@
+import random
+
+import pytest
+import torch
+
+import ntops.torch
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_cuda(shape, dtype, atol, rtol):
+    device = "cuda"
+
+    input = torch.randn(shape, dtype=dtype, device=device)
+    # Draw from the full valid range so negative axes are exercised too.
+    dim = random.randint(-input.ndim, input.ndim - 1)
+    dtype = random.choice([torch.float16, torch.float32, torch.float64])
+
+    ninetoothed_output = ntops.torch.softmax(input, dim, dtype)
+    reference_output = torch.nn.functional.softmax(input, dim=dim, dtype=dtype)
+
+    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)