File tree Expand file tree Collapse file tree
src/infiniop/ops/flash_attention/ninetoothed Expand file tree Collapse file tree
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@ def build():
12 12   with_attn_mask_values = (0,)
13 13   causal_variant_values = (CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT)
14 14   dtype_values = (ninetoothed.float16, ninetoothed.bfloat16, ninetoothed.float32)
15    - block_size_m_values = (64,)
   15 + block_size_m_values = (256,)
16 16   block_size_n_values = (64,)
17 17
18 18   constexpr_param_grid = {
Original file line number | Diff line number | Diff line change
@@ -67,10 +67,10 @@ class Descriptor final : public InfiniopDescriptor {
67 67   const auto emb_dim_{_query_shape[3]};
68 68   const auto is_causal_{_is_causal};
69 69   const auto with_attn_mask_{0};
70    - const auto causal_variant_{1};
   70 + const auto causal_variant_{2};
71 71   const auto dtype_{_dtype};
72 72
73    - constexpr auto block_size_m_{64};
   73 + constexpr auto block_size_m_{256};
74 74   constexpr auto block_size_n_{64};
75 75
76 76   if (launch_flash_attention(stream,
You can’t perform that action at this time.
0 commit comments