@@ -364,6 +364,7 @@ def __init__(self, fd_config: FDConfig = None, logprobs_mode: str = "raw_logprob
364364
365365 self .guided_decoding = GuidedDecoding (fd_config )
366366 self .logprobs_mode = fd_config .model_config .logprobs_mode if fd_config is not None else logprobs_mode
367+ self .compute_logits_stats = fd_config .model_config .compute_logits_stats if fd_config is not None else False
367368 # Can only be created when fd_config.early_stopper_config.enable_early_stop = True
368369 if (
369370 fd_config is not None
@@ -507,6 +508,19 @@ def forward_cuda(
507508 elif self .logprobs_mode == "raw_logits" :
508509 raw_logprobs = logits .clone ()
509510
511+ # Compute logits statistics (min/max/mean/std) per sequence before penalties
512+ logits_min = None
513+ logits_max = None
514+ logits_mean = None
515+ logits_std = None
516+ if num_logprobs is not None and self .compute_logits_stats :
517+ with paddle .no_grad ():
518+ # logits shape: [batch_size, vocab_size], compute stats per sequence (reduce over vocab dimension)
519+ logits_min = paddle .min (logits , axis = 1 ) # [batch_size]
520+ logits_max = paddle .max (logits , axis = 1 ) # [batch_size]
521+ logits_mean = paddle .mean (logits , axis = 1 ) # [batch_size]
522+ logits_std = paddle .std (logits , axis = 1 ) # [batch_size]
523+
510524 for proc in sampling_metadata .logits_processors or []:
511525 logits = proc .apply (logits )
512526
@@ -546,6 +560,33 @@ def forward_cuda(
546560 logprobs_tensors = (
547561 None if num_logprobs is None else self .gather_logprobs (raw_logprobs , num_logprobs , token_ids = next_tokens )
548562 )
563+
564+ # Pack logits stats into LogprobsTensors
565+ if logprobs_tensors is not None and logits_min is not None :
566+ if current_platform .is_cuda ():
567+ logits_min_cpu = paddle .empty_like (logits_min , device = "cpu" ).pin_memory ()
568+ logits_max_cpu = paddle .empty_like (logits_max , device = "cpu" ).pin_memory ()
569+ logits_mean_cpu = paddle .empty_like (logits_mean , device = "cpu" ).pin_memory ()
570+ logits_std_cpu = paddle .empty_like (logits_std , device = "cpu" ).pin_memory ()
571+ logits_min_cpu .copy_ (logits_min , False )
572+ logits_max_cpu .copy_ (logits_max , False )
573+ logits_mean_cpu .copy_ (logits_mean , False )
574+ logits_std_cpu .copy_ (logits_std , False )
575+ else :
576+ logits_min_cpu = logits_min .cpu ()
577+ logits_max_cpu = logits_max .cpu ()
578+ logits_mean_cpu = logits_mean .cpu ()
579+ logits_std_cpu = logits_std .cpu ()
580+ logprobs_tensors = LogprobsTensors (
581+ logprob_token_ids = logprobs_tensors .logprob_token_ids ,
582+ logprobs = logprobs_tensors .logprobs ,
583+ selected_token_ranks = logprobs_tensors .selected_token_ranks ,
584+ logits_min = logits_min_cpu ,
585+ logits_max = logits_max_cpu ,
586+ logits_mean = logits_mean_cpu ,
587+ logits_std = logits_std_cpu ,
588+ )
589+
549590 if sampling_metadata .enable_early_stop :
550591 # will set the stop batch in stop_flags
551592 assert sampling_metadata .stop_flags is not None , "need stop_flags for early stop"
@@ -621,6 +662,7 @@ def __init__(self, fd_config: FDConfig):
621662 else :
622663 raise NotImplementedError
623664 self .logprobs_mode = fd_config .model_config .logprobs_mode
665+ self .compute_logits_stats = fd_config .model_config .compute_logits_stats
624666 self .speculative_verify_window = fd_config .speculative_config .verify_window
625667 self .speculative_max_candidate_len = fd_config .speculative_config .max_candidate_len
626668 self .speculative_benchmark_mode = fd_config .speculative_config .benchmark_mode
@@ -872,12 +914,36 @@ def forward_cuda(
872914
873915 logprobs_tensors = None
874916 if num_logprobs is not None :
917+ # Compute logits statistics on target_logits for accepted tokens
918+ logits_min = None
919+ logits_max = None
920+ logits_mean = None
921+ logits_std = None
922+ if self .compute_logits_stats :
923+ with paddle .no_grad ():
924+ logits_min = paddle .min (target_logits , axis = 1 )
925+ logits_max = paddle .max (target_logits , axis = 1 )
926+ logits_mean = paddle .mean (target_logits , axis = 1 )
927+ logits_std = paddle .std (target_logits , axis = 1 )
928+
875929 token_ids = share_inputs ["accept_tokens" ]
876930 idx = paddle .arange (share_inputs ["accept_tokens" ].shape [1 ], dtype = "int32" )
877931 mask = idx < share_inputs ["accept_num" ].unsqueeze (1 )
878932 token_ids = paddle .masked_select (share_inputs ["accept_tokens" ], mask )
879933 logprobs_tensors = self .gather_logprobs (raw_logprobs , num_logprobs , token_ids = token_ids )
880934
935+ # Pack logits stats into LogprobsTensors
936+ if logits_min is not None :
937+ logprobs_tensors = LogprobsTensors (
938+ logprob_token_ids = logprobs_tensors .logprob_token_ids ,
939+ logprobs = logprobs_tensors .logprobs ,
940+ selected_token_ranks = logprobs_tensors .selected_token_ranks ,
941+ logits_min = logits_min ,
942+ logits_max = logits_max ,
943+ logits_mean = logits_mean ,
944+ logits_std = logits_std ,
945+ )
946+
881947 sampler_output = SamplerOutput (
882948 sampled_token_ids = share_inputs ["accept_tokens" ],
883949 logprobs_tensors = logprobs_tensors ,
@@ -987,6 +1053,7 @@ def __init__(self, fd_config: FDConfig):
9871053 else :
9881054 raise NotImplementedError
9891055 self .logprobs_mode = fd_config .model_config .logprobs_mode
1056+ self .compute_logits_stats = fd_config .model_config .compute_logits_stats
9901057 self .enable_draft_logprob = fd_config .speculative_config .enable_draft_logprob
9911058
9921059 def pre_process (self , skip_idx_list : List [int ] = []):
@@ -1167,6 +1234,24 @@ def forward_cuda(
11671234
11681235 logprobs_tensors = self .gather_logprobs (raw_logprobs , num_logprobs , token_ids = token_ids )
11691236
1237+ # Compute logits statistics on draft_logits for MTP tokens
1238+ if self .compute_logits_stats :
1239+ draft_logits_for_stats = share_inputs ["draft_logits" ][:real_token_num , :]
1240+ with paddle .no_grad ():
1241+ logits_min = paddle .min (draft_logits_for_stats , axis = 1 )
1242+ logits_max = paddle .max (draft_logits_for_stats , axis = 1 )
1243+ logits_mean = paddle .mean (draft_logits_for_stats , axis = 1 )
1244+ logits_std = paddle .std (draft_logits_for_stats , axis = 1 )
1245+ logprobs_tensors = LogprobsTensors (
1246+ logprob_token_ids = logprobs_tensors .logprob_token_ids ,
1247+ logprobs = logprobs_tensors .logprobs ,
1248+ selected_token_ranks = logprobs_tensors .selected_token_ranks ,
1249+ logits_min = logits_min ,
1250+ logits_max = logits_max ,
1251+ logits_mean = logits_mean ,
1252+ logits_std = logits_std ,
1253+ )
1254+
11701255 sampler_output = SamplerOutput (
11711256 sampled_token_ids = token_ids ,
11721257 logprobs_tensors = logprobs_tensors ,
0 commit comments