Skip to content

Commit 4201ea7

Browse files
committed
issue/889 - optimize flash attention performance from default setup
1 parent 79d142f commit 4201ea7

3 files changed

Lines changed: 5 additions & 5 deletions

File tree

src/infiniop/ops/flash_attention/ninetoothed/build.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@ def build():
1212
with_attn_mask_values = (0,)
1313
causal_variant_values = (CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT)
1414
dtype_values = (ninetoothed.float16, ninetoothed.bfloat16, ninetoothed.float32)
15-
block_size_m_values = (64,)
15+
block_size_m_values = (256,)
1616
block_size_n_values = (64,)
1717

1818
constexpr_param_grid = {

src/infiniop/ops/flash_attention/ninetoothed/descriptor.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -67,10 +67,10 @@ class Descriptor final : public InfiniopDescriptor {
6767
const auto emb_dim_{_query_shape[3]};
6868
const auto is_causal_{_is_causal};
6969
const auto with_attn_mask_{0};
70-
const auto causal_variant_{1};
70+
const auto causal_variant_{2};
7171
const auto dtype_{_dtype};
7272

73-
constexpr auto block_size_m_{64};
73+
constexpr auto block_size_m_{256};
7474
constexpr auto block_size_n_{64};
7575

7676
if (launch_flash_attention(stream,

src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -183,7 +183,7 @@ def application_without_kv_cache(
183183
lse = ntl.full((query_i.shape[-2],), 1, dtype=ntl.float32)
184184
max = ntl.full((query_i.shape[-2],), float("-inf"), dtype=ntl.float32)
185185

186-
for j in range(min(key.shape[0], actual_kv_len)):
186+
for j in range(key.shape[0]):
187187

188188
qk = ntl.dot(query_i, ntl.trans(key[j]))
189189

@@ -196,7 +196,7 @@ def application_without_kv_cache(
196196
if is_causal:
197197
query_pos = query[i].offsets(-2)
198198

199-
if causal_variant == 2:
199+
if causal_variant == 2: # CausalVariant.LOWER_RIGHT:
200200
mask = (
201201
query_pos[:, None] + actual_kv_len - query.source.shape[-2]
202202
>= key_pos[None, :]

0 commit comments

Comments
 (0)