Skip to content

Commit 6e0c1c2

Browse files
[TRTLLM-12669][refactor] Eagle3 sampling: auto-detect greedy fast-path, mixed-batch rejection sampling, draft honors target params (#14745)
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
1 parent 5fe0a17 commit 6e0c1c2

16 files changed

Lines changed: 772 additions & 268 deletions

File tree

examples/llm-api/quickstart_advanced.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,6 @@ def add_llm_args(parser):
190190
default=False,
191191
action='store_true')
192192
parser.add_argument('--dynamic_tree_max_topK', type=int, default=None)
193-
parser.add_argument('--allow_advanced_sampling',
194-
default=False,
195-
action='store_true')
196193
parser.add_argument('--eagle3_model_arch',
197194
type=str,
198195
default="llama3",
@@ -294,7 +291,6 @@ def setup_llm(args, **kwargs):
294291
eagle_choices=args.eagle_choices,
295292
use_dynamic_tree=args.use_dynamic_tree,
296293
dynamic_tree_max_topK=args.dynamic_tree_max_topK,
297-
allow_advanced_sampling=args.allow_advanced_sampling,
298294
eagle3_model_arch=args.eagle3_model_arch,
299295
max_total_draft_tokens=args.max_total_draft_tokens)
300296
elif spec_decode_algo == "DFLASH":

examples/models/core/nemotron/README_nemotron_super_v3.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@ kv_cache_config:
144144
speculative_config:
145145
decoding_type: MTP
146146
max_draft_len: 5
147-
allow_advanced_sampling: true
148147
cuda_graph_config:
149148
max_batch_size: 64
150149
enable_padding: true

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from ..memory_buffer_utils import get_memory_buffers
2020
from ..modules.multi_stream_utils import with_multi_stream
2121
from ..speculative.eagle3 import Eagle3ResourceManager
22+
from ..speculative.interface import SpecMetadata
2223
from ..speculative.spec_sampler_base import SampleStateTensorsSpec
2324
from ..speculative.utils import get_draft_kv_cache_manager
2425
from ..utils import make_weak_ref, piecewise_cuda_graph
@@ -30,7 +31,7 @@
3031

3132
# A large prime number used for dummy request IDs to avoid collisions
3233
CUDA_GRAPH_DUMMY_REQUEST_ID = (1 << 64) - 1
33-
KeyType: TypeAlias = Tuple[int, int, bool, bool]
34+
KeyType: TypeAlias = Tuple[int, int, bool, bool, bool]
3435

3536

3637
@dataclass
@@ -198,19 +199,28 @@ def get_graph_key(
198199
self,
199200
batch: ScheduledRequests,
200201
new_tensors_device: Optional[SampleStateTensors] = None,
201-
spec_resource_manager: Optional[BaseResourceManager] = None):
202+
spec_resource_manager: Optional[BaseResourceManager] = None,
203+
spec_metadata: Optional[SpecMetadata] = None):
202204
batch_size = batch.batch_size
203205

204206
# Get the sequence length mode.
205207
short_seq_len_mode = self._get_seq_len_mode(batch, new_tensors_device)
206208

