Skip to content

Commit 060b2f4

Browse files
author
niushengxiao
committed
feat: page size > 1 support
1 parent 8ed5074 commit 060b2f4

File tree

24 files changed

+1587
-29
lines changed

24 files changed

+1587
-29
lines changed

lightllm/common/basemodel/attention/create_utils.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import os
44
import torch
5-
from lightllm.utils.envs_utils import get_env_start_args
5+
from lightllm.utils.envs_utils import get_env_start_args, get_page_size
66
from lightllm.utils.log_utils import init_logger
77
from lightllm.utils.backend_validator import validate
88
from .base_att import BaseAttBackend
@@ -13,18 +13,24 @@
1313
from .fa3.fp import Fa3AttBackend
1414
from .fa3.fp8 import Fp8Fa3AttBackend
1515
from .fa3.mla import MlaFa3AttBackend
16+
from .paged_fa3.fp import PagedFa3AttBackend
17+
from .paged_fa3.mla import PagedMlaFa3AttBackend
1618
from .flashinfer.fp8 import Fp8FlashInferAttBackend
1719
from .flashinfer.fp import FlashInferAttBackend
1820
from .flashinfer.mla import MlaFlashInferAttBackend
21+
from .paged_flashinfer.fp import PagedFlashInferAttBackend
22+
from .paged_flashinfer.mla import PagedMlaFlashInferAttBackend
1923

2024
logger = init_logger(__name__)
2125

