Skip to content

Commit 468a0cc

Browse files
committed
add logits compute into FD
1 parent 67e0aa1 commit 468a0cc

3 files changed

Lines changed: 102 additions & 3 deletions

File tree

fastdeploy/model_executor/layers/sample/logprobs.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def build_output_logprobs(
133133
is_naive: bool = False,
134134
logprobs_mode: str = "default",
135135
compute_logprobs_fn: Optional[Callable] = None,
136+
compute_logits_stats: bool = False,
136137
) -> Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor]]:
137138
"""
138139
Build logprobs output for both NAIVE and speculative (MTP/Ngram) modes.
@@ -151,6 +152,7 @@ def build_output_logprobs(
151152
logprobs_mode: One of "raw_logprobs", "raw_logits", or "default".
152153
compute_logprobs_fn: Callable for computing logprobs with temperature
153154
scaling and top_p normalization. Used when logprobs_mode == "raw_logprobs".
155+
compute_logits_stats: Whether to compute per-token logits statistics (min/max/mean/std).
154156
155157
Returns:
156158
tuple: (logprobs_tensors, cu_batch_token_offset)
@@ -218,4 +220,21 @@ def build_output_logprobs(
218220

219221
logprobs_tensors = gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
220222

223+
# Compute logits statistics (min/max/mean/std) per token and pack into LogprobsTensors
224+
if compute_logits_stats:
225+
with paddle.no_grad():
226+
logits_min = paddle.min(output_logits, axis=1)
227+
logits_max = paddle.max(output_logits, axis=1)
228+
logits_mean = paddle.mean(output_logits, axis=1)
229+
logits_std = paddle.std(output_logits, axis=1)
230+
logprobs_tensors = LogprobsTensors(
231+
logprob_token_ids=logprobs_tensors.logprob_token_ids,
232+
logprobs=logprobs_tensors.logprobs,
233+
selected_token_ranks=logprobs_tensors.selected_token_ranks,
234+
logits_min=logits_min,
235+
logits_max=logits_max,
236+
logits_mean=logits_mean,
237+
logits_std=logits_std,
238+
)
239+
221240
return logprobs_tensors, cu_batch_token_offset

fastdeploy/model_executor/layers/sample/sampler.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@ def __init__(self, fd_config: FDConfig = None, logprobs_mode: str = "raw_logprob
375375

376376
self.guided_decoding = GuidedDecoding(fd_config)
377377
self.logprobs_mode = fd_config.model_config.logprobs_mode if fd_config is not None else logprobs_mode
378+
self.compute_logits_stats = fd_config.model_config.compute_logits_stats if fd_config is not None else False
378379
# Can only be created when fd_config.early_stopper_config.enable_early_stop = True
379380
if (
380381
fd_config is not None
@@ -522,6 +523,19 @@ def forward_cuda(
522523
elif self.logprobs_mode == "raw_logits":
523524
raw_logprobs = logits.clone()
524525

526+
# Compute logits statistics (min/max/mean/std) per sequence before penalties
527+
logits_min = None
528+
logits_max = None
529+
logits_mean = None
530+
logits_std = None
531+
if num_logprobs is not None and self.compute_logits_stats:
532+
with paddle.no_grad():
533+
# logits shape: [batch_size, vocab_size], compute stats per sequence (reduce over vocab dimension)
534+
logits_min = paddle.min(logits, axis=1) # [batch_size]
535+
logits_max = paddle.max(logits, axis=1) # [batch_size]
536+
logits_mean = paddle.mean(logits, axis=1) # [batch_size]
537+
logits_std = paddle.std(logits, axis=1) # [batch_size]
538+
525539
for proc in sampling_metadata.logits_processors or []:
526540
logits = proc.apply(logits)
527541

@@ -565,6 +579,33 @@ def forward_cuda(
565579
logprobs_tensors = (
566580
None if num_logprobs is None else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens)
567581
)
582+
583+
# Pack logits stats into LogprobsTensors
584+
if logprobs_tensors is not None and logits_min is not None:
585+
if current_platform.is_cuda():
586+
logits_min_cpu = paddle.empty_like(logits_min, device="cpu").pin_memory()
587+
logits_max_cpu = paddle.empty_like(logits_max, device="cpu").pin_memory()
588+
logits_mean_cpu = paddle.empty_like(logits_mean, device="cpu").pin_memory()
589+
logits_std_cpu = paddle.empty_like(logits_std, device="cpu").pin_memory()
590+
logits_min_cpu.copy_(logits_min, False)
591+
logits_max_cpu.copy_(logits_max, False)
592+
logits_mean_cpu.copy_(logits_mean, False)
593+
logits_std_cpu.copy_(logits_std, False)
594+
else:
595+
logits_min_cpu = logits_min.cpu()
596+
logits_max_cpu = logits_max.cpu()
597+
logits_mean_cpu = logits_mean.cpu()
598+
logits_std_cpu = logits_std.cpu()
599+
logprobs_tensors = LogprobsTensors(
600+
logprob_token_ids=logprobs_tensors.logprob_token_ids,
601+
logprobs=logprobs_tensors.logprobs,
602+
selected_token_ranks=logprobs_tensors.selected_token_ranks,
603+
logits_min=logits_min_cpu,
604+
logits_max=logits_max_cpu,
605+
logits_mean=logits_mean_cpu,
606+
logits_std=logits_std_cpu,
607+
)
608+
568609
if sampling_metadata.enable_early_stop:
569610
# will set the stop batch in stop_flags
570611
assert sampling_metadata.stop_flags is not None, "need stop_flags for early stop"
@@ -640,6 +681,7 @@ def __init__(self, fd_config: FDConfig):
640681
else:
641682
raise NotImplementedError
642683
self.logprobs_mode = fd_config.model_config.logprobs_mode
684+
self.compute_logits_stats = fd_config.model_config.compute_logits_stats
643685
self.speculative_verify_window = fd_config.speculative_config.verify_window
644686
self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len
645687
self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode
@@ -1038,6 +1080,7 @@ def forward_cuda(
10381080
is_naive=is_naive,
10391081
logprobs_mode=self.logprobs_mode,
10401082
compute_logprobs_fn=self.compute_logprobs,
1083+
compute_logits_stats=self.compute_logits_stats,
10411084
)
10421085
sampler_output.logprobs_tensors = logprobs_tensors
10431086
if cu_batch_token_offset is not None:
@@ -1147,6 +1190,7 @@ def __init__(self, fd_config: FDConfig):
11471190
else:
11481191
raise NotImplementedError
11491192
self.logprobs_mode = fd_config.model_config.logprobs_mode
1193+
self.compute_logits_stats = fd_config.model_config.compute_logits_stats
11501194
self.enable_draft_logprob = fd_config.speculative_config.enable_draft_logprob
11511195

11521196
def pre_process(self, skip_idx_list: List[int] = []):
@@ -1336,6 +1380,24 @@ def forward_cuda(
13361380

13371381
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
13381382

1383+
# Compute logits statistics on draft_logits for MTP tokens
1384+
if self.compute_logits_stats:
1385+
draft_logits_for_stats = share_inputs["draft_logits"][:real_token_num, :]
1386+
with paddle.no_grad():
1387+
logits_min = paddle.min(draft_logits_for_stats, axis=1)
1388+
logits_max = paddle.max(draft_logits_for_stats, axis=1)
1389+
logits_mean = paddle.mean(draft_logits_for_stats, axis=1)
1390+
logits_std = paddle.std(draft_logits_for_stats, axis=1)
1391+
logprobs_tensors = LogprobsTensors(
1392+
logprob_token_ids=logprobs_tensors.logprob_token_ids,
1393+
logprobs=logprobs_tensors.logprobs,
1394+
selected_token_ranks=logprobs_tensors.selected_token_ranks,
1395+
logits_min=logits_min,
1396+
logits_max=logits_max,
1397+
logits_mean=logits_mean,
1398+
logits_std=logits_std,
1399+
)
1400+
13391401
sampler_output = SamplerOutput(
13401402
sampled_token_ids=token_ids,
13411403
logprobs_tensors=logprobs_tensors,

fastdeploy/worker/input_batch.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
2222
from fastdeploy.model_executor.logits_processor import build_logits_processors
2323
from fastdeploy.platforms import current_platform
24+
from fastdeploy.worker.output import LogprobsTensors
2425

2526

2627
class InputBatch:
@@ -1150,9 +1151,26 @@ def recover_batch_index_for_sampler_output(sampler_output, index_to_batch_id, en
11501151
real_logprob_token_ids = _recover_tensor(logprob_token_ids, src_order)
11511152
real_logprobs = _recover_tensor(logprobs, src_order)
11521153
real_selected_token_ranks = _recover_tensor(selected_token_ranks, src_order)
1153-
sampler_output.logprobs_tensors.logprob_token_ids = real_logprob_token_ids
1154-
sampler_output.logprobs_tensors.logprobs = real_logprobs
1155-
sampler_output.logprobs_tensors.sampled_token_ranks = real_selected_token_ranks
1154+
1155+
real_logits_min = None
1156+
real_logits_max = None
1157+
real_logits_mean = None
1158+
real_logits_std = None
1159+
if sampler_output.logprobs_tensors.logits_min is not None:
1160+
real_logits_min = _recover_tensor(sampler_output.logprobs_tensors.logits_min, src_order)
1161+
real_logits_max = _recover_tensor(sampler_output.logprobs_tensors.logits_max, src_order)
1162+
real_logits_mean = _recover_tensor(sampler_output.logprobs_tensors.logits_mean, src_order)
1163+
real_logits_std = _recover_tensor(sampler_output.logprobs_tensors.logits_std, src_order)
1164+
1165+
sampler_output.logprobs_tensors = LogprobsTensors(
1166+
logprob_token_ids=real_logprob_token_ids,
1167+
logprobs=real_logprobs,
1168+
selected_token_ranks=real_selected_token_ranks,
1169+
logits_min=real_logits_min,
1170+
logits_max=real_logits_max,
1171+
logits_mean=real_logits_mean,
1172+
logits_std=real_logits_std,
1173+
)
11561174

11571175
if sampler_output.token_num_per_batch is not None:
11581176
token_num_per_batch = sampler_output.token_num_per_batch

0 commit comments

Comments
 (0)