Skip to content

Commit 85e722a

Browse files
committed
Add logits statistics computation (min/max/mean/std) to FastDeploy (FD)
1 parent cd71865 commit 85e722a

3 files changed

Lines changed: 107 additions & 3 deletions

File tree

fastdeploy/engine/common_engine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2029,6 +2029,7 @@ def _start_worker_service(self):
20292029
"use_internode_ll_two_stage": self.cfg.parallel_config.use_internode_ll_two_stage,
20302030
"disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe,
20312031
"enable_logprob": self.cfg.model_config.enable_logprob,
2032+
"compute_logits_stats": self.cfg.model_config.compute_logits_stats,
20322033
"lm_head_fp32": self.cfg.model_config.lm_head_fp32,
20332034
"enable_entropy": self.cfg.model_config.enable_entropy,
20342035
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,

fastdeploy/model_executor/layers/sample/sampler.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,7 @@ def __init__(self, fd_config: FDConfig = None, logprobs_mode: str = "raw_logprob
364364

365365
self.guided_decoding = GuidedDecoding(fd_config)
366366
self.logprobs_mode = fd_config.model_config.logprobs_mode if fd_config is not None else logprobs_mode
367+
self.compute_logits_stats = fd_config.model_config.compute_logits_stats if fd_config is not None else False
367368
# Can only be created when fd_config.early_stopper_config.enable_early_stop = True
368369
if (
369370
fd_config is not None
@@ -507,6 +508,19 @@ def forward_cuda(
507508
elif self.logprobs_mode == "raw_logits":
508509
raw_logprobs = logits.clone()
509510

511+
# Compute logits statistics (min/max/mean/std) per sequence before penalties
512+
logits_min = None
513+
logits_max = None
514+
logits_mean = None
515+
logits_std = None
516+
if num_logprobs is not None and self.compute_logits_stats:
517+
with paddle.no_grad():
518+
# logits shape: [batch_size, vocab_size], compute stats per sequence (reduce over vocab dimension)
519+
logits_min = paddle.min(logits, axis=1) # [batch_size]
520+
logits_max = paddle.max(logits, axis=1) # [batch_size]
521+
logits_mean = paddle.mean(logits, axis=1) # [batch_size]
522+
logits_std = paddle.std(logits, axis=1) # [batch_size]
523+
510524
for proc in sampling_metadata.logits_processors or []:
511525
logits = proc.apply(logits)
512526

@@ -546,6 +560,33 @@ def forward_cuda(
546560
logprobs_tensors = (
547561
None if num_logprobs is None else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens)
548562
)
563+
564+
# Pack logits stats into LogprobsTensors
565+
if logprobs_tensors is not None and logits_min is not None:
566+
if current_platform.is_cuda():
567+
logits_min_cpu = paddle.empty_like(logits_min, device="cpu").pin_memory()
568+
logits_max_cpu = paddle.empty_like(logits_max, device="cpu").pin_memory()
569+
logits_mean_cpu = paddle.empty_like(logits_mean, device="cpu").pin_memory()
570+
logits_std_cpu = paddle.empty_like(logits_std, device="cpu").pin_memory()
571+
logits_min_cpu.copy_(logits_min, False)
572+
logits_max_cpu.copy_(logits_max, False)
573+
logits_mean_cpu.copy_(logits_mean, False)
574+
logits_std_cpu.copy_(logits_std, False)
575+
else:
576+
logits_min_cpu = logits_min.cpu()
577+
logits_max_cpu = logits_max.cpu()
578+
logits_mean_cpu = logits_mean.cpu()
579+
logits_std_cpu = logits_std.cpu()
580+
logprobs_tensors = LogprobsTensors(
581+
logprob_token_ids=logprobs_tensors.logprob_token_ids,
582+
logprobs=logprobs_tensors.logprobs,
583+
selected_token_ranks=logprobs_tensors.selected_token_ranks,
584+
logits_min=logits_min_cpu,
585+
logits_max=logits_max_cpu,
586+
logits_mean=logits_mean_cpu,
587+
logits_std=logits_std_cpu,
588+
)
589+
549590
if sampling_metadata.enable_early_stop:
550591
# will set the stop batch in stop_flags
551592
assert sampling_metadata.stop_flags is not None, "need stop_flags for early stop"
@@ -621,6 +662,7 @@ def __init__(self, fd_config: FDConfig):
621662
else:
622663
raise NotImplementedError
623664
self.logprobs_mode = fd_config.model_config.logprobs_mode
665+
self.compute_logits_stats = fd_config.model_config.compute_logits_stats
624666
self.speculative_verify_window = fd_config.speculative_config.verify_window
625667
self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len
626668
self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode
@@ -872,12 +914,36 @@ def forward_cuda(
872914

873915
logprobs_tensors = None
874916
if num_logprobs is not None:
917+
# Compute logits statistics on target_logits for accepted tokens
918+
logits_min = None
919+
logits_max = None
920+
logits_mean = None
921+
logits_std = None
922+
if self.compute_logits_stats:
923+
with paddle.no_grad():
924+
logits_min = paddle.min(target_logits, axis=1)
925+
logits_max = paddle.max(target_logits, axis=1)
926+
logits_mean = paddle.mean(target_logits, axis=1)
927+
logits_std = paddle.std(target_logits, axis=1)
928+
875929
token_ids = share_inputs["accept_tokens"]
876930
idx = paddle.arange(share_inputs["accept_tokens"].shape[1], dtype="int32")
877931
mask = idx < share_inputs["accept_num"].unsqueeze(1)
878932
token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask)
879933
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
880934

935+
# Pack logits stats into LogprobsTensors
936+
if logits_min is not None:
937+
logprobs_tensors = LogprobsTensors(
938+
logprob_token_ids=logprobs_tensors.logprob_token_ids,
939+
logprobs=logprobs_tensors.logprobs,
940+
selected_token_ranks=logprobs_tensors.selected_token_ranks,
941+
logits_min=logits_min,
942+
logits_max=logits_max,
943+
logits_mean=logits_mean,
944+
logits_std=logits_std,
945+
)
946+
881947
sampler_output = SamplerOutput(
882948
sampled_token_ids=share_inputs["accept_tokens"],
883949
logprobs_tensors=logprobs_tensors,
@@ -987,6 +1053,7 @@ def __init__(self, fd_config: FDConfig):
9871053
else:
9881054
raise NotImplementedError
9891055
self.logprobs_mode = fd_config.model_config.logprobs_mode
1056+
self.compute_logits_stats = fd_config.model_config.compute_logits_stats
9901057
self.enable_draft_logprob = fd_config.speculative_config.enable_draft_logprob
9911058

9921059
def pre_process(self, skip_idx_list: List[int] = []):
@@ -1167,6 +1234,24 @@ def forward_cuda(
11671234

11681235
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
11691236

1237+
# Compute logits statistics on draft_logits for MTP tokens
1238+
if self.compute_logits_stats:
1239+
draft_logits_for_stats = share_inputs["draft_logits"][:real_token_num, :]
1240+
with paddle.no_grad():
1241+
logits_min = paddle.min(draft_logits_for_stats, axis=1)
1242+
logits_max = paddle.max(draft_logits_for_stats, axis=1)
1243+
logits_mean = paddle.mean(draft_logits_for_stats, axis=1)
1244+
logits_std = paddle.std(draft_logits_for_stats, axis=1)
1245+
logprobs_tensors = LogprobsTensors(
1246+
logprob_token_ids=logprobs_tensors.logprob_token_ids,
1247+
logprobs=logprobs_tensors.logprobs,
1248+
selected_token_ranks=logprobs_tensors.selected_token_ranks,
1249+
logits_min=logits_min,
1250+
logits_max=logits_max,
1251+
logits_mean=logits_mean,
1252+
logits_std=logits_std,
1253+
)
1254+
11701255
sampler_output = SamplerOutput(
11711256
sampled_token_ids=token_ids,
11721257
logprobs_tensors=logprobs_tensors,

fastdeploy/worker/input_batch.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
2222
from fastdeploy.model_executor.logits_processor import build_logits_processors
2323
from fastdeploy.platforms import current_platform
24+
from fastdeploy.worker.output import LogprobsTensors
2425

2526

2627
class InputBatch:
@@ -1127,9 +1128,26 @@ def recover_batch_index_for_sampler_output(sampler_output, index_to_batch_id, en
11271128
real_logprob_token_ids = _recover_tensor(logprob_token_ids, src_order)
11281129
real_logprobs = _recover_tensor(logprobs, src_order)
11291130
real_selected_token_ranks = _recover_tensor(selected_token_ranks, src_order)
1130-
sampler_output.logprobs_tensors.logprob_token_ids = real_logprob_token_ids
1131-
sampler_output.logprobs_tensors.logprobs = real_logprobs
1132-
sampler_output.logprobs_tensors.sampled_token_ranks = real_selected_token_ranks
1131+
1132+
real_logits_min = None
1133+
real_logits_max = None
1134+
real_logits_mean = None
1135+
real_logits_std = None
1136+
if sampler_output.logprobs_tensors.logits_min is not None:
1137+
real_logits_min = _recover_tensor(sampler_output.logprobs_tensors.logits_min, src_order)
1138+
real_logits_max = _recover_tensor(sampler_output.logprobs_tensors.logits_max, src_order)
1139+
real_logits_mean = _recover_tensor(sampler_output.logprobs_tensors.logits_mean, src_order)
1140+
real_logits_std = _recover_tensor(sampler_output.logprobs_tensors.logits_std, src_order)
1141+
1142+
sampler_output.logprobs_tensors = LogprobsTensors(
1143+
logprob_token_ids=real_logprob_token_ids,
1144+
logprobs=real_logprobs,
1145+
selected_token_ranks=real_selected_token_ranks,
1146+
logits_min=real_logits_min,
1147+
logits_max=real_logits_max,
1148+
logits_mean=real_logits_mean,
1149+
logits_std=real_logits_std,
1150+
)
11331151

11341152
if sampler_output.token_num_per_batch is not None:
11351153
token_num_per_batch = sampler_output.token_num_per_batch

0 commit comments

Comments
 (0)