Skip to content

Commit 464f58a

Browse files
committed
Use reduction.arrangement in softmax.py
1 parent c8c93a8 commit 464f58a

1 file changed

Lines changed: 1 addition & 32 deletions

File tree

src/ntops/kernels/softmax.py

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,9 @@
11
import functools
22

3-
import ninetoothed
43
import ninetoothed.language as ntl
54
from ninetoothed import Tensor
65

7-
# Module-level default obtained from ninetoothed.block_size(); `arrangement`
# falls back to this value when its `block_size` argument is None.
BLOCK_SIZE = ninetoothed.block_size()
8-
9-
10-
def arrangement(input, output, dim, block_size=None):
    """Arrange ``input`` and ``output`` the same way along axis ``dim``.

    Each tensor is tiled twice. The first tile shape is 1 on every axis
    except ``dim``, where it is ``block_size``; the second is 1 on every
    axis except ``dim``, where it is -1. After tiling, every axis other
    than ``dim`` is squeezed out of the two nested ``dtype`` levels of the
    arranged tensor.

    Args:
        input: Tensor to arrange; must expose ``ndim`` and ``tile``.
        output: Tensor arranged identically; must have the same ``ndim``
            as ``input``.
        dim: Index of the axis that keeps a non-unit tile extent.
        block_size: Tile extent along ``dim`` for the first tiling step.
            Falls back to the module-level ``BLOCK_SIZE`` when ``None``.

    Returns:
        A ``(input_arranged, output_arranged)`` pair.
    """
    assert input.ndim == output.ndim

    block = BLOCK_SIZE if block_size is None else block_size

    # Unit extents on both sides of `dim`; only the entry at `dim` differs
    # between the two tile shapes.
    leading = (1,) * dim
    trailing = (1,) * (input.ndim - dim - 1)
    inner_shape = leading + (block,) + trailing
    # NOTE(review): -1 presumably asks ninetoothed for the full extent along
    # `dim` -- confirm against the ninetoothed `Tensor.tile` documentation.
    outer_shape = leading + (-1,) + trailing

    def _arrange_one(tensor):
        arranged = tensor.tile(inner_shape).tile(outer_shape)

        # Every axis except `dim` is dropped from both nested dtype levels.
        other_axes = tuple(axis for axis in range(tensor.ndim) if axis != dim)
        arranged.dtype = arranged.dtype.squeeze(other_axes)
        arranged.dtype.dtype = arranged.dtype.dtype.squeeze(other_axes)
        return arranged

    return _arrange_one(input), _arrange_one(output)
6+
from ntops.kernels.reduction import arrangement
387

398

409
def _exp(x, dtype):

0 commit comments

Comments
 (0)