44from typing import List , Optional , Callable , Dict , Any
55from queue import Queue
66from lightllm .server .router .model_infer .mode_backend .base_backend import ModeBackend
7- from lightllm .server .router .model_infer .mode_backend .overlap_events import (
8- OverlapEventPack ,
9- )
7+ from lightllm .server .router .model_infer .mode_backend .overlap_events import OverlapEventPack
108from lightllm .server .router .model_infer .infer_batch import InferReq
119from lightllm .server .router .model_infer .mode_backend .pre import (
1210 prepare_prefill_inputs ,
@@ -43,10 +41,7 @@ def __init__(self) -> None:
4341 if get_env_start_args ().mtp_mode :
4442 self .prefill = self .prefill_mtp
4543 self .decode = self .decode_mtp
46- self .is_mtp_eagle = get_env_start_args ().mtp_mode in [
47- "eagle_with_att" ,
48- "eagle_no_att" ,
49- ]
44+ self .is_mtp_eagle = get_env_start_args ().mtp_mode in ["eagle_with_att" , "eagle_no_att" ]
5045 self .num_mtp_models = 1 if self .is_mtp_eagle else get_env_start_args ().mtp_step
5146 self ._draft_decode_func = self ._draft_decode_eagle if self .is_mtp_eagle else self ._draft_decode_vanilla
5247 else :
@@ -115,7 +110,7 @@ def prefill_normal(
115110 model_input , run_reqs = prepare_prefill_inputs (prefill_reqs , is_chuncked_mode = not self .disable_chunked_prefill )
116111 with torch .cuda .stream (g_infer_context .get_overlap_stream ()):
117112 model_output = self .model .forward (model_input )
118- ( _ , next_token_ids_cpu , next_token_logprobs_cpu ,) = self ._sample_and_scatter_token (
113+ _ , next_token_ids_cpu , next_token_logprobs_cpu = self ._sample_and_scatter_token (
119114 logits = model_output .logits ,
120115 b_req_idx = model_input .b_req_idx ,
121116 b_mtp_index = model_input .b_mtp_index ,
@@ -158,7 +153,7 @@ def decode_normal(
158153 model_input , run_reqs = prepare_decode_inputs (decode_reqs )
159154 with torch .cuda .stream (g_infer_context .get_overlap_stream ()):
160155 model_output = self .model .forward (model_input )
161- ( _ , next_token_ids_cpu , next_token_logprobs_cpu ,) = self ._sample_and_scatter_token (
156+ _ , next_token_ids_cpu , next_token_logprobs_cpu = self ._sample_and_scatter_token (
162157 logits = model_output .logits ,
163158 b_req_idx = model_input .b_req_idx ,
164159 b_mtp_index = model_input .b_mtp_index ,
@@ -196,7 +191,7 @@ def prefill_mtp(
196191 model_input , run_reqs = prepare_prefill_inputs (prefill_reqs , is_chuncked_mode = not self .disable_chunked_prefill )
197192 with torch .cuda .stream (g_infer_context .get_overlap_stream ()):
198193 model_output = self .model .forward (model_input )
199- ( next_token_ids , next_token_ids_cpu , next_token_logprobs_cpu ,) = self ._sample_and_scatter_token (
194+ next_token_ids , next_token_ids_cpu , next_token_logprobs_cpu = self ._sample_and_scatter_token (
200195 logits = model_output .logits ,
201196 b_req_idx = model_input .b_req_idx ,
202197 b_mtp_index = model_input .b_mtp_index ,
@@ -207,9 +202,7 @@ def prefill_mtp(
207202 )
208203 # mtp kv fill
209204 self ._draft_prefill_forward (
210- model_input = model_input ,
211- model_output = model_output ,
212- next_token_ids = next_token_ids ,
205+ model_input = model_input , model_output = model_output , next_token_ids = next_token_ids
213206 )
214207 g_infer_context .copy_linear_att_state_to_cache_buffer (
215208 b_req_idx = model_input .b_req_idx ,
@@ -249,11 +242,6 @@ def decode_mtp(
249242 """
250243 model_input , run_reqs = prepare_decode_inputs (decode_reqs )
251244
252- # Build the per-real-request accept tensor (carried InferReq.mtp_accept_len
253- # from the previous step). decode_reqs is one entry per real request,
254- # aligning 1:1 with the b_gdn_verify_cu_seqlens grouping (the same zip used
255- # by _update_mtp_accept_ratio). Threaded onto the infer_state via ModelInput
256- # (mirrors b_mtp_index); to_cuda() moves it inside forward. §3.1
257245 if self .mtp_step > 0 :
258246 accept_lens = [req .mtp_accept_len for req in decode_reqs ]
259247 model_input .b_num_accepted_tokens = g_pin_mem_manager .gen_from_list (
@@ -290,10 +278,9 @@ def decode_mtp(
290278 verify_event = torch .cuda .Event ()
291279 verify_event .record ()
292280
293- (
294- next_token_ids_cpu ,
295- next_token_logprobs_cpu ,
296- ) = self ._async_copy_next_token_infos_to_pin_mem (next_token_ids , next_token_logprobs )
281+ next_token_ids_cpu , next_token_logprobs_cpu = self ._async_copy_next_token_infos_to_pin_mem (
282+ next_token_ids , next_token_logprobs
283+ )
297284
298285 # 调用具体的draft decode函数
299286 additional_mem_indexes_cpu = self ._draft_decode_func (
@@ -315,12 +302,8 @@ def decode_mtp(
315302 # 第二阶段
316303 event_pack .notify_post_handle_and_wait_pre_post_handle ()
317304 verify_event .synchronize ()
318- # Commit the carried accept count HERE (phase 2 / pre_post_handle), not in
319- # phase 3: the next overlapped step reads req.mtp_accept_len as soon as this
320- # step calls notify_forward_and_wait_post_handle() below, so the update must
321- # land before that release to avoid feeding the kernels a stale (one-step-old)
322- # accept count. See _commit_mtp_accept_len for the full rationale.
323- self ._commit_mtp_accept_len (decode_reqs = decode_reqs , mtp_accept_len_cpu = mtp_accept_len_cpu )
305+ for req , accept_len in zip (decode_reqs , mtp_accept_len_cpu ):
306+ req .mtp_accept_len = int (accept_len )
324307 verify_ok_reqs = [run_reqs [i ] for i in range (len (run_reqs )) if accepted_index_cpu [i ] == 1 ]
325308 update_packs = self ._pre_post_handle (verify_ok_reqs , is_chuncked_mode = False )
326309
@@ -352,12 +335,7 @@ def decode_mtp(
352335 event_pack .notify_pre_post_handle ()
353336 return
354337
355- def _draft_prefill_forward (
356- self ,
357- model_input : ModelInput ,
358- model_output : ModelOutput ,
359- next_token_ids : torch .Tensor ,
360- ):
338+ def _draft_prefill_forward (self , model_input : ModelInput , model_output : ModelOutput , next_token_ids : torch .Tensor ):
361339 # spec prefill: MTP, 这个地方只是为了填充draft model的 kv, 并不会使用生成的token_id。
362340 draft_model_input = model_input
363341 draft_model_output = model_output
0 commit comments