Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
5f4824a
[TRTLLM-12669][refactor] Replace allow_advanced_sampling with auto-de…
zhaoyangwang-nvidia May 29, 2026
87300b0
[TRTLLM-12669][feat] Eagle3 one-model draft sampling honors target sa…
zhaoyangwang-nvidia Jun 1, 2026
69a4368
[TRTLLM-12669][feat] Enable rejection sampling by default for Eagle3 …
zhaoyangwang-nvidia Jun 2, 2026
376c2de
[TRTLLM-12669][feat] Slot-index draft_probs and support mixed-batch r…
zhaoyangwang-nvidia Jun 2, 2026
47fa873
[TRTLLM-12669][fix] Pre-capture both greedy and advanced sampling CUD…
zhaoyangwang-nvidia Jun 3, 2026
b54c8a6
[TRTLLM-12669][perf] Reuse draft probs to drop redundant softmax + cu…
zhaoyangwang-nvidia Jun 3, 2026
d40e315
[TRTLLM-12669][perf] Cache d2t target indices in spec metadata
zhaoyangwang-nvidia Jun 4, 2026
6641453
[TRTLLM-12669][chore] Apply CI yapf reformat to interface.py
zhaoyangwang-nvidia Jun 4, 2026
b27a720
[TRTLLM-12669][chore] Address review feedback
zhaoyangwang-nvidia Jun 4, 2026
3f544f9
[TRTLLM-12669][chore] Revert use_rejection_sampling default to False
zhaoyangwang-nvidia Jun 5, 2026
616b446
[TRTLLM-12669][fix] Restore last_tokens_idx dropped during rebase con…
zhaoyangwang-nvidia Jun 5, 2026
5641566
[TRTLLM-12669][fix] Remove stray allow_advanced_sampling arg in mtp_e…
zhaoyangwang-nvidia Jun 5, 2026
511e334
[TRTLLM-12669][fix] Fix draft_decoder AttributeError for MTP Eagle mode
zhaoyangwang-nvidia Jun 8, 2026
1f670b9
[TRTLLM-12669][fix] Fix PARD buffer overflow and CUDA-graph-incompati…
zhaoyangwang-nvidia Jun 9, 2026
e66f8b2
[TRTLLM-12669][fix] disable rejection sampling during CUDA graph capt…
zhaoyangwang-nvidia Jun 9, 2026
4f599c2
[TRTLLM-11508][fix] trim draft token to token_count when use_lm_head_…
zhaoyangwang-nvidia Jun 10, 2026
20a94d2
[TRTLLM-12669][fix] slice padded draft logits before advanced samplin…
zhaoyangwang-nvidia Jun 11, 2026
7bfeba0
[TRTLLM-12669][fix] refresh is_all_greedy_sample before CUDA graph ke…
zhaoyangwang-nvidia Jun 12, 2026
6cef5af
[TRTLLM-12669][fix] keep MTP-Eagle greedy draft sampling TP-aware to …
zhaoyangwang-nvidia Jun 13, 2026
9213129
[TRTLLM-12669][fix] all-gather sharded draft logits before advanced (…
zhaoyangwang-nvidia Jun 14, 2026
b3bdee0
[TRTLLM-12669][fix] only route plain-TP greedy MTP-Eagle draft sampli…
zhaoyangwang-nvidia Jun 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions examples/llm-api/quickstart_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,6 @@ def add_llm_args(parser):
default=False,
action='store_true')
parser.add_argument('--dynamic_tree_max_topK', type=int, default=None)
parser.add_argument('--allow_advanced_sampling',
default=False,
action='store_true')
parser.add_argument('--eagle3_model_arch',
type=str,
default="llama3",
Expand Down Expand Up @@ -294,7 +291,6 @@ def setup_llm(args, **kwargs):
eagle_choices=args.eagle_choices,
use_dynamic_tree=args.use_dynamic_tree,
dynamic_tree_max_topK=args.dynamic_tree_max_topK,
allow_advanced_sampling=args.allow_advanced_sampling,
eagle3_model_arch=args.eagle3_model_arch,
max_total_draft_tokens=args.max_total_draft_tokens)
elif spec_decode_algo == "DFLASH":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ kv_cache_config:
speculative_config:
decoding_type: MTP
max_draft_len: 5
allow_advanced_sampling: true
cuda_graph_config:
max_batch_size: 64
enable_padding: true
Expand Down
23 changes: 17 additions & 6 deletions tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ..memory_buffer_utils import get_memory_buffers
from ..modules.multi_stream_utils import with_multi_stream
from ..speculative.eagle3 import Eagle3ResourceManager
from ..speculative.interface import SpecMetadata
from ..speculative.spec_sampler_base import SampleStateTensorsSpec
from ..speculative.utils import get_draft_kv_cache_manager
from ..utils import make_weak_ref, piecewise_cuda_graph
Expand All @@ -29,7 +30,7 @@

