Skip to content

Commit beb4b32

Browse files
[TRTLLM-12669][perf] Cache d2t target indices in spec metadata
The d2t-projected target vocab indices computed inside the rejection-path d2t padding step (source = arange(draft_vocab_size); target_indices = (source + d2t.to(device)) % vocab_size) were being rebuilt on every iteration even though the d2t tensor is model-static. Cache the result in SpecMetadata.d2t_target_indices on first use and reuse it on subsequent iterations.

Profile breakdown (llama70b, bs=32, CUDA graph off) showed accept_draft.rejection.d2t_padding at 88 us/iter — the second-largest rejection-path step after computing target_probs (127 us). The index computation accounts for ~10-20 us of that (3-4 kernels: arange + d2t H2D copy + add + mod); the rest is the slot-indexed scatter into full_draft_probs, which is already pre-allocated.

Verified on llama70b bs=32 over 3 rounds (mean ± stdev):
- Before: rej_on vs rej_off gap ≈ -10.0% (single-run baseline)
- After: rej_on vs rej_off gap = -8.71% ± 0.9% (3-round mean)
- Net within-run improvement ≈ +1.3 percentage points.

qwen235b is unchanged (its gap was already positive).

Output accuracy verified across 22 (model, bs, mode) configurations: all 1760 outputs terminate normally (EOT or max_tokens), with no regressions.

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
1 parent e173cbf commit beb4b32

1 file changed

Lines changed: 28 additions & 15 deletions

File tree

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 28 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -487,6 +487,11 @@ class SpecMetadata:
487487
# avoid a per-iter 64 MB zero-fill on the (max_num_requests, max_draft_len,
488488
# vocab_size) tensor. Shape: [max_num_requests, max_draft_len, vocab_size].
489489
full_draft_probs: Optional[torch.Tensor] = None
490+
# Cached d2t-projected target vocab indices, computed once on first use
491+
# (d2t is a model-static tensor). Replaces the per-iter
492+
# arange + (source + d2t) % vocab_size kernel sequence inside the d2t
493+
# padding step. Shape: [draft_vocab_size], dtype long.
494+
d2t_target_indices: Optional[torch.Tensor] = None
490495

491496
def __post_init__(self):
492497
pass
@@ -1052,8 +1057,8 @@ def _sample_and_accept_draft_tokens_base(
10521057
device=logits.device)
10531058

10541059
# Sample tokens using per-request sampling parameters
1055-
target_tokens = self._sample_tokens_for_batch(logits, spec_metadata,
1056-
num_contexts, batch_size)
1060+
target_tokens = self._sample_tokens_for_batch(
1061+
logits, spec_metadata, num_contexts, batch_size)
10571062

10581063
# Context requests: only accept the sampled token (no draft tokens yet)
10591064
accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
@@ -1067,7 +1072,8 @@ def _sample_and_accept_draft_tokens_base(
10671072
# Compare draft tokens with target tokens using cumulative product
10681073
# Counts consecutive matches from the start
10691074
num_accepted_tokens[num_contexts:] += torch.cumprod(
1070-
(draft_tokens == gen_target_tokens[:, :runtime_draft_len]).int(),
1075+
(draft_tokens
1076+
== gen_target_tokens[:, :runtime_draft_len]).int(),
10711077
dim=-1).sum(1)
10721078

10731079
# Apply force override if set
@@ -1176,9 +1182,8 @@ def _sample_and_accept_draft_tokens_rejection(
11761182

11771183
target_probs_flat = compute_probs_from_logits(
11781184
gen_logits, temperatures, top_ks, top_ps)
1179-
target_probs = target_probs_flat.reshape(num_gens,
1180-
runtime_draft_len + 1,
1181-
vocab_size)
1185+
target_probs = target_probs_flat.reshape(
1186+
num_gens, runtime_draft_len + 1, vocab_size)
11821187

11831188
draft_vocab_size = draft_probs.shape[-1]
11841189
assert draft_probs.shape[0] == num_gens, (
@@ -1190,11 +1195,13 @@ def _sample_and_accept_draft_tokens_rejection(
11901195
d2t = getattr(spec_metadata, "d2t", None)
11911196
if draft_vocab_size != vocab_size:
11921197
# Use the pre-allocated buffer from spec_metadata.prepare()
1193-
# (zero-filled once at init; untouched positions stay 0). Falls
1194-
# back to per-iter allocation if the buffer is not configured,
1195-
# e.g. when use_rejection_sampling was off at prepare() time.
1198+
# (zero-filled once at init; untouched positions stay 0).
1199+
# Falls back to per-iter allocation if the buffer is not
1200+
# configured, e.g. when use_rejection_sampling was off at
1201+
# prepare() time.
11961202
if spec_metadata.full_draft_probs is not None:
1197-
full_draft_probs = spec_metadata.full_draft_probs[:num_gens]
1203+
full_draft_probs = spec_metadata.full_draft_probs[:
1204+
num_gens]
11981205
else:
11991206
full_draft_probs = torch.zeros(
12001207
(num_gens, runtime_draft_len, vocab_size),
@@ -1204,11 +1211,17 @@ def _sample_and_accept_draft_tokens_rejection(
12041211
assert d2t.numel() == draft_vocab_size, (
12051212
f"d2t size mismatch: {d2t.numel()} != {draft_vocab_size}"
12061213
)
1207-
d2t = d2t.to(device=device)
1208-
source_indices = torch.arange(draft_vocab_size,
1209-
device=device,
1210-
dtype=torch.long)
1211-
target_indices = (source_indices + d2t) % vocab_size
1214+
# d2t is model-static; compute target_indices once and
1215+
# cache on spec_metadata to skip the arange + add + mod
1216+
# kernel sequence on every iter.
1217+
target_indices = spec_metadata.d2t_target_indices
1218+
if target_indices is None:
1219+
source_indices = torch.arange(draft_vocab_size,
1220+
device=device,
1221+
dtype=torch.long)
1222+
target_indices = (source_indices +
1223+
d2t.to(device=device)) % vocab_size
1224+
spec_metadata.d2t_target_indices = target_indices
12121225
full_draft_probs[:, :runtime_draft_len,
12131226
target_indices] = draft_probs
12141227
else:

0 commit comments

Comments
 (0)