@@ -375,6 +375,7 @@ def __init__(self, fd_config: FDConfig = None, logprobs_mode: str = "raw_logprob
375375
376376 self .guided_decoding = GuidedDecoding (fd_config )
377377 self .logprobs_mode = fd_config .model_config .logprobs_mode if fd_config is not None else logprobs_mode
378+ self .compute_logits_stats = fd_config .model_config .compute_logits_stats if fd_config is not None else False
378379 # Can only be created when fd_config.early_stopper_config.enable_early_stop = True
379380 if (
380381 fd_config is not None
@@ -522,6 +523,19 @@ def forward_cuda(
522523 elif self .logprobs_mode == "raw_logits" :
523524 raw_logprobs = logits .clone ()
524525
526+ # Compute logits statistics (min/max/mean/std) per sequence before penalties
527+ logits_min = None
528+ logits_max = None
529+ logits_mean = None
530+ logits_std = None
531+ if num_logprobs is not None and self .compute_logits_stats :
532+ with paddle .no_grad ():
533+ # logits shape: [batch_size, vocab_size], compute stats per sequence (reduce over vocab dimension)
534+ logits_min = paddle .min (logits , axis = 1 ) # [batch_size]
535+ logits_max = paddle .max (logits , axis = 1 ) # [batch_size]
536+ logits_mean = paddle .mean (logits , axis = 1 ) # [batch_size]
537+ logits_std = paddle .std (logits , axis = 1 ) # [batch_size]
538+
525539 for proc in sampling_metadata .logits_processors or []:
526540 logits = proc .apply (logits )
527541
@@ -565,6 +579,33 @@ def forward_cuda(
565579 logprobs_tensors = (
566580 None if num_logprobs is None else self .gather_logprobs (raw_logprobs , num_logprobs , token_ids = next_tokens )
567581 )
582+
583+ # Pack logits stats into LogprobsTensors
584+ if logprobs_tensors is not None and logits_min is not None :
585+ if current_platform .is_cuda ():
586+ logits_min_cpu = paddle .empty_like (logits_min , device = "cpu" ).pin_memory ()
587+ logits_max_cpu = paddle .empty_like (logits_max , device = "cpu" ).pin_memory ()
588+ logits_mean_cpu = paddle .empty_like (logits_mean , device = "cpu" ).pin_memory ()
589+ logits_std_cpu = paddle .empty_like (logits_std , device = "cpu" ).pin_memory ()
590+ logits_min_cpu .copy_ (logits_min , False )
591+ logits_max_cpu .copy_ (logits_max , False )
592+ logits_mean_cpu .copy_ (logits_mean , False )
593+ logits_std_cpu .copy_ (logits_std , False )
594+ else :
595+ logits_min_cpu = logits_min .cpu ()
596+ logits_max_cpu = logits_max .cpu ()
597+ logits_mean_cpu = logits_mean .cpu ()
598+ logits_std_cpu = logits_std .cpu ()
599+ logprobs_tensors = LogprobsTensors (
600+ logprob_token_ids = logprobs_tensors .logprob_token_ids ,
601+ logprobs = logprobs_tensors .logprobs ,
602+ selected_token_ranks = logprobs_tensors .selected_token_ranks ,
603+ logits_min = logits_min_cpu ,
604+ logits_max = logits_max_cpu ,
605+ logits_mean = logits_mean_cpu ,
606+ logits_std = logits_std_cpu ,
607+ )
608+
568609 if sampling_metadata .enable_early_stop :
569610 # will set the stop batch in stop_flags
570611 assert sampling_metadata .stop_flags is not None , "need stop_flags for early stop"
@@ -640,6 +681,7 @@ def __init__(self, fd_config: FDConfig):
640681 else :
641682 raise NotImplementedError
642683 self .logprobs_mode = fd_config .model_config .logprobs_mode
684+ self .compute_logits_stats = fd_config .model_config .compute_logits_stats
643685 self .speculative_verify_window = fd_config .speculative_config .verify_window
644686 self .speculative_max_candidate_len = fd_config .speculative_config .max_candidate_len
645687 self .speculative_benchmark_mode = fd_config .speculative_config .benchmark_mode
@@ -1038,6 +1080,7 @@ def forward_cuda(
10381080 is_naive = is_naive ,
10391081 logprobs_mode = self .logprobs_mode ,
10401082 compute_logprobs_fn = self .compute_logprobs ,
1083+ compute_logits_stats = self .compute_logits_stats ,
10411084 )
10421085 sampler_output .logprobs_tensors = logprobs_tensors
10431086 if cu_batch_token_offset is not None :
@@ -1147,6 +1190,7 @@ def __init__(self, fd_config: FDConfig):
11471190 else :
11481191 raise NotImplementedError
11491192 self .logprobs_mode = fd_config .model_config .logprobs_mode
1193+ self .compute_logits_stats = fd_config .model_config .compute_logits_stats
11501194 self .enable_draft_logprob = fd_config .speculative_config .enable_draft_logprob
11511195
11521196 def pre_process (self , skip_idx_list : List [int ] = []):
@@ -1336,6 +1380,24 @@ def forward_cuda(
13361380
13371381 logprobs_tensors = self .gather_logprobs (raw_logprobs , num_logprobs , token_ids = token_ids )
13381382
1383+ # Compute logits statistics on draft_logits for MTP tokens
1384+ if self .compute_logits_stats :
1385+ draft_logits_for_stats = share_inputs ["draft_logits" ][:real_token_num , :]
1386+ with paddle .no_grad ():
1387+ logits_min = paddle .min (draft_logits_for_stats , axis = 1 )
1388+ logits_max = paddle .max (draft_logits_for_stats , axis = 1 )
1389+ logits_mean = paddle .mean (draft_logits_for_stats , axis = 1 )
1390+ logits_std = paddle .std (draft_logits_for_stats , axis = 1 )
1391+ logprobs_tensors = LogprobsTensors (
1392+ logprob_token_ids = logprobs_tensors .logprob_token_ids ,
1393+ logprobs = logprobs_tensors .logprobs ,
1394+ selected_token_ranks = logprobs_tensors .selected_token_ranks ,
1395+ logits_min = logits_min ,
1396+ logits_max = logits_max ,
1397+ logits_mean = logits_mean ,
1398+ logits_std = logits_std ,
1399+ )
1400+
13391401 sampler_output = SamplerOutput (
13401402 sampled_token_ids = token_ids ,
13411403 logprobs_tensors = logprobs_tensors ,
0 commit comments