# A large sentinel value (2**64 - 1, the maximum unsigned 64-bit integer; note it is not prime) used for dummy request IDs to avoid collisions
CUDA_GRAPH_DUMMY_REQUEST_ID = (1 << 64) - 1
KeyType: TypeAlias = Tuple[int, int, bool, bool]
KeyType: TypeAlias = Tuple[int, int, bool, bool, bool]


@dataclass
Expand Down Expand Up @@ -197,19 +198,28 @@ def get_graph_key(
self,
batch: ScheduledRequests,
new_tensors_device: Optional[SampleStateTensors] = None,
spec_resource_manager: Optional[BaseResourceManager] = None):
spec_resource_manager: Optional[BaseResourceManager] = None,
spec_metadata: Optional[SpecMetadata] = None):
batch_size = batch.batch_size

# Get the sequence length mode.
short_seq_len_mode = self._get_seq_len_mode(batch, new_tensors_device)

# Spec one-engine sampler has two code paths (argmax fast-path vs
# advanced sampling kernel). Include this in the key so we capture
# both variants and dispatch at replay based on actual batch state.
# Default to True (greedy fast-path) when the metadata doesn't carry
# this field (non-one-engine paths or non-spec batches).
is_all_greedy_sample = bool(
getattr(spec_metadata, "is_all_greedy_sample", True))