209+
# Spec one-engine sampler has two code paths (argmax fast-path vs
210+
# advanced sampling kernel). Include this in the key so we capture
211+
# both variants and dispatch at replay based on actual batch state.
212+
# Default to True (greedy fast-path) when the metadata doesn't carry
213+
# this field (non-one-engine paths or non-spec batches).
214+
is_all_greedy_sample = bool(
215+
getattr(spec_metadata, "is_all_greedy_sample", True))
216+
207217
if self.config.is_draft_model and spec_resource_manager is not None and isinstance(
208218
spec_resource_manager, Eagle3ResourceManager):
209219
# If 'is_first_draft' is True, even with tree decoding, the length of draft_len will only be 'max_draft_len', not 'max_total_draft_token'.
210220
# Because we will pad the input to 'max_draft_len' length for the first draft layer.
211221
draft_len = self.config.original_max_draft_len if spec_resource_manager.is_first_draft else 0
212222
key = (batch_size, draft_len, spec_resource_manager.is_first_draft,
213-
short_seq_len_mode)
223+
short_seq_len_mode, is_all_greedy_sample)
214224
else:
215225
# With dynamic spec decode, the draft length may be zero even when enable_spec_decode is True,
216226
# so we need to get the draft length from the batch instead of using enable_spec_decode.
@@ -220,7 +230,8 @@ def get_graph_key(
220230
draft_len = max(draft_len_list)
221231
assert len(
222232
set(draft_len_list)) == 1, "All draft lengths must be the same"
223-
key = (batch_size, draft_len, False, short_seq_len_mode)
233+
key = (batch_size, draft_len, False, short_seq_len_mode,
234+
is_all_greedy_sample)
224235
return key
225236

226237
def __del__(self):
@@ -231,7 +242,7 @@ def maybe_get_cuda_graph(
231242
batch: ScheduledRequests,
232243
enable_spec_decode: bool,
233244
attn_metadata: Any,
234-
spec_metadata: Optional[Any] = None,
245+
spec_metadata: Optional[SpecMetadata] = None,
235246
draft_tokens_cuda: Optional[torch.Tensor] = None,
236247
new_tensors_device: Optional[SampleStateTensors] = None,
237248
spec_resource_manager: Optional[BaseResourceManager] = None,
@@ -274,7 +285,7 @@ def maybe_get_cuda_graph(
274285
# can replay CUDA graphs using the cache.
275286
return None, None, None
276287
key = self.get_graph_key(batch, new_tensors_device,
277-
spec_resource_manager)
288+
spec_resource_manager, spec_metadata)
278289

279290
if key in self.graphs:
280291
return self.graph_metadata[key][

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 73 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,6 @@ def __init__(
498498
sparse_attention_config=self.sparse_attention_config)
499499

500500
if self.is_spec_decode:
501-
self.spec_metadata = None
502501
update_spec_config_from_model_config(self.spec_config,
503502
self.model.config)
504503
max_num_draft_tokens = self.max_draft_loop_tokens * self.batch_size
@@ -552,6 +551,7 @@ def __init__(
552551
# the model engine.
553552
self.attn_metadata = None
554553
self.encoder_attn_metadata = None
554+
self.spec_metadata = None
555555
self.iter_states = {}
556556
self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None
557557

@@ -1350,33 +1350,70 @@ def _capture_generation_cuda_graphs(self,
13501350
else:
13511351
max_seq_len_list = [effective_max_seq_len]
13521352

1353-
for bs, draft_len in graphs_to_capture:
1354-
if bs > self.batch_size:
1355-
continue
1356-
1357-
for max_seq_len in max_seq_len_list:
1358-
warmup_request = self._create_cuda_graph_warmup_request(
1359-
resource_manager, bs, draft_len, max_seq_len)
1360-
with self._release_batch_context(warmup_request,
1361-
resource_manager) as batch:
1362-
if batch is None:
1363-
# No KV cache space, cannot continue capturing graphs
1353+
def _run_capture_pass(force_non_greedy: bool, label: str) -> None:
1354+
spec_metadata = self.spec_metadata
1355+
if force_non_greedy and spec_metadata is not None:
1356+
spec_metadata._force_non_greedy_for_capture = True
1357+
# maybe_get_cuda_graph reads spec_metadata.is_all_greedy_sample
1358+
# to build the graph cache key BEFORE populate runs inside
1359+
# _prepare_inputs. Pre-flip it here so the very first capture
1360+
# in this pass uses the non-greedy key; populate's override
1361+
# below will keep it False on every subsequent iteration.
1362+
spec_metadata.is_all_greedy_sample = False
1363+
try:
1364+
for bs, draft_len in graphs_to_capture:
1365+
if bs > self.batch_size:
13641366
continue
1365-
logger.info(
1366-
f"Run generation-only CUDA graph warmup for batch size={bs}, draft_len={draft_len}, max_seq_len={max_seq_len}"
1367-
)
1368-
self.enable_spec_decode = draft_len > 0 or self.is_draft_model or (
1369-
self.spec_config is not None
1370-
and self.spec_config.spec_dec_mode.use_one_engine())
1371-
self._update_draft_inference_state_for_warmup(
1372-
batch, draft_len > 0, resource_manager)
1373-
self.runtime_draft_len = draft_len
1374-
self.forward(batch,
1375-
new_tensors_device=None,
1376-
resource_manager=resource_manager)
1377-
torch.cuda.synchronize()
1367+
1368+
for max_seq_len in max_seq_len_list:
1369+
warmup_request = self._create_cuda_graph_warmup_request(
1370+
resource_manager, bs, draft_len, max_seq_len)
1371+
with self._release_batch_context(
1372+
warmup_request, resource_manager) as batch:
1373+
if batch is None:
1374+
# No KV cache space, cannot continue capturing graphs
1375+
continue
1376+
logger.info(
1377+
f"Run generation-only CUDA graph warmup ({label}) "
1378+
f"for batch size={bs}, draft_len={draft_len}, "
1379+
f"max_seq_len={max_seq_len}")
1380+
self.enable_spec_decode = draft_len > 0 or self.is_draft_model or (
1381+
self.spec_config is not None and
1382+
self.spec_config.spec_dec_mode.use_one_engine())
1383+
self._update_draft_inference_state_for_warmup(
1384+
batch, draft_len > 0, resource_manager)
1385+
self.runtime_draft_len = draft_len
1386+
self.forward(batch,
1387+
new_tensors_device=None,
1388+
resource_manager=resource_manager)
1389+
torch.cuda.synchronize()
1390+
finally:
1391+
if force_non_greedy and spec_metadata is not None:
1392+
spec_metadata._force_non_greedy_for_capture = False
1393+
1394+
# Pass 1: greedy fast-path (dummy requests carry no sampling params,
1395+
# so is_all_greedy_sample is naturally True).
1396+
_run_capture_pass(force_non_greedy=False, label="greedy")
1397+
# Pass 2: advanced sampling variant. Required because on-the-fly capture
1398+
# is disabled outside warmup, so any inference batch that contains a
1399+
# non-greedy request would otherwise fall back to eager. Only meaningful
1400+
# for one-engine spec dec (where is_all_greedy_sample participates in
1401+
# the graph key); other paths default to True and would never key into
1402+
# this variant.
1403+
needs_non_greedy_capture = (
1404+
self.spec_config is not None
1405+
and self.spec_config.spec_dec_mode.use_one_engine())
1406+
if needs_non_greedy_capture:
1407+
_run_capture_pass(force_non_greedy=True, label="advanced sampling")
13781408
# Set the value back to the original value after cuda graph warmups are complete
13791409
self.enable_spec_decode = self.is_spec_decode
1410+
# The advanced-sampling capture pass above leaves is_all_greedy_sample
1411+
# set to False on spec_metadata. Reset it to the default so the first
1412+
# real iteration's graph-key selection is not seeded with this
1413+
# capture-only value. (update_is_all_greedy_sample refreshes it every
1414+
# iteration; this is a defensive guard.)
1415+
if self.spec_metadata is not None:
1416+
self.spec_metadata.is_all_greedy_sample = True
13801417

13811418
def _capture_piecewise_cuda_graphs(self, resource_manager: ResourceManager):
13821419
"""Captures piecewise CUDA graphs for context/prefill steps via torch.compile."""
@@ -4887,6 +4924,17 @@ def forward(self,
48874924
self.runtime_draft_len) as padded_requests:
48884925
self._pad_batch_seed_mrope_delta_cache(padded_requests)
48894926

4927+
# Refresh is_all_greedy_sample for the *current* batch BEFORE the
4928+
# CUDA graph key is built below. The key includes this flag to pick
4929+
# the argmax vs advanced-sampling graph variant; populate (inside
4930+
# _prepare_inputs) runs later and fills the matching GPU buffers.
4931+
# Without this pre-scan the key would use the previous iteration's
4932+
# stale value and could replay the advanced graph against
4933+
# unpopulated (greedy) buffers, hanging the run (e.g. MTP nextn>=2).
4934+
if spec_metadata is not None:
4935+
spec_metadata.update_is_all_greedy_sample(
4936+
padded_requests.all_requests())
4937+
48904938
maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph(
48914939
padded_requests,
48924940
enable_spec_decode=self.enable_spec_decode,

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -406,12 +406,6 @@ def create_py_executor(
406406
)
407407
llm_args.disable_overlap_scheduler = True
408408

409-
if spec_config is not None and spec_config.spec_dec_mode.use_one_engine():
410-
if not spec_config.allow_advanced_sampling:
411-
logger.warning(
412-
f"Falling back to greedy decoding for {spec_config.decoding_type}. If you "
413-
"want to use non-greedy sampling, please set allow_advanced_sampling=True."
414-
)
415409
# Check FLASHINFER compatibility with one-engine speculative decoding
416410
if llm_args.attn_backend == "FLASHINFER":
417411
raise ValueError(

tensorrt_llm/_torch/speculative/dynamic_tree_ops.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ def verify_dynamic_tree_rejection_from_logits_out(
237237
offset: int | torch.Tensor = 0,
238238
d2t: torch.Tensor | None = None,
239239
skip_all_sampling_params: bool = False,
240+
top_k_max: int | None = None,
240241
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
241242
"""Tree-aware rejection sampling from logits (three CUDA ops).
242243
@@ -266,9 +267,13 @@ def verify_dynamic_tree_rejection_from_logits_out(
266267
tree_valid = torch.ones(num_gens, dtype=torch.bool, device=candidates.device)
267268
tree_valid = tree_valid.contiguous()
268269

269-
if top_k is None:
270+
if top_k_max is not None:
271+
# Pre-computed CPU-side (CUDA-graph-safe): use as-is.
272+
pass
273+
elif top_k is None:
270274
top_k_max = 0
271275
else:
276+
# Fallback path (non-CUDA-graph contexts): compute from tensor.
272277
enabled_top_k = top_k[(top_k > 0) & (top_k < target_vocab_size)]
273278
top_k_max = int(enabled_top_k.max().item()) if enabled_top_k.numel() > 0 else 0
274279

0 commit comments

Comments
 (0)