@@ -739,12 +739,15 @@ def _process_batch_draft_tokens(self, mtype, batch, accept_num, tokens, scores,
739739 metrics = None ,
740740 )
741741
742- token_ids = tokens [i ][:, 0 ].tolist ()[: accept_num [i ]]
742+ tokens_i = tokens [i ].tolist ()
743+ scores_i = scores [i ].tolist ()
744+ ranks_i = ranks [i ].tolist ()
745+ token_ids = [row [0 ] for row in tokens_i [: accept_num [i ]]]
743746 for batch_token_index in range (len (token_ids )):
744- result .outputs .logprob = float ( scores [ i , batch_token_index , 0 ])
745- topk_token_ids = tokens [ i , batch_token_index , :]. tolist ()
746- topk_logprobs = scores [ i , batch_token_index , :]. tolist ()
747- sampled_rank = ranks [ i , batch_token_index ]. item ()
747+ result .outputs .logprob = scores_i [ batch_token_index ][ 0 ]
748+ topk_token_ids = tokens_i [ batch_token_index ]
749+ topk_logprobs = scores_i [ batch_token_index ]
750+ sampled_rank = ranks_i [ batch_token_index ]
748751
749752 if result .outputs .draft_top_logprobs is None :
750753 result .outputs .draft_top_logprobs = LogprobsLists (
@@ -771,16 +774,19 @@ def _process_batch_output(self):
771774 mtype = 3
772775 if self .cfg .speculative_config .method :
773776 if self .use_logprobs :
774- mtype = int (self .output_tokens [1 , 0 ].item ())
777+ # meta[1] packs message_flag (low 8 bits) and actual_topk (high 24 bits).
778+ packed_meta1 = int (self .output_tokens [1 , 0 ].item ())
779+ mtype = packed_meta1 & 0xFF
780+ actual_topk = packed_meta1 >> 8
775781 batch = self .output_tokens [2 , 0 ]
776782 accept_num = [int (num [0 ]) for num in self .output_tokens [3 : batch + 3 ]]
777783 tokens = tokens [3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1 )].reshape (
778784 [batch , MAX_DRAFT_TOKENS , K + 1 ]
779- )
785+ )[:, :, : actual_topk ]
780786 scores = (
781787 self .output_scores [: batch * MAX_DRAFT_TOKENS * (K + 1 )]
782788 .numpy ()
783- .reshape ([batch , MAX_DRAFT_TOKENS , K + 1 ])
789+ .reshape ([batch , MAX_DRAFT_TOKENS , K + 1 ])[:, :, : actual_topk ]
784790 )
785791 ranks = self .output_ranks [: batch * MAX_DRAFT_TOKENS ].numpy ().reshape ([batch , MAX_DRAFT_TOKENS ])
786792
@@ -789,6 +795,10 @@ def _process_batch_output(self):
789795 batch_result = self ._process_batch_draft_tokens (mtype , batch , accept_num , tokens , scores , ranks )
790796 self .postprocess (batch_result , mtype )
791797 return
798+ # Pre-convert full arrays to Python lists once for MTP target token path.
799+ tokens_lists = tokens .tolist ()
800+ scores_lists = scores .tolist ()
801+ ranks_list = ranks .tolist ()
792802 else :
793803 batch = self .output_tokens [1 ]
794804 accept_num = tokens [2 : batch + 2 ]
@@ -856,7 +866,7 @@ def _process_batch_output(self):
856866 )
857867 token_ids = [RECOVERY_STOP_SIGNAL ]
858868 elif self .use_logprobs :
859- token_ids = tokens [ i ][:, 0 ]. tolist ()[ : accept_num [i ]]
869+ token_ids = [ row [ 0 ] for row in tokens_lists [ i ][ : accept_num [i ] ]]
860870 else :
861871 token_ids = tokens [
862872 2
@@ -988,10 +998,10 @@ def _process_batch_output(self):
988998 task .output_token_ids .append (token_id )
989999 if self .use_logprobs :
9901000 if self .cfg .speculative_config .method :
991- result .outputs .logprob = float ( scores [ i , batch_token_index , 0 ])
992- topk_token_ids = tokens [ i , batch_token_index , :]. tolist ()
993- topk_logprobs = scores [ i , batch_token_index , :]. tolist ()
994- sampled_rank = ranks [ i , batch_token_index ]. item ()
1001+ result .outputs .logprob = scores_lists [ i ][ batch_token_index ][ 0 ]
1002+ topk_token_ids = tokens_lists [ i ][ batch_token_index ]
1003+ topk_logprobs = scores_lists [ i ][ batch_token_index ]
1004+ sampled_rank = ranks_list [ i ][ batch_token_index ]
9951005 else :
9961006 # Use pre-converted lists (batch .tolist() done before the loop).
9971007 result .outputs .logprob = scores_lists [i ][0 ]
0 commit comments