26+
_PAGE_ENABLED = get_page_size() > 1
27+
2228
# Backend class mappings by data type
2329
data_type_to_backend = {
2430
"None": {
2531
"triton": TritonAttBackend,
26-
"fa3": Fa3AttBackend,
27-
"flashinfer": FlashInferAttBackend,
32+
"fa3": PagedFa3AttBackend if _PAGE_ENABLED else Fa3AttBackend,
33+
"flashinfer": PagedFlashInferAttBackend if _PAGE_ENABLED else FlashInferAttBackend,
2834
},
2935
"int4kv": {
3036
"triton": Int4kvTritonAttBackend,
@@ -41,8 +47,8 @@
4147
mla_data_type_to_backend = {
4248
"None": {
4349
"triton": MlaTritonAttBackend,
44-
"fa3": MlaFa3AttBackend,
45-
"flashinfer": MlaFlashInferAttBackend,
50+
"fa3": PagedMlaFa3AttBackend if _PAGE_ENABLED else MlaFa3AttBackend,
51+
"flashinfer": PagedMlaFlashInferAttBackend if _PAGE_ENABLED else MlaFlashInferAttBackend,
4652
},
4753
}
4854

lightllm/common/basemodel/attention/fa3/fp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,15 @@ def prefill_att(
6666
alloc_func=torch.empty,
6767
) -> torch.Tensor:
6868
assert att_control.use_alibi is False
69-
return self._nomarl_prefill_att(
69+
return self._normal_prefill_att(
7070
q=q,
7171
k=k,
7272
v=v,
7373
att_control=att_control,
7474
alloc_func=alloc_func,
7575
)
7676

77-
def _nomarl_prefill_att(
77+
def _normal_prefill_att(
7878
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty
7979
) -> torch.Tensor:
8080
self.backend: Fa3AttBackend = self.backend # for typing

lightllm/common/basemodel/attention/flashinfer/fp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,14 @@ def prefill_att(
9999
and att_control.use_sliding_window is False
100100
and att_control.use_att_sink is False
101101
)
102-
return self._nomarl_prefill_att(
102+
return self._normal_prefill_att(
103103
q=q,
104104
k=k,
105105
v=v,
106106
alloc_func=alloc_func,
107107
)
108108

109-
def _nomarl_prefill_att(
109+
def _normal_prefill_att(
110110
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty
111111
) -> torch.Tensor:
112112
self.backend: FlashInferAttBackend = self.backend # for typing

lightllm/common/basemodel/attention/paged_fa3/__init__.py

Whitespace-only changes.
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
import dataclasses
2+
import torch
3+
import triton
4+
from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl
5+
from lightllm.utils.dist_utils import get_current_device_id
6+
from lightllm.utils.sgl_utils import flash_attn_with_kvcache
7+
from lightllm.utils.envs_utils import get_env_start_args, get_page_size
8+
from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy
9+
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor
10+
11+
12+
class PagedFa3AttBackend(BaseAttBackend):
13+
def __init__(self, model, page_size=None):
14+
super().__init__(model=model)
15+
self.page_size = page_size or get_page_size()
16+
self.get_page_table_buffer()
17+
18+
def get_page_table_buffer(self):
19+
model = self.model
20+
if not hasattr(self, "_shared_page_table_buffer"):
21+
shared_len = model.graph_max_batch_size * triton.cdiv(model.graph_max_len_in_batch, self.page_size)
22+
self._shared_page_table_buffer = [
23+
torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()),
24+
torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()),
25+
]
26+
return self._shared_page_table_buffer
27+
28+
def create_att_prefill_state(self, infer_state):
29+
return PagedFa3PrefillAttState(backend=self, infer_state=infer_state)
30+
31+
def create_att_decode_state(self, infer_state):
32+
return PagedFa3DecodeAttState(backend=self, infer_state=infer_state)
33+
34+
35+
@dataclasses.dataclass
36+
class PagedFa3PrefillAttState(BasePrefillAttState):
37+
cu_seqlens_q: torch.Tensor = None
38+
cu_seqlens_k: torch.Tensor = None
39+
page_table: torch.Tensor = None
40+
41+
def init_state(self):
42+
self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int()
43+
self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int()
44+
table_len = triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size)
45+
self.page_table = torch.empty(
46+
(self.infer_state.batch_size, table_len),
47+
dtype=torch.int32,
48+
device=self.infer_state.input_ids.device,
49+
)
50+
page_table_copy(
51+
page_table=self.page_table,
52+
req_to_token_indexs=self.infer_state.req_manager.req_to_token_indexs,
53+
b_req_idx=self.infer_state.b_req_idx,
54+
)
55+
56+
def prefill_att(self, q, k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty):
57+
assert att_control.use_alibi is False
58+
return self._normal_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func)
59+
60+
def _normal_prefill_att(self, q, k, v, att_control: AttControl, alloc_func=torch.empty):
61+
if att_control.use_sliding_window:
62+
window_size = att_control.sliding_window
63+
else:
64+
window_size = (-1, -1)
65+
66+
if att_control.use_att_sink:
67+
sink_weight = att_control.sink_weight
68+
else:
69+
sink_weight = None
70+
71+
sm_scale = 1.0 / (q.shape[-1] ** 0.5)
72+
return flash_attn_with_kvcache(
73+
q=q,
74+
k_cache=k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]),
75+
v_cache=v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]),
76+
page_table=self.page_table,
77+
cache_seqlens=self.infer_state.b_seq_len,
78+
cu_seqlens_q=self.cu_seqlens_q,
79+
cu_seqlens_k_new=self.cu_seqlens_k,
80+
max_seqlen_q=self.infer_state.max_q_seq_len,
81+
softmax_scale=sm_scale,
82+
causal=True,
83+
window_size=window_size,
84+
softcap=0.0,
85+
k_descale=None,
86+
v_descale=None,
87+
return_softmax_lse=False,
88+
sinks=sink_weight,
89+
)
90+
91+
92+
@dataclasses.dataclass
93+
class PagedFa3DecodeAttState(BaseDecodeAttState):
94+
cu_seqlens_q: torch.Tensor = None
95+
cu_seqlens_k: torch.Tensor = None
96+
page_table: torch.Tensor = None
97+
b_att_seq_len: torch.Tensor = None
98+
decode_max_q_seq_len: int = None
99+
100+
def init_state(self):
101+
args_mtp_step = get_env_start_args().mtp_step
102+
if args_mtp_step > 0:
103+
mtp_size = args_mtp_step + 1
104+
b_q_seq_len = torch.full(
105+
(self.infer_state.b_seq_len.shape[0] // mtp_size,),
106+
fill_value=mtp_size,
107+
dtype=torch.int32,
108+
device=self.infer_state.b_seq_len.device,
109+
)
110+
b_kv_seq_len = self.infer_state.b_seq_len[mtp_size - 1 :: mtp_size]
111+
b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len)
112+
self.cu_seqlens_q = b1_cu_q_seq_len.int()
113+
self.cu_seqlens_k = b1_cu_kv_seq_len.int()
114+
else:
115+
self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int()
116+
self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int()
117+
118+
att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1)
119+
assert self.infer_state.batch_size % (args_mtp_step + 1) == 0
120+
model = self.backend.model
121+
table_len = triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size)
122+
if (
123+
self.infer_state.batch_size <= model.graph_max_batch_size
124+
and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch
125+
):
126+
page_buffer = self.backend.get_page_table_buffer()
127+
shared_table_len = triton.cdiv(model.graph_max_len_in_batch, self.backend.page_size)
128+
self.page_table = page_buffer[self.infer_state.microbatch_index][
129+
: att_batch_size * shared_table_len
130+
].reshape(att_batch_size, shared_table_len)
131+
else:
132+
self.page_table = torch.empty(
133+
(att_batch_size, table_len),
134+
dtype=torch.int32,
135+
device=self.infer_state.input_ids.device,
136+
)
137+
138+
if args_mtp_step > 0:
139+
page_table_copy(
140+
page_table=self.page_table[:, :table_len],
141+
req_to_token_indexs=model.req_manager.req_to_token_indexs,
142+
b_req_idx=self.infer_state.b_req_idx[args_mtp_step :: (args_mtp_step + 1)],
143+
)
144+
self.b_att_seq_len = self.infer_state.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous()
145+
self.decode_max_q_seq_len = args_mtp_step + 1
146+
else:
147+
page_table_copy(
148+
page_table=self.page_table[:, :table_len],
149+
req_to_token_indexs=model.req_manager.req_to_token_indexs,
150+
b_req_idx=self.infer_state.b_req_idx,
151+
)
152+
self.b_att_seq_len = self.infer_state.b_seq_len
153+
self.decode_max_q_seq_len = 1
154+
155+
def decode_att(self, q, k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty):
156+
assert att_control.use_alibi is False
157+
return self._normal_decode_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func)
158+
159+
def _normal_decode_att(self, q, k, v, att_control: AttControl, alloc_func=torch.empty):
160+
if att_control.use_sliding_window:
161+
window_size = att_control.sliding_window
162+
else:
163+
window_size = (-1, -1)
164+
165+
if att_control.use_att_sink:
166+
sink_weight = att_control.sink_weight
167+
else:
168+
sink_weight = None
169+
170+
sm_scale = 1.0 / (q.shape[-1] ** 0.5)
171+
return flash_attn_with_kvcache(
172+
q=q,
173+
k_cache=k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]),
174+
v_cache=v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]),
175+
page_table=self.page_table,
176+
cache_seqlens=self.b_att_seq_len,
177+
cu_seqlens_q=self.cu_seqlens_q,
178+
cu_seqlens_k_new=self.cu_seqlens_k,
179+
max_seqlen_q=self.decode_max_q_seq_len,
180+
softmax_scale=sm_scale,
181+
causal=True,
182+
window_size=window_size,
183+
softcap=0.0,
184+
k_descale=None,
185+
v_descale=None,
186+
return_softmax_lse=False,
187+
sinks=sink_weight,
188+
)

0 commit comments

Comments
 (0)