Skip to content

Commit d59ba94

Browse files
author
niushengxiao
committed
feat: page attention for fa4
1 parent 73627f7 commit d59ba94

4 files changed

Lines changed: 73 additions & 25 deletions

File tree

lightllm/common/basemodel/attention/fa4/fp.py

Lines changed: 12 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -2,16 +2,17 @@
22
import torch
33

44
from ..base_att import AttControl
5-
from ..fa3.fp import Fa3AttBackend, Fa3PrefillAttState, Fa3DecodeAttState
5+
from ..paged_fa3.fp import PagedFa3AttBackend, PagedFa3PrefillAttState, PagedFa3DecodeAttState
66
from lightllm.utils.fa4_utils import (
77
ensure_fa4_available,
88
ensure_fa4_supported_gpu,
99
flash_attn_varlen_func,
10+
sm90_fa4_paged_kv_tile_n,
1011
unwrap_fa4_output,
1112
)
1213

1314

14-
class Fa4AttBackend(Fa3AttBackend):
15+
class Fa4AttBackend(PagedFa3AttBackend):
1516
def __init__(self, model):
1617
ensure_fa4_available()
1718
ensure_fa4_supported_gpu()
@@ -29,20 +30,7 @@ def _sm90_fa4_paged_kv_tile_n(
2930
head_dim_v: int,
3031
window_size: tuple[int, int],
3132
) -> int | None:
32-
major, _minor = torch.cuda.get_device_capability()
33-
if major != 9:
34-
return None
35-
36-
is_local = window_size != (-1, -1)
37-
if head_dim <= 64:
38-
return 128
39-
if head_dim <= 96:
40-
return 128 if is_local else 144
41-
if head_dim <= 128:
42-
return 128
43-
if head_dim <= 192:
44-
return 96 if is_local else (128 if head_dim_v <= 128 else 112)
45-
return 64 if is_local else 80
33+
return sm90_fa4_paged_kv_tile_n(head_dim=head_dim, head_dim_v=head_dim_v, window_size=window_size)
4634

4735

4836
def _ensure_fa4_paged_kv_supported(
@@ -67,7 +55,7 @@ def _ensure_fa4_paged_kv_supported(
6755

6856

6957
@dataclasses.dataclass
70-
class Fa4PrefillAttState(Fa3PrefillAttState):
58+
class Fa4PrefillAttState(PagedFa3PrefillAttState):
7159
def _nomarl_prefill_att(
7260
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty
7361
) -> torch.Tensor:
@@ -84,12 +72,12 @@ def _nomarl_prefill_att(
8472
head_dim = q.shape[-1]
8573
head_dim_v = v.shape[-1]
8674
softmax_scale = 1.0 / (head_dim ** 0.5)
87-
_ensure_fa4_paged_kv_supported(head_dim, head_dim_v, window_size, page_size=1)
75+
_ensure_fa4_paged_kv_supported(head_dim, head_dim_v, window_size, page_size=self.backend.page_size)
8876

8977
out = flash_attn_varlen_func(
9078
q=q,
91-
k=k.view(k.shape[0], 1, k.shape[1], k.shape[2]),
92-
v=v.view(v.shape[0], 1, v.shape[1], v.shape[2]),
79+
k=k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]),
80+
v=v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]),
9381
cu_seqlens_q=self.cu_seqlens_q,
9482
seqused_k=self.infer_state.b_seq_len.int(),
9583
max_seqlen_q=self.infer_state.max_q_seq_len,
@@ -106,7 +94,7 @@ def _nomarl_prefill_att(
10694

10795

10896
@dataclasses.dataclass
109-
class Fa4DecodeAttState(Fa3DecodeAttState):
97+
class Fa4DecodeAttState(PagedFa3DecodeAttState):
11098
def _normal_decode_att(
11199
self,
112100
q: torch.Tensor,
@@ -128,12 +116,12 @@ def _normal_decode_att(
128116
head_dim = q.shape[-1]
129117
head_dim_v = v.shape[-1]
130118
softmax_scale = 1.0 / (head_dim ** 0.5)
131-
_ensure_fa4_paged_kv_supported(head_dim, head_dim_v, window_size, page_size=1)
119+
_ensure_fa4_paged_kv_supported(head_dim, head_dim_v, window_size, page_size=self.backend.page_size)
132120

133121
out = flash_attn_varlen_func(
134122
q=q,
135-
k=k.view(k.shape[0], 1, k.shape[1], k.shape[2]),
136-
v=v.view(v.shape[0], 1, v.shape[1], v.shape[2]),
123+
k=k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]),
124+
v=v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]),
137125
cu_seqlens_q=self.cu_seqlens_q,
138126
seqused_k=self.b_att_seq_len.int(),
139127
max_seqlen_q=self.decode_max_q_seq_len,

