Skip to content

Commit 16170f3

Browse files
committed
style: align formatting with upstream/main and inline mtp accept-len commit
- Revert local reformatting to match upstream/main exactly, minimizing the PR diff
- Inline _commit_mtp_accept_len into decode_mtp (phase-2 ordering preserved)
- Drop redundant inline comments
1 parent 48c15de commit 16170f3

6 files changed

Lines changed: 37 additions & 131 deletions

File tree

lightllm/common/basemodel/attention/fa3/fp.py

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,12 @@
11
import dataclasses
22
import torch
3-
from ..base_att import (
4-
BaseAttBackend,
5-
BasePrefillAttState,
6-
BaseDecodeAttState,
7-
AttControl,
8-
)
3+
from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl
94
from typing import Optional, TYPE_CHECKING
105
from lightllm.utils.dist_utils import get_current_device_id
116
from lightllm.utils.sgl_utils import flash_attn_with_kvcache
127
from lightllm.utils.envs_utils import get_env_start_args
138
from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy
14-
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import (
15-
gen_cumsum_pad0_tensor,
16-
)
9+
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor
1710

1811

1912
class Fa3AttBackend(BaseAttBackend):
@@ -28,14 +21,12 @@ def get_page_table_buffer(self):
2821
model = self.model
2922
if not hasattr(self, "_shared_page_table_buffer"):
3023
self._shared_page_table_buffer = [
31-
torch.empty(
32-
model.graph_max_batch_size * model.graph_max_len_in_batch,
33-
dtype=torch.int32,
34-
).to(get_current_device_id()),
35-
torch.empty(
36-
model.graph_max_batch_size * model.graph_max_len_in_batch,
37-
dtype=torch.int32,
38-
).to(get_current_device_id()),
24+
torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to(
25+
get_current_device_id()
26+
),
27+
torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to(
28+
get_current_device_id()
29+
),
3930
]
4031
return self._shared_page_table_buffer
4132

@@ -84,12 +75,7 @@ def prefill_att(
8475
)
8576

8677
def _nomarl_prefill_att(
87-
self,
88-
q: torch.Tensor,
89-
k: torch.Tensor,
90-
v: torch.Tensor,
91-
att_control: AttControl,
92-
alloc_func=torch.empty,
78+
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty
9379
) -> torch.Tensor:
9480
self.backend: Fa3AttBackend = self.backend # for typing
9581

lightllm/common/basemodel/attention/fa3/fp8.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,9 @@ def init_state(self):
4444
torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len
4545
)
4646
# 为了减少推理计算量,在推理外部初始化k_descale和v_descale
47-
self.k_descale = (
48-
offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
49-
)
50-
self.v_descale = (
51-
offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
52-
)
47+
self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
48+
self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
49+
5350

