22import torch
33
44from ..base_att import AttControl
5- from ..fa3 .fp import Fa3AttBackend , Fa3PrefillAttState , Fa3DecodeAttState
5+ from ..paged_fa3 .fp import PagedFa3AttBackend , PagedFa3PrefillAttState , PagedFa3DecodeAttState
66from lightllm .utils .fa4_utils import (
77 ensure_fa4_available ,
88 ensure_fa4_supported_gpu ,
99 flash_attn_varlen_func ,
10+ sm90_fa4_paged_kv_tile_n ,
1011 unwrap_fa4_output ,
1112)
1213
1314
14- class Fa4AttBackend (Fa3AttBackend ):
15+ class Fa4AttBackend (PagedFa3AttBackend ):
1516 def __init__ (self , model ):
1617 ensure_fa4_available ()
1718 ensure_fa4_supported_gpu ()
@@ -29,20 +30,7 @@ def _sm90_fa4_paged_kv_tile_n(
2930 head_dim_v : int ,
3031 window_size : tuple [int , int ],
3132) -> int | None :
32- major , _minor = torch .cuda .get_device_capability ()
33- if major != 9 :
34- return None
35-
36- is_local = window_size != (- 1 , - 1 )
37- if head_dim <= 64 :
38- return 128
39- if head_dim <= 96 :
40- return 128 if is_local else 144
41- if head_dim <= 128 :
42- return 128
43- if head_dim <= 192 :
44- return 96 if is_local else (128 if head_dim_v <= 128 else 112 )
45- return 64 if is_local else 80
33+ return sm90_fa4_paged_kv_tile_n (head_dim = head_dim , head_dim_v = head_dim_v , window_size = window_size )
4634
4735
4836def _ensure_fa4_paged_kv_supported (
@@ -67,7 +55,7 @@ def _ensure_fa4_paged_kv_supported(
6755
6856
6957@dataclasses .dataclass
70- class Fa4PrefillAttState (Fa3PrefillAttState ):
58+ class Fa4PrefillAttState (PagedFa3PrefillAttState ):
7159 def _nomarl_prefill_att (
7260 self , q : torch .Tensor , k : torch .Tensor , v : torch .Tensor , att_control : AttControl , alloc_func = torch .empty
7361 ) -> torch .Tensor :
@@ -84,12 +72,12 @@ def _nomarl_prefill_att(
8472 head_dim = q .shape [- 1 ]
8573 head_dim_v = v .shape [- 1 ]
8674 softmax_scale = 1.0 / (head_dim ** 0.5 )
87- _ensure_fa4_paged_kv_supported (head_dim , head_dim_v , window_size , page_size = 1 )
75+ _ensure_fa4_paged_kv_supported (head_dim , head_dim_v , window_size , page_size = self . backend . page_size )
8876
8977 out = flash_attn_varlen_func (
9078 q = q ,
91- k = k .view (k . shape [ 0 ], 1 , k .shape [1 ], k .shape [2 ]),
92- v = v .view (v . shape [ 0 ], 1 , v .shape [1 ], v .shape [2 ]),
79+ k = k .view (- 1 , self . backend . page_size , k .shape [1 ], k .shape [2 ]),
80+ v = v .view (- 1 , self . backend . page_size , v .shape [1 ], v .shape [2 ]),
9381 cu_seqlens_q = self .cu_seqlens_q ,
9482 seqused_k = self .infer_state .b_seq_len .int (),
9583 max_seqlen_q = self .infer_state .max_q_seq_len ,
@@ -106,7 +94,7 @@ def _nomarl_prefill_att(
10694
10795
10896@dataclasses .dataclass
109- class Fa4DecodeAttState (Fa3DecodeAttState ):
97+ class Fa4DecodeAttState (PagedFa3DecodeAttState ):
11098 def _normal_decode_att (
11199 self ,
112100 q : torch .Tensor ,
@@ -128,12 +116,12 @@ def _normal_decode_att(
128116 head_dim = q .shape [- 1 ]
129117 head_dim_v = v .shape [- 1 ]
130118 softmax_scale = 1.0 / (head_dim ** 0.5 )
131- _ensure_fa4_paged_kv_supported (head_dim , head_dim_v , window_size , page_size = 1 )
119+ _ensure_fa4_paged_kv_supported (head_dim , head_dim_v , window_size , page_size = self . backend . page_size )
132120
133121 out = flash_attn_varlen_func (
134122 q = q ,
135- k = k .view (k . shape [ 0 ], 1 , k .shape [1 ], k .shape [2 ]),
136- v = v .view (v . shape [ 0 ], 1 , v .shape [1 ], v .shape [2 ]),
123+ k = k .view (- 1 , self . backend . page_size , k .shape [1 ], k .shape [2 ]),
124+ v = v .view (- 1 , self . backend . page_size , v .shape [1 ], v .shape [2 ]),
137125 cu_seqlens_q = self .cu_seqlens_q ,
138126 seqused_k = self .b_att_seq_len .int (),
139127 max_seqlen_q = self .decode_max_q_seq_len ,
0 commit comments