@@ -796,12 +796,15 @@ def _process_batch_draft_tokens(self, mtype, batch, accept_num, tokens, scores,
796796 metrics = None ,
797797 )
798798
799- token_ids = tokens [i ][:, 0 ].tolist ()[: accept_num [i ]]
799+ tokens_i = tokens [i ].tolist ()
800+ scores_i = scores [i ].tolist ()
801+ ranks_i = ranks [i ].tolist ()
802+ token_ids = [row [0 ] for row in tokens_i [: accept_num [i ]]]
800803 for batch_token_index in range (len (token_ids )):
801- result .outputs .logprob = float ( scores [ i , batch_token_index , 0 ])
802- topk_token_ids = tokens [ i , batch_token_index , :]. tolist ()
803- topk_logprobs = scores [ i , batch_token_index , :]. tolist ()
804- sampled_rank = ranks [ i , batch_token_index ]. item ()
804+ result .outputs .logprob = scores_i [ batch_token_index ][ 0 ]
805+ topk_token_ids = tokens_i [ batch_token_index ]
806+ topk_logprobs = scores_i [ batch_token_index ]
807+ sampled_rank = ranks_i [ batch_token_index ]
805808
806809 if result .outputs .draft_top_logprobs is None :
807810 result .outputs .draft_top_logprobs = LogprobsLists (
@@ -828,16 +831,19 @@ def _process_batch_output(self):
828831 mtype = 3
829832 if self .cfg .speculative_config .method :
830833 if self .use_logprobs :
831- mtype = int (self .output_tokens [1 , 0 ].item ())
834+ # meta[1] packs message_flag (low 8 bits) and actual_topk (high 24 bits).
835+ packed_meta1 = int (self .output_tokens [1 , 0 ].item ())
836+ mtype = packed_meta1 & 0xFF
837+ actual_topk = packed_meta1 >> 8
832838 batch = self .output_tokens [2 , 0 ]
833839 accept_num = [int (num [0 ]) for num in self .output_tokens [3 : batch + 3 ]]
834840 tokens = tokens [3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1 )].reshape (
835841 [batch , MAX_DRAFT_TOKENS , K + 1 ]
836- )
842+ )[:, :, : actual_topk ]
837843 scores = (
838844 self .output_scores [: batch * MAX_DRAFT_TOKENS * (K + 1 )]
839845 .numpy ()
840- .reshape ([batch , MAX_DRAFT_TOKENS , K + 1 ])
846+ .reshape ([batch , MAX_DRAFT_TOKENS , K + 1 ])[:, :, : actual_topk ]
841847 )
842848 ranks = self .output_ranks [: batch * MAX_DRAFT_TOKENS ].numpy ().reshape ([batch , MAX_DRAFT_TOKENS ])
843849
@@ -846,6 +852,10 @@ def _process_batch_output(self):
846852 batch_result = self ._process_batch_draft_tokens (mtype , batch , accept_num , tokens , scores , ranks )
847853 self .postprocess (batch_result , mtype )
848854 return
855+ # Pre-convert full arrays to Python lists once for MTP target token path.
856+ tokens_lists = tokens .tolist ()
857+ scores_lists = scores .tolist ()
858+ ranks_list = ranks .tolist ()
849859 else :
850860 batch = self .output_tokens [1 ]
851861 accept_num = tokens [2 : batch + 2 ]
@@ -914,7 +924,7 @@ def _process_batch_output(self):
914924 llm_logger .info (f"recovery stop signal found at task { task_id } " )
915925 token_ids = [RECOVERY_STOP_SIGNAL ]
916926 elif self .use_logprobs :
917- token_ids = tokens [ i ][:, 0 ]. tolist ()[ : accept_num [i ]]
927+ token_ids = [ row [ 0 ] for row in tokens_lists [ i ][ : accept_num [i ] ]]
918928 else :
919929 token_ids = tokens [
920930 2
@@ -1033,10 +1043,10 @@ def _process_batch_output(self):
10331043 task .output_token_ids .append (token_id )
10341044 if self .use_logprobs :
10351045 if self .cfg .speculative_config .method :
1036- result .outputs .logprob = float ( scores [ i , batch_token_index , 0 ])
1037- topk_token_ids = tokens [ i , batch_token_index , :]. tolist ()
1038- topk_logprobs = scores [ i , batch_token_index , :]. tolist ()
1039- sampled_rank = ranks [ i , batch_token_index ]. item ()
1046+ result .outputs .logprob = scores_lists [ i ][ batch_token_index ][ 0 ]
1047+ topk_token_ids = tokens_lists [ i ][ batch_token_index ]
1048+ topk_logprobs = scores_lists [ i ][ batch_token_index ]
1049+ sampled_rank = ranks_list [ i ][ batch_token_index ]
10401050 else :
10411051 # Use pre-converted lists (batch .tolist() done before the loop).
10421052 result .outputs .logprob = scores_lists [i ][0 ]
0 commit comments