11import torch
22import time
3+ import copy
34from typing import List , Optional , Callable , Dict , Any
45from queue import Queue
56from lightllm .server .router .model_infer .mode_backend .base_backend import ModeBackend
@@ -240,17 +241,23 @@ def decode_mtp(
240241 """
241242 model_input , run_reqs = prepare_decode_inputs (decode_reqs )
242243
244+ if self .mtp_step > 0 :
245+ accept_lens = [req .mtp_accept_len for req in decode_reqs ]
246+ model_input .b_num_accepted_tokens = g_pin_mem_manager .gen_from_list (
247+ key = "b_num_accepted_tokens" ,
248+ data = accept_lens ,
249+ dtype = torch .int32 ,
250+ )
251+
243252 with torch .cuda .stream (g_infer_context .get_overlap_stream ()):
244- b_mtp_index_cpu = model_input .b_mtp_index
245253 model_output = self .model .forward (model_input )
246254 next_token_ids , next_token_logprobs = sample (model_output .logits , run_reqs , self .eos_id )
247- # verify the next_token_ids
248- b_req_mtp_start_loc = [index for index , mtp_index in enumerate (b_mtp_index_cpu ) if mtp_index == 0 ]
249- b_req_mtp_start_loc = g_pin_mem_manager .gen_from_list (
250- key = "b_req_mtp_start_loc" ,
251- data = b_req_mtp_start_loc ,
252- dtype = torch .int32 ,
253- ).cuda (non_blocking = True )
255+ # verify the next_token_ids. The chunked decode batch is the contiguous
256+ # (mtp_step+1)-expanded layout, so request starts are structurally
257+ # arange(n_real)*(mtp_step+1). Compute on device instead of a per-step Python
258+ # list-comp + pinned pack + H2D (#22).
259+ n_real = model_input .batch_size // (self .mtp_step + 1 )
260+ b_req_mtp_start_loc = torch .arange (n_real , dtype = torch .int32 , device = "cuda" ) * (self .mtp_step + 1 )
254261
255262 mtp_accept_len , accepted_index = self ._verify_mtp_v2 (
256263 new_next_token_ids = next_token_ids ,
@@ -292,6 +299,8 @@ def decode_mtp(
292299 # 第二阶段
293300 event_pack .notify_post_handle_and_wait_pre_post_handle ()
294301 verify_event .synchronize ()
302+ for req , accept_len in zip (decode_reqs , mtp_accept_len_cpu ):
303+ req .mtp_accept_len = int (accept_len )
295304 verify_ok_reqs = [run_reqs [i ] for i in range (len (run_reqs )) if accepted_index_cpu [i ] == 1 ]
296305 update_packs = self ._pre_post_handle (verify_ok_reqs , is_chuncked_mode = False )
297306
@@ -344,15 +353,19 @@ def _draft_decode_vanilla(
344353 mtp_accept_len : torch .Tensor ,
345354 b_req_mtp_start_loc : torch .Tensor ,
346355 ):
347- # share some inference info with the main model
348- draft_model_input = main_model_input
356+ # Share some inference info with the main model. After copy.copy, clear b_num_accepted_tokens
357+ # so the draft (MTP) forward takes the plain decode layout (bs, False); otherwise it would inherit the
358+ # verify layout set by the main model's decode_mtp and hit a cudagraph key (bs, True) that the MTP draft model never captured -> KeyError
359+ # (with cudagraph disabled, it would instead misapply the S+1-grouped verify attention to the flat draft batch).
360+ # Mirrors the clearing of b_num_accepted_tokens in the eagle path's _build_eagle_accepted_draft_input.
361+ draft_model_input = copy .copy (main_model_input )
362+ draft_model_input .b_num_accepted_tokens = None
349363 draft_model_output = main_model_output
350364 draft_next_token_ids = next_token_ids
351365 all_next_token_ids = []
352366 all_next_token_ids .append (next_token_ids )
353367 # process the draft model output
354368 for draft_model_idx in range (self .mtp_step ):
355-
356369 draft_model_input .input_ids = draft_next_token_ids
357370 draft_model_input .mtp_draft_input_hiddens = draft_model_output .mtp_main_output_hiddens
358371 # spec decode: MTP
@@ -379,44 +392,47 @@ def _draft_decode_eagle(
379392 mtp_accept_len : torch .Tensor ,
380393 b_req_mtp_start_loc : torch .Tensor ,
381394 ):
382- batch_size = main_model_input .batch_size
383- num_reqs = batch_size // (self .mtp_step + 1 )
395+ num_reqs = b_req_mtp_start_loc .shape [0 ]
384396 if g_infer_context .radix_cache is not None :
385397 g_infer_context .radix_cache .free_radix_cache_to_get_enough_token (num_reqs * self .mtp_step )
386398 eagle_mem_indexes_cpu = g_infer_context .req_manager .mem_manager .alloc (num_reqs * self .mtp_step )
387399 eagle_mem_indexes = eagle_mem_indexes_cpu .cuda (non_blocking = True )
388400
389- # share some inference info with the main model
390- draft_model_input = main_model_input
401+ (draft_model_input , draft_next_token_ids , accepted_req_idx ,) = self ._build_eagle_accepted_draft_input (
402+ main_model_input = main_model_input ,
403+ main_model_output = main_model_output ,
404+ next_token_ids = next_token_ids ,
405+ mtp_accept_len = mtp_accept_len ,
406+ b_req_mtp_start_loc = b_req_mtp_start_loc ,
407+ )
391408 draft_model_output = main_model_output
392- draft_next_token_ids = next_token_ids
393409 all_next_token_ids = []
394- all_next_token_ids .append (next_token_ids )
395- # process the draft model output
396- for _step in range (self .mtp_step ):
410+ all_next_token_ids .append (draft_next_token_ids )
411+
412+ mtp_size = self .mtp_step + 1
413+ main_mem_indexes = main_model_input .mem_indexes .view (num_reqs , mtp_size )
414+ eagle_mem_indexes_by_req = eagle_mem_indexes .view (self .mtp_step , num_reqs ).transpose (0 , 1 ).contiguous ()
415+ mem_index_plan = torch .cat ([main_mem_indexes , eagle_mem_indexes_by_req ], dim = 1 )
416+ accepted_offsets = mtp_accept_len .long () - 1
417+ req_offsets = torch .arange (num_reqs , dtype = torch .long , device = mtp_accept_len .device )
397418
419+ for _step in range (self .mtp_step ):
398420 draft_model_input .input_ids = draft_next_token_ids
399- draft_model_input .mtp_draft_input_hiddens = draft_model_output .mtp_main_output_hiddens
421+ if _step > 0 :
422+ draft_model_input .mtp_draft_input_hiddens = draft_model_output .mtp_main_output_hiddens
423+ draft_model_input .mem_indexes = mem_index_plan [req_offsets , accepted_offsets + _step ]
400424 # spec decode: MTP
401425 draft_model_idx = _step % self .num_mtp_models
402426 draft_model_output : ModelOutput = self .draft_models [draft_model_idx ].forward (draft_model_input )
403427 draft_next_token_ids = self ._gen_argmax_token_ids (draft_model_output )
404428 draft_model_input .b_seq_len += 1
405429 draft_model_input .max_kv_seq_len += 1
406- eagle_mem_indexes_i = eagle_mem_indexes [_step * num_reqs : (_step + 1 ) * num_reqs ]
407- draft_model_input .mem_indexes = torch .cat (
408- [draft_model_input .mem_indexes .view (- 1 , self .mtp_step + 1 )[:, 1 :], eagle_mem_indexes_i .view (- 1 , 1 )],
409- dim = 1 ,
410- ).view (- 1 )
411430 all_next_token_ids .append (draft_next_token_ids )
412431
413432 all_next_token_ids = torch .stack (all_next_token_ids , dim = 1 ) # [batch_size, mtp_step + 1]
414433
415- mtp_scatter_next_token_ids (
416- req_to_next_token_ids = self .model .req_manager .req_sampling_params_manager .req_to_next_token_ids ,
417- b_req_mtp_start_loc = b_req_mtp_start_loc ,
434+ self ._scatter_accepted_next_token_ids (
435+ accepted_req_idx = accepted_req_idx ,
418436 all_next_token_ids = all_next_token_ids ,
419- b_req_idx = main_model_input .b_req_idx ,
420- mtp_accept_len = mtp_accept_len ,
421437 )
422438 return eagle_mem_indexes_cpu
0 commit comments