if self.config.is_draft_model and spec_resource_manager is not None and isinstance(
spec_resource_manager, Eagle3ResourceManager):
# If 'is_first_draft' is True, draft_len is only 'max_draft_len', not 'max_total_draft_token', even with tree decoding,
# because the input is padded to 'max_draft_len' length for the first draft layer.
draft_len = self.config.original_max_draft_len if spec_resource_manager.is_first_draft else 0
key = (batch_size, draft_len, spec_resource_manager.is_first_draft,
short_seq_len_mode)
short_seq_len_mode, is_all_greedy_sample)
else:
# With dynamic spec decode, the draft length may be zero even when enable_spec_decode is True,
# so we need to get the draft length from the batch instead of using enable_spec_decode.
Expand All @@ -219,7 +229,8 @@ def get_graph_key(
draft_len = max(draft_len_list)
assert len(
set(draft_len_list)) == 1, "All draft lengths must be the same"
key = (batch_size, draft_len, False, short_seq_len_mode)
key = (batch_size, draft_len, False, short_seq_len_mode,
is_all_greedy_sample)
return key

def __del__(self):
Expand All @@ -230,7 +241,7 @@ def maybe_get_cuda_graph(
batch: ScheduledRequests,
enable_spec_decode: bool,
attn_metadata: Any,
spec_metadata: Optional[Any] = None,
spec_metadata: Optional[SpecMetadata] = None,
draft_tokens_cuda: Optional[torch.Tensor] = None,
new_tensors_device: Optional[SampleStateTensors] = None,
spec_resource_manager: Optional[BaseResourceManager] = None,
Expand Down Expand Up @@ -273,7 +284,7 @@ def maybe_get_cuda_graph(
# can replay CUDA graphs using the cache.
return None, None, None
key = self.get_graph_key(batch, new_tensors_device,
spec_resource_manager)
spec_resource_manager, spec_metadata)

if key in self.graphs:
return self.graph_metadata[key][
Expand Down
98 changes: 73 additions & 25 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,6 @@ def __init__(
sparse_attn_config=self.sparse_attention_config)

if self.is_spec_decode:
self.spec_metadata = None
update_spec_config_from_model_config(self.spec_config,
self.model.config)
max_num_draft_tokens = self.max_draft_loop_tokens * self.batch_size
Expand Down Expand Up @@ -551,6 +550,7 @@ def __init__(
# the model engine.
self.attn_metadata = None
self.encoder_attn_metadata = None
self.spec_metadata = None
self.iter_states = {}
self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None

Expand Down Expand Up @@ -1343,33 +1343,70 @@ def _capture_generation_cuda_graphs(self,
else:
max_seq_len_list = [effective_max_seq_len]

for bs, draft_len in graphs_to_capture:
if bs > self.batch_size:
continue

for max_seq_len in max_seq_len_list:
warmup_request = self._create_cuda_graph_warmup_request(
resource_manager, bs, draft_len, max_seq_len)
with self._release_batch_context(warmup_request,
resource_manager) as batch:
if batch is None:
# No KV cache space, cannot continue capturing graphs
def _run_capture_pass(force_non_greedy: bool, label: str) -> None:
spec_metadata = self.spec_metadata
if force_non_greedy and spec_metadata is not None:
spec_metadata._force_non_greedy_for_capture = True
# maybe_get_cuda_graph reads spec_metadata.is_all_greedy_sample
# to build the graph cache key BEFORE populate runs inside
# _prepare_inputs. Pre-flip it here so the very first capture
# in this pass uses the non-greedy key; populate's override
# below will keep it False on every subsequent iteration.
spec_metadata.is_all_greedy_sample = False
try:
for bs, draft_len in graphs_to_capture:
if bs > self.batch_size:
continue
logger.info(
f"Run generation-only CUDA graph warmup for batch size={bs}, draft_len={draft_len}, max_seq_len={max_seq_len}"
)
self.enable_spec_decode = draft_len > 0 or self.is_draft_model or (
self.spec_config is not None
and self.spec_config.spec_dec_mode.use_one_engine())
self._update_draft_inference_state_for_warmup(
batch, draft_len > 0, resource_manager)
self.runtime_draft_len = draft_len
self.forward(batch,
new_tensors_device=None,
resource_manager=resource_manager)
torch.cuda.synchronize()

for max_seq_len in max_seq_len_list:
warmup_request = self._create_cuda_graph_warmup_request(
resource_manager, bs, draft_len, max_seq_len)
with self._release_batch_context(
warmup_request, resource_manager) as batch:
if batch is None:
# No KV cache space, cannot continue capturing graphs
continue
logger.info(
f"Run generation-only CUDA graph warmup ({label}) "
f"for batch size={bs}, draft_len={draft_len}, "
f"max_seq_len={max_seq_len}")
self.enable_spec_decode = draft_len > 0 or self.is_draft_model or (
self.spec_config is not None and
self.spec_config.spec_dec_mode.use_one_engine())
self._update_draft_inference_state_for_warmup(
batch, draft_len > 0, resource_manager)
self.runtime_draft_len = draft_len
self.forward(batch,
new_tensors_device=None,
resource_manager=resource_manager)
torch.cuda.synchronize()
finally:
if force_non_greedy and spec_metadata is not None:
spec_metadata._force_non_greedy_for_capture = False

# Pass 1: greedy fast-path (dummy requests carry no sampling params,
# so is_all_greedy_sample is naturally True).
_run_capture_pass(force_non_greedy=False, label="greedy")
# Pass 2: advanced sampling variant. Required because on-the-fly capture
# is disabled outside warmup, so any inference batch that contains a
# non-greedy request would otherwise fall back to eager. Only meaningful
# for one-engine spec dec (where is_all_greedy_sample participates in
# the graph key); other paths default to True and would never key into
# this variant.
needs_non_greedy_capture = (
self.spec_config is not None
and self.spec_config.spec_dec_mode.use_one_engine())
if needs_non_greedy_capture:
_run_capture_pass(force_non_greedy=True, label="advanced sampling")
# Restore enable_spec_decode to its original value now that the CUDA graph warmups are complete
self.enable_spec_decode = self.is_spec_decode
# The advanced-sampling capture pass above leaves is_all_greedy_sample
# set to False on spec_metadata. Reset it to the default so the first
# real iteration's graph-key selection is not seeded with this
# capture-only value. (update_is_all_greedy_sample refreshes it every
# iteration; this is a defensive guard.)
if self.spec_metadata is not None:
self.spec_metadata.is_all_greedy_sample = True

def _capture_piecewise_cuda_graphs(self, resource_manager: ResourceManager):
"""Captures piecewise CUDA graphs for context/prefill steps via torch.compile."""
Expand Down Expand Up @@ -4690,6 +4727,17 @@ def forward(self,
self.runtime_draft_len) as padded_requests:
self._pad_batch_seed_mrope_delta_cache(padded_requests)

# Refresh is_all_greedy_sample for the *current* batch BEFORE the
# CUDA graph key is built below. The key includes this flag to pick
# the argmax vs advanced-sampling graph variant; populate (inside
# _prepare_inputs) runs later and fills the matching GPU buffers.
# Without this pre-scan the key would use the previous iteration's
# stale value and could replay the advanced graph against
# unpopulated (greedy) buffers, hanging the run (e.g. MTP nextn>=2).
if spec_metadata is not None:
spec_metadata.update_is_all_greedy_sample(
padded_requests.all_requests())

maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph(
padded_requests,
enable_spec_decode=self.enable_spec_decode,
Expand Down
6 changes: 0 additions & 6 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,12 +406,6 @@ def create_py_executor(
)
llm_args.disable_overlap_scheduler = True

if spec_config is not None and spec_config.spec_dec_mode.use_one_engine():
if not spec_config.allow_advanced_sampling:
logger.warning(
f"Falling back to greedy decoding for {spec_config.decoding_type}. If you "
"want to use non-greedy sampling, please set allow_advanced_sampling=True."
)
# Check FLASHINFER compatibility with one-engine speculative decoding
if llm_args.attn_backend == "FLASHINFER":
raise ValueError(
Expand Down
7 changes: 6 additions & 1 deletion tensorrt_llm/_torch/speculative/dynamic_tree_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def verify_dynamic_tree_rejection_from_logits_out(
offset: int | torch.Tensor = 0,
d2t: torch.Tensor | None = None,
skip_all_sampling_params: bool = False,
top_k_max: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Tree-aware rejection sampling from logits (three CUDA ops).

Expand Down Expand Up @@ -266,9 +267,13 @@ def verify_dynamic_tree_rejection_from_logits_out(
tree_valid = torch.ones(num_gens, dtype=torch.bool, device=candidates.device)
tree_valid = tree_valid.contiguous()

if top_k is None:
if top_k_max is not None:
# Pre-computed CPU-side (CUDA-graph-safe): use as-is.
pass
elif top_k is None:
top_k_max = 0
else:
# Fallback path (non-CUDA-graph contexts): compute from tensor.
enabled_top_k = top_k[(top_k > 0) & (top_k < target_vocab_size)]
top_k_max = int(enabled_top_k.max().item()) if enabled_top_k.numel() > 0 else 0

Expand Down
Loading
Loading