Skip to content

Commit d237690

Browse files
[TRTLLM-12669][refactor] Replace allow_advanced_sampling with auto-detected dual-graph dispatch
Remove the static `allow_advanced_sampling` config flag and replace it with a per-step auto-detected `is_all_greedy_sample` boolean on SpecMetadata. The flag is computed in `populate_sampling_params_for_one_model` from the actual temperature/top_k/top_p of every request in the batch. `is_all_greedy_sample` is included in the CUDA graph key so we lazily capture two graph variants (argmax fast-path vs advanced sampling kernel) and dispatch by replaying the right one based on the current batch composition. Both variants stay CUDA-graph-compatible because the dispatch is a host-side decision outside the captured region. Additional optimizations for the all-greedy batch (the common default): - Populate skips per-token list building and 6 H->D copies entirely. - Rejection sampling is bypassed (argmax is equivalent for all-greedy) in both linear and dynamic-tree paths. - _compute_and_store_draft_probs is skipped, saving a softmax pass and draft-probs copy. Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
1 parent 83ec591 commit d237690

13 files changed

Lines changed: 86 additions & 68 deletions

File tree

examples/llm-api/quickstart_advanced.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,6 @@ def add_llm_args(parser):
190190
default=False,
191191
action='store_true')
192192
parser.add_argument('--dynamic_tree_max_topK', type=int, default=None)
193-
parser.add_argument('--allow_advanced_sampling',
194-
default=False,
195-
action='store_true')
196193
parser.add_argument('--eagle3_model_arch',
197194
type=str,
198195
default="llama3",
@@ -294,7 +291,6 @@ def setup_llm(args, **kwargs):
294291
eagle_choices=args.eagle_choices,
295292
use_dynamic_tree=args.use_dynamic_tree,
296293
dynamic_tree_max_topK=args.dynamic_tree_max_topK,
297-
allow_advanced_sampling=args.allow_advanced_sampling,
298294
eagle3_model_arch=args.eagle3_model_arch,
299295
max_total_draft_tokens=args.max_total_draft_tokens)
300296
elif spec_decode_algo == "DFLASH":

examples/models/core/nemotron/README_nemotron_super_v3.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@ kv_cache_config:
144144
speculative_config:
145145
decoding_type: MTP
146146
max_draft_len: 5
147-
allow_advanced_sampling: true
148147
cuda_graph_config:
149148
max_batch_size: 64
150149
enable_padding: true

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
# A large prime number used for dummy request IDs to avoid collisions
2828
CUDA_GRAPH_DUMMY_REQUEST_ID = (1 << 64) - 1
29-
KeyType: TypeAlias = Tuple[int, int, bool, bool]
29+
KeyType: TypeAlias = Tuple[int, int, bool, bool, bool]
3030

3131

3232
@dataclass
@@ -201,19 +201,28 @@ def get_graph_key(
201201
self,
202202
batch: ScheduledRequests,
203203
new_tensors_device: Optional[SampleStateTensors] = None,
204-
spec_resource_manager: Optional[BaseResourceManager] = None):
204+
spec_resource_manager: Optional[BaseResourceManager] = None,
205+
spec_metadata: Optional[Any] = None):
205206
batch_size = batch.batch_size
206207

207208
# Get the sequence length mode.
208209
short_seq_len_mode = self._get_seq_len_mode(batch, new_tensors_device)
209210

