Skip to content

Commit 764edb7

Browse files
[TRTLLM-12669][fix] refresh is_all_greedy_sample before CUDA graph key selection
The one-engine CUDA graph key includes is_all_greedy_sample to dispatch between the argmax fast path and the advanced-sampling graph variant. The flag was only (re)computed inside populate_sampling_params_for_one_model, which runs in _prepare_inputs AFTER maybe_get_cuda_graph has already built the key. The key therefore used the previous iteration's stale flag, and warmup left it False (from the advanced-sampling capture pass).

On the first real decode iteration a greedy batch would then replay the advanced-sampling graph while populate skips filling the sampling/draft_probs buffers, reading uninitialized slot-indexed data. For MTP with num_nextn>=2 this hung the executor ("Hang detected on rank 0").

Fix:
- Extract the greediness detection into _scan_one_model_sampling (single source of truth) and add update_is_all_greedy_sample, called before the graph key is built so the key matches the buffers populate fills. populate now reuses the same scan.
- Defensively reset spec_metadata.is_all_greedy_sample to True after CUDA graph warmup so the stale capture-only False does not seed the first iteration.

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
1 parent 4aa80bf commit 764edb7

2 files changed

Lines changed: 66 additions & 14 deletions

File tree

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,6 +1331,13 @@ def _run_capture_pass(force_non_greedy: bool, label: str) -> None:
13311331
_run_capture_pass(force_non_greedy=True, label="advanced sampling")
13321332
# Set the value back to the original value after cuda graph warmups are complete
13331333
self.enable_spec_decode = self.is_spec_decode
1334+
# The advanced-sampling capture pass above leaves is_all_greedy_sample
1335+
# set to False on spec_metadata. Reset it to the default so the first
1336+
# real iteration's graph-key selection is not seeded with this
1337+
# capture-only value. (update_is_all_greedy_sample refreshes it every
1338+
# iteration; this is a defensive guard.)
1339+
if self.spec_metadata is not None:
1340+
self.spec_metadata.is_all_greedy_sample = True
13341341

13351342
def _capture_piecewise_cuda_graphs(self, resource_manager: ResourceManager):
13361343
"""Captures piecewise CUDA graphs for context/prefill steps via torch.compile."""
@@ -4584,6 +4591,17 @@ def forward(self,
45844591
scheduled_requests, resource_manager,
45854592
self.runtime_draft_len) as padded_requests:
45864593

4594+
# Refresh is_all_greedy_sample for the *current* batch BEFORE the
4595+
# CUDA graph key is built below. The key includes this flag to pick
4596+
# the argmax vs advanced-sampling graph variant; populate (inside
4597+
# _prepare_inputs) runs later and fills the matching GPU buffers.
4598+
# Without this pre-scan the key would use the previous iteration's
4599+
# stale value and could replay the advanced graph against
4600+
# unpopulated (greedy) buffers, hanging the run (e.g. MTP nextn>=2).
4601+
if spec_metadata is not None:
4602+
spec_metadata.update_is_all_greedy_sample(
4603+
padded_requests.all_requests())
4604+
45874605
maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph(
45884606
padded_requests,
45894607
enable_spec_decode=self.enable_spec_decode,

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
if TYPE_CHECKING:
3838
from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
39+
from ..pyexecutor.llm_request import LlmRequest
3940

4041
if IS_FLASHINFER_AVAILABLE:
4142
import flashinfer
@@ -574,25 +575,20 @@ def maybe_capture_hidden_states(self, layer_id: int,
574575
model. Use this method to record them. By default, does nothing.
575576
"""
576577

577-
def populate_sampling_params_for_one_model(
578-
self, requests: list["LlmRequest"]) -> None:
579-
"""
580-
Set up topp/topk/temperatures for 1-model sampler.
578+
def _scan_one_model_sampling(
579+
self, requests: list["LlmRequest"]
580+
) -> tuple[list[tuple[float, int, float, int]], list[int]]:
581+
"""Single source of truth for one-engine sampling-param detection.
581582
582-
Scans sampling configs to set skip_*/is_all_greedy_sample flags. When
583-
any request needs sampling, also builds per-token/per-request lists
584-
and copies them to GPU buffers; all-greedy batches skip this entirely.
583+
Scans the batch's sampling configs and sets skip_*/has_greedy_requests/
584+
is_all_greedy_sample (honoring the warmup capture override). Returns
585+
``(per_request_normalized, per_request_slot_ids)`` for buffer
586+
population. Does NOT allocate or fill GPU buffers, so it is safe to call
587+
before the CUDA graph key is built.
585588
"""
586589
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState
587590
from tensorrt_llm.sampling_params import SamplingParams
588591

589-
if not self.spec_dec_mode.use_one_engine():
590-
return
591-
592-
if self.temperatures is None:
593-
# Ensures determinism across ranks.
594-
torch.manual_seed(0)
595-
596592
# Need to use a very small value for temperature when disabled to avoid division by 0
597593
DISABLE_TEMP_VAL = 1e-5
598594
# Very large values disable topk.
@@ -708,6 +704,44 @@ def _normalize_request_sampling_params(
708704
for (_, _, _, num_tokens) in per_request_normalized
709705
]
710706

707+
return per_request_normalized, per_request_slot_ids
708+
709+
def update_is_all_greedy_sample(self, requests: list["LlmRequest"]) -> None:
710+
"""Refresh ``is_all_greedy_sample`` for the *current* batch.
711+
712+
Must be called BEFORE the CUDA graph key is built (the key includes
713+
``is_all_greedy_sample`` to choose the argmax vs advanced-sampling graph
714+
variant). ``populate_sampling_params_for_one_model`` runs later, inside
715+
``_prepare_inputs``, and re-derives the same flag while filling the GPU
716+
sampling buffers. Computing the flag here first keeps the selected graph
717+
consistent with the buffers ``populate`` fills; otherwise the key would
718+
use the previous iteration's stale value and could replay the advanced
719+
graph against unpopulated (greedy) buffers, which can hang/corrupt the
720+
run (notably for MTP with num_nextn>=2).
721+
"""
722+
if not self.spec_dec_mode.use_one_engine():
723+
return
724+
self._scan_one_model_sampling(requests)
725+
726+
def populate_sampling_params_for_one_model(
727+
self, requests: list["LlmRequest"]) -> None:
728+
"""
729+
Set up topp/topk/temperatures for 1-model sampler.
730+
731+
Scans sampling configs to set skip_*/is_all_greedy_sample flags. When
732+
any request needs sampling, also builds per-token/per-request lists
733+
and copies them to GPU buffers; all-greedy batches skip this entirely.
734+
"""
735+
if not self.spec_dec_mode.use_one_engine():
736+
return
737+
738+
if self.temperatures is None:
739+
# Ensures determinism across ranks.
740+
torch.manual_seed(0)
741+
742+
per_request_normalized, per_request_slot_ids = (
743+
self._scan_one_model_sampling(requests))
744+
711745
tokens_per_request = (self.max_total_draft_tokens + 1 if
712746
self.is_spec_dec_tree else self.max_draft_len + 1)
713747
# Warmup batches may exceed max_num_requests * tokens_per_request (e.g.

0 commit comments

Comments
 (0)