5451
def prefill_att(
5552
self,
@@ -125,12 +122,8 @@ def init_state(self):
125122
head_num = mem_manager.head_num
126123

127124
# 为了减少推理计算量,在推理外部初始化k_descale和v_descale
128-
self.k_descale = (
129-
offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
130-
)
131-
self.v_descale = (
132-
offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
133-
)
125+
self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
126+
self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
134127

135128
return
136129

lightllm/common/basemodel/attention/fa3/mla.py

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,12 @@
11
import dataclasses
22
import torch
3-
from ..base_att import (
4-
BaseAttBackend,
5-
BasePrefillAttState,
6-
BaseDecodeAttState,
7-
AttControl,
8-
)
3+
from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl
94
from typing import Optional, TYPE_CHECKING, Tuple
105
from lightllm.utils.dist_utils import get_current_device_id
116
from lightllm.utils.sgl_utils import flash_attn_with_kvcache
127
from lightllm.utils.envs_utils import get_env_start_args
138
from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy
14-
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import (
15-
gen_cumsum_pad0_tensor,
16-
)
9+
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor
1710
from lightllm.utils.sgl_utils import flash_attn_varlen_func
1811

1912

@@ -29,14 +22,12 @@ def get_page_table_buffer(self):
2922
model = self.model
3023
if not hasattr(self, "_shared_page_table_buffer"):
3124
self._shared_page_table_buffer = [
32-
torch.empty(
33-
model.graph_max_batch_size * model.graph_max_len_in_batch,
34-
dtype=torch.int32,
35-
).to(get_current_device_id()),
36-
torch.empty(
37-
model.graph_max_batch_size * model.graph_max_len_in_batch,
38-
dtype=torch.int32,
39-
).to(get_current_device_id()),
25+
torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to(
26+
get_current_device_id()
27+
),
28+
torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to(
29+
get_current_device_id()
30+
),
4031
]
4132
return self._shared_page_table_buffer
4233

@@ -78,12 +69,7 @@ def prefill_att(
7869
)
7970

8071
def _mla_prefill_att(
81-
self,
82-
q: torch.Tensor,
83-
k: torch.Tensor,
84-
v: torch.Tensor,
85-
att_control: AttControl,
86-
alloc_func=torch.empty,
72+
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty
8773
) -> torch.Tensor:
8874
self.backend: MlaFa3AttBackend = self.backend # for typing
8975
k_nope, k_rope = k

lightllm/common/basemodel/triton_kernel/linear_att_copy.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,7 @@ def _copy_linear_att_state_to_kv_buffer(
4646
accept_len = tl.load(num_accepted_tokens_ptr + cur_batch).to(tl.int64)
4747
canonical_off = accept_len - 1
4848

49-
# --- conv snapshot ---
50-
# conv is a single WIDENED slot keyed by req_idx (asymmetric layout, §3.4).
51-
# The committed NARROW window of byte length conv_narrow_row_bytes sits at
52-
# byte offset canonical_off * itemsize inside each widened row. The flattened
53-
# uint8 tail lays out element [d, w] at d * gpu_conv_row_bytes + w (bytes),
54-
# so the narrow window is strided per row: copy row-by-row.
5549
conv_src_slot = cur_req_idx
56-
# gpu_conv_stride_d carries the per-element byte size (itemsize); the narrow
57-
# window starts canonical_off elements into the widened row.
5850
conv_off_bytes = canonical_off * gpu_conv_stride_d
5951
gpu_conv_base = gpu_conv_ptr + cur_layer * gpu_conv_stride_l + conv_src_slot * gpu_conv_stride_s + conv_off_bytes
6052
cpu_conv_base = cpu_kv_conv_ptr + big_page_buffer_idx * cpu_kv_conv_stride_s + cur_layer * cpu_kv_conv_stride_l
@@ -65,9 +57,6 @@ def _copy_linear_att_state_to_kv_buffer(
6557
conv_data = tl.load(gpu_conv_base + d * gpu_conv_row_bytes + off, mask=mask)
6658
tl.store(cpu_conv_base + d * cpu_kv_conv_stride_d + off, conv_data, mask=mask)
6759

68-
# --- ssm snapshot ---
69-
# ssm is an (S+1) BLOCK per request; the committed block slot is
70-
# req_idx * (mtp_step + 1) + canonical_off.
7160
ssm_src_slot = (cur_req_idx * (mtp_step + 1) + canonical_off).to(tl.int64)
7261
for i in range(tl.cdiv(gpu_ssm_tail_dim, BLOCK)):
7362
gpu_start_off = i * BLOCK + tl.arange(0, BLOCK)
@@ -98,10 +87,6 @@ def copy_linear_att_state_to_kv_buffer(
9887
assert len(b_req_idx) == b_num_accepted_tokens.shape[0]
9988
BLOCK = 4096
10089

101-
# Conv: keep the (conv_dim, width) tail un-flattened so the committed narrow
102-
# window can be read per row at the canonical offset (the window is strided
103-
# in the flattened widened layout). Capture itemsize BEFORE the uint8 view to
104-
# convert the element-unit canonical offset into a byte offset.
10590
assert gpu_conv_state.dim() >= 4, "gpu_conv_state must be [layer, s, conv_dim, widened_width]"
10691
assert cpu_kv_conv_state.dim() >= 4, "cpu_kv_conv_state must be [size, layer, conv_dim, width_narrow]"
10792
conv_itemsize = gpu_conv_state.element_size()

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -794,35 +794,13 @@ def _verify_mtp_v2(
794794
)
795795
return mtp_accept_len, accepted_index
796796

797-
def _commit_mtp_accept_len(
798-
self,
799-
decode_reqs: List[InferReq],
800-
mtp_accept_len_cpu: torch.Tensor,
801-
):
802-
# Carry the per-req accept count into the NEXT step as the canonical
803-
# pointer (design §3.1). This must run on every rank (not only master):
804-
# the kernels on this rank read req.mtp_accept_len.
805-
#
806-
# CRITICAL ordering (overlap scheduler): the next step's decode_mtp reads
807-
# req.mtp_accept_len (to build b_num_accepted_tokens) the moment its
808-
# wait_to_forward() is released, which happens at THIS step's
809-
# notify_forward_and_wait_post_handle() (start of phase 3). So this carry
810-
# MUST be committed in phase 2 (pre_post_handle), before that release —
811-
# otherwise the next step reads a one-step-stale accept count. The error
812-
# is invisible while accept_len is constant (==1) and corrupts the GDN
813-
# conv/ssm committed-state read-offset the instant a multi-token accept
814-
# (accept_len>=2) occurs.
815-
for req, accept_len in zip(decode_reqs, mtp_accept_len_cpu):
816-
req.mtp_accept_len = int(accept_len)
817-
return
818-
819797
def _update_mtp_accept_ratio(
820798
self,
821799
decode_reqs: List[InferReq],
822800
mtp_accept_len_cpu: torch.Tensor,
823801
):
824-
# Master-only accept-ratio statistics. Unlike _commit_mtp_accept_len this
825-
# only feeds metrics, so it may stay in the phase-3 post_handle region.
802+
# Master-only accept-ratio statistics. Unlike the phase-2 mtp_accept_len commit
803+
# (inlined in decode_mtp) this only feeds metrics, so it may stay in phase 3.
826804
if self.is_master_in_dp:
827805
for req, accept_len in zip(decode_reqs, mtp_accept_len_cpu):
828806
req.update_mtp_accepted_token_num(accept_token_num=accept_len - 1)

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py

Lines changed: 12 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44
from typing import List, Optional, Callable, Dict, Any
55
from queue import Queue
66
from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend
7-
from lightllm.server.router.model_infer.mode_backend.overlap_events import (
8-
OverlapEventPack,
9-
)
7+
from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventPack
108
from lightllm.server.router.model_infer.infer_batch import InferReq
119
from lightllm.server.router.model_infer.mode_backend.pre import (
1210
prepare_prefill_inputs,
@@ -43,10 +41,7 @@ def __init__(self) -> None:
4341
if get_env_start_args().mtp_mode:
4442
self.prefill = self.prefill_mtp
4543
self.decode = self.decode_mtp
46-
self.is_mtp_eagle = get_env_start_args().mtp_mode in [
47-
"eagle_with_att",
48-
"eagle_no_att",
49-
]
44+
self.is_mtp_eagle = get_env_start_args().mtp_mode in ["eagle_with_att", "eagle_no_att"]
5045
self.num_mtp_models = 1 if self.is_mtp_eagle else get_env_start_args().mtp_step
5146
self._draft_decode_func = self._draft_decode_eagle if self.is_mtp_eagle else self._draft_decode_vanilla
5247
else:
@@ -115,7 +110,7 @@ def prefill_normal(
115110
model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill)
116111
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
117112
model_output = self.model.forward(model_input)
118-
(_, next_token_ids_cpu, next_token_logprobs_cpu,) = self._sample_and_scatter_token(
113+
_, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
119114
logits=model_output.logits,
120115
b_req_idx=model_input.b_req_idx,
121116
b_mtp_index=model_input.b_mtp_index,
@@ -158,7 +153,7 @@ def decode_normal(
158153
model_input, run_reqs = prepare_decode_inputs(decode_reqs)
159154
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
160155
model_output = self.model.forward(model_input)
161-
(_, next_token_ids_cpu, next_token_logprobs_cpu,) = self._sample_and_scatter_token(
156+
_, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
162157
logits=model_output.logits,
163158
b_req_idx=model_input.b_req_idx,
164159
b_mtp_index=model_input.b_mtp_index,
@@ -196,7 +191,7 @@ def prefill_mtp(
196191
model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill)
197192
with torch.cuda.stream(g_infer_context.get_overlap_stream()):
198193
model_output = self.model.forward(model_input)
199-
(next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu,) = self._sample_and_scatter_token(
194+
next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
200195
logits=model_output.logits,
201196
b_req_idx=model_input.b_req_idx,
202197
b_mtp_index=model_input.b_mtp_index,
@@ -207,9 +202,7 @@ def prefill_mtp(
207202
)
208203
# mtp kv fill
209204
self._draft_prefill_forward(
210-
model_input=model_input,
211-
model_output=model_output,
212-
next_token_ids=next_token_ids,
205+
model_input=model_input, model_output=model_output, next_token_ids=next_token_ids
213206
)
214207
g_infer_context.copy_linear_att_state_to_cache_buffer(
215208
b_req_idx=model_input.b_req_idx,
@@ -249,11 +242,6 @@ def decode_mtp(
249242
"""
250243
model_input, run_reqs = prepare_decode_inputs(decode_reqs)
251244

252-
# Build the per-real-request accept tensor (carried InferReq.mtp_accept_len
253-
# from the previous step). decode_reqs is one entry per real request,
254-
# aligning 1:1 with the b_gdn_verify_cu_seqlens grouping (the same zip used
255-
# by _update_mtp_accept_ratio). Threaded onto the infer_state via ModelInput
256-
# (mirrors b_mtp_index); to_cuda() moves it inside forward. §3.1
257245
if self.mtp_step > 0:
258246
accept_lens = [req.mtp_accept_len for req in decode_reqs]
259247
model_input.b_num_accepted_tokens = g_pin_mem_manager.gen_from_list(
@@ -290,10 +278,9 @@ def decode_mtp(
290278
verify_event = torch.cuda.Event()
291279
verify_event.record()
292280

293-
(
294-
next_token_ids_cpu,
295-
next_token_logprobs_cpu,
296-
) = self._async_copy_next_token_infos_to_pin_mem(next_token_ids, next_token_logprobs)
281+
next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
282+
next_token_ids, next_token_logprobs
283+
)
297284

298285
# 调用具体的draft decode函数
299286
additional_mem_indexes_cpu = self._draft_decode_func(
@@ -315,12 +302,8 @@ def decode_mtp(
315302
# 第二阶段
316303
event_pack.notify_post_handle_and_wait_pre_post_handle()
317304
verify_event.synchronize()
318-
# Commit the carried accept count HERE (phase 2 / pre_post_handle), not in
319-
# phase 3: the next overlapped step reads req.mtp_accept_len as soon as this
320-
# step calls notify_forward_and_wait_post_handle() below, so the update must
321-
# land before that release to avoid feeding the kernels a stale (one-step-old)
322-
# accept count. See _commit_mtp_accept_len for the full rationale.
323-
self._commit_mtp_accept_len(decode_reqs=decode_reqs, mtp_accept_len_cpu=mtp_accept_len_cpu)
305+
for req, accept_len in zip(decode_reqs, mtp_accept_len_cpu):
306+
req.mtp_accept_len = int(accept_len)
324307
verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu[i] == 1]
325308
update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False)
326309

@@ -352,12 +335,7 @@ def decode_mtp(
352335
event_pack.notify_pre_post_handle()
353336
return
354337

355-
def _draft_prefill_forward(
356-
self,
357-
model_input: ModelInput,
358-
model_output: ModelOutput,
359-
next_token_ids: torch.Tensor,
360-
):
338+
def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor):
361339
# spec prefill: MTP, 这个地方只是为了填充draft model的 kv, 并不会使用生成的token_id。
362340
draft_model_input = model_input
363341
draft_model_output = model_output

0 commit comments

Comments
 (0)