Skip to content

[Cherry-Pick][Optimization] Reduce logprob processing overhead by using actual topk instead of fixed K+1 (#7860)#7861

Merged
Sunny-bot1 merged 3 commits into
PaddlePaddle:release/2.6from
Sunny-bot1:opt_logprob_process_26
May 20, 2026
Merged

[Cherry-Pick][Optimization] Reduce logprob processing overhead by using actual topk instead of fixed K+1 (#7860)#7861
Sunny-bot1 merged 3 commits into
PaddlePaddle:release/2.6from
Sunny-bot1:opt_logprob_process_26

Conversation

@Sunny-bot1

@Sunny-bot1 Sunny-bot1 commented May 19, 2026

Copy link
Copy Markdown
Collaborator

Motivation

开启 logprob(top_logprobs=0)时,性能下降明显,TPOT 比不开 logprob 高约 5ms。分析发现开销大的原因之一是 logprob 数据传输和处理均按固定的 K+1=21 列处理,而用户实际只需要 actual_topk=1 列,存在约 10 倍的冗余计算。

save_output_topk C++ op 中:

  • sender 循环固定写入 K+1=21 列,stride 固定为 K+1
  • mtext[1] 只存 bsz,actual_topk 信息丢失

get_output_topk C++ op 中:

  • receiver 循环固定读取 K+1=21 列,stride 固定为 K+1

token_processor.py 中:

  • reshape 固定按 K+1=21 列展开,导致 output_scores.numpy()
    拷贝 batch*21 个 float(10752),以及 per-request tolist()
    每行处理 21 个元素

Modifications

save_output_msg_with_topk.cc

  • mtext[1] 改为 bit-pack 存储:bsz(低16位)| actual_topk(高16位)
  • sender 循环改为 max_num_logprobs 次,stride 改为 max_num_logprobs

get_output_msg_with_topk.cc

  • 从 mtext[1] 解包 bsz 和 actual_topk
  • receiver 循环改为 actual_topk 次,stride 改为 actual_topk

token_processor.py

  • 从 packed mtext[1] 解包 batch 和 actual_topk
  • output_scores.numpy() 切片范围从 batch21 缩小到 batchactual_topk
  • reshape 列数从固定 K+1 改为动态 actual_topk
  • 提前to_list(),避免在逐token的循环里调用to_list(),增加开销

msgsnd/msgrcv 消息结构体大小不变,向后兼容。
actual_topk 通过 mtext[1] bit-pack 传递,不增加额外字段。

top_logprobs=0,concurrency=256,GLM-4.5-Air,TP8:

指标 优化前 优化后 提升
平均TPOT 33.80ms 31.41ms -2.39ms
平均解码速度 30.72 32.79 +6%
QPS 0.571 req/s 0.590 req/s +3%

Usage or Command

Accuracy Tests

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@paddle-bot

paddle-bot Bot commented May 19, 2026

Copy link
Copy Markdown

Thanks for your contribution!

PaddlePaddle-bot

This comment was marked as outdated.

@PaddlePaddle-bot

PaddlePaddle-bot commented May 19, 2026

Copy link
Copy Markdown

🤖 Paddle-CI-Agent | ci_status_monitor | 2026-05-20 18:03:13

CI报告基于以下代码生成(30分钟更新一次):


1 任务总览

Required 任务已全部通过(10/10),当前失败均为 Optional 任务,不阻塞合并;建议从 required 维度通过,Optional 失败可按需处理或重跑。

总执行(rerun次数) 总任务 ✅ 通过 ❌ 失败 ⏳ 运行中 ⏸️ 等待中 跳过
37(0) 37 33 3 0 1 0

2 任务状态汇总

日志列说明:失败任务直接使用 log_links_markdown 字段(已预生成),运行中任务手动拼接 [Job]({html_url})

2.1 Required任务 : 10/10 通过

必选任务阻塞合并,失败需优先处理。

状态 任务 耗时 根因 修复建议 日志 重跑
其余 10 个必选任务通过 - - - - -

2.2 可选任务 — 23/27 通过

可选任务不阻塞合并,失败仅供参考。

状态 任务 耗时 日志 重跑
Run iluvatar Tests / run_iluvatar_cases 1m39s Job -
Check PR Template 12s Job -
Trigger Jenkins for PR 46s Job -
⏸️ CI_HPU - - -
其余 23 个可选任务通过 - - -

3 失败详情(仅 required)

无 required 失败任务。

@codecov-commenter

codecov-commenter commented May 19, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (release/2.6@41d44d6). Learn more about missing BASE report.

Additional details and impacted files
@@              Coverage Diff               @@
##             release/2.6    #7861   +/-   ##
==============================================
  Coverage               ?   72.39%           
==============================================
  Files                  ?      381           
  Lines                  ?    54222           
  Branches               ?     8473           
==============================================
  Hits                   ?    39252           
  Misses                 ?    12209           
  Partials               ?     2761           
Flag Coverage Δ
GPU 72.39% <100.00%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@PaddlePaddle-bot PaddlePaddle-bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 Paddle-CI-Agent | pr_review | 2026-05-20 15:23:02

📋 Review 摘要

PR 概述:通过 bit-pack 方式将 actual_topk 信息随消息传递,使 sender/receiver/Python 三端均按实际所需列数处理 logprob 数据,消除固定 K+1=21 列的冗余计算,实测 TPOT 降低约 2.4ms。
变更范围custom_ops/gpu_ops/custom_ops/xpu_ops/fastdeploy/output/tests/output/
影响面 Tag[OP] [DataProcessor] [XPU]

问题

级别 文件 概述
❓ 疑问 get_output_msg_with_topk.cc 滚动升级时老版 sender 导致 actual_topk=0,数据静默丢失
📝 PR 规范 Usage or Command / Accuracy Tests 段为空;Checklist 全部未勾选

❓ 疑问:滚动升级兼容性风险

mtext[1] 的语义从 bsz 变更为 bsz | (actual_topk << 16),消息结构体大小不变,但字段语义发生了 breaking change:

  • 旧版 sender + 新版 receivermtext[1] = bsz(无高16位打包),receiver 解包 actual_topk = (mtext[1] >> 16) & 0xFFFF = 0,内层循环 for (j = 0; j < 0; j++) 不执行,out_datascores_data 全为零/未初始化,静默丢数据。

建议:

  1. 确认部署模型是否为原子升级(sender / receiver 总是同步更新),若是则风险可接受;
  2. 若存在滚动升级场景,可增加防御逻辑:
// 兼容旧版 sender:actual_topk == 0 时回退到 K+1
if (actual_topk == 0) actual_topk = K + 1;

📝 PR 规范检查

## Usage or Command## Accuracy Tests 两段内容仅为 HTML 注释,未填写实际内容或 N/A;Checklist 五条均未勾选,至少应勾选已满足的条目。

PR 描述建议(可直接复制,复刻 checklist §D2 完整结构):

## Motivation

开启 logprob(top_logprobs=0)时,性能下降明显,TPOT 比不开 logprob 高约 5ms。分析发现开销大的原因之一是 logprob 数据传输和处理均按固定的 K+1=21 列处理,而用户实际只需要 actual_topk=1 列,存在约 10 倍的冗余计算:

- save_output_topk C++ op:sender 循环固定写入 K+1=21 列,stride 固定为 K+1;mtext[1] 只存 bsz,actual_topk 信息丢失
- get_output_topk C++ op:receiver 循环固定读取 K+1=21 列,stride 固定为 K+1
- token_processor.py:reshape 固定按 K+1=21 列展开,导致 output_scores.numpy() 拷贝 batch*21 个 float,以及 per-request tolist() 每行处理 21 个元素

## Modifications

**save_output_msg_with_topk.cc**(GPU & XPU)
- mtext[1] 改为 bit-pack 存储:bsz(低16位)| actual_topk(高16位)
- sender 循环改为 max_num_logprobs 次,stride 改为 max_num_logprobs

**get_output_msg_with_topk.cc**(GPU & XPU)
- 从 mtext[1] 解包 bsz 和 actual_topk
- receiver 循环改为 actual_topk 次,stride 改为 actual_topk

**fastdeploy/output/token_processor.py**
- 从 packed mtext[1] 解包 batch 和 actual_topk
- output_scores.numpy() 切片范围从 batch\*21 缩小到 batch\*actual_topk
- reshape 列数从固定 K+1 改为动态 actual_topk
- 预先执行全量 .tolist() 转换,避免 per-request 重复调用

top_logprobs=0,concurrency=256,GLM-4.5-Air,TP8:

| 指标 | 优化前 | 优化后 | 提升 |
|------|--------|--------|------|
| 平均TPOT | 33.80ms | 31.41ms | -2.39ms |
| 平均解码速度 | 30.72 | 32.79 | +6% |
| QPS | 0.571 req/s | 0.590 req/s | +3% |

## Usage or Command

N/A

## Accuracy Tests

N/A(本次变更仅影响 logprob 数据的传输和处理范围,不改变模型输出)

## Checklist

- [x] Add at least a tag in the PR title.
  - Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
  - You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [ ] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [x] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.

总体评价

优化思路清晰,实现方案正确,GPU 与 XPU 两端同步更新,Python 解包逻辑与 C++ 打包语义一致,测试也做了对应更新,实测性能提升显著。主要建议关注滚动升级场景下的向后兼容性,以及补全 PR 描述中的空白段落。

@Sunny-bot1 Sunny-bot1 merged commit 31b12ee into PaddlePaddle:release/2.6 May 20, 2026
34 of 38 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants