@@ -821,38 +821,63 @@ def _create_chat_logprobs(
821821 output_top_logprobs is None
822822 or len (output_top_logprobs ) < 3
823823 or any (not lst for lst in output_top_logprobs [:3 ])
824- ): # check top 3 because logits_stats maybe None
824+ ):
825825 return None
826826 logprobs_res : Optional [LogProbs ] = None
827827
828- # Extract logits stats from LogprobsLists if available
829- has_logits_stats = False if output_top_logprobs .logits_min is None else True
828+ # Check if output_top_logprobs is a LogprobsLists object(NamedTuple) or a list
829+ is_logprobslists = hasattr (output_top_logprobs , "logprob_token_ids" )
830+
831+ # Extract logits stats if available
832+ if is_logprobslists :
833+ # output_top_logprobs is LogprobsLists namedtuple
834+ has_logits_stats = output_top_logprobs .logits_min is not None
835+ else :
836+ # list from msgpack: [logprob_token_ids, logprobs, sampled_token_ranks, logits_min, logits_max, logits_mean, logits_std]
837+ has_logits_stats = len (output_top_logprobs ) >= 7 and output_top_logprobs [3 ] is not None
838+
839+ if is_logprobslists :
840+ num_tokens = len (output_top_logprobs .logprobs )
841+ _tk_ids = lambda idx : output_top_logprobs .logprob_token_ids [idx ]
842+ _lps = lambda idx : output_top_logprobs .logprobs [idx ]
843+ _ranks = lambda idx : output_top_logprobs .sampled_token_ranks [idx ]
844+ _lmin = lambda idx : output_top_logprobs .logits_min [idx ]
845+ _lmax = lambda idx : output_top_logprobs .logits_max [idx ]
846+ _lmean = lambda idx : output_top_logprobs .logits_mean [idx ]
847+ _lstd = lambda idx : output_top_logprobs .logits_std [idx ]
848+ else :
849+ num_tokens = len (output_top_logprobs [1 ])
850+ _tk_ids = lambda idx : output_top_logprobs [0 ][idx ]
851+ _lps = lambda idx : output_top_logprobs [1 ][idx ]
852+ _ranks = lambda idx : output_top_logprobs [2 ][idx ]
853+ _lmin = lambda idx : output_top_logprobs [3 ][idx ]
854+ _lmax = lambda idx : output_top_logprobs [4 ][idx ]
855+ _lmean = lambda idx : output_top_logprobs [5 ][idx ]
856+ _lstd = lambda idx : output_top_logprobs [6 ][idx ]
830857
831- # Iterate by index over mandatory fields; optionally include logits stats
832- num_tokens = len (output_top_logprobs .logprobs )
833858 for idx in range (num_tokens ):
834859 logits_stats = None
835860 if has_logits_stats :
836861 top_logprobs = LogprobsLists (
837- logprob_token_ids = [output_top_logprobs . logprob_token_ids [ idx ] ],
838- logprobs = [output_top_logprobs . logprobs [ idx ] ],
839- sampled_token_ranks = [output_top_logprobs . sampled_token_ranks [ idx ] ],
840- logits_min = [output_top_logprobs . logits_min [ idx ] ],
841- logits_max = [output_top_logprobs . logits_max [ idx ] ],
842- logits_mean = [output_top_logprobs . logits_mean [ idx ] ],
843- logits_std = [output_top_logprobs . logits_std [ idx ] ],
862+ logprob_token_ids = [_tk_ids ( idx ) ],
863+ logprobs = [_lps ( idx ) ],
864+ sampled_token_ranks = [_ranks ( idx ) ],
865+ logits_min = [_lmin ( idx ) ],
866+ logits_max = [_lmax ( idx ) ],
867+ logits_mean = [_lmean ( idx ) ],
868+ logits_std = [_lstd ( idx ) ],
844869 )
845870 logits_stats = {
846- "min" : float (output_top_logprobs . logits_min [ idx ] ),
847- "max" : float (output_top_logprobs . logits_max [ idx ] ),
848- "mean" : float (output_top_logprobs . logits_mean [ idx ] ),
849- "std" : float (output_top_logprobs . logits_std [ idx ] ),
871+ "min" : float (_lmin ( idx ) ),
872+ "max" : float (_lmax ( idx ) ),
873+ "mean" : float (_lmean ( idx ) ),
874+ "std" : float (_lstd ( idx ) ),
850875 }
851876 else :
852877 top_logprobs = LogprobsLists (
853- logprob_token_ids = [output_top_logprobs . logprob_token_ids [ idx ] ],
854- logprobs = [output_top_logprobs . logprobs [ idx ] ],
855- sampled_token_ranks = [output_top_logprobs . sampled_token_ranks [ idx ] ],
878+ logprob_token_ids = [_tk_ids ( idx ) ],
879+ logprobs = [_lps ( idx ) ],
880+ sampled_token_ranks = [_ranks ( idx ) ],
856881 )
857882 step_logprobs_res = self ._build_logprobs_response (
858883 request_logprobs = request_logprobs ,
@@ -943,7 +968,11 @@ def _build_prompt_logprobs(
943968 tensors.
944969 """
945970
946- token_ids , logprobs , ranks = prompt_logprobs_tensors
971+ token_ids , logprobs , ranks = (
972+ prompt_logprobs_tensors .logprob_token_ids ,
973+ prompt_logprobs_tensors .logprobs ,
974+ prompt_logprobs_tensors .selected_token_ranks ,
975+ )
947976
948977 # Normalize to plain Python lists (support both Tensor and list inputs)
949978 if hasattr (token_ids , "tolist" ):
0 commit comments