Skip to content

Commit 485792d

Browse files
[TRTLLM-12669][fix] Pre-capture both greedy and advanced sampling CUDA graphs during warmup
On-the-fly CUDA graph capture is disabled outside the warmup window (the allow_capture context manager) because it can resize the shared cuda_graph_workspace tensor and invalidate addresses baked into previously captured graphs. As a result, the (is_all_greedy_sample=False) graph key introduced for one-engine spec dec was never captured: warmup only ran dummy requests with greedy sampling params, so inference batches with temperature / top_k / top_p fell back to eager.

Fix: run the warmup capture loop twice for one-engine spec dec. The first pass captures the greedy fast-path (existing behavior). The second pass flips spec_metadata.is_all_greedy_sample to False before forward so that maybe_get_cuda_graph computes the non-greedy key, and sets a runtime attribute that populate_sampling_params_for_one_model honors to override the dummy-request-derived greedy detection and substitute synthetic non-greedy values into the per-request buffers.

Other paths are unaffected: non-one-engine spec dec and non-spec dec default is_all_greedy_sample to True, so the second pass is skipped.

End-to-end (qwen3_8b_eagle3, bs=32, T=0.7/top_k=50/top_p=0.9):
- rej_off baseline: TPS=3713.73
- rej_on (before fix): TPS=3854.01 (+3.8%; non-greedy ran eager)
- rej_on (after fix): TPS=6013.58 (+62.0%; non-greedy uses graph)

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
1 parent df6be84 commit 485792d

2 files changed

Lines changed: 71 additions & 24 deletions

File tree

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,31 +1071,61 @@ def _capture_generation_cuda_graphs(self,
10711071
else:
10721072
max_seq_len_list = [effective_max_seq_len]
10731073

1074-
for bs, draft_len in graphs_to_capture:
1075-
if bs > self.batch_size:
1076-
continue
1077-
1078-
for max_seq_len in max_seq_len_list:
1079-
warmup_request = self._create_cuda_graph_warmup_request(
1080-
resource_manager, bs, draft_len, max_seq_len)
1081-
with self._release_batch_context(warmup_request,
1082-
resource_manager) as batch:
1083-
if batch is None:
1084-
# No KV cache space, cannot continue capturing graphs
1074+
def _run_capture_pass(force_non_greedy: bool, label: str) -> None:
1075+
spec_metadata = getattr(self, 'spec_metadata', None)
1076+
if force_non_greedy and spec_metadata is not None:
1077+
spec_metadata._force_non_greedy_for_capture = True
1078+
# maybe_get_cuda_graph reads spec_metadata.is_all_greedy_sample
1079+
# to build the graph cache key BEFORE populate runs inside
1080+
# _prepare_inputs. Pre-flip it here so the very first capture
1081+
# in this pass uses the non-greedy key; populate's override
1082+
# below will keep it False on every subsequent iteration.
1083+
spec_metadata.is_all_greedy_sample = False
1084+
try:
1085+
for bs, draft_len in graphs_to_capture:
1086+
if bs > self.batch_size:
10851087
continue
1086-
logger.info(
1087-
f"Run generation-only CUDA graph warmup for batch size={bs}, draft_len={draft_len}, max_seq_len={max_seq_len}"
1088-
)
1089-
self.enable_spec_decode = draft_len > 0 or self.is_draft_model or (
1090-
self.spec_config is not None
1091-
and self.spec_config.spec_dec_mode.use_one_engine())
1092-
self._update_draft_inference_state_for_warmup(
1093-
batch, draft_len > 0, resource_manager)
1094-
self.runtime_draft_len = draft_len
1095-
self.forward(batch,
1096-
new_tensors_device=None,
1097-
resource_manager=resource_manager)
1098-
torch.cuda.synchronize()
1088+
1089+
for max_seq_len in max_seq_len_list:
1090+
warmup_request = self._create_cuda_graph_warmup_request(
1091+
resource_manager, bs, draft_len, max_seq_len)
1092+
with self._release_batch_context(
1093+
warmup_request, resource_manager) as batch:
1094+
if batch is None:
1095+
# No KV cache space, cannot continue capturing graphs
1096+
continue
1097+
logger.info(
1098+
f"Run generation-only CUDA graph warmup ({label}) "
1099+
f"for batch size={bs}, draft_len={draft_len}, "
1100+
f"max_seq_len={max_seq_len}")
1101+
self.enable_spec_decode = draft_len > 0 or self.is_draft_model or (
1102+
self.spec_config is not None and
1103+
self.spec_config.spec_dec_mode.use_one_engine())
1104+
self._update_draft_inference_state_for_warmup(
1105+
batch, draft_len > 0, resource_manager)
1106+
self.runtime_draft_len = draft_len
1107+
self.forward(batch,
1108+
new_tensors_device=None,
1109+
resource_manager=resource_manager)
1110+
torch.cuda.synchronize()
1111+
finally:
1112+
if force_non_greedy and spec_metadata is not None:
1113+
spec_metadata._force_non_greedy_for_capture = False
1114+
1115+
# Pass 1: greedy fast-path (dummy requests carry no sampling params,
1116+
# so is_all_greedy_sample is naturally True).
1117+
_run_capture_pass(force_non_greedy=False, label="greedy")
1118+
# Pass 2: advanced sampling variant. Required because on-the-fly capture
1119+
# is disabled outside warmup, so any inference batch that contains a
1120+
# non-greedy request would otherwise fall back to eager. Only meaningful
1121+
# for one-engine spec dec (where is_all_greedy_sample participates in
1122+
# the graph key); other paths default to True and would never key into
1123+
# this variant.
1124+
needs_non_greedy_capture = (
1125+
self.spec_config is not None
1126+
and self.spec_config.spec_dec_mode.use_one_engine())
1127+
if needs_non_greedy_capture:
1128+
_run_capture_pass(force_non_greedy=True, label="advanced sampling")
10991129
# Set the value back to the original value after cuda graph warmups are complete
11001130
self.enable_spec_decode = self.is_spec_decode
11011131

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,23 @@ def _normalize_request_sampling_params(
647647
self.is_all_greedy_sample = (self.skip_temperature and self.skip_top_k
648648
and self.skip_top_p)
649649

650+
# Warmup-time override (set via runtime attribute by the model engine):
651+
# force the advanced-sampling code path so the CUDA graph for the
652+
# (is_all_greedy_sample=False) key gets captured. Dummy warmup requests
653+
# carry no sampling params, so the natural detection above always
654+
# returns True; this branch substitutes synthetic non-greedy scalars
655+
# into the per-request data and lets Phase 2 run normally to populate
656+
# the GPU buffers used by the captured kernels.
657+
if getattr(self, '_force_non_greedy_for_capture', False):
658+
self.skip_temperature = False
659+
self.skip_top_k = False
660+
self.skip_top_p = False
661+
self.is_all_greedy_sample = False
662+
per_request_normalized = [
663+
(0.7, 50, 0.9, num_tokens)
664+
for (_, _, _, num_tokens) in per_request_normalized
665+
]
666+
650667
tokens_per_request = (self.max_total_draft_tokens + 1 if
651668
self.is_spec_dec_tree else self.max_draft_len + 1)
652669
required_flat_size = tokens_per_request * self.max_num_requests

0 commit comments

Comments
 (0)