Skip to content

Commit 7f2123a

Browse files
committed
issue/889 - optimize flash attention performance from default setup
1 parent 79d142f commit 7f2123a

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

src/infiniop/ops/flash_attention/ninetoothed/build.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def build():
1212
with_attn_mask_values = (0,)
1313
causal_variant_values = (CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT)
1414
dtype_values = (ninetoothed.float16, ninetoothed.bfloat16, ninetoothed.float32)
15-
block_size_m_values = (64,)
15+
block_size_m_values = (256,)
1616
block_size_n_values = (64,)
1717

1818
constexpr_param_grid = {

src/infiniop/ops/flash_attention/ninetoothed/descriptor.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ class Descriptor final : public InfiniopDescriptor {
6767
const auto emb_dim_{_query_shape[3]};
6868
const auto is_causal_{_is_causal};
6969
const auto with_attn_mask_{0};
70-
const auto causal_variant_{1};
70+
const auto causal_variant_{2};
7171
const auto dtype_{_dtype};
7272

73-
constexpr auto block_size_m_{64};
73+
constexpr auto block_size_m_{256};
7474
constexpr auto block_size_n_{64};
7575

7676
if (launch_flash_attention(stream,

0 commit comments

Comments
 (0)