Commit e173cbf
[TRTLLM-12669][perf] Reuse draft probs to drop redundant softmax + cut rejection-path overhead
This commit refactors the rejection-sampling draft path so that the
filtered + normalized prob distribution is computed exactly once per draft
step, and combines three related optimizations into one coherent change:
1. Single-pass compute_probs + sample on draft side
_draft_sampler_advanced_for_rejection now calls a new
sampling_batch_spec_dec_one_model_for_rejection which returns both the
sampled token AND the probs in one go. The probs are scattered into the
slot-indexed draft_probs buffer immediately, so the previous separate
_compute_and_store_draft_probs path (which redundantly re-ran
temperature + top_k + top_p + softmax on the cloned logits) is gone.
2. Faster compute_probs_from_logits via flashinfer fast path
compute_probs_from_logits now composes flashinfer's radix-based O(N)
kernels (top_k_mask_logits → fused softmax+temp → top_p_renorm_probs)
when CUDA + flashinfer are available. The previous C++ op path fell back
to torch.sort (O(N log N) per row) because of a hard-coded kMax=0, which
severely under-utilized SMs at small batch sizes. The C++ op and PyTorch
CPU paths are retained as fallbacks.
3. Pre-allocated full_draft_probs buffer
The (max_num_requests, max_draft_len, vocab_size) scratch used to pad
draft probs to target vocab is now zero-filled once at prepare() and
reused across iterations, which avoids a ~25 us/iter zero-fill of 64 MB.
The buffer is allocated only when use_rejection_sampling=True.
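The filter order in item 2 and the single-pass idea in item 1 can be sketched in plain Python for one logits row. This is a reference sketch only, not the flashinfer kernels or the TensorRT-LLM code: probs_from_logits and sample_with_probs are hypothetical names, and the top-k step here uses a sort rather than a radix-based selection.

```python
import math
import random


def probs_from_logits(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Top-k mask, temperature softmax, then top-p renormalization (one row)."""
    # Top-k: keep the k largest logits and mask the rest to -inf.
    # top_k <= 0 disables the filter; ties at the threshold are all kept.
    if 0 < top_k < len(logits):
        kth = sorted(logits, reverse=True)[top_k - 1]
        logits = [x if x >= kth else -math.inf for x in logits]

    # Temperature softmax, shifted by the row max for numerical stability.
    row_max = max(logits)
    exps = [math.exp((x - row_max) / temperature) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Top-p: keep the smallest high-probability set whose mass reaches top_p,
    # zero everything else, and renormalize.
    if top_p < 1.0:
        order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
        keep, mass = set(), 0.0
        for i in order:
            keep.add(i)
            mass += probs[i]
            if mass >= top_p:
                break
        probs = [p if i in keep else 0.0 for i, p in enumerate(probs)]
        total = sum(probs)
        probs = [p / total for p in probs]
    return probs


def sample_with_probs(logits, rng, **filters):
    """Return (token, probs) so the caller can store probs without recomputing."""
    probs = probs_from_logits(logits, **filters)
    token = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return token, probs
```

Returning the distribution together with the token is what removes the need for a second temperature + top_k + top_p + softmax pass over cloned logits: the caller stores the probs it already has.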
The eagle3 draft loop is simplified accordingly: it no longer accumulates
a draft_logits_list or invokes _compute_and_store_draft_probs after the
loop; per-step scatter happens inside _draft_sampler_advanced_for_rejection
keyed on the (already-required) draft_step index.
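A minimal sketch of the per-step scatter into a pre-allocated, slot-indexed scratch, using nested lists in place of tensors. The class name, the scatter method and the draft_to_target mapping are hypothetical and only illustrate the idea; the assumption that a fixed draft-to-target id mapping makes a one-time zero-fill sufficient is mine, not something the commit states.

```python
class DraftProbsScratch:
    """Hypothetical stand-in for a (requests, draft_len, vocab) scratch buffer."""

    def __init__(self, max_num_requests, max_draft_len, target_vocab_size,
                 draft_to_target):
        # Assumed fixed mapping from draft-vocab ids to target-vocab ids.
        self.draft_to_target = list(draft_to_target)
        # Zero-filled once, then reused for every iteration.
        self.data = [
            [[0.0] * target_vocab_size for _ in range(max_draft_len)]
            for _ in range(max_num_requests)
        ]

    def scatter(self, slot_ids, draft_step, draft_probs_rows):
        """Write one draft step's probs for each request into its slot."""
        for slot, row in zip(slot_ids, draft_probs_rows):
            dest = self.data[slot][draft_step]
            # Only mapped positions are written, so under a fixed mapping the
            # unmapped positions keep the zeros from the initial fill.
            for draft_id, p in enumerate(row):
                dest[self.draft_to_target[draft_id]] = p
```

Because each call writes directly at (slot, draft_step), nothing has to be accumulated across the draft loop and processed afterwards.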
Net effect on llama70b at bs=32 (T=0.7/top_k=50/top_p=0.9, MT-bench 2000):
ΔTPS recovers to ~-5% with the flashinfer fast path, compared with -32%
post-refactor with the sort fallback and -12% pre-refactor with the double
softmax.
The remaining gap is fundamental: llama70b's Eagle3 draft already tracks
the target closely (AR uplift of only +2%), so the inherent rejection
sampling overhead (chain_speculative_sampling kernel + target_probs +
d2t padding, about 340 us/iter or roughly 1.5%) is not fully offset by the
small AR gain. qwen8b/qwen235b, with ΔAR of +9% to +14%, remain solidly net
positive.
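For context on why draft probs are needed at all, the standard speculative-sampling acceptance rule can be sketched as below. This is the textbook rule in plain Python, not the chain_speculative_sampling kernel; verify_draft_tokens is a hypothetical name, and the bonus token normally sampled from the target when every draft token is accepted is omitted for brevity.

```python
import random


def verify_draft_tokens(draft_tokens, draft_probs, target_probs, rng):
    """Accept or reject a chain of draft tokens against the target model.

    draft_probs[i] (q) and target_probs[i] (p) are full distributions at
    position i. Returns the accepted prefix, plus one resampled token at the
    first rejection.
    """
    out = []
    for tok, q, p in zip(draft_tokens, draft_probs, target_probs):
        # Accept with probability min(1, p(tok) / q(tok)). q[tok] > 0 because
        # tok was sampled from q.
        if rng.random() < min(1.0, p[tok] / q[tok]):
            out.append(tok)
            continue
        # Rejected: resample from the residual max(p - q, 0); rng.choices
        # treats the unnormalized weights as relative, so no explicit
        # normalization is needed. The chain stops here.
        residual = [max(pi - qi, 0.0) for pi, qi in zip(p, q)]
        out.append(rng.choices(range(len(residual)), weights=residual, k=1)[0])
        break
    return out
```

The rule is only correct when q is the distribution the draft token was actually sampled from, which is why reusing the probs produced during draft sampling is a natural fit. It also shows why a draft that already tracks the target closely gains little: acceptance is already high, while the cost of producing p and q is paid on every iteration.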
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
1 parent 485792d commit e173cbf
3 files changed
Lines changed: 173 additions & 76 deletions