Skip to content

Commit 231c8ea

Browse files
committed
Set block sizes dynamically if not provided
1 parent 8b6136a commit 231c8ea

6 files changed

Lines changed: 48 additions & 19 deletions

File tree

src/ntops/kernels/addmm.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,19 @@ def arrangement(
1414
beta,
1515
alpha,
1616
output,
17-
block_size_m=mm.BLOCK_SIZE_M,
18-
block_size_n=mm.BLOCK_SIZE_N,
19-
block_size_k=mm.BLOCK_SIZE_K,
17+
block_size_m=None,
18+
block_size_n=None,
19+
block_size_k=None,
2020
):
21+
if block_size_m is None:
22+
block_size_m = mm.BLOCK_SIZE_M
23+
24+
if block_size_n is None:
25+
block_size_n = mm.BLOCK_SIZE_N
26+
27+
if block_size_k is None:
28+
block_size_k = mm.BLOCK_SIZE_K
29+
2130
_, _, input_arranged = mm.arrangement(
2231
x,
2332
y,

src/ntops/kernels/bmm.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,17 @@
77

88

99
def arrangement(
10-
input,
11-
other,
12-
output,
13-
block_size_m=BLOCK_SIZE_M,
14-
block_size_n=BLOCK_SIZE_N,
15-
block_size_k=BLOCK_SIZE_K,
10+
input, other, output, block_size_m=None, block_size_n=None, block_size_k=None
1611
):
12+
if block_size_m is None:
13+
block_size_m = BLOCK_SIZE_M
14+
15+
if block_size_n is None:
16+
block_size_n = BLOCK_SIZE_N
17+
18+
if block_size_k is None:
19+
block_size_k = BLOCK_SIZE_K
20+
1721
output_arranged = output.tile((1, block_size_m, block_size_n))
1822
output_arranged.dtype = output_arranged.dtype.squeeze(0)
1923

src/ntops/kernels/element_wise.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import ninetoothed
22

33

4-
def arrangement(*tensors, block_size=ninetoothed.block_size()):
4+
def arrangement(*tensors, block_size=None):
5+
if block_size is None:
6+
block_size = ninetoothed.block_size()
7+
58
ndim = max(tensor.ndim for tensor in tensors)
69

710
assert all(tensor.ndim == ndim or tensor.ndim == 0 for tensor in tensors)

src/ntops/kernels/mm.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,17 @@
1010

1111

1212
def arrangement(
13-
input,
14-
other,
15-
output,
16-
block_size_m=BLOCK_SIZE_M,
17-
block_size_n=BLOCK_SIZE_N,
18-
block_size_k=BLOCK_SIZE_K,
13+
input, other, output, block_size_m=None, block_size_n=None, block_size_k=None
1914
):
15+
if block_size_m is None:
16+
block_size_m = BLOCK_SIZE_M
17+
18+
if block_size_n is None:
19+
block_size_n = BLOCK_SIZE_N
20+
21+
if block_size_k is None:
22+
block_size_k = BLOCK_SIZE_K
23+
2024
output_arranged = output.tile((block_size_m, block_size_n))
2125

2226
input_arranged = input.tile((block_size_m, block_size_k))

src/ntops/kernels/scaled_dot_product_attention.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ def arrangement(
2222
output,
2323
with_attn_mask,
2424
with_kv_cache,
25-
block_size_m=BLOCK_SIZE_M,
26-
block_size_n=BLOCK_SIZE_N,
25+
block_size_m=None,
26+
block_size_n=None,
2727
):
2828
def arrange_query_or_output(input):
2929
arranged = input.tile((1, 1, block_size_m, -1)).tile(
@@ -58,6 +58,12 @@ def arrange_attn_mask(input):
5858

5959
return arranged
6060

61+
if block_size_m is None:
62+
block_size_m = BLOCK_SIZE_M
63+
64+
if block_size_n is None:
65+
block_size_n = BLOCK_SIZE_N
66+
6167
query_arranged = arrange_query_or_output(query)
6268
key_arranged = arrange_key_or_value(key)
6369
value_arranged = arrange_key_or_value(value)

src/ntops/kernels/softmax.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
BLOCK_SIZE = ninetoothed.block_size()
88

99

10-
def arrangement(input, output, dim, block_size=BLOCK_SIZE):
10+
def arrangement(input, output, dim, block_size=None):
1111
assert input.ndim == output.ndim
1212

1313
def create_axis_tile_shape(dim, dim_block):
@@ -28,6 +28,9 @@ def arrange(input):
2828
)
2929
return input_arranged
3030

31+
if block_size is None:
32+
block_size = BLOCK_SIZE
33+
3134
inner_block_shape = create_axis_tile_shape(dim, block_size)
3235
outer_block_shape = create_axis_tile_shape(dim, -1)
3336

0 commit comments

Comments
 (0)