Commit e173cbf

[TRTLLM-12669][perf] Reuse draft probs to drop redundant softmax + cut rejection-path overhead
This commit refactors the rejection-sampling draft path to compute the filtered + normalized prob distribution exactly once per draft step, and folds three independent optimizations into one PR-coherent change:

1. Single-pass compute_probs + sample on draft side

   _draft_sampler_advanced_for_rejection now calls a new sampling_batch_spec_dec_one_model_for_rejection which returns both the sampled token AND the probs in one go. The probs are scattered into the slot-indexed draft_probs buffer immediately, so the previous separate _compute_and_store_draft_probs path (which redundantly re-ran temperature + top_k + top_p + softmax on the cloned logits) is gone.

2. Faster compute_probs_from_logits via flashinfer fast path

   compute_probs_from_logits now composes flashinfer's radix-based O(N) kernels (top_k_mask_logits → fused softmax+temp → top_p_renorm_probs) when CUDA + flashinfer are available. The previous C++ op path triggered torch.sort fallback (O(N log N) per row) due to a hard-coded kMax=0, which severely under-utilized SMs at small batch sizes. C++ op and PyTorch CPU paths are retained as fallbacks.

3. Pre-allocated full_draft_probs buffer

   The (max_num_requests, max_draft_len, vocab_size) scratch used to pad draft probs to target vocab is now zero-filled once at prepare() and reused across iters, saving ~25 us/iter of 64 MB zero-fill. Only allocated when use_rejection_sampling=True.

The eagle3 draft loop is simplified accordingly: it no longer accumulates a draft_logits_list or invokes _compute_and_store_draft_probs after the loop; per-step scatter happens inside _draft_sampler_advanced_for_rejection keyed on the (already-required) draft_step index.

Net effect on llama70b bs=32 (T=0.7/top_k=50/top_p=0.9, MT-bench 2000): ΔTPS recovered from -32% (post-refactor with sort fallback) and -12% (pre-refactor with double softmax) to ~-5% (flashinfer fast path).

The remaining gap is fundamental: llama70b's Eagle3 draft already tracks the target closely (AR uplift only +2%), so the inherent rejection sampling overhead (chain_speculative_sampling kernel + target_probs + d2t padding ≈ ~340 us/iter ≈ 1.5%) is not fully offset by the small AR gain. qwen8b/qwen235b with ΔAR +9%~+14% remain solidly net positive.

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
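
The single-pass idea from item 1 can be sketched in plain Python for one row of logits. This is only an illustration under stated assumptions: `filtered_probs` and `sample_with_probs` are hypothetical names, the sketch uses sorting for clarity (unlike the radix-based kernels the commit message describes), and it omits the greedy one-hot override.

```python
import math
import random
from typing import List, Tuple


def filtered_probs(logits: List[float], temperature: float, top_k: int,
                   top_p: float) -> List[float]:
    """top-k mask -> softmax with temperature -> top-p renormalization."""
    n = len(logits)
    # Top-k: keep the k largest logits and mask the rest to -inf.
    order = sorted(range(n), key=lambda i: logits[i], reverse=True)
    keep = set(order[:top_k])
    scaled = [logits[i] / temperature if i in keep else -math.inf
              for i in range(n)]
    # Numerically stable softmax; masked entries become exactly 0.
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Top-p: keep the smallest high-probability prefix whose mass reaches top_p.
    order = sorted(range(n), key=lambda i: probs[i], reverse=True)
    kept, mass = set(), 0.0
    for i in order:
        kept.add(i)
        mass += probs[i]
        if mass >= top_p:
            break
    return [probs[i] / mass if i in kept else 0.0 for i in range(n)]


def sample_with_probs(logits: List[float], temperature: float, top_k: int,
                      top_p: float,
                      rng: random.Random) -> Tuple[int, List[float]]:
    """Return the sampled token AND the distribution it was drawn from."""
    probs = filtered_probs(logits, temperature, top_k, top_p)
    token = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return token, probs


token, probs = sample_with_probs([2.0, 1.0, 0.5, -1.0], 0.7, 3, 0.9,
                                 random.Random(0))
```

Because the caller receives `probs` alongside `token`, it can store them for later verification without running the filter pipeline a second time.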
1 parent 485792d commit e173cbf

3 files changed

Lines changed: 173 additions & 76 deletions

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 29 additions & 19 deletions
@@ -661,7 +661,6 @@ def _forward_linear_draft_loop(self, inputs, attn_metadata, spec_metadata,
         """Original linear draft loop (1 token per layer)."""
         runtime_draft_len = spec_metadata.runtime_draft_len
         next_draft_tokens = []
-        draft_logits_list = []
         position_ids = inputs["position_ids"]

         with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager):
@@ -714,11 +713,11 @@ def _forward_linear_draft_loop(self, inputs, attn_metadata, spec_metadata,
                     d2t,
                     draft_step=i)

-                if spec_metadata.use_rejection_sampling:
-                    draft_logits_list.append(logits.clone())
-
-                new_draft_token = self.draft_decoder(logits, draft_model,
-                                                     spec_metadata, batch_size)
+                new_draft_token = self.draft_decoder(logits,
+                                                     draft_model,
+                                                     spec_metadata,
+                                                     batch_size,
+                                                     draft_step=i)
                 next_draft_tokens.append(new_draft_token)
                 # update inputs
                 hidden_states = hidden_states_to_save[gather_ids]
@@ -759,19 +758,18 @@ def _forward_linear_draft_loop(self, inputs, attn_metadata, spec_metadata,
                 gen_draft_tokens)
             next_draft_tokens[num_contexts:] = gen_draft_tokens

-        # Skip when the whole batch is greedy: _can_use_rejection_sampling will
-        # bypass the rejection path anyway, so computing draft probs is wasted.
-        if (spec_metadata.use_rejection_sampling and draft_logits_list
-                and not spec_metadata.is_all_greedy_sample):
-            d2t_param = getattr(draft_model.model, "d2t", None)
-            spec_metadata.d2t = d2t_param.data if d2t_param is not None else None
-            self._compute_and_store_draft_probs(draft_logits_list,
-                                                spec_metadata, batch_size)
-        elif spec_metadata.use_rejection_sampling:
-            # No draft probs were written this iter (all-greedy or empty draft
-            # loop). Invalidate the buffer so the next iter does not read stale
-            # data if it transitions back to a non-greedy mix.
-            spec_metadata.draft_probs_valid = False
+        # Probs were already scattered into the slot-indexed buffer by
+        # _draft_sampler_advanced_for_rejection on each draft step (non-greedy
+        # batches only). All-greedy batches skip storage — rejection sampling
+        # will be bypassed by _can_use_rejection_sampling. Finalize the validity
+        # flag and d2t for next-iter target-side verification.
+        if spec_metadata.use_rejection_sampling:
+            if not spec_metadata.is_all_greedy_sample:
+                d2t_param = getattr(draft_model.model, "d2t", None)
+                spec_metadata.d2t = d2t_param.data if d2t_param is not None else None
+                spec_metadata.draft_probs_valid = True
+            else:
+                spec_metadata.draft_probs_valid = False

         return next_draft_tokens


@@ -802,22 +800,34 @@ def draft_decoder(
         draft_model: nn.Module,
         spec_metadata: Optional[Eagle3OneModelSpecMetadata] = None,
         batch_size: Optional[int] = None,
+        draft_step: Optional[int] = None,
     ):
         '''
         Sample draft tokens. When spec_metadata + batch_size are provided, use
         the target's per-request sampling params (temperature/top_k/top_p);
         otherwise fall back to argmax.

+        When rejection sampling is enabled and draft_step is provided, take the
+        single-pass path that also scatters the draft prob distribution into the
+        slot-indexed buffer (avoids a redundant softmax later).
+
         Args:
             logits: [batch_size, vocab_size] - Draft model logits.
             draft_model: The draft model.
             spec_metadata: Carries per-request sampling param tensors. When
                 None, sampling is forced greedy.
             batch_size: Active requests, used to slice per-request tensors.
+            draft_step: Current draft step index (0..max_draft_len-1). Required
+                for the rejection-sampling code path so probs are written to
+                the correct slice of spec_metadata.draft_probs.
         '''

         d2t = getattr(draft_model.model, "d2t", None)
         if spec_metadata is not None and batch_size is not None:
+            if (spec_metadata.use_rejection_sampling and draft_step is not None
+                    and not spec_metadata.is_all_greedy_sample):
+                return self._draft_sampler_advanced_for_rejection(
+                    logits, spec_metadata, batch_size, d2t, draft_step)
             return self._draft_sampler_advanced(logits, spec_metadata,
                                                 batch_size, d2t)
         return self._draft_sampler_greedy(logits, d2t)
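
The per-step scatter and the validity flag that this file's changes rely on can be illustrated with a toy buffer built from nested lists. This is a minimal sketch, not the TensorRT-LLM data structure: `DraftProbStore`, `scatter_step`, and `finalize` are hypothetical names, and the real buffer is a CUDA tensor indexed with `batch_slot_ids`.

```python
from typing import List


class DraftProbStore:
    """Toy stand-in for a [max_num_requests, max_draft_len, vocab] buffer."""

    def __init__(self, max_num_requests: int, max_draft_len: int,
                 vocab_size: int) -> None:
        self.buf = [[[0.0] * vocab_size for _ in range(max_draft_len)]
                    for _ in range(max_num_requests)]
        self.valid = False

    def scatter_step(self, batch_slots: List[int], draft_step: int,
                     probs: List[List[float]]) -> None:
        # Row `r` of this step's batch belongs to request slot batch_slots[r],
        # so each request's data lands in a stable row across iterations.
        for row, slot in enumerate(batch_slots):
            n = len(probs[row])
            self.buf[slot][draft_step][:n] = probs[row]

    def finalize(self, use_rejection_sampling: bool,
                 is_all_greedy: bool) -> None:
        # Mirrors the flag handling after the draft loop: stored probs are
        # only valid when a non-greedy batch actually wrote them.
        if use_rejection_sampling:
            self.valid = not is_all_greedy


store = DraftProbStore(max_num_requests=4, max_draft_len=2, vocab_size=3)
store.scatter_step([2, 0], 0, [[0.1, 0.9, 0.0], [1.0, 0.0, 0.0]])
store.scatter_step([2, 0], 1, [[0.5, 0.5, 0.0], [0.0, 0.0, 1.0]])
store.finalize(use_rejection_sampling=True, is_all_greedy=False)
```

Writing one `draft_step` slice per call is what lets the loop drop the accumulated logits list: nothing has to be stacked or reshaped after the loop finishes.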

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 88 additions & 54 deletions
@@ -27,7 +27,8 @@

 from .one_model_sampler import (compute_probs_from_logits,
                                 rejection_sampling_one_model,
-                                sampling_batch_spec_dec_one_model)
+                                sampling_batch_spec_dec_one_model,
+                                sampling_batch_spec_dec_one_model_for_rejection)

 # Environment variable name for forcing the number of accepted tokens in speculative decoding
 FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"
@@ -479,6 +480,13 @@ class SpecMetadata:
     batch_slot_ids: Optional[torch.Tensor] = None
     # Draft-to-target vocab offset tensor.
     d2t: Optional[torch.Tensor] = None
+    # Pre-allocated scratch for draft probs expanded to the target vocab size.
+    # Filled with zeros once at prepare(); each rejection iter only overwrites
+    # the positions selected by d2t (or [:draft_vocab] when there is no d2t),
+    # so the zeros outside those positions persist across iterations and we
+    # avoid a per-iter 64 MB zero-fill on the (max_num_requests, max_draft_len,
+    # vocab_size) tensor. Shape: [max_num_requests, max_draft_len, vocab_size].
+    full_draft_probs: Optional[torch.Tensor] = None

     def __post_init__(self):
         pass
@@ -501,6 +509,16 @@ def prepare(self):
             self.batch_slot_ids = torch.empty((self.max_num_requests, ),
                                               dtype=torch.long,
                                               device='cuda')
+        if (self.use_rejection_sampling and self.full_draft_probs is None
+                and self.vocab_size > 0):
+            # Zero-fill once. Subsequent iters only overwrite the d2t-mapped
+            # positions (constant across iters since d2t is model-static), so
+            # untouched positions stay 0 forever — saves the per-iter 64 MB
+            # zero-fill in _sample_and_accept_draft_tokens_rejection.
+            self.full_draft_probs = torch.zeros(
+                (self.max_num_requests, self.max_draft_len, self.vocab_size),
+                dtype=torch.float32,
+                device='cuda')

     def create_cuda_graph_metadata(self, max_batch_size: int):
         """
@@ -692,7 +710,7 @@ def _normalize_request_sampling_params(

         # Always-populate the per-request slot id table when rejection sampling
         # is configured: it's tiny (max_num_requests longs) and needed at
-        # _compute_and_store_draft_probs time to scatter draft probs by slot.
+        # draft-sampler time to scatter draft probs by slot.
         if self.use_rejection_sampling and self.batch_slot_ids is not None:
             self.batch_slot_ids[:len(per_request_slot_ids)].copy_(
                 torch.tensor(per_request_slot_ids,
@@ -1157,7 +1175,7 @@ def _sample_and_accept_draft_tokens_rejection(
                 spec_metadata.top_ps[gen_start:gen_end])

         target_probs_flat = compute_probs_from_logits(
-            gen_logits.clone(), temperatures, top_ks, top_ps)
+            gen_logits, temperatures, top_ks, top_ps)
         target_probs = target_probs_flat.reshape(num_gens,
                                                  runtime_draft_len + 1,
                                                  vocab_size)
@@ -1171,10 +1189,17 @@ def _sample_and_accept_draft_tokens_rejection(
                 f"{runtime_draft_len}")
         d2t = getattr(spec_metadata, "d2t", None)
         if draft_vocab_size != vocab_size:
-            full_draft_probs = torch.zeros(
-                (num_gens, runtime_draft_len, vocab_size),
-                dtype=torch.float32,
-                device=device)
+            # Use the pre-allocated buffer from spec_metadata.prepare()
+            # (zero-filled once at init; untouched positions stay 0). Falls
+            # back to per-iter allocation if the buffer is not configured,
+            # e.g. when use_rejection_sampling was off at prepare() time.
+            if spec_metadata.full_draft_probs is not None:
+                full_draft_probs = spec_metadata.full_draft_probs[:num_gens]
+            else:
+                full_draft_probs = torch.zeros(
+                    (num_gens, runtime_draft_len, vocab_size),
+                    dtype=torch.float32,
+                    device=device)
             if d2t is not None:
                 assert d2t.numel() == draft_vocab_size, (
                     f"d2t size mismatch: {d2t.numel()} != {draft_vocab_size}"
@@ -1295,62 +1320,71 @@ def _draft_sampler_advanced(

         return draft_tokens.type(torch.int32)

-    def _compute_and_store_draft_probs(
+    def _draft_sampler_advanced_for_rejection(
         self,
-        draft_logits_list: List[torch.Tensor],
-        spec_metadata: SpecMetadata,
+        logits: torch.Tensor,
+        spec_metadata: "SpecMetadata",
         batch_size: int,
+        d2t: Optional[torch.Tensor] = None,
+        draft_step: int = 0,
     ):
         """
-        Compute draft probabilities and store them for next-step rejection
-        sampling. The storage is keyed by py_seq_slot, so the data is robust
-        to batch composition shifts across iterations (chunking ctxs, gen
-        completion, new ctxs joining).
+        Rejection-sampling-aware variant of ``_draft_sampler_advanced``.
+
+        Single-pass compute + sample + scatter: computes the per-request prob
+        distribution once via TRT-LLM's fused ``compute_probs_from_logits``
+        (temp + top_k + top_p + softmax + greedy override in one CUDA kernel),
+        samples the draft token from that distribution, and scatters the same
+        probs into the slot-indexed ``spec_metadata.draft_probs`` buffer for
+        next-iter rejection verification. Replaces the previous two-stage path
+        (flashinfer fused sampling kernel + a redundant softmax pass to store
+        probs).
+
+        All-greedy batches take the cheaper argmax path —
+        ``_can_use_rejection_sampling`` will bypass rejection for those anyway.
         """
-        draft_tokens_per_request = len(draft_logits_list)
-        vocab_size = draft_logits_list[0].shape[-1]
-        device = draft_logits_list[0].device
-
-        draft_logits = torch.stack(draft_logits_list, dim=0)
-        draft_logits_flat = draft_logits.transpose(0, 1).reshape(-1, vocab_size)
-
-        num_draft_tokens = batch_size * draft_tokens_per_request
-        if spec_metadata.request_temperatures is not None:
-            draft_temps = spec_metadata.request_temperatures[:batch_size].repeat_interleave(
-                draft_tokens_per_request)
-            draft_top_ks = (
-                spec_metadata.request_top_ks[:batch_size].repeat_interleave(
-                    draft_tokens_per_request) if not spec_metadata.skip_top_k
-                and spec_metadata.request_top_ks is not None else None)
-            draft_top_ps = (
-                spec_metadata.request_top_ps[:batch_size].repeat_interleave(
-                    draft_tokens_per_request) if not spec_metadata.skip_top_p
-                and spec_metadata.request_top_ps is not None else None)
-        else:
-            draft_temps = torch.ones(num_draft_tokens, device=device)
-            draft_top_ks = None
-            draft_top_ps = None
-
-        draft_probs_flat = compute_probs_from_logits(draft_logits_flat,
-                                                     draft_temps, draft_top_ks,
-                                                     draft_top_ps)
-        # [batch_size, draft_len, draft_vocab]
-        draft_probs_per_request = draft_probs_flat.reshape(
-            batch_size, draft_tokens_per_request, vocab_size)
-
-        # Scatter into draft_probs[slot] for each request in the current batch.
-        # spec_metadata.draft_probs is shaped [max_num_requests, max_draft_len,
-        # vocab_size]. Different iterations may have different batch
-        # compositions, but a given request's data always lives at its
-        # py_seq_slot row, so reads at the next iter pick up the right data.
+        if spec_metadata.is_all_greedy_sample:
+            return self._draft_sampler_greedy(logits, d2t)
+
+        temperatures = spec_metadata.request_temperatures[:batch_size]
+        top_ks = spec_metadata.request_top_ks[:batch_size]
+        top_ps = spec_metadata.request_top_ps[:batch_size]
+
+        if self.seed is None:
+            self.seed = torch.tensor([0],
+                                     dtype=torch.int64,
+                                     device=logits.device)
+            self.offset = torch.tensor([0],
+                                       dtype=torch.int64,
+                                       device=logits.device)
+        self.seed += 1
+        self.seed %= (2**31)
+
+        draft_tokens, probs = sampling_batch_spec_dec_one_model_for_rejection(
+            logits,
+            temperatures,
+            top_ks,
+            top_ps,
+            seed=self.seed,
+            offset=self.offset,
+        )
+
+        # Scatter probs into the slot-indexed buffer (shaped
+        # [max_num_requests, max_draft_len, vocab_size]). Each request's data
+        # always lands at its stable py_seq_slot row regardless of batch
+        # composition shifts across iterations.
         assert spec_metadata.batch_slot_ids is not None, (
             "batch_slot_ids must be populated by "
             "populate_sampling_params_for_one_model before draft probs storage")
         batch_slots = spec_metadata.batch_slot_ids[:batch_size]
-        spec_metadata.draft_probs[batch_slots, :draft_tokens_per_request, :
-                                  vocab_size] = draft_probs_per_request
-        spec_metadata.draft_probs_last_dim = vocab_size
-        spec_metadata.draft_probs_valid = True
+        vocab = probs.shape[-1]
+        spec_metadata.draft_probs[batch_slots, draft_step, :vocab] = probs
+        spec_metadata.draft_probs_last_dim = vocab
+
+        if d2t is not None:
+            draft_tokens = d2t[draft_tokens] + draft_tokens
+
+        return draft_tokens.type(torch.int32)

     def _execute_guided_decoder_if_present(self, logits):
         """Execute guided decoder on target model logits if available."""
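
Why a buffer zeroed only once stays correct across iterations can be shown with a small list-based sketch. It assumes, as the diff's `d2t[draft_tokens] + draft_tokens` line suggests, that `d2t` holds per-token offsets from draft ids to target ids; `expand_to_target_vocab` and the tiny vocabularies are hypothetical.

```python
from typing import List


def expand_to_target_vocab(full_row: List[float], draft_row: List[float],
                           d2t: List[int]) -> List[float]:
    """Write draft probs into a target-vocab row at index i + d2t[i]."""
    for i, p in enumerate(draft_row):
        full_row[i + d2t[i]] = p
    return full_row


d2t = [0, 1, 3]       # draft ids 0, 1, 2 map to target ids 0, 2, 5
target_vocab = 6

# Reused buffer: zero-filled once, then overwritten in place every iteration.
reused = [0.0] * target_vocab
expand_to_target_vocab(reused, [0.2, 0.3, 0.5], d2t)   # iteration 1
expand_to_target_vocab(reused, [0.6, 0.4, 0.0], d2t)   # iteration 2

# Reference: a freshly zeroed row for iteration 2 only.
fresh = expand_to_target_vocab([0.0] * target_vocab, [0.6, 0.4, 0.0], d2t)
```

Since the mapped positions are the same on every iteration, each write fully replaces the previous one and the remaining entries never leave zero, so `reused` and `fresh` agree without a second zero-fill.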

tensorrt_llm/_torch/speculative/one_model_sampler.py

Lines changed: 56 additions & 3 deletions
@@ -5,9 +5,20 @@
 from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE

 if IS_FLASHINFER_AVAILABLE:
-    from flashinfer.sampling import chain_speculative_sampling, top_k_top_p_sampling_from_logits
+    from flashinfer.sampling import (
+        chain_speculative_sampling,
+        sampling_from_probs,
+        top_k_mask_logits,
+        top_k_top_p_sampling_from_logits,
+        top_p_renorm_probs,
+    )
+    from flashinfer.sampling import softmax as flashinfer_softmax
 else:
     chain_speculative_sampling = None
+    sampling_from_probs = None
+    flashinfer_softmax = None
+    top_k_mask_logits = None
+    top_p_renorm_probs = None
     top_k_top_p_sampling_from_logits = None


@@ -114,9 +125,29 @@ def compute_probs_from_logits(
     skip_temperature: bool = False,
 ) -> torch.Tensor:
     """
-    Compute probabilities from logits with temperature, top-k, and top-p applied.
+    Compute filtered + normalized probs from logits (temperature + top_k +
+    top_p + softmax). Picks the fastest path for the input device:
+
+    1. CUDA + flashinfer: ``top_k_mask_logits`` → fused ``softmax+temp`` →
+       ``top_p_renorm_probs`` (all O(N) radix). ``skip_temperature`` ignored.
+    2. CUDA, no flashinfer: ``compute_probs_from_logits_op`` (sort-based,
+       O(N log N)).
+    3. CPU: manual PyTorch fallback.
     """
+    if logits.is_cuda and IS_FLASHINFER_AVAILABLE:
+        # Fast path: flashinfer composition (O(N) per row, friendly to small
+        # batch sizes). skip_temperature is ignored — flashinfer's softmax
+        # always applies the temperature tensor.
+        if top_k is not None:
+            logits = top_k_mask_logits(logits, top_k)
+        probs = flashinfer_softmax(logits, temperatures)
+        if top_p is not None:
+            probs = top_p_renorm_probs(probs, top_p)
+        return probs
+
     if logits.is_cuda:
+        # CUDA without flashinfer: fall back to the C++ op (slower sort-based
+        # top-k path, but works without flashinfer).
         return torch.ops.trtllm.compute_probs_from_logits_op(
             logits, temperatures, top_k, top_p, skip_temperature
         )
@@ -125,7 +156,6 @@ def compute_probs_from_logits(
     logits = apply_temperature(logits, temperatures)
     logits = apply_top_k_top_p(logits, top_k, top_p)
     probs = logits.softmax(dim=-1, dtype=torch.float32)
-
     # Greedy rows should remain exactly one-hot so rejection sampling does not
     # spuriously reject numerically-near argmax tokens.
     greedy_temp_threshold = 1e-4
@@ -135,6 +165,29 @@ def compute_probs_from_logits(
     return torch.where(is_greedy.unsqueeze(1), one_hot, probs)


+def sampling_batch_spec_dec_one_model_for_rejection(
+    logits: torch.Tensor,
+    temperatures: torch.Tensor,
+    top_k: torch.Tensor,
+    top_p: torch.Tensor,
+    seed: Optional[torch.Tensor] = None,
+    offset: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Rejection-sampling-aware draft sampler: returns BOTH the sampled tokens
+    AND the prob distribution they were sampled from, so the downstream
+    rejection-sampling path can reuse the probs without a second softmax +
+    temp/top_k/top_p pass.
+    """
+    if sampling_from_probs is None:
+        raise RuntimeError(
+            "Rejection sampling for one-model speculative decoding requires flashinfer"
+        )
+    probs = compute_probs_from_logits(logits, temperatures, top_k, top_p)
+    tokens = sampling_from_probs(probs, deterministic=True, seed=seed, offset=offset)
+    return tokens, probs
+
+
 def rejection_sampling_one_model(
     draft_probs: torch.Tensor,
     draft_token_ids: torch.Tensor,
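
For context on why the draft distribution has to be kept at all, the textbook speculative-sampling acceptance rule for a single position is sketched below. It is not the `chain_speculative_sampling` kernel or `rejection_sampling_one_model` from this file: `verify_draft_token` is a hypothetical helper, and the fallback used when the residual is empty is an assumption of this sketch.

```python
import random
from typing import List, Tuple


def verify_draft_token(token: int, draft_probs: List[float],
                       target_probs: List[float],
                       rng: random.Random) -> Tuple[int, bool]:
    """Accept `token` with probability min(1, p/q); otherwise resample."""
    q = draft_probs[token]    # probability the draft assigned to its sample
    p = target_probs[token]   # probability the target assigns to that token
    if q > 0.0 and rng.random() < min(1.0, p / q):
        return token, True
    # Rejected: resample from the normalized residual max(p - q, 0).
    residual = [max(pt - qd, 0.0) for pt, qd in zip(target_probs, draft_probs)]
    if sum(residual) == 0.0:
        # Assumption for this sketch: fall back to the target distribution.
        residual = list(target_probs)
    new_token = rng.choices(range(len(residual)), weights=residual, k=1)[0]
    return new_token, False


rng = random.Random(0)
same = verify_draft_token(0, [0.5, 0.5], [0.5, 0.5], rng)
disjoint = verify_draft_token(0, [1.0, 0.0], [0.0, 1.0], rng)
```

The rule needs `q` to be the exact filtered distribution the draft token was drawn from, which is why returning the probs together with the token matters for correctness as well as speed.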