211+
# Spec one-engine sampler has two code paths (argmax fast-path vs
212+
# advanced sampling kernel). Include this in the key so we capture
213+
# both variants and dispatch at replay based on actual batch state.
214+
# Default to True (greedy fast-path) when the metadata doesn't carry
215+
# this field (non-one-engine paths or non-spec batches).
216+
is_all_greedy_sample = bool(
217+
getattr(spec_metadata, "is_all_greedy_sample", True))
218+
210219
if self.config.is_draft_model and spec_resource_manager is not None and isinstance(
211220
spec_resource_manager, Eagle3ResourceManager):
212221
# If 'is_first_draft' is True, even with tree decoding, the length of draft_len will only be 'max_draft_len', not 'max_total_draft_token'.
213222
# Because we will pad the input to 'max_draft_len' length for the first draft layer.
214223
draft_len = self.config.original_max_draft_len if spec_resource_manager.is_first_draft else 0
215224
key = (batch_size, draft_len, spec_resource_manager.is_first_draft,
216-
short_seq_len_mode)
225+
short_seq_len_mode, is_all_greedy_sample)
217226
else:
218227
# With dynamic spec decode, the draft length may be zero even when enable_spec_decode is True,
219228
# so we need to get the draft length from the batch instead of using enable_spec_decode.
@@ -223,7 +232,8 @@ def get_graph_key(
223232
draft_len = max(draft_len_list)
224233
assert len(
225234
set(draft_len_list)) == 1, "All draft lengths must be the same"
226-
key = (batch_size, draft_len, False, short_seq_len_mode)
235+
key = (batch_size, draft_len, False, short_seq_len_mode,
236+
is_all_greedy_sample)
227237
return key
228238

229239
def __del__(self):
@@ -268,7 +278,7 @@ def maybe_get_cuda_graph(
268278
if not self.enabled or not can_run_cuda_graph:
269279
return None, None, None
270280
key = self.get_graph_key(batch, new_tensors_device,
271-
spec_resource_manager)
281+
spec_resource_manager, spec_metadata)
272282

273283
if key in self.graphs:
274284
return self.graph_metadata[key][

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -406,12 +406,6 @@ def create_py_executor(
406406
)
407407
llm_args.disable_overlap_scheduler = True
408408

409-
if spec_config is not None and spec_config.spec_dec_mode.use_one_engine():
410-
if not spec_config.allow_advanced_sampling:
411-
logger.warning(
412-
f"Falling back to greedy decoding for {spec_config.decoding_type}. If you "
413-
"want to use non-greedy sampling, please set allow_advanced_sampling=True."
414-
)
415409
# Check FLASHINFER compatibility with one-engine speculative decoding
416410
if llm_args.attn_backend == "FLASHINFER":
417411
raise ValueError(

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -758,7 +758,10 @@ def _forward_linear_draft_loop(self, inputs, attn_metadata, spec_metadata,
758758
gen_draft_tokens)
759759
next_draft_tokens[num_contexts:] = gen_draft_tokens
760760

761-
if spec_metadata.use_rejection_sampling and draft_logits_list:
761+
# Skip when the whole batch is greedy: _can_use_rejection_sampling will
762+
# bypass the rejection path anyway, so computing draft probs is wasted.
763+
if (spec_metadata.use_rejection_sampling and draft_logits_list
764+
and not spec_metadata.is_all_greedy_sample):
762765
d2t_param = getattr(draft_model.model, "d2t", None)
763766
spec_metadata.d2t = d2t_param.data if d2t_param is not None else None
764767
self._compute_and_store_draft_probs(draft_logits_list,

tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -950,7 +950,13 @@ def _can_use_rejection_sampling(self, spec_metadata) -> bool:
950950
Returns:
951951
True if rejection sampling is enabled and the draft logit buffer is allocated
952952
"""
953-
return spec_metadata.use_rejection_sampling and self._draft_depth_logits_cat is not None
953+
# Skip rejection sampling when the whole batch is greedy: argmax is
954+
# equivalent and avoids the rejection kernel cost.
955+
return (
956+
spec_metadata.use_rejection_sampling
957+
and self._draft_depth_logits_cat is not None
958+
and not spec_metadata.is_all_greedy_sample
959+
)
954960

955961
def _finalize_dynamic_tree_verify_outputs(
956962
self,

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -436,8 +436,14 @@ class SpecMetadata:
436436
# Always set by model_engine.forward() before any downstream code reads it.
437437
runtime_draft_len: int = 0
438438

439-
# For non-greedy sampling on 1-model.
440-
allow_advanced_sampling: bool = False
439+
# Auto-detected per step from populated sampling params:
440+
# True if every request is greedy (no temp/top_k/top_p) and we can take
441+
# the argmax fast-path. False if any request needs sampling.
442+
# Used as part of the CUDA graph key so we capture two variants
443+
# (greedy fast-path vs advanced sampling) and dispatch at replay.
444+
# Defaults to True so non-one-engine paths (where populate is a no-op)
445+
# never accidentally select the advanced graph variant.
446+
is_all_greedy_sample: bool = True
441447
# Whether to use rejection sampling for one-model speculative decoding.
442448
use_rejection_sampling: bool = False
443449
# Sampling parameters for non-greedy sampling (per-request)
@@ -515,29 +521,21 @@ def populate_sampling_params_for_one_model(
515521
self, requests: list["LlmRequest"]) -> None:
516522
"""
517523
Set up topp/topk/temperatures for 1-model sampler.
524+
525+
Scans sampling configs to set skip_*/is_all_greedy_sample flags. When
526+
any request needs sampling, also builds per-token/per-request lists
527+
and copies them to GPU buffers; all-greedy batches skip this entirely.
518528
"""
519529
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState
520530
from tensorrt_llm.sampling_params import SamplingParams
521531

522-
if not self.allow_advanced_sampling or not self.spec_dec_mode.use_one_engine(
523-
):
532+
if not self.spec_dec_mode.use_one_engine():
524533
return
525534

526535
if self.temperatures is None:
527536
# Ensures determinism across ranks.
528537
torch.manual_seed(0)
529538

530-
temperatures = []
531-
top_ks = []
532-
top_ps = []
533-
request_temperatures = []
534-
request_top_ks = []
535-
request_top_ps = []
536-
top_k_enabled = False
537-
top_p_enabled = False
538-
has_greedy_requests = False
539-
temperature_enabled = False
540-
541539
# Need to use a very small value for temperature when disabled to avoid division by 0
542540
DISABLE_TEMP_VAL = 1e-5
543541
# Very large values disable topk.
@@ -583,6 +581,13 @@ def _normalize_request_sampling_params(
583581
is_greedy,
584582
)
585583

584+
# Phase 1: collect per-request flags and normalized values.
585+
per_request_normalized: list[tuple[float, int, float, int]] = []
586+
temperature_enabled = False
587+
top_k_enabled = False
588+
top_p_enabled = False
589+
has_greedy_requests = False
590+
586591
for request in requests:
587592
sampling_config = request.sampling_config
588593
temp_val = _first_or_none(sampling_config.temperature)
@@ -611,19 +616,24 @@ def _normalize_request_sampling_params(
611616
top_p_enabled |= use_top_p
612617
has_greedy_requests |= is_greedy
613618

614-
request_temperatures.append(temp_val)
615-
request_top_ks.append(tk_val)
616-
request_top_ps.append(tp_val)
617-
temperatures.extend(temp_val for _ in range(num_tokens))
618-
top_ks.extend(tk_val for _ in range(num_tokens))
619-
top_ps.extend(tp_val for _ in range(num_tokens))
619+
per_request_normalized.append(
620+
(temp_val, tk_val, tp_val, num_tokens))
621+
622+
self.skip_temperature = not temperature_enabled
623+
self.skip_top_k = not top_k_enabled
624+
self.skip_top_p = not top_p_enabled
625+
self.has_greedy_requests = has_greedy_requests
626+
# Used in the CUDA graph key to pick the argmax / advanced variant.
627+
self.is_all_greedy_sample = (self.skip_temperature and self.skip_top_k
628+
and self.skip_top_p)
620629

621630
tokens_per_request = (self.max_total_draft_tokens + 1 if
622631
self.is_spec_dec_tree else self.max_draft_len + 1)
623632
required_flat_size = tokens_per_request * self.max_num_requests
624633

625634
if self.temperatures is None or self.temperatures.numel(
626635
) < required_flat_size:
636+
# Allocate once; the captured graph reads from these stable addresses.
627637
self.temperatures = torch.ones(required_flat_size,
628638
dtype=torch.float32,
629639
device='cuda')
@@ -643,6 +653,27 @@ def _normalize_request_sampling_params(
643653
dtype=torch.float32,
644654
device='cuda')
645655

656+
# All-greedy: sampler takes the argmax branch (and rejection sampling
657+
# is also bypassed for all-greedy), so the buffers are never read.
658+
# Skip the H->D copies.
659+
if self.is_all_greedy_sample:
660+
return
661+
662+
# Phase 2: build per-token / per-request lists and copy to GPU.
663+
temperatures: list[float] = []
664+
top_ks: list[int] = []
665+
top_ps: list[float] = []
666+
request_temperatures: list[float] = []
667+
request_top_ks: list[int] = []
668+
request_top_ps: list[float] = []
669+
for temp_val, tk_val, tp_val, num_tokens in per_request_normalized:
670+
request_temperatures.append(temp_val)
671+
request_top_ks.append(tk_val)
672+
request_top_ps.append(tp_val)
673+
temperatures.extend(temp_val for _ in range(num_tokens))
674+
top_ks.extend(tk_val for _ in range(num_tokens))
675+
top_ps.extend(tp_val for _ in range(num_tokens))
676+
646677
self.temperatures[:len(temperatures)].copy_(torch.tensor(
647678
temperatures, dtype=torch.float32, pin_memory=prefer_pinned()),
648679
non_blocking=True)
@@ -669,10 +700,6 @@ def _normalize_request_sampling_params(
669700
pin_memory=prefer_pinned()),
670701
non_blocking=True,
671702
)
672-
self.skip_temperature = not temperature_enabled
673-
self.skip_top_k = not top_k_enabled
674-
self.skip_top_p = not top_p_enabled
675-
self.has_greedy_requests = has_greedy_requests
676703

677704

678705
class SpecWorkerBase(nn.Module, ABC):
@@ -1004,8 +1031,11 @@ def _accept_draft_tokens(self, logits, draft_tokens, num_contexts,
10041031

10051032
def _can_use_rejection_sampling(self, spec_metadata: SpecMetadata,
10061033
num_contexts: int) -> bool:
1034+
# Skip rejection sampling when the whole batch is greedy: the
1035+
# accepted result is identical to argmax and the base path is cheaper.
10071036
return (spec_metadata.use_rejection_sampling
1008-
and spec_metadata.draft_probs_valid and num_contexts == 0)
1037+
and spec_metadata.draft_probs_valid and num_contexts == 0
1038+
and not spec_metadata.is_all_greedy_sample)
10091039

10101040
def _sample_and_accept_draft_tokens_rejection(
10111041
self,
@@ -1282,7 +1312,7 @@ def _sample_tokens_for_batch(
12821312
Returns:
12831313
sampled_tokens: [num_tokens] - Sampled token ids
12841314
"""
1285-
if spec_metadata.allow_advanced_sampling:
1315+
if not spec_metadata.is_all_greedy_sample:
12861316
num_gens = batch_size - num_contexts
12871317
num_tokens = num_contexts + num_gens * (
12881318
spec_metadata.runtime_draft_len + 1)

tensorrt_llm/_torch/speculative/utils.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ def get_spec_metadata(spec_config,
5151
mtp_num_modules=spec_config.max_draft_len,
5252
max_num_requests=max_num_requests,
5353
mtp_hidden_states_manager=spec_resource_manager,
54-
allow_advanced_sampling=spec_config.allow_advanced_sampling,
5554
)
5655
if spec_config.spec_dec_mode.is_mtp_eagle():
5756
return Eagle3SpecMetadata(
@@ -97,7 +96,6 @@ def get_spec_metadata(spec_config,
9796
hidden_size=model_config.hidden_size,
9897
max_num_tokens=max_num_tokens,
9998
layers_to_capture=spec_config.eagle3_layers_to_capture,
100-
allow_advanced_sampling=spec_config.allow_advanced_sampling,
10199
use_rejection_sampling=use_rejection_sampling,
102100
vocab_size=vocab_size,
103101
spec_resource_manager=spec_resource_manager,
@@ -110,7 +108,6 @@ def get_spec_metadata(spec_config,
110108
max_total_draft_tokens=spec_config.tokens_per_gen_step - 1,
111109
spec_dec_mode=spec_config.spec_dec_mode,
112110
max_num_requests=max_num_requests,
113-
allow_advanced_sampling=spec_config.allow_advanced_sampling,
114111
spec_resource_manager=spec_resource_manager,
115112
)
116113
if spec_config.spec_dec_mode.is_dflash():
@@ -120,7 +117,6 @@ def get_spec_metadata(spec_config,
120117
max_total_draft_tokens=spec_config.tokens_per_gen_step - 1,
121118
spec_dec_mode=spec_config.spec_dec_mode,
122119
max_num_requests=max_num_requests,
123-
allow_advanced_sampling=spec_config.allow_advanced_sampling,
124120
layers_to_capture=target_layer_ids,
125121
hidden_size=model_config.hidden_size,
126122
max_num_tokens=max_num_tokens,
@@ -133,7 +129,6 @@ def get_spec_metadata(spec_config,
133129
spec_dec_mode=spec_config.spec_dec_mode,
134130
max_num_requests=max_num_requests,
135131
max_num_tokens=max_num_tokens,
136-
allow_advanced_sampling=spec_config.allow_advanced_sampling,
137132
)
138133
if spec_config.spec_dec_mode.is_save_hidden_states():
139134
return SaveHiddenStatesSpecMetadata(

tensorrt_llm/llmapi/llm_args.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -896,14 +896,6 @@ class DecodingBaseConfig(StrictBaseModel):
896896
"rolling average over the last N completed requests (N = acceptance_window) drops below this value. "
897897
"PyTorch backend only.")
898898

899-
allow_advanced_sampling: bool = Field(
900-
default=False,
901-
status="prototype",
902-
description=
903-
"If true, allows non-greedy sampling when speculation is used. Only applicable "
904-
"to 1-model code paths; non-greedy sampling is always enabled on 2-model paths."
905-
)
906-
907899
use_rejection_sampling: bool = Field(
908900
default=False,
909901
status="prototype",

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,6 @@ def test_eagle3_rejection_dynamic_tree_smoke(self, use_dynamic_tree,
201201
max_draft_len=4,
202202
speculative_model=eagle_model_dir,
203203
eagle3_one_model=True,
204-
allow_advanced_sampling=True,
205204
use_rejection_sampling=True,
206205
)
207206
max_batch_size = 1
@@ -5819,8 +5818,7 @@ def test_eagle3_4gpus(self, v2_kv_cache, moe_backend, one_model,
58195818
draft_len = 3
58205819
spec_config = Eagle3DecodingConfig(max_draft_len=draft_len,
58215820
speculative_model=eagle_model_dir,
5822-
eagle3_one_model=one_model,
5823-
allow_advanced_sampling=True)
5821+
eagle3_one_model=one_model)
58245822

58255823
max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN
58265824
llm = LLM(self.MODEL_PATH,
@@ -5885,8 +5883,7 @@ def test_eagle3_vswa_reuse_4gpus(self, one_model, mocker):
58855883
draft_len = 3
58865884
spec_config = Eagle3DecodingConfig(max_draft_len=draft_len,
58875885
speculative_model=eagle_model_dir,
5888-
eagle3_one_model=one_model,
5889-
allow_advanced_sampling=True)
5886+
eagle3_one_model=one_model)
58905887

58915888
max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN
58925889
llm = LLM(self.MODEL_PATH,
@@ -5949,8 +5946,7 @@ def test_eagle3_guided_decoding_4gpus(self, one_model, mocker):
59495946
draft_len = 3
59505947
spec_config = Eagle3DecodingConfig(max_draft_len=draft_len,
59515948
speculative_model=eagle_model_dir,
5952-
eagle3_one_model=one_model,
5953-
allow_advanced_sampling=True)
5949+
eagle3_one_model=one_model)
59545950

59555951
max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN
59565952
llm = LLM(self.MODEL_PATH,

0 commit comments

Comments
 (0)