From 29dfb383688b0b45968ed2ed02b3713329c83e57 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Tue, 6 May 2025 17:22:46 +0800 Subject: [PATCH 1/2] Add basic `softmax` implementation that handles limited axis dim length --- src/ntops/softmax.py | 64 +++++++++++++++++++++++++++++++++++++++++++ tests/test_softmax.py | 22 +++++++++++++++ 2 files changed, 86 insertions(+) create mode 100644 src/ntops/softmax.py create mode 100644 tests/test_softmax.py diff --git a/src/ntops/softmax.py b/src/ntops/softmax.py new file mode 100644 index 0000000..44c592e --- /dev/null +++ b/src/ntops/softmax.py @@ -0,0 +1,64 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +import torch +from ninetoothed import Tensor + + +def arrangement(input, output, dim): + assert input.ndim == output.ndim + + def create_axis_tile_shape(dim, dim_block): + return tuple(1 for _ in range(dim)) + (dim_block,) + tuple(1 for _ in range(input.ndim - dim - 1)) + + inner_block_shape = create_axis_tile_shape(dim, input.shape[dim]) + outer_block_shape = create_axis_tile_shape(dim, -1) + + def arrange(input): + input_arranged = input.tile(inner_block_shape).tile(outer_block_shape) + + input_arranged.dtype = input_arranged.dtype.squeeze( + tuple(d for d in range(input.ndim) if d != dim) + ) + input_arranged.dtype.dtype = input_arranged.dtype.dtype.squeeze( + tuple(d for d in range(input.ndim) if d != dim) + ) + return input_arranged + + input_arranged = arrange(input) + output_arranged = arrange(output) + + return input_arranged, output_arranged + + +def application(input, output): + for i in range(input.shape[0]): + input_i = input[i] + row_minus_max = input_i - ntl.max(input_i) + numerator = ntl.exp(ntl.cast(row_minus_max, ntl.float32)) + denominator = ntl.sum(numerator) + output[i] = numerator / denominator # noqa: F841 + + +def softmax(input, dim, output=None): + if output is None: + output = torch.empty_like(input) + + kernel = _make(input.ndim, dim) + + kernel(input, output) + + return output + + +@functools.cache +def _make(ndim, dim): + return ninetoothed.make( + functools.partial(arrangement, dim=dim), + application, + ( + Tensor(ndim, other=float("-inf"), shape_options={"constexpr": True}), + Tensor(ndim), + ), + ) diff --git a/tests/test_softmax.py b/tests/test_softmax.py new file mode 100644 index 0000000..97bc997 --- /dev/null +++ b/tests/test_softmax.py @@ -0,0 +1,22 @@ +import random + +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_cuda(shape, dtype, atol, rtol): + device = "cuda" + + input = torch.randn(shape, dtype=dtype, device=device) + dim = random.randint(0, input.ndim - 1) + + ninetoothed_output = ntops.softmax(input, dim) + reference_output = torch.nn.functional.softmax(input, dim=dim) + + assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol) From 682c5a079fa859ad0858d98b6e878f575a98159f Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Thu, 8 May 2025 15:02:50 +0800 Subject: [PATCH 2/2] Accommodate `softmax` to handle arbitrary axis dim length --- src/ntops/kernels/softmax.py | 69 ++++++++++++++++++++++++++++++++++++ src/ntops/softmax.py | 64 --------------------------------- src/ntops/torch.py | 13 +++++++ tests/test_softmax.py | 7 ++-- 4 files changed, 86 insertions(+), 67 deletions(-) create mode 100644 src/ntops/kernels/softmax.py delete mode 100644 src/ntops/softmax.py diff --git a/src/ntops/kernels/softmax.py b/src/ntops/kernels/softmax.py new file mode 100644 index 0000000..9abe814 --- /dev/null +++ b/src/ntops/kernels/softmax.py @@ -0,0 +1,69 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +BLOCK_SIZE = ninetoothed.block_size() + + +def arrangement(input, output, dim): + assert input.ndim == output.ndim + + def create_axis_tile_shape(dim, dim_block): + return ( + tuple(1 for _ in range(dim)) + + (dim_block,) + + tuple(1 for _ in range(input.ndim - dim - 1)) + ) + + def arrange(input): + input_arranged = input.tile(inner_block_shape).tile(outer_block_shape) + + input_arranged.dtype = input_arranged.dtype.squeeze( + tuple(d for d in range(input.ndim) if d != dim) + ) + input_arranged.dtype.dtype = input_arranged.dtype.dtype.squeeze( + tuple(d for d in range(input.ndim) if d != dim) + ) + return input_arranged + + inner_block_shape = create_axis_tile_shape(dim, BLOCK_SIZE) + outer_block_shape = create_axis_tile_shape(dim, -1) + + return arrange(input), arrange(output) + + +def _exp(x, dtype): + exp_dtype = dtype if dtype != ntl.float16 else ntl.float32 + return ntl.cast(ntl.exp(ntl.cast(x, exp_dtype)), dtype) + + +def application(input, output): + dtype = output.dtype.dtype + prev_max = ntl.cast(float("-inf"), dtype) + denominator = ntl.cast(0, dtype) + + for i in range(input.shape[0]): + input_i = ntl.cast(input[i], dtype) + curr_max = ntl.cast(ntl.maximum(prev_max, ntl.max(input_i)), dtype) + input_max_diff_exp = _exp(input_i - curr_max, dtype) + prev_curr_max_diff_exp = _exp(prev_max - curr_max, dtype) + denominator = denominator * prev_curr_max_diff_exp + ntl.sum(input_max_diff_exp) + prev_max = curr_max + + for i in range(input.shape[0]): + numerator = _exp(input[i] - prev_max, dtype) + output[i] = numerator / denominator + + +@functools.cache +def make(ndim, dim): + return ninetoothed.make( + functools.partial(arrangement, dim=dim), + application, + ( + Tensor(ndim, other=float("-inf"), shape_options={"constexpr": True}), + Tensor(ndim), + ), + ) diff --git a/src/ntops/softmax.py b/src/ntops/softmax.py deleted file mode 100644 index 44c592e..0000000 --- a/src/ntops/softmax.py +++ /dev/null @@ -1,64 +0,0 @@ -import functools - -import ninetoothed -import ninetoothed.language as ntl -import torch -from ninetoothed import Tensor - - -def arrangement(input, output, dim): - assert input.ndim == output.ndim - - def create_axis_tile_shape(dim, dim_block): - return tuple(1 for _ in range(dim)) + (dim_block,) + tuple(1 for _ in range(input.ndim - dim - 1)) - - inner_block_shape = create_axis_tile_shape(dim, input.shape[dim]) - outer_block_shape = create_axis_tile_shape(dim, -1) - - def arrange(input): - input_arranged = input.tile(inner_block_shape).tile(outer_block_shape) - - input_arranged.dtype = input_arranged.dtype.squeeze( - tuple(d for d in range(input.ndim) if d != dim) - ) - input_arranged.dtype.dtype = input_arranged.dtype.dtype.squeeze( - tuple(d for d in range(input.ndim) if d != dim) - ) - return input_arranged - - input_arranged = arrange(input) - output_arranged = arrange(output) - - return input_arranged, output_arranged - - -def application(input, output): - for i in range(input.shape[0]): - input_i = input[i] - row_minus_max = input_i - ntl.max(input_i) - numerator = ntl.exp(ntl.cast(row_minus_max, ntl.float32)) - denominator = ntl.sum(numerator) - output[i] = numerator / denominator # noqa: F841 - - -def softmax(input, dim, output=None): - if output is None: - output = torch.empty_like(input) - - kernel = _make(input.ndim, dim) - - kernel(input, output) - - return output - - -@functools.cache -def _make(ndim, dim): - return ninetoothed.make( - functools.partial(arrangement, dim=dim), - application, - ( - Tensor(ndim, other=float("-inf"), shape_options={"constexpr": True}), - Tensor(ndim), - ), - ) diff --git a/src/ntops/torch.py b/src/ntops/torch.py index 6d4e487..d9cc71b 100644 --- a/src/ntops/torch.py +++ b/src/ntops/torch.py @@ -10,6 +10,7 @@ import ntops.kernels.mm import ntops.kernels.mul import ntops.kernels.rsqrt +import ntops.kernels.softmax def abs(input, *, out=None): @@ -128,3 +129,15 @@ def rsqrt(input, *, out=None): kernel(input, out) return out + + +def softmax(input, dim, dtype=None): + tensor_dtype = dtype if dtype is not None else input.dtype + + output = torch.empty_like(input, dtype=tensor_dtype) + + kernel = ntops.kernels.softmax.make(input.ndim, dim) + + kernel(input, output) + + return output diff --git a/tests/test_softmax.py b/tests/test_softmax.py index 97bc997..dfbd440 100644 --- a/tests/test_softmax.py +++ b/tests/test_softmax.py @@ -3,7 +3,7 @@ import pytest import torch -import ntops +import ntops.torch from tests.skippers import skip_if_cuda_not_available from tests.utils import generate_arguments @@ -15,8 +15,9 @@ def test_cuda(shape, dtype, atol, rtol): input = torch.randn(shape, dtype=dtype, device=device) dim = random.randint(0, input.ndim - 1) + dtype = random.choice([torch.float16, torch.float32, torch.float64]) - ninetoothed_output = ntops.softmax(input, dim) - reference_output = torch.nn.functional.softmax(input, dim=dim) + ninetoothed_output = ntops.torch.softmax(input, dim, dtype) + reference_output = torch.nn.functional.softmax(input, dim=dim, dtype=dtype) assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)