lightllm/server/api_start.py

Lines changed: 20 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
from .embed_cache.manager import start_cache_manager
1212
from lightllm.utils.log_utils import init_logger
1313
from lightllm.utils.envs_utils import set_env_start_args, set_unique_server_name, get_unique_server_name
14-
from lightllm.utils.envs_utils import get_lightllm_gunicorn_keep_alive, get_page_size
14+
from lightllm.utils.envs_utils import get_lightllm_gunicorn_keep_alive, get_page_size, set_page_size
1515
from .detokenization.manager import start_detokenization_process
1616
from .router.manager import start_router_process
1717
from lightllm.utils.process_check import is_process_active
@@ -29,6 +29,20 @@
2929
logger = init_logger(__name__)
3030

3131

32+
def _auto_set_fa4_page_size(args, requested_backends):
33+
if "fa4" not in requested_backends or "PAGE_SIZE" in os.environ:
34+
return
35+
36+
from lightllm.utils.fa4_utils import infer_fa4_page_size
37+
38+
page_size = infer_fa4_page_size(args.model_dir)
39+
if page_size is None:
40+
return
41+
42+
set_page_size(page_size)
43+
logger.info(f"auto set PAGE_SIZE={page_size} for FA4 backend")
44+
45+
3246
def setup_signal_handlers(http_server_process, process_manager):
3347
def signal_handler(sig, frame):
3448
if sig == signal.SIGINT:
@@ -205,6 +219,11 @@ def normal_or_p_d_start(args):
205219
f"{sorted(allowed_ep_att_backends)}; flashinfer is not supported."
206220
)
207221

222+
llm_requested_backends = list(args.llm_prefill_att_backend) + list(args.llm_decode_att_backend)
223+
requested_backends = llm_requested_backends + list(args.vit_att_backend)
224+
if "fa4" in requested_backends:
225+
_auto_set_fa4_page_size(args, llm_requested_backends)
226+
208227
# mtp params check
209228
if args.mtp_mode is not None:
210229
assert args.mtp_draft_model_dir is not None

lightllm/utils/envs_utils.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -170,6 +170,11 @@ def get_page_size():
170170
return int(os.getenv("PAGE_SIZE", 1))
171171

172172

173+
def set_page_size(page_size: int):
174+
os.environ["PAGE_SIZE"] = str(page_size)
175+
get_page_size.cache_clear()
176+
177+
173178
g_model_init_done = False
174179

175180

lightllm/utils/fa4_utils.py

Lines changed: 36 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,5 +42,41 @@ def ensure_fa4_supported_gpu() -> None:
4242
)
4343

4444

45+
def sm90_fa4_paged_kv_tile_n(head_dim: int, head_dim_v: int, window_size: tuple[int, int] = (-1, -1)) -> int | None:
46+
major, _minor = torch.cuda.get_device_capability()
47+
if major != 9:
48+
return None
49+
50+
is_local = window_size != (-1, -1)
51+
if head_dim <= 64:
52+
return 128
53+
if head_dim <= 96:
54+
return 128 if is_local else 144
55+
if head_dim <= 128:
56+
return 128
57+
if head_dim <= 192:
58+
return 96 if is_local else (128 if head_dim_v <= 128 else 112)
59+
return 64 if is_local else 80
60+
61+
62+
def infer_fa4_page_size(model_dir: str) -> int | None:
63+
from transformers.configuration_utils import PretrainedConfig
64+
65+
model_cfg, _ = PretrainedConfig.get_config_dict(model_dir)
66+
llm_config = model_cfg.get("text_config", model_cfg)
67+
68+
head_dim = llm_config.get("head_dim")
69+
if head_dim is None:
70+
head_dim = llm_config["hidden_size"] // llm_config["num_attention_heads"]
71+
head_dim_v = llm_config.get("v_head_dim", head_dim)
72+
73+
window_size = (-1, -1)
74+
sliding_window = llm_config.get("sliding_window", None)
75+
if sliding_window is not None and not llm_config.get("full_attention_interval", None):
76+
window_size = (sliding_window - 1, sliding_window - 1)
77+
78+
return sm90_fa4_paged_kv_tile_n(head_dim=head_dim, head_dim_v=head_dim_v, window_size=window_size)
79+
80+
4581
def unwrap_fa4_output(output):
4682
return output[0] if isinstance(output, tuple) else output

0 commit comments

Comments
 (0)