
Commit df6be84

[TRTLLM-12669][feat] Slot-index draft_probs and support mixed-batch rejection
The rejection-sampling acceptance path used to gate on `num_contexts == 0`, so any mixed batch (chunked prefill + decode, ramp-up step with new ctx joining, etc.) silently fell back to exact-match verification. The underlying reason was that `spec_metadata.draft_probs` was a flat buffer indexed by batch position at write time: when batch composition shifted across iterations (chunking ctxs polluting the prefix, gen completions leaving holes, new ctxs inserted in the ctx region), the row at buffer index `i` no longer reliably mapped to the request now at batch position `i`.

Refactor `draft_probs` to be slot-indexed, matching the convention `next_draft_tokens` already uses on `SampleStateSpec.store`:

- `SpecMetadata.draft_probs` is reshaped to `[max_num_requests, max_draft_len, vocab_size]`, addressed by `py_seq_slot`.
- A new `SpecMetadata.batch_slot_ids` device tensor carries the current batch's slot ids in batch order; it is populated alongside the other per-request sampling-param tensors in `populate_sampling_params_for_one_model` and is always refreshed when rejection sampling is configured, even for all-greedy batches (it is tiny relative to the per-token buffers we skip).
- `_compute_and_store_draft_probs` scatters into `draft_probs[batch_slot_ids, ...]` so each request's data lands at its own stable slot row.
- `_accept_draft_tokens` gathers `draft_probs[gen_slot_ids, ...]` for the gen subset of the current batch, so it always reads the per-request data written in the most recent iter that ran the draft loop for that slot.
- `_sample_and_accept_draft_tokens_rejection` now accepts `num_contexts` and splits ctx / gen subsets: ctx rows go through `_sample_tokens_for_batch` (no draft tokens to verify); gen rows feed the slot-gathered draft probs into the unchanged `rejection_sampling_one_model` kernel.
- `_can_use_rejection_sampling` drops the `num_contexts == 0` constraint now that mixed batches are handled correctly via slot indexing.
- When an iter skips the draft-probs store (all-greedy or empty draft loop), `draft_probs_valid` is reset to False so the following iter cannot read stale data if it transitions back to a non-greedy mix.

The all-gen case is byte-equivalent to the prior implementation: gen subset gathering by slot id collapses to the same data and the per-token buffer slicing for gen rows is unchanged.

CUDA graph capture is unaffected: `scheduler.can_run_cuda_graph` already requires `num_context_requests == 0`, so the new slicing code paths exercise exclusively in eager mode.

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
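The indexing problem described in the commit message can be illustrated with a small toy model. This is only a sketch under stated assumptions: plain Python lists stand in for the CUDA tensors, a single float stands in for a `[draft_len, vocab]` block, and every name (`write_by_position`, `slot_of`, `probs_of`, and so on) is hypothetical rather than part of TensorRT-LLM.

```python
# Toy model, not TensorRT-LLM code: a "probs" value is one float standing in
# for a [draft_len, vocab] block, and plain lists stand in for CUDA tensors.

def write_by_position(buffer, batch, probs_of):
    """Old scheme: buffer row i belongs to whatever sits at batch position i."""
    for pos, req in enumerate(batch):
        buffer[pos] = probs_of[req]


def read_by_position(buffer, batch):
    return {req: buffer[pos] for pos, req in enumerate(batch)}


def write_by_slot(buffer, batch, slot_of, probs_of):
    """New scheme: the buffer row is the request's stable slot id."""
    for req in batch:
        buffer[slot_of[req]] = probs_of[req]


def read_by_slot(buffer, batch, slot_of):
    return {req: buffer[slot_of[req]] for req in batch}


slot_of = {"A": 0, "B": 1, "C": 2}         # hypothetical slot assignment
probs_of = {"A": 0.1, "B": 0.2, "C": 0.3}  # hypothetical draft probs

iter1 = ["A", "B", "C"]  # draft loop writes for all three requests
iter2 = ["A", "C"]       # B left the batch, so C moved to batch position 1

flat = [None] * 3
write_by_position(flat, iter1, probs_of)
flat_read = read_by_position(flat, iter2)

slotted = [None] * 3
write_by_slot(slotted, iter1, slot_of, probs_of)
slot_read = read_by_slot(slotted, iter2, slot_of)

print(flat_read)  # {'A': 0.1, 'C': 0.2}  -> C picked up B's row
print(slot_read)  # {'A': 0.1, 'C': 0.3}  -> C reads its own row
```

The position-indexed read is only correct while the batch order stays fixed between the write and the read; the slot-indexed read does not depend on batch order at all, which is the property the refactor relies on.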
1 parent d599fb5 commit df6be84

2 files changed

Lines changed: 187 additions & 86 deletions


tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 5 additions & 0 deletions
@@ -767,6 +767,11 @@ def _forward_linear_draft_loop(self, inputs, attn_metadata, spec_metadata,
             spec_metadata.d2t = d2t_param.data if d2t_param is not None else None
             self._compute_and_store_draft_probs(draft_logits_list,
                                                 spec_metadata, batch_size)
+        elif spec_metadata.use_rejection_sampling:
+            # No draft probs were written this iter (all-greedy or empty draft
+            # loop). Invalidate the buffer so the next iter does not read stale
+            # data if it transitions back to a non-greedy mix.
+            spec_metadata.draft_probs_valid = False
 
         return next_draft_tokens
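The effect of the new `elif` branch can be traced with a small state sketch. It assumes, as the commit message suggests, that each iteration first verifies using the probs stored by the previous iteration and then runs the draft loop. `ToyMeta`, `draft_step`, and `run` are hypothetical stand-ins, not the real `SpecMetadata` or draft loop.

```python
from dataclasses import dataclass


@dataclass
class ToyMeta:
    """Hypothetical stand-in for the three SpecMetadata flags used here."""
    use_rejection_sampling: bool = True
    draft_probs_valid: bool = False
    is_all_greedy_sample: bool = False


def draft_step(meta, wrote_probs):
    """Mirror of the if/elif in the hunk above: store, or invalidate."""
    if wrote_probs:
        meta.draft_probs_valid = True
    elif meta.use_rejection_sampling:
        meta.draft_probs_valid = False


def can_use_rejection_sampling(meta):
    return (meta.use_rejection_sampling and meta.draft_probs_valid
            and not meta.is_all_greedy_sample)


def run(meta, greedy_flags):
    """Per iteration: verify with last iteration's probs, then draft."""
    decisions = []
    for all_greedy in greedy_flags:
        meta.is_all_greedy_sample = all_greedy
        decisions.append(can_use_rejection_sampling(meta))
        draft_step(meta, wrote_probs=not all_greedy)
    return decisions


# non-greedy, all-greedy, non-greedy, non-greedy
decisions = run(ToyMeta(), [False, True, False, False])
print(decisions)  # [False, False, False, True]
```

In this toy trace the third iteration falls back to the base path because the second iteration wrote nothing. Without the invalidation it would have verified against probs left over from the first iteration.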

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 182 additions & 86 deletions
@@ -463,13 +463,20 @@ class SpecMetadata:
     use_sampling_params_for_draft_tokens: bool = False
     # Vocab size used for draft_probs buffer allocation.
     vocab_size: int = 0
-    # Draft probabilities buffer for rejection sampling, stored flat.
+    # Draft probabilities buffer for rejection sampling, indexed by py_seq_slot
+    # so per-request data is stable across iterations regardless of batch
+    # composition shifts (chunking ctx, gen completion, new ctx joining).
+    # Shape: [max_num_requests, max_draft_len, vocab_size].
     draft_probs: Optional[torch.Tensor] = None
     draft_probs_vocab_size: int = 0
     # Whether draft_probs contains valid data.
     draft_probs_valid: bool = False
     # Last dimension size of the draft logits/probs stored in draft_probs.
     draft_probs_last_dim: int = 0
+    # Per-request slot ids (py_seq_slot) for the current batch, in batch order.
+    # Used to scatter draft probs by slot at write time and gather them by slot
+    # at the next iter's verify. Shape: [max_num_requests], dtype=long.
+    batch_slot_ids: Optional[torch.Tensor] = None
     # Draft-to-target vocab offset tensor.
     d2t: Optional[torch.Tensor] = None

@@ -482,12 +489,18 @@ def prepare(self):
         """
         if (self.use_rejection_sampling and self.draft_probs is None
                 and self.vocab_size > 0):
-            buffer_size = (self.max_num_requests * self.max_draft_len *
-                           self.vocab_size)
-            self.draft_probs = torch.empty(buffer_size,
-                                           dtype=torch.float32,
-                                           device='cuda')
+            # 3D [slot, draft_step, vocab] so we can scatter/gather by slot id
+            # and avoid the brittle "batch position == buffer position" mapping.
+            self.draft_probs = torch.empty(
+                (self.max_num_requests, self.max_draft_len, self.vocab_size),
+                dtype=torch.float32,
+                device='cuda')
             self.draft_probs_vocab_size = self.vocab_size
+        if (self.use_rejection_sampling and self.batch_slot_ids is None
+                and self.max_num_requests > 0):
+            self.batch_slot_ids = torch.empty((self.max_num_requests, ),
+                                              dtype=torch.long,
+                                              device='cuda')
 
     def create_cuda_graph_metadata(self, max_batch_size: int):
         """
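A quick size check shows why the commit message calls `batch_slot_ids` tiny. The 3D buffer holds the same number of float32 elements as the old flat one, so only the slot-id table adds memory. The configuration values below (64 requests, draft length 3, 128000-token vocab) are assumed purely for illustration, and the helper names are hypothetical.

```python
def draft_probs_bytes(max_num_requests, max_draft_len, vocab_size,
                      bytes_per_elem=4):
    """Footprint of a float32 [max_num_requests, max_draft_len, vocab_size] buffer."""
    return max_num_requests * max_draft_len * vocab_size * bytes_per_elem


def slot_ids_bytes(max_num_requests, bytes_per_elem=8):
    """Footprint of an int64 (torch.long) [max_num_requests] table."""
    return max_num_requests * bytes_per_elem


# Assumed example configuration, not taken from the commit.
probs_bytes = draft_probs_bytes(64, 3, 128000)
ids_bytes = slot_ids_bytes(64)

print(probs_bytes)          # 98304000
print(probs_bytes / 2**20)  # 93.75 (MiB)
print(ids_bytes)            # 512
```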
@@ -587,6 +600,7 @@ def _normalize_request_sampling_params(
         top_k_enabled = False
         top_p_enabled = False
         has_greedy_requests = False
+        per_request_slot_ids: list[int] = []
 
         for request in requests:
             sampling_config = request.sampling_config
@@ -618,6 +632,12 @@ def _normalize_request_sampling_params(
 
             per_request_normalized.append(
                 (temp_val, tk_val, tp_val, num_tokens))
+            # py_seq_slot is a stable per-request id used to scatter / gather
+            # draft probs across iterations. Dummies / unallocated slots fall
+            # back to 0 (any valid index is fine — the data at that slot will
+            # be overwritten on the next real iteration before being read).
+            per_request_slot_ids.append(
+                request.py_seq_slot if request.py_seq_slot is not None else 0)
 
         self.skip_temperature = not temperature_enabled
         self.skip_top_k = not top_k_enabled
@@ -653,9 +673,20 @@ def _normalize_request_sampling_params(
                                         dtype=torch.float32,
                                         device='cuda')
 
+        # Always-populate the per-request slot id table when rejection sampling
+        # is configured: it's tiny (max_num_requests longs) and needed at
+        # _compute_and_store_draft_probs time to scatter draft probs by slot.
+        if self.use_rejection_sampling and self.batch_slot_ids is not None:
+            self.batch_slot_ids[:len(per_request_slot_ids)].copy_(
+                torch.tensor(per_request_slot_ids,
+                             dtype=torch.long,
+                             pin_memory=prefer_pinned()),
+                non_blocking=True,
+            )
+
         # All-greedy: sampler takes the argmax branch (and rejection sampling
-        # is also bypassed for all-greedy), so the buffers are never read.
-        # Skip the H->D copies.
+        # is also bypassed for all-greedy), so the per-token buffers are never
+        # read. Skip the heavier H->D copies.
         if self.is_all_greedy_sample:
             return
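The slot-id collection and the prefix-only refresh in the two hunks above can be sketched in plain Python. The sketch assumes list slicing as a stand-in for the tensor `copy_`, uses `-1` as a visible "never written" marker (the real buffer comes from `torch.empty` and is simply uninitialized), and introduces the hypothetical helpers `collect_slot_ids` and `refresh_slot_table`.

```python
from types import SimpleNamespace


def collect_slot_ids(requests):
    """Slot id per request in batch order; None falls back to 0 as in the diff."""
    return [r.py_seq_slot if r.py_seq_slot is not None else 0 for r in requests]


def refresh_slot_table(table, slot_ids):
    """Overwrite only the leading len(slot_ids) entries; the tail keeps old values."""
    table[:len(slot_ids)] = slot_ids
    return table


table = [-1] * 6  # toy marker for "never written"

batch1 = [SimpleNamespace(py_seq_slot=s) for s in (4, 2, None)]
refresh_slot_table(table, collect_slot_ids(batch1))
after_batch1 = list(table)

batch2 = [SimpleNamespace(py_seq_slot=5)]
refresh_slot_table(table, collect_slot_ids(batch2))
after_batch2 = list(table)
live_ids = table[:len(batch2)]

print(after_batch1)  # [4, 2, 0, -1, -1, -1]
print(after_batch2)  # [5, 2, 0, -1, -1, -1]
print(live_ids)      # [5]
```

Because entries past the current batch size keep values from earlier batches, readers have to slice by the current batch size, which is what the `[:batch_size]` and `[num_contexts:batch_size]` slices elsewhere in this diff do.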

@@ -1014,112 +1045,162 @@ def _accept_draft_tokens(self, logits, draft_tokens, num_contexts,
                              batch_size, spec_metadata):
         """
         Accept draft tokens with optional rejection sampling support.
+
+        Mixed batches (num_contexts > 0) are supported: context rows take the
+        first sampled target token via the base logic, and rejection sampling
+        runs on the gen subset. Draft probs for the gen subset are gathered
+        from the slot-indexed buffer by `py_seq_slot`.
         """
-        if self._can_use_rejection_sampling(spec_metadata, num_contexts):
+        num_gens = batch_size - num_contexts
+        if num_gens > 0 and self._can_use_rejection_sampling(spec_metadata):
             draft_len = draft_tokens.shape[1]
             stored_vocab = (spec_metadata.draft_probs_last_dim
                             if spec_metadata.draft_probs_last_dim > 0 else
                             spec_metadata.draft_probs_vocab_size)
-            draft_probs = spec_metadata.draft_probs[:batch_size * draft_len *
-                                                    stored_vocab].reshape(
-                                                        batch_size, draft_len,
-                                                        stored_vocab)
+            # Gather the slot rows for the gen subset. The buffer was filled
+            # at the previous draft step indexed by py_seq_slot, so each gen
+            # request reads back exactly its own probs, regardless of batch
+            # composition changes since then.
+            gen_slot_ids = spec_metadata.batch_slot_ids[num_contexts:batch_size]
+            draft_probs = spec_metadata.draft_probs[
+                gen_slot_ids, :draft_len, :stored_vocab]
             return self._sample_and_accept_draft_tokens_rejection(
-                logits, draft_tokens, draft_probs, batch_size, spec_metadata)
+                logits, draft_tokens, draft_probs, num_contexts, batch_size,
+                spec_metadata)
         return self._sample_and_accept_draft_tokens_base(
             logits, draft_tokens, num_contexts, batch_size, spec_metadata)
 
-    def _can_use_rejection_sampling(self, spec_metadata: SpecMetadata,
-                                    num_contexts: int) -> bool:
-        # Skip rejection sampling when the whole batch is greedy: the
-        # accepted result is identical to argmax and the base path is cheaper.
+    def _can_use_rejection_sampling(self, spec_metadata: SpecMetadata) -> bool:
+        # Skip rejection sampling when the whole batch is greedy: the accepted
+        # result is identical to argmax and the base path is cheaper. Mixed
+        # batches (context + gen) are handled via slot-indexed draft probs and
+        # are split inside _sample_and_accept_draft_tokens_rejection.
         return (spec_metadata.use_rejection_sampling
-                and spec_metadata.draft_probs_valid and num_contexts == 0
+                and spec_metadata.draft_probs_valid
                 and not spec_metadata.is_all_greedy_sample)
 
     def _sample_and_accept_draft_tokens_rejection(
         self,
         logits: torch.Tensor,
         draft_tokens: torch.Tensor,
         draft_probs: torch.Tensor,
+        num_contexts: int,
         batch_size: int,
         spec_metadata,
     ):
         """
         Rejection-sampling acceptance for one-model speculative decoding.
+
+        Mixed batches are handled by treating the two subsets separately:
+        - context rows (first `num_contexts`) take the target's sampled first
+          token; no draft tokens to verify.
+        - generation rows (`[num_contexts:batch_size]`) run the rejection
+          sampling kernel on slot-gathered draft probs.
+
+        Per-token sampling-parameter tensors (`temperatures / top_ks / top_ps`)
+        are laid out as `[ctx (1 each), gen (draft_len+1 each)]`, matching the
+        logits layout, so slicing is symmetric for both subsets.
         """
         device = logits.device
         vocab_size = logits.shape[-1]
+        num_gens = batch_size - num_contexts
+        runtime_draft_len = draft_tokens.shape[1]
 
         if logits.dim() == 1:
             logits = logits.unsqueeze(0)
 
-        runtime_draft_len = draft_tokens.shape[1]
-        draft_vocab_size = draft_probs.shape[-1]
-        num_target_tokens = batch_size * (runtime_draft_len + 1)
-
-        temperatures = spec_metadata.temperatures[:num_target_tokens]
-        # Pass None instead of an all-disabled tensor so the C++ op can short-circuit
-        # on a host-side check rather than a `.item<bool>()` sync, which would break
-        # CUDA graph capture.
-        top_ks = None if spec_metadata.skip_top_k else spec_metadata.top_ks[:
-                                                                            num_target_tokens]
-        top_ps = None if spec_metadata.skip_top_p else spec_metadata.top_ps[:
-                                                                            num_target_tokens]
-
-        target_probs_flat = compute_probs_from_logits(logits.clone(),
-                                                      temperatures, top_ks,
-                                                      top_ps)
-        target_probs = target_probs_flat.reshape(batch_size,
-                                                 runtime_draft_len + 1,
-                                                 vocab_size)
-
-        assert draft_probs.shape[1] == runtime_draft_len, (
-            f"draft_probs draft length mismatch: {draft_probs.shape[1]} != "
-            f"{runtime_draft_len}")
-        d2t = getattr(spec_metadata, "d2t", None)
-        if draft_vocab_size != vocab_size:
-            full_draft_probs = torch.zeros(
-                (batch_size, runtime_draft_len, vocab_size),
-                dtype=torch.float32,
-                device=device)
-            if d2t is not None:
-                assert d2t.numel() == draft_vocab_size, (
-                    f"d2t size mismatch: {d2t.numel()} != {draft_vocab_size}")
-                d2t = d2t.to(device=device)
-                source_indices = torch.arange(draft_vocab_size,
-                                              device=device,
-                                              dtype=torch.long)
-                target_indices = (source_indices + d2t) % vocab_size
-                full_draft_probs[:, :runtime_draft_len,
-                                 target_indices] = draft_probs
+        accepted_tokens = torch.empty((batch_size, runtime_draft_len + 1),
+                                      dtype=torch.int,
+                                      device=device)
+        num_accepted_tokens = torch.ones(batch_size,
+                                         dtype=torch.int,
+                                         device=device)
+
+        # === Context subset: sample target's first token directly ===
+        if num_contexts > 0:
+            ctx_target_tokens = self._sample_tokens_for_batch(
+                logits[:num_contexts], spec_metadata, num_contexts,
+                num_contexts)
+            accepted_tokens[:num_contexts, 0] = ctx_target_tokens
+
+        # === Generation subset: rejection sampling on the gen slice ===
+        if num_gens > 0:
+            num_gen_logits = num_gens * (runtime_draft_len + 1)
+            gen_logits = logits[num_contexts:num_contexts + num_gen_logits]
+            gen_start = num_contexts
+            gen_end = num_contexts + num_gen_logits
+
+            temperatures = spec_metadata.temperatures[gen_start:gen_end]
+            # Pass None instead of an all-disabled tensor so the C++ op can short-circuit
+            # on a host-side check rather than a `.item<bool>()` sync, which would break
+            # CUDA graph capture.
+            top_ks = (None if spec_metadata.skip_top_k else
+                      spec_metadata.top_ks[gen_start:gen_end])
+            top_ps = (None if spec_metadata.skip_top_p else
+                      spec_metadata.top_ps[gen_start:gen_end])
+
+            target_probs_flat = compute_probs_from_logits(
+                gen_logits.clone(), temperatures, top_ks, top_ps)
+            target_probs = target_probs_flat.reshape(num_gens,
+                                                     runtime_draft_len + 1,
+                                                     vocab_size)
+
+            draft_vocab_size = draft_probs.shape[-1]
+            assert draft_probs.shape[0] == num_gens, (
+                f"draft_probs batch mismatch: {draft_probs.shape[0]} != "
+                f"num_gens={num_gens}")
+            assert draft_probs.shape[1] == runtime_draft_len, (
+                f"draft_probs draft length mismatch: {draft_probs.shape[1]} != "
+                f"{runtime_draft_len}")
+            d2t = getattr(spec_metadata, "d2t", None)
+            if draft_vocab_size != vocab_size:
+                full_draft_probs = torch.zeros(
+                    (num_gens, runtime_draft_len, vocab_size),
+                    dtype=torch.float32,
+                    device=device)
+                if d2t is not None:
+                    assert d2t.numel() == draft_vocab_size, (
+                        f"d2t size mismatch: {d2t.numel()} != {draft_vocab_size}"
+                    )
+                    d2t = d2t.to(device=device)
+                    source_indices = torch.arange(draft_vocab_size,
+                                                  device=device,
+                                                  dtype=torch.long)
+                    target_indices = (source_indices + d2t) % vocab_size
+                    full_draft_probs[:, :runtime_draft_len,
+                                     target_indices] = draft_probs
+                else:
+                    assert draft_vocab_size < vocab_size
+                    full_draft_probs[:, :runtime_draft_len, :
+                                     draft_vocab_size] = (draft_probs)
             else:
-                assert draft_vocab_size < vocab_size
-                full_draft_probs[:, :runtime_draft_len, :draft_vocab_size] = (
-                    draft_probs)
-        else:
-            full_draft_probs = draft_probs
-
-        full_draft_tokens = draft_tokens.to(torch.int32).contiguous()
-
-        if self.seed is None:
-            self.seed = torch.tensor([0], dtype=torch.int64, device=device)
-        if self.offset is None:
-            self.offset = torch.tensor([0], dtype=torch.int64, device=device)
-        self.seed += 1
-        self.seed %= 2**31
-
-        accepted_tokens, num_accepted_tokens = rejection_sampling_one_model(
-            draft_probs=full_draft_probs,
-            draft_token_ids=full_draft_tokens,
-            target_probs=target_probs,
-            deterministic=True,
-            seed=self.seed,
-            offset=self.offset,
-        )
+                full_draft_probs = draft_probs
+
+            full_draft_tokens = draft_tokens.to(torch.int32).contiguous()
+
+            if self.seed is None:
+                self.seed = torch.tensor([0], dtype=torch.int64, device=device)
+            if self.offset is None:
+                self.offset = torch.tensor([0],
+                                           dtype=torch.int64,
+                                           device=device)
+            self.seed += 1
+            self.seed %= 2**31
+
+            gen_accepted, gen_num_accepted = rejection_sampling_one_model(
+                draft_probs=full_draft_probs,
+                draft_token_ids=full_draft_tokens,
+                target_probs=target_probs,
+                deterministic=True,
+                seed=self.seed,
+                offset=self.offset,
+            )
+
+            accepted_tokens[num_contexts:] = gen_accepted
+            num_accepted_tokens[num_contexts:] = gen_num_accepted
 
         num_accepted_tokens = self._apply_force_accepted_tokens(
-            num_accepted_tokens, 0, draft_tokens.shape[1])
+            num_accepted_tokens, num_contexts, runtime_draft_len)
         return accepted_tokens, num_accepted_tokens
 
     def _draft_sampler_greedy(self, logits: torch.Tensor, d2t=None):
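The slice bounds used in the rejection path above follow from the `[ctx (1 each), gen (draft_len+1 each)]` layout named in the new docstring. As a minimal sketch of that index arithmetic, with hypothetical helper names and assumed example sizes:

```python
def split_token_ranges(num_contexts, batch_size, draft_len):
    """Return ((ctx_start, ctx_end), (gen_start, gen_end)) over the flat token axis.

    Assumes one entry per context request first, then draft_len + 1 entries
    per generation request, as described in the docstring in the diff.
    """
    num_gens = batch_size - num_contexts
    gen_start = num_contexts
    gen_end = gen_start + num_gens * (draft_len + 1)
    return (0, num_contexts), (gen_start, gen_end)


def gen_request_range(gen_index, num_contexts, draft_len):
    """Token range of the gen_index-th generation request (0-based)."""
    start = num_contexts + gen_index * (draft_len + 1)
    return start, start + draft_len + 1


mixed = split_token_ranges(num_contexts=2, batch_size=5, draft_len=3)
all_gen = split_token_ranges(num_contexts=0, batch_size=4, draft_len=3)
second_gen = gen_request_range(gen_index=1, num_contexts=2, draft_len=3)

print(mixed)       # ((0, 2), (2, 14))
print(all_gen)     # ((0, 0), (0, 16))
print(second_gen)  # (6, 10)
```

With `num_contexts == 0` the gen range covers `batch_size * (draft_len + 1)` tokens starting at 0, which is the same span the removed `[:num_target_tokens]` slices selected. That matches the commit message's claim that the all-gen case is unchanged.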
@@ -1204,7 +1285,10 @@ def _compute_and_store_draft_probs(
         batch_size: int,
     ):
         """
-        Compute draft probabilities and store them for next-step rejection sampling.
+        Compute draft probabilities and store them for next-step rejection
+        sampling. The storage is keyed by py_seq_slot, so the data is robust
+        to batch composition shifts across iterations (chunking ctxs, gen
+        completion, new ctxs joining).
         """
         draft_tokens_per_request = len(draft_logits_list)
         vocab_size = draft_logits_list[0].shape[-1]
@@ -1233,9 +1317,21 @@ def _compute_and_store_draft_probs(
         draft_probs_flat = compute_probs_from_logits(draft_logits_flat,
                                                      draft_temps, draft_top_ks,
                                                      draft_top_ps)
-        num_elements = batch_size * draft_tokens_per_request * vocab_size
-        spec_metadata.draft_probs[:num_elements].copy_(
-            draft_probs_flat.flatten())
+        # [batch_size, draft_len, draft_vocab]
+        draft_probs_per_request = draft_probs_flat.reshape(
+            batch_size, draft_tokens_per_request, vocab_size)
+
+        # Scatter into draft_probs[slot] for each request in the current batch.
+        # spec_metadata.draft_probs is shaped [max_num_requests, max_draft_len,
+        # vocab_size]. Different iterations may have different batch
+        # compositions, but a given request's data always lives at its
+        # py_seq_slot row, so reads at the next iter pick up the right data.
+        assert spec_metadata.batch_slot_ids is not None, (
+            "batch_slot_ids must be populated by "
+            "populate_sampling_params_for_one_model before draft probs storage")
+        batch_slots = spec_metadata.batch_slot_ids[:batch_size]
+        spec_metadata.draft_probs[batch_slots, :draft_tokens_per_request, :
+                                  vocab_size] = draft_probs_per_request
         spec_metadata.draft_probs_last_dim = vocab_size
         spec_metadata.draft_probs_valid = True
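Because `draft_probs_last_dim` records the draft vocabulary width, rows stored here may be narrower than the target vocabulary. The verify path in `_sample_and_accept_draft_tokens_rejection` widens them with `target_indices = (source_indices + d2t) % vocab_size`. A one-row sketch of that mapping follows, with a hypothetical helper name and made-up offsets. It uses plain lists instead of tensors.

```python
def expand_to_target_vocab(draft_row, d2t, vocab_size):
    """Scatter draft-vocab probabilities into a zero-filled target-vocab row."""
    assert len(d2t) == len(draft_row), "one offset per draft-vocab id"
    full = [0.0] * vocab_size
    for src, prob in enumerate(draft_row):
        # Same index formula as the diff: draft id + offset, wrapped by vocab size.
        full[(src + d2t[src]) % vocab_size] = prob
    return full


draft_row = [0.5, 0.3, 0.2]  # hypothetical 3-token draft vocab
d2t = [0, 1, 3]              # hypothetical offsets: draft id i maps to i + d2t[i]
full_row = expand_to_target_vocab(draft_row, d2t, vocab_size=6)

print(full_row)  # [0.5, 0.0, 0.3, 0.0, 0.0, 0.2]
```

Target ids that no draft id maps to keep probability zero in this sketch, mirroring the `torch.zeros` initialization of `full_draft_probs` in the diff.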
