From 5f4824a7cf39d1b8cce717def07a40fe5f90bb9d Mon Sep 17 00:00:00 2001 From: ZhaoyangWang Date: Fri, 29 May 2026 03:05:34 -0700 Subject: [PATCH 01/21] [TRTLLM-12669][refactor] Replace allow_advanced_sampling with auto-detected dual-graph dispatch Remove the static `allow_advanced_sampling` config flag and replace it with a per-step auto-detected `is_all_greedy_sample` boolean on SpecMetadata. The flag is computed in `populate_sampling_params_for_one_model` from the actual temperature/top_k/top_p of every request in the batch. `is_all_greedy_sample` is included in the CUDA graph key so we lazily capture two graph variants (argmax fast-path vs advanced sampling kernel) and dispatch by replaying the right one based on the current batch composition. Both variants stay CUDA-graph-compatible because the dispatch is a host-side decision outside the captured region. Additional optimizations for the all-greedy batch (the common default): - `populate_sampling_params_for_one_model` skips per-token list building and the 6 H->D copies entirely. - Rejection sampling is bypassed (argmax is equivalent for all-greedy) in both linear and dynamic-tree paths. - _compute_and_store_draft_probs is skipped, saving a softmax pass and a draft-probs copy. 
Signed-off-by: ZhaoyangWang --- examples/llm-api/quickstart_advanced.py | 4 - .../core/nemotron/README_nemotron_super_v3.md | 1 - .../_torch/pyexecutor/cuda_graph_runner.py | 20 +++-- .../_torch/pyexecutor/py_executor_creator.py | 6 -- tensorrt_llm/_torch/speculative/eagle3.py | 5 +- .../_torch/speculative/eagle3_dynamic_tree.py | 8 +- tensorrt_llm/_torch/speculative/interface.py | 84 +++++++++++++------ tensorrt_llm/_torch/speculative/utils.py | 5 -- tensorrt_llm/llmapi/llm_args.py | 8 -- .../defs/accuracy/test_llm_api_pytorch.py | 10 +-- .../Nemotron3_Super_120B_NVFP4.yml | 1 - .../defs/perf/pytorch_model_config.py | 1 - .../_torch/speculative/test_eagle3.py | 1 - 13 files changed, 86 insertions(+), 68 deletions(-) diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py index 10901d87c520..8c449283aa2e 100644 --- a/examples/llm-api/quickstart_advanced.py +++ b/examples/llm-api/quickstart_advanced.py @@ -190,9 +190,6 @@ def add_llm_args(parser): default=False, action='store_true') parser.add_argument('--dynamic_tree_max_topK', type=int, default=None) - parser.add_argument('--allow_advanced_sampling', - default=False, - action='store_true') parser.add_argument('--eagle3_model_arch', type=str, default="llama3", @@ -294,7 +291,6 @@ def setup_llm(args, **kwargs): eagle_choices=args.eagle_choices, use_dynamic_tree=args.use_dynamic_tree, dynamic_tree_max_topK=args.dynamic_tree_max_topK, - allow_advanced_sampling=args.allow_advanced_sampling, eagle3_model_arch=args.eagle3_model_arch, max_total_draft_tokens=args.max_total_draft_tokens) elif spec_decode_algo == "DFLASH": diff --git a/examples/models/core/nemotron/README_nemotron_super_v3.md b/examples/models/core/nemotron/README_nemotron_super_v3.md index e78992359c19..1e59febce82e 100644 --- a/examples/models/core/nemotron/README_nemotron_super_v3.md +++ b/examples/models/core/nemotron/README_nemotron_super_v3.md @@ -144,7 +144,6 @@ kv_cache_config: speculative_config: decoding_type: 
MTP max_draft_len: 5 - allow_advanced_sampling: true cuda_graph_config: max_batch_size: 64 enable_padding: true diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py index 69bd91cd2178..f69956eaf158 100644 --- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py +++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py @@ -29,7 +29,7 @@ # A large prime number used for dummy request IDs to avoid collisions CUDA_GRAPH_DUMMY_REQUEST_ID = (1 << 64) - 1 -KeyType: TypeAlias = Tuple[int, int, bool, bool] +KeyType: TypeAlias = Tuple[int, int, bool, bool, bool] @dataclass @@ -197,19 +197,28 @@ def get_graph_key( self, batch: ScheduledRequests, new_tensors_device: Optional[SampleStateTensors] = None, - spec_resource_manager: Optional[BaseResourceManager] = None): + spec_resource_manager: Optional[BaseResourceManager] = None, + spec_metadata: Optional[Any] = None): batch_size = batch.batch_size # Get the sequence length mode. short_seq_len_mode = self._get_seq_len_mode(batch, new_tensors_device) + # Spec one-engine sampler has two code paths (argmax fast-path vs + # advanced sampling kernel). Include this in the key so we capture + # both variants and dispatch at replay based on actual batch state. + # Default to True (greedy fast-path) when the metadata doesn't carry + # this field (non-one-engine paths or non-spec batches). + is_all_greedy_sample = bool( + getattr(spec_metadata, "is_all_greedy_sample", True)) + if self.config.is_draft_model and spec_resource_manager is not None and isinstance( spec_resource_manager, Eagle3ResourceManager): # If 'is_first_draft' is True, even with tree decoding, the length of draft_len will only be 'max_draft_len', not 'max_total_draft_token'. # Because we will pad the input to 'max_draft_len' length for the first draft layer. 
draft_len = self.config.original_max_draft_len if spec_resource_manager.is_first_draft else 0 key = (batch_size, draft_len, spec_resource_manager.is_first_draft, - short_seq_len_mode) + short_seq_len_mode, is_all_greedy_sample) else: # With dynamic spec decode, the draft length may be zero even when enable_spec_decode is True, # so we need to get the draft length from the batch instead of using enable_spec_decode. @@ -219,7 +228,8 @@ def get_graph_key( draft_len = max(draft_len_list) assert len( set(draft_len_list)) == 1, "All draft lengths must be the same" - key = (batch_size, draft_len, False, short_seq_len_mode) + key = (batch_size, draft_len, False, short_seq_len_mode, + is_all_greedy_sample) return key def __del__(self): @@ -273,7 +283,7 @@ def maybe_get_cuda_graph( # can replay CUDA graphs using the cache. return None, None, None key = self.get_graph_key(batch, new_tensors_device, - spec_resource_manager) + spec_resource_manager, spec_metadata) if key in self.graphs: return self.graph_metadata[key][ diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 25ab939758a9..797f2fd48666 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -406,12 +406,6 @@ def create_py_executor( ) llm_args.disable_overlap_scheduler = True - if spec_config is not None and spec_config.spec_dec_mode.use_one_engine(): - if not spec_config.allow_advanced_sampling: - logger.warning( - f"Falling back to greedy decoding for {spec_config.decoding_type}. If you " - "want to use non-greedy sampling, please set allow_advanced_sampling=True." 
- ) # Check FLASHINFER compatibility with one-engine speculative decoding if llm_args.attn_backend == "FLASHINFER": raise ValueError( diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 6acef9ed348f..c0fb22bfefe7 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -952,7 +952,10 @@ def _forward_linear_draft_loop(self, inputs, attn_metadata, spec_metadata, gen_draft_tokens) next_draft_tokens[num_contexts:] = gen_draft_tokens - if spec_metadata.use_rejection_sampling and draft_logits_list: + # Skip when the whole batch is greedy: _can_use_rejection_sampling will + # bypass the rejection path anyway, so computing draft probs is wasted. + if (spec_metadata.use_rejection_sampling and draft_logits_list + and not spec_metadata.is_all_greedy_sample): d2t_param = getattr(draft_model.model, "d2t", None) spec_metadata.d2t = d2t_param.data if d2t_param is not None else None self._compute_and_store_draft_probs(draft_logits_list, diff --git a/tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py b/tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py index 47376001166d..ddfd62812eb5 100644 --- a/tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py +++ b/tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py @@ -967,7 +967,13 @@ def _can_use_rejection_sampling(self, spec_metadata) -> bool: Returns: True if rejection sampling is enabled and the draft logit buffer is allocated """ - return spec_metadata.use_rejection_sampling and self._draft_depth_logits_cat is not None + # Skip rejection sampling when the whole batch is greedy: argmax is + # equivalent and avoids the rejection kernel cost. 
+ return ( + spec_metadata.use_rejection_sampling + and self._draft_depth_logits_cat is not None + and not spec_metadata.is_all_greedy_sample + ) def _finalize_dynamic_tree_verify_outputs( self, diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index c62111f0f511..07ddfaae01bc 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -454,8 +454,14 @@ class SpecMetadata: # Always set by model_engine.forward() before any downstream code reads it. runtime_draft_len: int = 0 - # For non-greedy sampling on 1-model. - allow_advanced_sampling: bool = False + # Auto-detected per step from populated sampling params: + # True if every request is greedy (no temp/top_k/top_p) and we can take + # the argmax fast-path. False if any request needs sampling. + # Used as part of the CUDA graph key so we capture two variants + # (greedy fast-path vs advanced sampling) and dispatch at replay. + # Defaults to True so non-one-engine paths (where populate is a no-op) + # never accidentally select the advanced graph variant. + is_all_greedy_sample: bool = True # Whether to use rejection sampling for one-model speculative decoding. use_rejection_sampling: bool = False # Sampling parameters for non-greedy sampling (per-request) @@ -533,29 +539,21 @@ def populate_sampling_params_for_one_model( self, requests: list["LlmRequest"]) -> None: """ Set up topp/topk/temperatures for 1-model sampler. + + Scans sampling configs to set skip_*/is_all_greedy_sample flags. When + any request needs sampling, also builds per-token/per-request lists + and copies them to GPU buffers; all-greedy batches skip this entirely. 
""" from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState from tensorrt_llm.sampling_params import SamplingParams - if not self.allow_advanced_sampling or not self.spec_dec_mode.use_one_engine( - ): + if not self.spec_dec_mode.use_one_engine(): return if self.temperatures is None: # Ensures determinism across ranks. torch.manual_seed(0) - temperatures = [] - top_ks = [] - top_ps = [] - request_temperatures = [] - request_top_ks = [] - request_top_ps = [] - top_k_enabled = False - top_p_enabled = False - has_greedy_requests = False - temperature_enabled = False - # Need to use a very small value for temperature when disabled to avoid division by 0 DISABLE_TEMP_VAL = 1e-5 # Very large values disable topk. @@ -601,6 +599,13 @@ def _normalize_request_sampling_params( is_greedy, ) + # Phase 1: collect per-request flags and normalized values. + per_request_normalized: list[tuple[float, int, float, int]] = [] + temperature_enabled = False + top_k_enabled = False + top_p_enabled = False + has_greedy_requests = False + for request in requests: sampling_config = request.sampling_config temp_val = _first_or_none(sampling_config.temperature) @@ -629,12 +634,16 @@ def _normalize_request_sampling_params( top_p_enabled |= use_top_p has_greedy_requests |= is_greedy - request_temperatures.append(temp_val) - request_top_ks.append(tk_val) - request_top_ps.append(tp_val) - temperatures.extend(temp_val for _ in range(num_tokens)) - top_ks.extend(tk_val for _ in range(num_tokens)) - top_ps.extend(tp_val for _ in range(num_tokens)) + per_request_normalized.append( + (temp_val, tk_val, tp_val, num_tokens)) + + self.skip_temperature = not temperature_enabled + self.skip_top_k = not top_k_enabled + self.skip_top_p = not top_p_enabled + self.has_greedy_requests = has_greedy_requests + # Used in the CUDA graph key to pick the argmax / advanced variant. 
+ self.is_all_greedy_sample = (self.skip_temperature and self.skip_top_k + and self.skip_top_p) tokens_per_request = (self.max_total_draft_tokens + 1 if self.is_spec_dec_tree else self.max_draft_len + 1) @@ -642,6 +651,7 @@ def _normalize_request_sampling_params( if self.temperatures is None or self.temperatures.numel( ) < required_flat_size: + # Allocate once; the captured graph reads from these stable addresses. self.temperatures = torch.ones(required_flat_size, dtype=torch.float32, device='cuda') @@ -661,6 +671,27 @@ def _normalize_request_sampling_params( dtype=torch.float32, device='cuda') + # All-greedy: sampler takes the argmax branch (and rejection sampling + # is also bypassed for all-greedy), so the buffers are never read. + # Skip the H->D copies. + if self.is_all_greedy_sample: + return + + # Phase 2: build per-token / per-request lists and copy to GPU. + temperatures: list[float] = [] + top_ks: list[int] = [] + top_ps: list[float] = [] + request_temperatures: list[float] = [] + request_top_ks: list[int] = [] + request_top_ps: list[float] = [] + for temp_val, tk_val, tp_val, num_tokens in per_request_normalized: + request_temperatures.append(temp_val) + request_top_ks.append(tk_val) + request_top_ps.append(tp_val) + temperatures.extend(temp_val for _ in range(num_tokens)) + top_ks.extend(tk_val for _ in range(num_tokens)) + top_ps.extend(tp_val for _ in range(num_tokens)) + self.temperatures[:len(temperatures)].copy_(torch.tensor( temperatures, dtype=torch.float32, pin_memory=prefer_pinned()), non_blocking=True) @@ -687,10 +718,6 @@ def _normalize_request_sampling_params( pin_memory=prefer_pinned()), non_blocking=True, ) - self.skip_temperature = not temperature_enabled - self.skip_top_k = not top_k_enabled - self.skip_top_p = not top_p_enabled - self.has_greedy_requests = has_greedy_requests class SpecWorkerBase(nn.Module, ABC): @@ -1029,8 +1056,11 @@ def _accept_draft_tokens(self, logits, draft_tokens, num_contexts, def 
_can_use_rejection_sampling(self, spec_metadata: SpecMetadata, num_contexts: int) -> bool: + # Skip rejection sampling when the whole batch is greedy: the + # accepted result is identical to argmax and the base path is cheaper. return (spec_metadata.use_rejection_sampling - and spec_metadata.draft_probs_valid and num_contexts == 0) + and spec_metadata.draft_probs_valid and num_contexts == 0 + and not spec_metadata.is_all_greedy_sample) def _sample_and_accept_draft_tokens_rejection( self, @@ -1307,7 +1337,7 @@ def _sample_tokens_for_batch( Returns: sampled_tokens: [num_tokens] - Sampled token ids """ - if spec_metadata.allow_advanced_sampling: + if not spec_metadata.is_all_greedy_sample: num_gens = batch_size - num_contexts num_tokens = num_contexts + num_gens * ( spec_metadata.runtime_draft_len + 1) diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py index 9c4284878b06..8f132bbcdb22 100644 --- a/tensorrt_llm/_torch/speculative/utils.py +++ b/tensorrt_llm/_torch/speculative/utils.py @@ -71,7 +71,6 @@ def get_spec_metadata(spec_config, mtp_num_modules=spec_config.max_draft_len, max_num_requests=max_num_requests, mtp_hidden_states_manager=spec_resource_manager, - allow_advanced_sampling=spec_config.allow_advanced_sampling, ) if spec_config.spec_dec_mode.is_mtp_eagle(): return Eagle3SpecMetadata( @@ -117,7 +116,6 @@ def get_spec_metadata(spec_config, hidden_size=model_config.hidden_size, max_num_tokens=max_num_tokens, layers_to_capture=spec_config.eagle3_layers_to_capture, - allow_advanced_sampling=spec_config.allow_advanced_sampling, use_rejection_sampling=use_rejection_sampling, vocab_size=vocab_size, spec_resource_manager=spec_resource_manager, @@ -130,7 +128,6 @@ def get_spec_metadata(spec_config, max_total_draft_tokens=spec_config.tokens_per_gen_step - 1, spec_dec_mode=spec_config.spec_dec_mode, max_num_requests=max_num_requests, - allow_advanced_sampling=spec_config.allow_advanced_sampling, 
spec_resource_manager=spec_resource_manager, ) if spec_config.spec_dec_mode.is_dflash(): @@ -140,7 +137,6 @@ def get_spec_metadata(spec_config, max_total_draft_tokens=spec_config.tokens_per_gen_step - 1, spec_dec_mode=spec_config.spec_dec_mode, max_num_requests=max_num_requests, - allow_advanced_sampling=spec_config.allow_advanced_sampling, layers_to_capture=target_layer_ids, hidden_size=model_config.hidden_size, max_num_tokens=max_num_tokens, @@ -153,7 +149,6 @@ def get_spec_metadata(spec_config, spec_dec_mode=spec_config.spec_dec_mode, max_num_requests=max_num_requests, max_num_tokens=max_num_tokens, - allow_advanced_sampling=spec_config.allow_advanced_sampling, ) if spec_config.spec_dec_mode.is_save_hidden_states(): return SaveHiddenStatesSpecMetadata( diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 6425cb7d2f2d..6541dbfcd5ed 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1120,14 +1120,6 @@ class DecodingBaseConfig(StrictBaseModel): "rolling average over the last N completed requests (N = acceptance_window) drops below this value. " "PyTorch backend only.") - allow_advanced_sampling: bool = Field( - default=False, - status="prototype", - description= - "If true, allows non-greedy sampling when speculation is used. Only applicable " - "to 1-model code paths; non-greedy sampling is always enabled on 2-model paths." 
- ) - use_rejection_sampling: bool = Field( default=False, status="prototype", diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 41e3d93944ea..7a72ffebc88f 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -201,7 +201,6 @@ def test_eagle3_rejection_dynamic_tree_smoke(self, use_dynamic_tree, max_draft_len=4, speculative_model=eagle_model_dir, eagle3_one_model=True, - allow_advanced_sampling=True, use_rejection_sampling=True, ) max_batch_size = 1 @@ -5435,8 +5434,7 @@ def test_eagle3_4gpus(self, v2_kv_cache, moe_backend, one_model, draft_len = 3 spec_config = Eagle3DecodingConfig(max_draft_len=draft_len, speculative_model=eagle_model_dir, - eagle3_one_model=one_model, - allow_advanced_sampling=True) + eagle3_one_model=one_model) max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN llm = LLM(self.MODEL_PATH, @@ -5501,8 +5499,7 @@ def test_eagle3_vswa_reuse_4gpus(self, one_model, mocker): draft_len = 3 spec_config = Eagle3DecodingConfig(max_draft_len=draft_len, speculative_model=eagle_model_dir, - eagle3_one_model=one_model, - allow_advanced_sampling=True) + eagle3_one_model=one_model) max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN llm = LLM(self.MODEL_PATH, @@ -5565,8 +5562,7 @@ def test_eagle3_guided_decoding_4gpus(self, one_model, mocker): draft_len = 3 spec_config = Eagle3DecodingConfig(max_draft_len=draft_len, speculative_model=eagle_model_dir, - eagle3_one_model=one_model, - allow_advanced_sampling=True) + eagle3_one_model=one_model) max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN llm = LLM(self.MODEL_PATH, diff --git a/tests/integration/defs/examples/serve/test_configs/Nemotron3_Super_120B_NVFP4.yml b/tests/integration/defs/examples/serve/test_configs/Nemotron3_Super_120B_NVFP4.yml index 72a1df6265e7..ed1cf2667329 100644 --- a/tests/integration/defs/examples/serve/test_configs/Nemotron3_Super_120B_NVFP4.yml +++ 
b/tests/integration/defs/examples/serve/test_configs/Nemotron3_Super_120B_NVFP4.yml @@ -23,4 +23,3 @@ print_iter_log: true speculative_config: decoding_type: MTP num_nextn_predict_layers: 3 - allow_advanced_sampling: true diff --git a/tests/integration/defs/perf/pytorch_model_config.py b/tests/integration/defs/perf/pytorch_model_config.py index 66d1dd334002..716dc6ef46df 100644 --- a/tests/integration/defs/perf/pytorch_model_config.py +++ b/tests/integration/defs/perf/pytorch_model_config.py @@ -526,7 +526,6 @@ def get_model_yaml_config(model_label: str, 'speculative_config': { 'decoding_type': 'MTP', 'num_nextn_predict_layers': 3, - 'allow_advanced_sampling': True, }, } }, diff --git a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py index 8a393f21ad35..29459e4903dc 100644 --- a/tests/unittest/_torch/speculative/test_eagle3.py +++ b/tests/unittest/_torch/speculative/test_eagle3.py @@ -864,7 +864,6 @@ def test_llama_eagle3_rejection_sampling_modes(use_dynamic_tree: bool, max_draft_len=max_draft_len, speculative_model=eagle_model, eagle3_one_model=True, - allow_advanced_sampling=True, use_rejection_sampling=True, ) if use_dynamic_tree: From 87300b02d9558f245996a413086e9982631e5f42 Mon Sep 17 00:00:00 2001 From: ZhaoyangWang Date: Mon, 1 Jun 2026 02:44:49 -0700 Subject: [PATCH 02/21] [TRTLLM-12669][feat] Eagle3 one-model draft sampling honors target sampling params The Eagle3 one-model draft loop was hardcoded to greedy argmax, even when the target sampler used non-greedy (temperature/top_k/top_p) parameters. This made rejection sampling math degenerate: with draft forced to argmax, p_draft is a one-hot on the argmax token, so p_target / p_draft is zero everywhere else and acceptance is biased. 
Route the linear draft loop through a new `_draft_sampler_advanced` on `SpecWorkerBase` that reuses the per-request sampling tensors already populated by `populate_sampling_params_for_one_model` (request_temperatures / request_top_ks / request_top_ps) and the shared seed/offset used by the target sampler. When the batch is all-greedy (`is_all_greedy_sample=True`) it short-circuits to `_draft_sampler_greedy`, matching the target sampler's argmax fast-path so existing CUDA-graph variants are unaffected. Other spec modes (Draft-Target, MTP, Pard) are untouched and keep their current argmax draft sampling. Signed-off-by: ZhaoyangWang --- tensorrt_llm/_torch/speculative/eagle3.py | 29 ++++++++++ tensorrt_llm/_torch/speculative/interface.py | 56 ++++++++++++++++++++ 2 files changed, 85 insertions(+) diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index c0fb22bfefe7..7c022a0585bd 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -887,6 +887,8 @@ def _forward_linear_draft_loop(self, inputs, attn_metadata, spec_metadata, if not self.is_mtp_eagle and spec_metadata.use_rejection_sampling: draft_logits_list.append(logits.clone()) + new_draft_token = self.draft_decoder(logits, draft_model, + spec_metadata, batch_size) next_draft_tokens.append(new_draft_token) # Update hidden states for the next iteration. @@ -1177,6 +1179,33 @@ def sample_and_accept_draft_tokens( return self._accept_draft_tokens(logits, draft_tokens, num_contexts, batch_size, spec_metadata) + def draft_decoder( + self, + logits: torch.Tensor, + draft_model: nn.Module, + spec_metadata: Optional[Eagle3OneModelSpecMetadata] = None, + batch_size: Optional[int] = None, + ): + ''' + Sample draft tokens. When spec_metadata + batch_size are provided, use + the target's per-request sampling params (temperature/top_k/top_p); + otherwise fall back to argmax. 
+ + Args: + logits: [batch_size, vocab_size] - Draft model logits. + draft_model: The draft model. + spec_metadata: Carries per-request sampling param tensors. When + None, sampling is forced greedy. + batch_size: Active requests, used to slice per-request tensors. + ''' + + d2t = getattr(draft_model.model, "d2t", None) + if spec_metadata is not None and batch_size is not None: + return self._draft_sampler_advanced(logits, spec_metadata, + batch_size, d2t) + return self._draft_sampler_greedy(logits, d2t) + + def prepare_1st_drafter_inputs( self, input_ids: torch.LongTensor, diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 07ddfaae01bc..2388959ae51c 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -1166,6 +1166,62 @@ def _draft_sampler_greedy(self, logits: torch.Tensor, d2t=None): return draft_tokens.type(torch.int32) + def _draft_sampler_advanced( + self, + logits: torch.Tensor, + spec_metadata: "SpecMetadata", + batch_size: int, + d2t: Optional[torch.Tensor] = None, + ): + """ + Draft token sampling using per-request sampling parameters from the + target's sampling config. Falls back to argmax when the batch is + all-greedy. + + Args: + logits: [batch_size, vocab_size] - Draft model logits (one row per + request, since each draft step emits one token per request). + spec_metadata: Source of per-request temperatures / top_k / top_p + tensors populated by populate_sampling_params_for_one_model. + batch_size: Number of active requests in the batch. + d2t: Optional dictionary offset tensor for vocab mapping. 
+ + Returns: + draft_tokens: [batch_size] - Sampled draft token ids (int32) + """ + if spec_metadata.is_all_greedy_sample: + return self._draft_sampler_greedy(logits, d2t) + + temperatures = spec_metadata.request_temperatures[:batch_size] + top_ks = spec_metadata.request_top_ks[:batch_size] + top_ps = spec_metadata.request_top_ps[:batch_size] + + if self.use_flashinfer: + top_ks = top_ks.clamp(min=1, max=logits.shape[-1] - 1) + if self.seed is None: + self.seed = torch.tensor([0], + dtype=torch.int64, + device=logits.device) + self.offset = torch.tensor([0], + dtype=torch.int64, + device=logits.device) + self.seed += 1 + self.seed %= (2**31) + + draft_tokens = sampling_batch_spec_dec_one_model( + logits, + temperatures, + top_ks, + top_ps, + use_flashinfer=self.use_flashinfer, + seed=self.seed, + offset=self.offset) + + if d2t is not None: + draft_tokens = d2t[draft_tokens] + draft_tokens + + return draft_tokens.type(torch.int32) + def _compute_and_store_draft_probs( self, draft_logits_list: List[torch.Tensor], From 69a43681424f8d4c9064dc6150e9e046b9088a46 Mon Sep 17 00:00:00 2001 From: ZhaoyangWang Date: Tue, 2 Jun 2026 03:06:42 -0700 Subject: [PATCH 03/21] [TRTLLM-12669][feat] Enable rejection sampling by default for Eagle3 one-model Flip the default of `use_rejection_sampling` from `False` to `True` on DecodingBaseConfig. With the refactor of the all-greedy fast path in place, this is safe: the runtime guard in `_can_use_rejection_sampling` still requires a non-greedy batch, so all-greedy batches keep taking the argmax fast path unchanged. Only batches that already opted into non-greedy sampling now see the rejection sampling acceptance behavior. Benchmark results on Qwen3-235B-A22B + Eagle3 (tp=8) show consistent +6.4% to +9.4% throughput and +3.4 to +4.3 pp acceptance rate across batch sizes 1-16 vs the exact-match baseline. Other Eagle3 deployments see smaller but uniformly positive acceptance-rate gains. 
Two prior `raise ValueError` paths are converted to silent fallbacks so the new default does not break existing users: - Non-Eagle3 spec configs (PARD, DFlash, MTP, ...) silently disable the flag in TorchLlmArgs post-validation, since rejection sampling is only wired up for Eagle3 one-model paths. - SA-enhanced Eagle3 configs disable the flag in the per-config validator, since SA may override proposed draft tokens. Users who want the prior exact-match behavior can still pass `use_rejection_sampling=False` explicitly. Signed-off-by: ZhaoyangWang --- tensorrt_llm/llmapi/llm_args.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 6541dbfcd5ed..521e4a438c5f 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1121,11 +1121,13 @@ class DecodingBaseConfig(StrictBaseModel): "PyTorch backend only.") use_rejection_sampling: bool = Field( - default=False, + default=True, status="prototype", description= - "If true, enables rejection sampling for one-model speculative decoding paths. " - "This is intended for non-greedy sampling configurations on the PyTorch backend. " + "If true (default), enables rejection sampling for one-model speculative " + "decoding paths when the batch contains any non-greedy request. All-greedy " + "batches always take the argmax fast path regardless of this flag. Set to " + "false to fall back to exact-match verification on non-greedy batches. " "The non-dynamic-tree one-model path requires FlashInfer.") # If set, drafting is allowed to use chain drafter. 
@@ -1182,13 +1184,14 @@ def validate_draft_len_schedule_and_sort(cls, v, info): @model_validator(mode='after') def validate_rejection_sampling_config(self): - """Reject SA-enhanced configurations that invalidate rejection sampling.""" + """Disable rejection sampling when SA-enhanced configurations are + active, since SA may override the proposed draft tokens. This is a + silent fallback so the new default (True) does not break sa_config + users. + """ if self.use_rejection_sampling and getattr(self, 'sa_config', None) is not None: - raise ValueError( - "use_rejection_sampling is incompatible with sa_config " - "because SA enhancement may override the proposed draft tokens." - ) + self.use_rejection_sampling = False return self @model_validator(mode='after') @@ -4527,12 +4530,12 @@ def validate_speculative_config(self): exclude={"decoding_type"}) self.speculative_config = Eagle3DecodingConfig(**eagle_data) - if self.speculative_config.use_rejection_sampling: - if not isinstance(self.speculative_config, - Eagle3DecodingConfig): - raise ValueError( - "use_rejection_sampling is only supported for " - "PyTorch Eagle3 one-model speculative decoding paths.") + if self.speculative_config.use_rejection_sampling and not isinstance( + self.speculative_config, Eagle3DecodingConfig): + # Rejection sampling is only wired up for Eagle3 one-model paths. + # Silently fall back for other spec types so the new default + # (True) does not break them. 
+ self.speculative_config.use_rejection_sampling = False if isinstance(self.speculative_config, PARDDecodingConfig): assert self.speculative_config.max_draft_len > 0, "PARD max_draft_len must be > 0" From 376c2de1e1f5a376c1468b171467c0875d685b13 Mon Sep 17 00:00:00 2001 From: ZhaoyangWang Date: Tue, 2 Jun 2026 05:41:33 -0700 Subject: [PATCH 04/21] [TRTLLM-12669][feat] Slot-index draft_probs and support mixed-batch rejection The rejection-sampling acceptance path used to gate on `num_contexts == 0`, so any mixed batch (chunked prefill + decode, ramp-up step with new ctx joining, etc.) silently fell back to exact-match verification. The underlying reason was that `spec_metadata.draft_probs` was a flat buffer indexed by batch position at write time: when batch composition shifted across iterations (chunking ctxs polluting the prefix, gen completions leaving holes, new ctxs inserted in the ctx region), the row at buffer index `i` no longer reliably mapped to the request now at batch position `i`. Refactor `draft_probs` to be slot-indexed, matching the convention `next_draft_tokens` already uses on `SampleStateSpec.store`: - `SpecMetadata.draft_probs` is reshaped to `[max_num_requests, max_draft_len, vocab_size]`, addressed by `py_seq_slot`. - A new `SpecMetadata.batch_slot_ids` device tensor carries the current batch's slot ids in batch order; it is populated alongside the other per-request sampling-param tensors in `populate_sampling_params_for_one_model` and is always refreshed when rejection sampling is configured, even for all-greedy batches (it is tiny relative to the per-token buffers we skip). - `_compute_and_store_draft_probs` scatters into `draft_probs[batch_slot_ids, ...]` so each request's data lands at its own stable slot row. - `_accept_draft_tokens` gathers `draft_probs[gen_slot_ids, ...]` for the gen subset of the current batch, so it always reads the per-request data written in the most recent iter that ran the draft loop for that slot. 
- `_sample_and_accept_draft_tokens_rejection` now accepts `num_contexts` and splits ctx / gen subsets: ctx rows go through `_sample_tokens_for_batch` (no draft tokens to verify); gen rows feed the slot-gathered draft probs into the unchanged `rejection_sampling_one_model` kernel. - `_can_use_rejection_sampling` drops the `num_contexts == 0` constraint now that mixed batches are handled correctly via slot indexing. - When an iter skips the draft-probs store (all-greedy or empty draft loop), `draft_probs_valid` is reset to False so the following iter cannot read stale data if it transitions back to a non-greedy mix. The all-gen case is byte-equivalent to the prior implementation: gen subset gathering by slot id collapses to the same data and the per-token buffer slicing for gen rows is unchanged. CUDA graph capture is unaffected: `scheduler.can_run_cuda_graph` already requires `num_context_requests == 0`, so the new slicing code paths exercise exclusively in eager mode. Signed-off-by: ZhaoyangWang --- tensorrt_llm/_torch/speculative/eagle3.py | 5 + tensorrt_llm/_torch/speculative/interface.py | 268 +++++++++++++------ 2 files changed, 187 insertions(+), 86 deletions(-) diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 7c022a0585bd..c1fd0fac3bf3 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -962,6 +962,11 @@ def _forward_linear_draft_loop(self, inputs, attn_metadata, spec_metadata, spec_metadata.d2t = d2t_param.data if d2t_param is not None else None self._compute_and_store_draft_probs(draft_logits_list, spec_metadata, batch_size) + elif spec_metadata.use_rejection_sampling: + # No draft probs were written this iter (all-greedy or empty draft + # loop). Invalidate the buffer so the next iter does not read stale + # data if it transitions back to a non-greedy mix. 
+ spec_metadata.draft_probs_valid = False return next_draft_tokens diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 2388959ae51c..65471331e1a1 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -481,13 +481,20 @@ class SpecMetadata: use_sampling_params_for_draft_tokens: bool = False # Vocab size used for draft_probs buffer allocation. vocab_size: int = 0 - # Draft probabilities buffer for rejection sampling, stored flat. + # Draft probabilities buffer for rejection sampling, indexed by py_seq_slot + # so per-request data is stable across iterations regardless of batch + # composition shifts (chunking ctx, gen completion, new ctx joining). + # Shape: [max_num_requests, max_draft_len, vocab_size]. draft_probs: Optional[torch.Tensor] = None draft_probs_vocab_size: int = 0 # Whether draft_probs contains valid data. draft_probs_valid: bool = False # Last dimension size of the draft logits/probs stored in draft_probs. draft_probs_last_dim: int = 0 + # Per-request slot ids (py_seq_slot) for the current batch, in batch order. + # Used to scatter draft probs by slot at write time and gather them by slot + # at the next iter's verify. Shape: [max_num_requests], dtype=long. + batch_slot_ids: Optional[torch.Tensor] = None # Draft-to-target vocab offset tensor. d2t: Optional[torch.Tensor] = None @@ -500,12 +507,18 @@ def prepare(self): """ if (self.use_rejection_sampling and self.draft_probs is None and self.vocab_size > 0): - buffer_size = (self.max_num_requests * self.max_draft_len * - self.vocab_size) - self.draft_probs = torch.empty(buffer_size, - dtype=torch.float32, - device='cuda') + # 3D [slot, draft_step, vocab] so we can scatter/gather by slot id + # and avoid the brittle "batch position == buffer position" mapping. 
+ self.draft_probs = torch.empty( + (self.max_num_requests, self.max_draft_len, self.vocab_size), + dtype=torch.float32, + device='cuda') self.draft_probs_vocab_size = self.vocab_size + if (self.use_rejection_sampling and self.batch_slot_ids is None + and self.max_num_requests > 0): + self.batch_slot_ids = torch.empty((self.max_num_requests, ), + dtype=torch.long, + device='cuda') def create_cuda_graph_metadata(self, max_batch_size: int): """ @@ -605,6 +618,7 @@ def _normalize_request_sampling_params( top_k_enabled = False top_p_enabled = False has_greedy_requests = False + per_request_slot_ids: list[int] = [] for request in requests: sampling_config = request.sampling_config @@ -636,6 +650,12 @@ def _normalize_request_sampling_params( per_request_normalized.append( (temp_val, tk_val, tp_val, num_tokens)) + # py_seq_slot is a stable per-request id used to scatter / gather + # draft probs across iterations. Dummies / unallocated slots fall + # back to 0 (any valid index is fine — the data at that slot will + # be overwritten on the next real iteration before being read). + per_request_slot_ids.append( + request.py_seq_slot if request.py_seq_slot is not None else 0) self.skip_temperature = not temperature_enabled self.skip_top_k = not top_k_enabled @@ -671,9 +691,20 @@ def _normalize_request_sampling_params( dtype=torch.float32, device='cuda') + # Always-populate the per-request slot id table when rejection sampling + # is configured: it's tiny (max_num_requests longs) and needed at + # _compute_and_store_draft_probs time to scatter draft probs by slot. + if self.use_rejection_sampling and self.batch_slot_ids is not None: + self.batch_slot_ids[:len(per_request_slot_ids)].copy_( + torch.tensor(per_request_slot_ids, + dtype=torch.long, + pin_memory=prefer_pinned()), + non_blocking=True, + ) + # All-greedy: sampler takes the argmax branch (and rejection sampling - # is also bypassed for all-greedy), so the buffers are never read. - # Skip the H->D copies. 
+ # is also bypassed for all-greedy), so the per-token buffers are never + # read. Skip the heavier H->D copies. if self.is_all_greedy_sample: return @@ -1039,27 +1070,38 @@ def _accept_draft_tokens(self, logits, draft_tokens, num_contexts, batch_size, spec_metadata): """ Accept draft tokens with optional rejection sampling support. + + Mixed batches (num_contexts > 0) are supported: context rows take the + first sampled target token via the base logic, and rejection sampling + runs on the gen subset. Draft probs for the gen subset are gathered + from the slot-indexed buffer by `py_seq_slot`. """ - if self._can_use_rejection_sampling(spec_metadata, num_contexts): + num_gens = batch_size - num_contexts + if num_gens > 0 and self._can_use_rejection_sampling(spec_metadata): draft_len = draft_tokens.shape[1] stored_vocab = (spec_metadata.draft_probs_last_dim if spec_metadata.draft_probs_last_dim > 0 else spec_metadata.draft_probs_vocab_size) - draft_probs = spec_metadata.draft_probs[:batch_size * draft_len * - stored_vocab].reshape( - batch_size, draft_len, - stored_vocab) + # Gather the slot rows for the gen subset. The buffer was filled + # at the previous draft step indexed by py_seq_slot, so each gen + # request reads back exactly its own probs, regardless of batch + # composition changes since then. 
+ gen_slot_ids = spec_metadata.batch_slot_ids[num_contexts:batch_size] + draft_probs = spec_metadata.draft_probs[ + gen_slot_ids, :draft_len, :stored_vocab] return self._sample_and_accept_draft_tokens_rejection( - logits, draft_tokens, draft_probs, batch_size, spec_metadata) + logits, draft_tokens, draft_probs, num_contexts, batch_size, + spec_metadata) return self._sample_and_accept_draft_tokens_base( logits, draft_tokens, num_contexts, batch_size, spec_metadata) - def _can_use_rejection_sampling(self, spec_metadata: SpecMetadata, - num_contexts: int) -> bool: - # Skip rejection sampling when the whole batch is greedy: the - # accepted result is identical to argmax and the base path is cheaper. + def _can_use_rejection_sampling(self, spec_metadata: SpecMetadata) -> bool: + # Skip rejection sampling when the whole batch is greedy: the accepted + # result is identical to argmax and the base path is cheaper. Mixed + # batches (context + gen) are handled via slot-indexed draft probs and + # are split inside _sample_and_accept_draft_tokens_rejection. return (spec_metadata.use_rejection_sampling - and spec_metadata.draft_probs_valid and num_contexts == 0 + and spec_metadata.draft_probs_valid and not spec_metadata.is_all_greedy_sample) def _sample_and_accept_draft_tokens_rejection( @@ -1067,84 +1109,123 @@ def _sample_and_accept_draft_tokens_rejection( logits: torch.Tensor, draft_tokens: torch.Tensor, draft_probs: torch.Tensor, + num_contexts: int, batch_size: int, spec_metadata, ): """ Rejection-sampling acceptance for one-model speculative decoding. + + Mixed batches are handled by treating the two subsets separately: + - context rows (first `num_contexts`) take the target's sampled first + token; no draft tokens to verify. + - generation rows (`[num_contexts:batch_size]`) run the rejection + sampling kernel on slot-gathered draft probs. 
+ + Per-token sampling-parameter tensors (`temperatures / top_ks / top_ps`) + are laid out as `[ctx (1 each), gen (draft_len+1 each)]`, matching the + logits layout, so slicing is symmetric for both subsets. """ device = logits.device vocab_size = logits.shape[-1] + num_gens = batch_size - num_contexts + runtime_draft_len = draft_tokens.shape[1] if logits.dim() == 1: logits = logits.unsqueeze(0) - runtime_draft_len = draft_tokens.shape[1] - draft_vocab_size = draft_probs.shape[-1] - num_target_tokens = batch_size * (runtime_draft_len + 1) - - temperatures = spec_metadata.temperatures[:num_target_tokens] - # Pass None instead of an all-disabled tensor so the C++ op can short-circuit - # on a host-side check rather than a `.item()` sync, which would break - # CUDA graph capture. - top_ks = None if spec_metadata.skip_top_k else spec_metadata.top_ks[: - num_target_tokens] - top_ps = None if spec_metadata.skip_top_p else spec_metadata.top_ps[: - num_target_tokens] - - target_probs_flat = compute_probs_from_logits(logits.clone(), - temperatures, top_ks, - top_ps) - target_probs = target_probs_flat.reshape(batch_size, - runtime_draft_len + 1, - vocab_size) - - assert draft_probs.shape[1] == runtime_draft_len, ( - f"draft_probs draft length mismatch: {draft_probs.shape[1]} != " - f"{runtime_draft_len}") - d2t = getattr(spec_metadata, "d2t", None) - if draft_vocab_size != vocab_size: - full_draft_probs = torch.zeros( - (batch_size, runtime_draft_len, vocab_size), - dtype=torch.float32, - device=device) - if d2t is not None: - assert d2t.numel() == draft_vocab_size, ( - f"d2t size mismatch: {d2t.numel()} != {draft_vocab_size}") - d2t = d2t.to(device=device) - source_indices = torch.arange(draft_vocab_size, - device=device, - dtype=torch.long) - target_indices = (source_indices + d2t) % vocab_size - full_draft_probs[:, :runtime_draft_len, - target_indices] = draft_probs + accepted_tokens = torch.empty((batch_size, runtime_draft_len + 1), + dtype=torch.int, + device=device) + 
num_accepted_tokens = torch.ones(batch_size, + dtype=torch.int, + device=device) + + # === Context subset: sample target's first token directly === + if num_contexts > 0: + ctx_target_tokens = self._sample_tokens_for_batch( + logits[:num_contexts], spec_metadata, num_contexts, + num_contexts) + accepted_tokens[:num_contexts, 0] = ctx_target_tokens + + # === Generation subset: rejection sampling on the gen slice === + if num_gens > 0: + num_gen_logits = num_gens * (runtime_draft_len + 1) + gen_logits = logits[num_contexts:num_contexts + num_gen_logits] + gen_start = num_contexts + gen_end = num_contexts + num_gen_logits + + temperatures = spec_metadata.temperatures[gen_start:gen_end] + # Pass None instead of an all-disabled tensor so the C++ op can short-circuit + # on a host-side check rather than a `.item()` sync, which would break + # CUDA graph capture. + top_ks = (None if spec_metadata.skip_top_k else + spec_metadata.top_ks[gen_start:gen_end]) + top_ps = (None if spec_metadata.skip_top_p else + spec_metadata.top_ps[gen_start:gen_end]) + + target_probs_flat = compute_probs_from_logits( + gen_logits.clone(), temperatures, top_ks, top_ps) + target_probs = target_probs_flat.reshape(num_gens, + runtime_draft_len + 1, + vocab_size) + + draft_vocab_size = draft_probs.shape[-1] + assert draft_probs.shape[0] == num_gens, ( + f"draft_probs batch mismatch: {draft_probs.shape[0]} != " + f"num_gens={num_gens}") + assert draft_probs.shape[1] == runtime_draft_len, ( + f"draft_probs draft length mismatch: {draft_probs.shape[1]} != " + f"{runtime_draft_len}") + d2t = getattr(spec_metadata, "d2t", None) + if draft_vocab_size != vocab_size: + full_draft_probs = torch.zeros( + (num_gens, runtime_draft_len, vocab_size), + dtype=torch.float32, + device=device) + if d2t is not None: + assert d2t.numel() == draft_vocab_size, ( + f"d2t size mismatch: {d2t.numel()} != {draft_vocab_size}" + ) + d2t = d2t.to(device=device) + source_indices = torch.arange(draft_vocab_size, + device=device, 
+ dtype=torch.long) + target_indices = (source_indices + d2t) % vocab_size + full_draft_probs[:, :runtime_draft_len, + target_indices] = draft_probs + else: + assert draft_vocab_size < vocab_size + full_draft_probs[:, :runtime_draft_len, : + draft_vocab_size] = (draft_probs) else: - assert draft_vocab_size < vocab_size - full_draft_probs[:, :runtime_draft_len, :draft_vocab_size] = ( - draft_probs) - else: - full_draft_probs = draft_probs - - full_draft_tokens = draft_tokens.to(torch.int32).contiguous() - - if self.seed is None: - self.seed = torch.tensor([0], dtype=torch.int64, device=device) - if self.offset is None: - self.offset = torch.tensor([0], dtype=torch.int64, device=device) - self.seed += 1 - self.seed %= 2**31 - - accepted_tokens, num_accepted_tokens = rejection_sampling_one_model( - draft_probs=full_draft_probs, - draft_token_ids=full_draft_tokens, - target_probs=target_probs, - deterministic=True, - seed=self.seed, - offset=self.offset, - ) + full_draft_probs = draft_probs + + full_draft_tokens = draft_tokens.to(torch.int32).contiguous() + + if self.seed is None: + self.seed = torch.tensor([0], dtype=torch.int64, device=device) + if self.offset is None: + self.offset = torch.tensor([0], + dtype=torch.int64, + device=device) + self.seed += 1 + self.seed %= 2**31 + + gen_accepted, gen_num_accepted = rejection_sampling_one_model( + draft_probs=full_draft_probs, + draft_token_ids=full_draft_tokens, + target_probs=target_probs, + deterministic=True, + seed=self.seed, + offset=self.offset, + ) + + accepted_tokens[num_contexts:] = gen_accepted + num_accepted_tokens[num_contexts:] = gen_num_accepted num_accepted_tokens = self._apply_force_accepted_tokens( - num_accepted_tokens, 0, draft_tokens.shape[1]) + num_accepted_tokens, num_contexts, runtime_draft_len) return accepted_tokens, num_accepted_tokens def _draft_sampler_greedy(self, logits: torch.Tensor, d2t=None): @@ -1229,7 +1310,10 @@ def _compute_and_store_draft_probs( batch_size: int, ): """ - Compute 
draft probabilities and store them for next-step rejection sampling. + Compute draft probabilities and store them for next-step rejection + sampling. The storage is keyed by py_seq_slot, so the data is robust + to batch composition shifts across iterations (chunking ctxs, gen + completion, new ctxs joining). """ draft_tokens_per_request = len(draft_logits_list) vocab_size = draft_logits_list[0].shape[-1] @@ -1258,9 +1342,21 @@ def _compute_and_store_draft_probs( draft_probs_flat = compute_probs_from_logits(draft_logits_flat, draft_temps, draft_top_ks, draft_top_ps) - num_elements = batch_size * draft_tokens_per_request * vocab_size - spec_metadata.draft_probs[:num_elements].copy_( - draft_probs_flat.flatten()) + # [batch_size, draft_len, draft_vocab] + draft_probs_per_request = draft_probs_flat.reshape( + batch_size, draft_tokens_per_request, vocab_size) + + # Scatter into draft_probs[slot] for each request in the current batch. + # spec_metadata.draft_probs is shaped [max_num_requests, max_draft_len, + # vocab_size]. Different iterations may have different batch + # compositions, but a given request's data always lives at its + # py_seq_slot row, so reads at the next iter pick up the right data. 
+ assert spec_metadata.batch_slot_ids is not None, ( + "batch_slot_ids must be populated by " + "populate_sampling_params_for_one_model before draft probs storage") + batch_slots = spec_metadata.batch_slot_ids[:batch_size] + spec_metadata.draft_probs[batch_slots, :draft_tokens_per_request, : + vocab_size] = draft_probs_per_request spec_metadata.draft_probs_last_dim = vocab_size spec_metadata.draft_probs_valid = True From 47fa873d7fcac8191d076cb7056e9130a416b02e Mon Sep 17 00:00:00 2001 From: ZhaoyangWang Date: Wed, 3 Jun 2026 00:44:19 -0700 Subject: [PATCH 05/21] [TRTLLM-12669][fix] Pre-capture both greedy and advanced sampling CUDA graphs during warmup On-the-fly CUDA graph capture is disabled outside the warmup window (allow_capture context manager) because it can resize the shared cuda_graph_workspace tensor and invalidate addresses baked into previously captured graphs. As a result, the (is_all_greedy_sample=False) graph key introduced for one-engine spec dec was never captured: warmup only ran dummy requests with greedy sampling params, so inference batches with temperature / top_k / top_p fell back to eager. Fix: run the warmup capture loop twice for one-engine spec dec. The first pass captures the greedy fast-path (existing behavior). The second pass flips spec_metadata.is_all_greedy_sample to False before forward so maybe_get_cuda_graph computes the non-greedy key, and sets a runtime attribute that populate_sampling_params_for_one_model honors to override the dummy-request-derived greedy detection and substitute synthetic non-greedy values into the per-request buffers. Other paths are unaffected: non-one-engine spec dec and non-spec dec default is_all_greedy_sample to True, so the second pass is skipped. 
End-to-end (qwen3_8b_eagle3, bs=32, T=0.7/top_k=50/top_p=0.9): rej_off baseline: TPS=3713.73 rej_on (before fix): TPS=3854.01 (+3.8%; non-greedy ran eager) rej_on (after fix): TPS=6013.58 (+62.0%; non-greedy uses graph) Signed-off-by: ZhaoyangWang --- .../_torch/pyexecutor/model_engine.py | 78 +++++++++++++------ tensorrt_llm/_torch/speculative/interface.py | 17 ++++ 2 files changed, 71 insertions(+), 24 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 152ed8c825fe..be7cbe485252 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -1343,31 +1343,61 @@ def _capture_generation_cuda_graphs(self, else: max_seq_len_list = [effective_max_seq_len] - for bs, draft_len in graphs_to_capture: - if bs > self.batch_size: - continue - - for max_seq_len in max_seq_len_list: - warmup_request = self._create_cuda_graph_warmup_request( - resource_manager, bs, draft_len, max_seq_len) - with self._release_batch_context(warmup_request, - resource_manager) as batch: - if batch is None: - # No KV cache space, cannot continue capturing graphs + def _run_capture_pass(force_non_greedy: bool, label: str) -> None: + spec_metadata = getattr(self, 'spec_metadata', None) + if force_non_greedy and spec_metadata is not None: + spec_metadata._force_non_greedy_for_capture = True + # maybe_get_cuda_graph reads spec_metadata.is_all_greedy_sample + # to build the graph cache key BEFORE populate runs inside + # _prepare_inputs. Pre-flip it here so the very first capture + # in this pass uses the non-greedy key; populate's override + # below will keep it False on every subsequent iteration. 
+ spec_metadata.is_all_greedy_sample = False + try: + for bs, draft_len in graphs_to_capture: + if bs > self.batch_size: continue - logger.info( - f"Run generation-only CUDA graph warmup for batch size={bs}, draft_len={draft_len}, max_seq_len={max_seq_len}" - ) - self.enable_spec_decode = draft_len > 0 or self.is_draft_model or ( - self.spec_config is not None - and self.spec_config.spec_dec_mode.use_one_engine()) - self._update_draft_inference_state_for_warmup( - batch, draft_len > 0, resource_manager) - self.runtime_draft_len = draft_len - self.forward(batch, - new_tensors_device=None, - resource_manager=resource_manager) - torch.cuda.synchronize() + + for max_seq_len in max_seq_len_list: + warmup_request = self._create_cuda_graph_warmup_request( + resource_manager, bs, draft_len, max_seq_len) + with self._release_batch_context( + warmup_request, resource_manager) as batch: + if batch is None: + # No KV cache space, cannot continue capturing graphs + continue + logger.info( + f"Run generation-only CUDA graph warmup ({label}) " + f"for batch size={bs}, draft_len={draft_len}, " + f"max_seq_len={max_seq_len}") + self.enable_spec_decode = draft_len > 0 or self.is_draft_model or ( + self.spec_config is not None and + self.spec_config.spec_dec_mode.use_one_engine()) + self._update_draft_inference_state_for_warmup( + batch, draft_len > 0, resource_manager) + self.runtime_draft_len = draft_len + self.forward(batch, + new_tensors_device=None, + resource_manager=resource_manager) + torch.cuda.synchronize() + finally: + if force_non_greedy and spec_metadata is not None: + spec_metadata._force_non_greedy_for_capture = False + + # Pass 1: greedy fast-path (dummy requests carry no sampling params, + # so is_all_greedy_sample is naturally True). + _run_capture_pass(force_non_greedy=False, label="greedy") + # Pass 2: advanced sampling variant. 
Required because on-the-fly capture + # is disabled outside warmup, so any inference batch that contains a + # non-greedy request would otherwise fall back to eager. Only meaningful + # for one-engine spec dec (where is_all_greedy_sample participates in + # the graph key); other paths default to True and would never key into + # this variant. + needs_non_greedy_capture = ( + self.spec_config is not None + and self.spec_config.spec_dec_mode.use_one_engine()) + if needs_non_greedy_capture: + _run_capture_pass(force_non_greedy=True, label="advanced sampling") # Set the value back to the original value after cuda graph warmups are complete self.enable_spec_decode = self.is_spec_decode diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 65471331e1a1..ad56e7a62dfd 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -665,6 +665,23 @@ def _normalize_request_sampling_params( self.is_all_greedy_sample = (self.skip_temperature and self.skip_top_k and self.skip_top_p) + # Warmup-time override (set via runtime attribute by the model engine): + # force the advanced-sampling code path so the CUDA graph for the + # (is_all_greedy_sample=False) key gets captured. Dummy warmup requests + # carry no sampling params, so the natural detection above always + # returns True; this branch substitutes synthetic non-greedy scalars + # into the per-request data and lets Phase 2 run normally to populate + # the GPU buffers used by the captured kernels. 
+ if getattr(self, '_force_non_greedy_for_capture', False): + self.skip_temperature = False + self.skip_top_k = False + self.skip_top_p = False + self.is_all_greedy_sample = False + per_request_normalized = [ + (0.7, 50, 0.9, num_tokens) + for (_, _, _, num_tokens) in per_request_normalized + ] + tokens_per_request = (self.max_total_draft_tokens + 1 if self.is_spec_dec_tree else self.max_draft_len + 1) required_flat_size = tokens_per_request * self.max_num_requests From b54c8a6b7a7949cdb36800d8cbb341fc1bbea3ae Mon Sep 17 00:00:00 2001 From: ZhaoyangWang Date: Wed, 3 Jun 2026 06:06:43 -0700 Subject: [PATCH 06/21] [TRTLLM-12669][perf] Reuse draft probs to drop redundant softmax + cut rejection-path overhead MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit refactors the rejection-sampling draft path to compute the filtered + normalized prob distribution exactly once per draft step, and folds three independent optimizations into one PR-coherent change: 1. Single-pass compute_probs + sample on draft side _draft_sampler_advanced_for_rejection now calls a new sampling_batch_spec_dec_one_model_for_rejection which returns both the sampled token AND the probs in one go. The probs are scattered into the slot-indexed draft_probs buffer immediately, so the previous separate _compute_and_store_draft_probs path (which redundantly re-ran temperature + top_k + top_p + softmax on the cloned logits) is gone. 2. Faster compute_probs_from_logits via flashinfer fast path compute_probs_from_logits now composes flashinfer's radix-based O(N) kernels (top_k_mask_logits → fused softmax+temp → top_p_renorm_probs) when CUDA + flashinfer are available. The previous C++ op path triggered torch.sort fallback (O(N log N) per row) due to a hard-coded kMax=0, which severely under-utilized SMs at small batch sizes. C++ op and PyTorch CPU paths are retained as fallbacks. 3. 
Pre-allocated full_draft_probs buffer The (max_num_requests, max_draft_len, vocab_size) scratch used to pad draft probs to target vocab is now zero-filled once at prepare() and reused across iters, saving ~25 us/iter of 64 MB zero-fill. Only allocated when use_rejection_sampling=True. The eagle3 draft loop is simplified accordingly: it no longer accumulates a draft_logits_list or invokes _compute_and_store_draft_probs after the loop; per-step scatter happens inside _draft_sampler_advanced_for_rejection keyed on the (already-required) draft_step index. Net effect on llama70b bs=32 (T=0.7/top_k=50/top_p=0.9, MT-bench 2000): ΔTPS recovered from -32% (post-refactor with sort fallback) and -12% (pre-refactor with double softmax) to ~-5% (flashinfer fast path). The remaining gap is fundamental: llama70b's Eagle3 draft already tracks the target closely (AR uplift only +2%), so the inherent rejection sampling overhead (chain_speculative_sampling kernel + target_probs + d2t padding ≈ ~340 us/iter ≈ 1.5%) is not fully offset by the small AR gain. qwen8b/qwen235b with ΔAR +9%~+14% remain solidly net positive. 
Signed-off-by: ZhaoyangWang --- tensorrt_llm/_torch/speculative/eagle3.py | 73 ++++----- tensorrt_llm/_torch/speculative/interface.py | 142 +++++++++++------- .../_torch/speculative/one_model_sampler.py | 59 +++++++- 3 files changed, 174 insertions(+), 100 deletions(-) diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index c1fd0fac3bf3..e72ce751cb33 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -772,9 +772,7 @@ def _forward_linear_draft_loop(self, inputs, attn_metadata, spec_metadata, runtime_draft_len = spec_metadata.runtime_draft_len num_gens = batch_size - num_contexts next_draft_tokens = [] - draft_logits_list = [] - last_tokens_idx = torch.cumsum( - attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1 + position_ids = inputs["position_ids"] with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager): for i in range(runtime_draft_len): @@ -862,33 +860,11 @@ def _forward_linear_draft_loop(self, inputs, attn_metadata, spec_metadata, d2t, draft_step=i) - # Sample the next draft token. - # MTP Eagle: TP-aware sampler; when ADP+LM-head-TP is active - # logits are padded to max_num_requests across TP ranks, so - # the result must be trimmed back to token_count. - # Eagle3: simple greedy sampling; d2t remaps vocab indices when - # the draft model uses a compressed vocabulary. - if self.is_mtp_eagle: - if use_lm_head_tp_in_adp: - mapping_lm_head_tp = draft_model.mtp_layers[ - 0].shared_head.mapping_lm_head_tp - new_draft_token = self.draft_sampler( - logits, mapping_lm_head_tp) - new_draft_token = new_draft_token[:token_count] - else: - new_draft_token = self.draft_sampler(logits) - else: - d2t = getattr(draft_model.model, "d2t", None) - new_draft_token = self._draft_sampler_greedy(logits, d2t) - - # Stash unpadded Eagle3 draft logits for rejection sampling on - # the next iteration. 
MTP Eagle's logits may be ADP-padded to - # max_num_requests, so we skip them here. - if not self.is_mtp_eagle and spec_metadata.use_rejection_sampling: - draft_logits_list.append(logits.clone()) - - new_draft_token = self.draft_decoder(logits, draft_model, - spec_metadata, batch_size) + new_draft_token = self.draft_decoder(logits, + draft_model, + spec_metadata, + batch_size, + draft_step=i) next_draft_tokens.append(new_draft_token) # Update hidden states for the next iteration. @@ -954,19 +930,18 @@ def _forward_linear_draft_loop(self, inputs, attn_metadata, spec_metadata, gen_draft_tokens) next_draft_tokens[num_contexts:] = gen_draft_tokens - # Skip when the whole batch is greedy: _can_use_rejection_sampling will - # bypass the rejection path anyway, so computing draft probs is wasted. - if (spec_metadata.use_rejection_sampling and draft_logits_list - and not spec_metadata.is_all_greedy_sample): - d2t_param = getattr(draft_model.model, "d2t", None) - spec_metadata.d2t = d2t_param.data if d2t_param is not None else None - self._compute_and_store_draft_probs(draft_logits_list, - spec_metadata, batch_size) - elif spec_metadata.use_rejection_sampling: - # No draft probs were written this iter (all-greedy or empty draft - # loop). Invalidate the buffer so the next iter does not read stale - # data if it transitions back to a non-greedy mix. - spec_metadata.draft_probs_valid = False + # Probs were already scattered into the slot-indexed buffer by + # _draft_sampler_advanced_for_rejection on each draft step (non-greedy + # batches only). All-greedy batches skip storage — rejection sampling + # will be bypassed by _can_use_rejection_sampling. Finalize the validity + # flag and d2t for next-iter target-side verification. 
+ if spec_metadata.use_rejection_sampling: + if not spec_metadata.is_all_greedy_sample: + d2t_param = getattr(draft_model.model, "d2t", None) + spec_metadata.d2t = d2t_param.data if d2t_param is not None else None + spec_metadata.draft_probs_valid = True + else: + spec_metadata.draft_probs_valid = False return next_draft_tokens @@ -1190,22 +1165,34 @@ def draft_decoder( draft_model: nn.Module, spec_metadata: Optional[Eagle3OneModelSpecMetadata] = None, batch_size: Optional[int] = None, + draft_step: Optional[int] = None, ): ''' Sample draft tokens. When spec_metadata + batch_size are provided, use the target's per-request sampling params (temperature/top_k/top_p); otherwise fall back to argmax. + When rejection sampling is enabled and draft_step is provided, take the + single-pass path that also scatters the draft prob distribution into the + slot-indexed buffer (avoids a redundant softmax later). + Args: logits: [batch_size, vocab_size] - Draft model logits. draft_model: The draft model. spec_metadata: Carries per-request sampling param tensors. When None, sampling is forced greedy. batch_size: Active requests, used to slice per-request tensors. + draft_step: Current draft step index (0..max_draft_len-1). Required + for the rejection-sampling code path so probs are written to + the correct slice of spec_metadata.draft_probs. 
''' d2t = getattr(draft_model.model, "d2t", None) if spec_metadata is not None and batch_size is not None: + if (spec_metadata.use_rejection_sampling and draft_step is not None + and not spec_metadata.is_all_greedy_sample): + return self._draft_sampler_advanced_for_rejection( + logits, spec_metadata, batch_size, d2t, draft_step) return self._draft_sampler_advanced(logits, spec_metadata, batch_size, d2t) return self._draft_sampler_greedy(logits, d2t) diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index ad56e7a62dfd..2206b9e2dd1b 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -42,7 +42,8 @@ from .one_model_sampler import (compute_probs_from_logits, rejection_sampling_one_model, - sampling_batch_spec_dec_one_model) + sampling_batch_spec_dec_one_model, + sampling_batch_spec_dec_one_model_for_rejection) # Environment variable name for forcing the number of accepted tokens in speculative decoding FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS" @@ -497,6 +498,13 @@ class SpecMetadata: batch_slot_ids: Optional[torch.Tensor] = None # Draft-to-target vocab offset tensor. d2t: Optional[torch.Tensor] = None + # Pre-allocated scratch for draft probs expanded to the target vocab size. + # Filled with zeros once at prepare(); each rejection iter only overwrites + # the positions selected by d2t (or [:draft_vocab] when there is no d2t), + # so the zeros outside those positions persist across iterations and we + # avoid a per-iter 64 MB zero-fill on the (max_num_requests, max_draft_len, + # vocab_size) tensor. Shape: [max_num_requests, max_draft_len, vocab_size]. 
+ full_draft_probs: Optional[torch.Tensor] = None def __post_init__(self): pass @@ -519,6 +527,16 @@ def prepare(self): self.batch_slot_ids = torch.empty((self.max_num_requests, ), dtype=torch.long, device='cuda') + if (self.use_rejection_sampling and self.full_draft_probs is None + and self.vocab_size > 0): + # Zero-fill once. Subsequent iters only overwrite the d2t-mapped + # positions (constant across iters since d2t is model-static), so + # untouched positions stay 0 forever — saves the per-iter 64 MB + # zero-fill in _sample_and_accept_draft_tokens_rejection. + self.full_draft_probs = torch.zeros( + (self.max_num_requests, self.max_draft_len, self.vocab_size), + dtype=torch.float32, + device='cuda') def create_cuda_graph_metadata(self, max_batch_size: int): """ @@ -710,7 +728,7 @@ def _normalize_request_sampling_params( # Always-populate the per-request slot id table when rejection sampling # is configured: it's tiny (max_num_requests longs) and needed at - # _compute_and_store_draft_probs time to scatter draft probs by slot. + # draft-sampler time to scatter draft probs by slot. 
if self.use_rejection_sampling and self.batch_slot_ids is not None: self.batch_slot_ids[:len(per_request_slot_ids)].copy_( torch.tensor(per_request_slot_ids, @@ -1182,7 +1200,7 @@ def _sample_and_accept_draft_tokens_rejection( spec_metadata.top_ps[gen_start:gen_end]) target_probs_flat = compute_probs_from_logits( - gen_logits.clone(), temperatures, top_ks, top_ps) + gen_logits, temperatures, top_ks, top_ps) target_probs = target_probs_flat.reshape(num_gens, runtime_draft_len + 1, vocab_size) @@ -1196,10 +1214,17 @@ def _sample_and_accept_draft_tokens_rejection( f"{runtime_draft_len}") d2t = getattr(spec_metadata, "d2t", None) if draft_vocab_size != vocab_size: - full_draft_probs = torch.zeros( - (num_gens, runtime_draft_len, vocab_size), - dtype=torch.float32, - device=device) + # Use the pre-allocated buffer from spec_metadata.prepare() + # (zero-filled once at init; untouched positions stay 0). Falls + # back to per-iter allocation if the buffer is not configured, + # e.g. when use_rejection_sampling was off at prepare() time. + if spec_metadata.full_draft_probs is not None: + full_draft_probs = spec_metadata.full_draft_probs[:num_gens] + else: + full_draft_probs = torch.zeros( + (num_gens, runtime_draft_len, vocab_size), + dtype=torch.float32, + device=device) if d2t is not None: assert d2t.numel() == draft_vocab_size, ( f"d2t size mismatch: {d2t.numel()} != {draft_vocab_size}" @@ -1320,62 +1345,71 @@ def _draft_sampler_advanced( return draft_tokens.type(torch.int32) - def _compute_and_store_draft_probs( + def _draft_sampler_advanced_for_rejection( self, - draft_logits_list: List[torch.Tensor], - spec_metadata: SpecMetadata, + logits: torch.Tensor, + spec_metadata: "SpecMetadata", batch_size: int, + d2t: Optional[torch.Tensor] = None, + draft_step: int = 0, ): """ - Compute draft probabilities and store them for next-step rejection - sampling. 
The storage is keyed by py_seq_slot, so the data is robust - to batch composition shifts across iterations (chunking ctxs, gen - completion, new ctxs joining). + Rejection-sampling-aware variant of ``_draft_sampler_advanced``. + + Single-pass compute + sample + scatter: computes the per-request prob + distribution once via TRT-LLM's fused ``compute_probs_from_logits`` + (temp + top_k + top_p + softmax + greedy override in one CUDA kernel), + samples the draft token from that distribution, and scatters the same + probs into the slot-indexed ``spec_metadata.draft_probs`` buffer for + next-iter rejection verification. Replaces the previous two-stage path + (flashinfer fused sampling kernel + a redundant softmax pass to store + probs). + + All-greedy batches take the cheaper argmax path — + ``_can_use_rejection_sampling`` will bypass rejection for those anyway. """ - draft_tokens_per_request = len(draft_logits_list) - vocab_size = draft_logits_list[0].shape[-1] - device = draft_logits_list[0].device - - draft_logits = torch.stack(draft_logits_list, dim=0) - draft_logits_flat = draft_logits.transpose(0, 1).reshape(-1, vocab_size) - - num_draft_tokens = batch_size * draft_tokens_per_request - if spec_metadata.request_temperatures is not None: - draft_temps = spec_metadata.request_temperatures[:batch_size].repeat_interleave( - draft_tokens_per_request) - draft_top_ks = ( - spec_metadata.request_top_ks[:batch_size].repeat_interleave( - draft_tokens_per_request) if not spec_metadata.skip_top_k - and spec_metadata.request_top_ks is not None else None) - draft_top_ps = ( - spec_metadata.request_top_ps[:batch_size].repeat_interleave( - draft_tokens_per_request) if not spec_metadata.skip_top_p - and spec_metadata.request_top_ps is not None else None) - else: - draft_temps = torch.ones(num_draft_tokens, device=device) - draft_top_ks = None - draft_top_ps = None - - draft_probs_flat = compute_probs_from_logits(draft_logits_flat, - draft_temps, draft_top_ks, - draft_top_ps) - # 
[batch_size, draft_len, draft_vocab] - draft_probs_per_request = draft_probs_flat.reshape( - batch_size, draft_tokens_per_request, vocab_size) - - # Scatter into draft_probs[slot] for each request in the current batch. - # spec_metadata.draft_probs is shaped [max_num_requests, max_draft_len, - # vocab_size]. Different iterations may have different batch - # compositions, but a given request's data always lives at its - # py_seq_slot row, so reads at the next iter pick up the right data. + if spec_metadata.is_all_greedy_sample: + return self._draft_sampler_greedy(logits, d2t) + + temperatures = spec_metadata.request_temperatures[:batch_size] + top_ks = spec_metadata.request_top_ks[:batch_size] + top_ps = spec_metadata.request_top_ps[:batch_size] + + if self.seed is None: + self.seed = torch.tensor([0], + dtype=torch.int64, + device=logits.device) + self.offset = torch.tensor([0], + dtype=torch.int64, + device=logits.device) + self.seed += 1 + self.seed %= (2**31) + + draft_tokens, probs = sampling_batch_spec_dec_one_model_for_rejection( + logits, + temperatures, + top_ks, + top_ps, + seed=self.seed, + offset=self.offset, + ) + + # Scatter probs into the slot-indexed buffer (shaped + # [max_num_requests, max_draft_len, vocab_size]). Each request's data + # always lands at its stable py_seq_slot row regardless of batch + # composition shifts across iterations. 
assert spec_metadata.batch_slot_ids is not None, ( "batch_slot_ids must be populated by " "populate_sampling_params_for_one_model before draft probs storage") batch_slots = spec_metadata.batch_slot_ids[:batch_size] - spec_metadata.draft_probs[batch_slots, :draft_tokens_per_request, : - vocab_size] = draft_probs_per_request - spec_metadata.draft_probs_last_dim = vocab_size - spec_metadata.draft_probs_valid = True + vocab = probs.shape[-1] + spec_metadata.draft_probs[batch_slots, draft_step, :vocab] = probs + spec_metadata.draft_probs_last_dim = vocab + + if d2t is not None: + draft_tokens = d2t[draft_tokens] + draft_tokens + + return draft_tokens.type(torch.int32) def _execute_guided_decoder_if_present(self, logits): """Execute guided decoder on target model logits if available.""" diff --git a/tensorrt_llm/_torch/speculative/one_model_sampler.py b/tensorrt_llm/_torch/speculative/one_model_sampler.py index 7e3b06b383cc..6734b5e9f79a 100644 --- a/tensorrt_llm/_torch/speculative/one_model_sampler.py +++ b/tensorrt_llm/_torch/speculative/one_model_sampler.py @@ -5,9 +5,20 @@ from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE if IS_FLASHINFER_AVAILABLE: - from flashinfer.sampling import chain_speculative_sampling, top_k_top_p_sampling_from_logits + from flashinfer.sampling import ( + chain_speculative_sampling, + sampling_from_probs, + top_k_mask_logits, + top_k_top_p_sampling_from_logits, + top_p_renorm_probs, + ) + from flashinfer.sampling import softmax as flashinfer_softmax else: chain_speculative_sampling = None + sampling_from_probs = None + flashinfer_softmax = None + top_k_mask_logits = None + top_p_renorm_probs = None top_k_top_p_sampling_from_logits = None @@ -114,9 +125,29 @@ def compute_probs_from_logits( skip_temperature: bool = False, ) -> torch.Tensor: """ - Compute probabilities from logits with temperature, top-k, and top-p applied. + Compute filtered + normalized probs from logits (temperature + top_k + + top_p + softmax). 
Picks the fastest path for the input device: + + 1. CUDA + flashinfer: ``top_k_mask_logits`` → fused ``softmax+temp`` → + ``top_p_renorm_probs`` (all O(N) radix). ``skip_temperature`` ignored. + 2. CUDA, no flashinfer: ``compute_probs_from_logits_op`` (sort-based, + O(N log N)). + 3. CPU: manual PyTorch fallback. """ + if logits.is_cuda and IS_FLASHINFER_AVAILABLE: + # Fast path: flashinfer composition (O(N) per row, friendly to small + # batch sizes). skip_temperature is ignored — flashinfer's softmax + # always applies the temperature tensor. + if top_k is not None: + logits = top_k_mask_logits(logits, top_k) + probs = flashinfer_softmax(logits, temperatures) + if top_p is not None: + probs = top_p_renorm_probs(probs, top_p) + return probs + if logits.is_cuda: + # CUDA without flashinfer: fall back to the C++ op (slower sort-based + # top-k path, but works without flashinfer). return torch.ops.trtllm.compute_probs_from_logits_op( logits, temperatures, top_k, top_p, skip_temperature ) @@ -125,7 +156,6 @@ def compute_probs_from_logits( logits = apply_temperature(logits, temperatures) logits = apply_top_k_top_p(logits, top_k, top_p) probs = logits.softmax(dim=-1, dtype=torch.float32) - # Greedy rows should remain exactly one-hot so rejection sampling does not # spuriously reject numerically-near argmax tokens. greedy_temp_threshold = 1e-4 @@ -135,6 +165,29 @@ def compute_probs_from_logits( return torch.where(is_greedy.unsqueeze(1), one_hot, probs) +def sampling_batch_spec_dec_one_model_for_rejection( + logits: torch.Tensor, + temperatures: torch.Tensor, + top_k: torch.Tensor, + top_p: torch.Tensor, + seed: Optional[torch.Tensor] = None, + offset: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Rejection-sampling-aware draft sampler: returns BOTH the sampled tokens + AND the prob distribution they were sampled from, so the downstream + rejection-sampling path can reuse the probs without a second softmax + + temp/top_k/top_p pass. 
+ """ + if sampling_from_probs is None: + raise RuntimeError( + "Rejection sampling for one-model speculative decoding requires flashinfer" + ) + probs = compute_probs_from_logits(logits, temperatures, top_k, top_p) + tokens = sampling_from_probs(probs, deterministic=True, seed=seed, offset=offset) + return tokens, probs + + def rejection_sampling_one_model( draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, From d40e31574443befe22ed8251b4a5d5ce1066680b Mon Sep 17 00:00:00 2001 From: ZhaoyangWang Date: Wed, 3 Jun 2026 18:42:20 -0700 Subject: [PATCH 07/21] [TRTLLM-12669][perf] Cache d2t target indices in spec metadata MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The d2t-projected target vocab indices computed inside the rejection-path d2t padding step (arange(draft_vocab) + (source + d2t.to(device)) % vocab_size) were being rebuilt every iteration even though the d2t tensor is model-static. Cache the result on SpecMetadataBase.d2t_target_indices on first use and reuse it on subsequent iterations. Profile breakdown (llama70b bs=32, CUDA graph off) showed accept_draft.rejection.d2t_padding at 88 us/iter — the second-largest rejection-path step after compute target_probs (127 us). The index sequence costs ~10-20 us of that (3-4 kernels: arange + d2t H2D copy + add + mod); the rest is the slot-indexed scatter into full_draft_probs which is already pre-allocated. Verified on llama70b bs=32 over 3 rounds (mean ± stdev): Before: rej_on vs rej_off gap ≈ -10.0% (single-run baseline) After : rej_on vs rej_off gap = -8.71% ± 0.9% (3-round mean) Net within-run improvement ≈ +1.3%. qwen235b unchanged (already positive). Output accuracy verified across 22 (model, bs, mode) configurations: all 1760 outputs terminate normally (EOT or max_tokens), no regressions. 
Signed-off-by: ZhaoyangWang --- tensorrt_llm/_torch/speculative/interface.py | 43 +++++++++++++------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 2206b9e2dd1b..e0e89409bddb 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -505,6 +505,11 @@ class SpecMetadata: # avoid a per-iter 64 MB zero-fill on the (max_num_requests, max_draft_len, # vocab_size) tensor. Shape: [max_num_requests, max_draft_len, vocab_size]. full_draft_probs: Optional[torch.Tensor] = None + # Cached d2t-projected target vocab indices, computed once on first use + # (d2t is a model-static tensor). Replaces the per-iter + # arange + (source + d2t) % vocab_size kernel sequence inside the d2t + # padding step. Shape: [draft_vocab_size], dtype long. + d2t_target_indices: Optional[torch.Tensor] = None def __post_init__(self): pass @@ -1077,8 +1082,8 @@ def _sample_and_accept_draft_tokens_base( device=logits.device) # Sample tokens using per-request sampling parameters - target_tokens = self._sample_tokens_for_batch(logits, spec_metadata, - num_contexts, batch_size) + target_tokens = self._sample_tokens_for_batch( + logits, spec_metadata, num_contexts, batch_size) # Context requests: only accept the sampled token (no draft tokens yet) accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts] @@ -1092,7 +1097,8 @@ def _sample_and_accept_draft_tokens_base( # Compare draft tokens with target tokens using cumulative product # Counts consecutive matches from the start num_accepted_tokens[num_contexts:] += torch.cumprod( - (draft_tokens == gen_target_tokens[:, :runtime_draft_len]).int(), + (draft_tokens + == gen_target_tokens[:, :runtime_draft_len]).int(), dim=-1).sum(1) # Apply force override if set @@ -1201,9 +1207,8 @@ def _sample_and_accept_draft_tokens_rejection( target_probs_flat = compute_probs_from_logits( gen_logits, 
temperatures, top_ks, top_ps) - target_probs = target_probs_flat.reshape(num_gens, - runtime_draft_len + 1, - vocab_size) + target_probs = target_probs_flat.reshape( + num_gens, runtime_draft_len + 1, vocab_size) draft_vocab_size = draft_probs.shape[-1] assert draft_probs.shape[0] == num_gens, ( @@ -1215,11 +1220,13 @@ def _sample_and_accept_draft_tokens_rejection( d2t = getattr(spec_metadata, "d2t", None) if draft_vocab_size != vocab_size: # Use the pre-allocated buffer from spec_metadata.prepare() - # (zero-filled once at init; untouched positions stay 0). Falls - # back to per-iter allocation if the buffer is not configured, - # e.g. when use_rejection_sampling was off at prepare() time. + # (zero-filled once at init; untouched positions stay 0). + # Falls back to per-iter allocation if the buffer is not + # configured, e.g. when use_rejection_sampling was off at + # prepare() time. if spec_metadata.full_draft_probs is not None: - full_draft_probs = spec_metadata.full_draft_probs[:num_gens] + full_draft_probs = spec_metadata.full_draft_probs[: + num_gens] else: full_draft_probs = torch.zeros( (num_gens, runtime_draft_len, vocab_size), @@ -1229,11 +1236,17 @@ def _sample_and_accept_draft_tokens_rejection( assert d2t.numel() == draft_vocab_size, ( f"d2t size mismatch: {d2t.numel()} != {draft_vocab_size}" ) - d2t = d2t.to(device=device) - source_indices = torch.arange(draft_vocab_size, - device=device, - dtype=torch.long) - target_indices = (source_indices + d2t) % vocab_size + # d2t is model-static; compute target_indices once and + # cache on spec_metadata to skip the arange + add + mod + # kernel sequence on every iter. 
+ target_indices = spec_metadata.d2t_target_indices + if target_indices is None: + source_indices = torch.arange(draft_vocab_size, + device=device, + dtype=torch.long) + target_indices = (source_indices + + d2t.to(device=device)) % vocab_size + spec_metadata.d2t_target_indices = target_indices full_draft_probs[:, :runtime_draft_len, target_indices] = draft_probs else: From 6641453acc4c9003d5a6ed8117563c272c46d40b Mon Sep 17 00:00:00 2001 From: ZhaoyangWang Date: Wed, 3 Jun 2026 18:56:18 -0700 Subject: [PATCH 08/21] [TRTLLM-12669][chore] Apply CI yapf reformat to interface.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CI yapf hook reformatted a few line wraps in interface.py — apply locally to keep CI green. No functional change. Signed-off-by: ZhaoyangWang --- tensorrt_llm/_torch/speculative/interface.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index e0e89409bddb..d91c72374638 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -1082,8 +1082,8 @@ def _sample_and_accept_draft_tokens_base( device=logits.device) # Sample tokens using per-request sampling parameters - target_tokens = self._sample_tokens_for_batch( - logits, spec_metadata, num_contexts, batch_size) + target_tokens = self._sample_tokens_for_batch(logits, spec_metadata, + num_contexts, batch_size) # Context requests: only accept the sampled token (no draft tokens yet) accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts] @@ -1097,8 +1097,7 @@ def _sample_and_accept_draft_tokens_base( # Compare draft tokens with target tokens using cumulative product # Counts consecutive matches from the start num_accepted_tokens[num_contexts:] += torch.cumprod( - (draft_tokens - == gen_target_tokens[:, :runtime_draft_len]).int(), + (draft_tokens == gen_target_tokens[:, 
:runtime_draft_len]).int(), dim=-1).sum(1) # Apply force override if set @@ -1207,8 +1206,9 @@ def _sample_and_accept_draft_tokens_rejection( target_probs_flat = compute_probs_from_logits( gen_logits, temperatures, top_ks, top_ps) - target_probs = target_probs_flat.reshape( - num_gens, runtime_draft_len + 1, vocab_size) + target_probs = target_probs_flat.reshape(num_gens, + runtime_draft_len + 1, + vocab_size) draft_vocab_size = draft_probs.shape[-1] assert draft_probs.shape[0] == num_gens, ( @@ -1225,8 +1225,7 @@ def _sample_and_accept_draft_tokens_rejection( # configured, e.g. when use_rejection_sampling was off at # prepare() time. if spec_metadata.full_draft_probs is not None: - full_draft_probs = spec_metadata.full_draft_probs[: - num_gens] + full_draft_probs = spec_metadata.full_draft_probs[:num_gens] else: full_draft_probs = torch.zeros( (num_gens, runtime_draft_len, vocab_size), From b27a7201b5822ca0935f6e96527a10d1552062c5 Mon Sep 17 00:00:00 2001 From: ZhaoyangWang Date: Wed, 3 Jun 2026 19:27:58 -0700 Subject: [PATCH 09/21] [TRTLLM-12669][chore] Address review feedback - llm_args: keep allow_advanced_sampling as a deprecated no-op field with a logger warning when explicitly set, so removing it isn't an abrupt API break - llm_args: add TODO above the Eagle3-only rejection-sampling whitelist to track extending support to MTP / DraftTarget / PARD / DFlash / SaveHiddenStates / SA and unifying the dispatch in SpecMetadata - cuda_graph_runner: type spec_metadata as Optional[SpecMetadata] instead of Optional[Any] - model_engine: always initialize self.spec_metadata = None so the capture-pass can access it directly without a getattr() fallback - eagle3.draft_decoder: drop the dead Optional/_draft_sampler_greedy fallback; spec_metadata and batch_size are always passed by the sole caller Signed-off-by: ZhaoyangWang --- .../_torch/pyexecutor/cuda_graph_runner.py | 5 ++-- .../_torch/pyexecutor/model_engine.py | 4 +-- tensorrt_llm/_torch/speculative/eagle3.py | 26 
+++++++--------- tensorrt_llm/llmapi/llm_args.py | 30 +++++++++++++++++++ 4 files changed, 46 insertions(+), 19 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py index f69956eaf158..f5e0820989a9 100644 --- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py +++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py @@ -18,6 +18,7 @@ from ..memory_buffer_utils import get_memory_buffers from ..modules.multi_stream_utils import with_multi_stream from ..speculative.eagle3 import Eagle3ResourceManager +from ..speculative.interface import SpecMetadata from ..speculative.spec_sampler_base import SampleStateTensorsSpec from ..speculative.utils import get_draft_kv_cache_manager from ..utils import make_weak_ref, piecewise_cuda_graph @@ -198,7 +199,7 @@ def get_graph_key( batch: ScheduledRequests, new_tensors_device: Optional[SampleStateTensors] = None, spec_resource_manager: Optional[BaseResourceManager] = None, - spec_metadata: Optional[Any] = None): + spec_metadata: Optional[SpecMetadata] = None): batch_size = batch.batch_size # Get the sequence length mode. 
@@ -240,7 +241,7 @@ def maybe_get_cuda_graph( batch: ScheduledRequests, enable_spec_decode: bool, attn_metadata: Any, - spec_metadata: Optional[Any] = None, + spec_metadata: Optional[SpecMetadata] = None, draft_tokens_cuda: Optional[torch.Tensor] = None, new_tensors_device: Optional[SampleStateTensors] = None, spec_resource_manager: Optional[BaseResourceManager] = None, diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index be7cbe485252..b4df20e60eaa 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -497,7 +497,6 @@ def __init__( sparse_attn_config=self.sparse_attention_config) if self.is_spec_decode: - self.spec_metadata = None update_spec_config_from_model_config(self.spec_config, self.model.config) max_num_draft_tokens = self.max_draft_loop_tokens * self.batch_size @@ -551,6 +550,7 @@ def __init__( # the model engine. self.attn_metadata = None self.encoder_attn_metadata = None + self.spec_metadata = None self.iter_states = {} self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None @@ -1344,7 +1344,7 @@ def _capture_generation_cuda_graphs(self, max_seq_len_list = [effective_max_seq_len] def _run_capture_pass(force_non_greedy: bool, label: str) -> None: - spec_metadata = getattr(self, 'spec_metadata', None) + spec_metadata = self.spec_metadata if force_non_greedy and spec_metadata is not None: spec_metadata._force_non_greedy_for_capture = True # maybe_get_cuda_graph reads spec_metadata.is_all_greedy_sample diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index e72ce751cb33..c0af6092ddc9 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -1163,14 +1163,13 @@ def draft_decoder( self, logits: torch.Tensor, draft_model: nn.Module, - spec_metadata: 
Optional[Eagle3OneModelSpecMetadata] = None, - batch_size: Optional[int] = None, + spec_metadata: Eagle3OneModelSpecMetadata, + batch_size: int, draft_step: Optional[int] = None, ): ''' - Sample draft tokens. When spec_metadata + batch_size are provided, use - the target's per-request sampling params (temperature/top_k/top_p); - otherwise fall back to argmax. + Sample draft tokens using the target's per-request sampling params + (temperature/top_k/top_p). When rejection sampling is enabled and draft_step is provided, take the single-pass path that also scatters the draft prob distribution into the @@ -1179,8 +1178,7 @@ def draft_decoder( Args: logits: [batch_size, vocab_size] - Draft model logits. draft_model: The draft model. - spec_metadata: Carries per-request sampling param tensors. When - None, sampling is forced greedy. + spec_metadata: Carries per-request sampling param tensors. batch_size: Active requests, used to slice per-request tensors. draft_step: Current draft step index (0..max_draft_len-1). 
Required for the rejection-sampling code path so probs are written to @@ -1188,14 +1186,12 @@ def draft_decoder( ''' d2t = getattr(draft_model.model, "d2t", None) - if spec_metadata is not None and batch_size is not None: - if (spec_metadata.use_rejection_sampling and draft_step is not None - and not spec_metadata.is_all_greedy_sample): - return self._draft_sampler_advanced_for_rejection( - logits, spec_metadata, batch_size, d2t, draft_step) - return self._draft_sampler_advanced(logits, spec_metadata, - batch_size, d2t) - return self._draft_sampler_greedy(logits, d2t) + if (spec_metadata.use_rejection_sampling and draft_step is not None + and not spec_metadata.is_all_greedy_sample): + return self._draft_sampler_advanced_for_rejection( + logits, spec_metadata, batch_size, d2t, draft_step) + return self._draft_sampler_advanced(logits, spec_metadata, batch_size, + d2t) def prepare_1st_drafter_inputs( diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 521e4a438c5f..fcf9865e4944 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1130,6 +1130,14 @@ class DecodingBaseConfig(StrictBaseModel): "false to fall back to exact-match verification on non-greedy batches. " "The non-dynamic-tree one-model path requires FlashInfer.") + allow_advanced_sampling: bool = Field( + default=False, + status="deprecated", + description= + "DEPRECATED: no-op kept for backward compatibility. Will be removed " + "in a future release. Non-greedy sampling is now auto-detected per " + "request; this flag no longer has any effect.") + # If set, drafting is allowed to use chain drafter. _allow_chain_drafter: bool = PrivateAttr(True) # If set, drafting uses greedy sampling, irrespective of sampling parameters. 
@@ -1194,6 +1202,23 @@ def validate_rejection_sampling_config(self): self.use_rejection_sampling = False return self + @model_validator(mode='before') + @classmethod + def _warn_deprecated_allow_advanced_sampling(cls, data): + """Warn when users set the deprecated allow_advanced_sampling flag. + + Non-greedy sampling is now auto-detected per request and always + available, so the flag is a no-op; warn loudly so callers update + their configs before the flag is removed. + """ + if isinstance(data, dict) and 'allow_advanced_sampling' in data: + logger.warning( + "DecodingBaseConfig: 'allow_advanced_sampling' is deprecated " + "and will be removed in a future release. The flag has no " + "effect — non-greedy sampling is now auto-detected per " + "request.") + return data + @model_validator(mode='after') # 1. Validate that max_concurrency and draft_len_schedule are mutually exclusive. # 2. If max_concurrency is set, translate it to the corresponding draft_len_schedule. @@ -4535,6 +4560,11 @@ def validate_speculative_config(self): # Rejection sampling is only wired up for Eagle3 one-model paths. # Silently fall back for other spec types so the new default # (True) does not break them. + # TODO: extend rejection sampling to the remaining speculative + # decoding paths (MTP / DraftTarget / PARD / DFlash / + # SaveHiddenStates / SA) and unify the dispatch in SpecMetadata + # so new spec algorithms get rejection sampling for free; once + # all paths are covered this whitelist guard can be removed. self.speculative_config.use_rejection_sampling = False if isinstance(self.speculative_config, PARDDecodingConfig): From 3f544f9376936e20980d7c919efa1fc84543aede Mon Sep 17 00:00:00 2001 From: ZhaoyangWang Date: Thu, 4 Jun 2026 19:19:16 -0700 Subject: [PATCH 10/21] [TRTLLM-12669][chore] Revert use_rejection_sampling default to False Flip the default back to False so existing deployments are not silently opted into rejection sampling behavior. 
Users who want the acceptance-rate gains must pass use_rejection_sampling=True explicitly. Signed-off-by: ZhaoyangWang --- tensorrt_llm/llmapi/llm_args.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index fcf9865e4944..9b317050a81d 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1121,13 +1121,13 @@ class DecodingBaseConfig(StrictBaseModel): "PyTorch backend only.") use_rejection_sampling: bool = Field( - default=True, + default=False, status="prototype", description= - "If true (default), enables rejection sampling for one-model speculative " - "decoding paths when the batch contains any non-greedy request. All-greedy " - "batches always take the argmax fast path regardless of this flag. Set to " - "false to fall back to exact-match verification on non-greedy batches. " + "If true, enables rejection sampling for one-model speculative decoding " + "paths when the batch contains any non-greedy request. All-greedy batches " + "always take the argmax fast path regardless of this flag. Set to false " + "(default) to use exact-match verification on non-greedy batches. " "The non-dynamic-tree one-model path requires FlashInfer.") allow_advanced_sampling: bool = Field( From 616b446a77e6994b86df892b37fa024fcb33016d Mon Sep 17 00:00:00 2001 From: ZhaoyangWang Date: Thu, 4 Jun 2026 19:22:42 -0700 Subject: [PATCH 11/21] [TRTLLM-12669][fix] Restore last_tokens_idx dropped during rebase conflict resolution The rebase conflict resolution accidentally removed the last_tokens_idx initialisation that gather_ids depends on at draft step 0. Re-add the cumsum computation. Also revert use_rejection_sampling default to False so existing deployments are not silently opted into rejection sampling. 
Signed-off-by: ZhaoyangWang --- tensorrt_llm/_torch/speculative/eagle3.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index c0af6092ddc9..0002d2931d98 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -772,6 +772,8 @@ def _forward_linear_draft_loop(self, inputs, attn_metadata, spec_metadata, runtime_draft_len = spec_metadata.runtime_draft_len num_gens = batch_size - num_contexts next_draft_tokens = [] + last_tokens_idx = torch.cumsum( + attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1 position_ids = inputs["position_ids"] with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager): @@ -1193,7 +1195,6 @@ def draft_decoder( return self._draft_sampler_advanced(logits, spec_metadata, batch_size, d2t) - def prepare_1st_drafter_inputs( self, input_ids: torch.LongTensor, From 5641566abdc47902a521a064c620f049a2dd588e Mon Sep 17 00:00:00 2001 From: ZhaoyangWang Date: Fri, 5 Jun 2026 00:02:41 -0700 Subject: [PATCH 12/21] [TRTLLM-12669][fix] Remove stray allow_advanced_sampling arg in mtp_eagle_one_model path Commit 25c8c9fc3e removed allow_advanced_sampling from all Eagle3OneModelSpecMetadata constructor calls in utils.py except the is_mtp_eagle_one_model() branch, causing a TypeError at runtime when MTPDecodingConfig is used (e.g. DeepSeek-R1 throughput_mtp test on GB200). 
Signed-off-by: ZhaoyangWang --- tensorrt_llm/_torch/speculative/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py index 8f132bbcdb22..91f60243834c 100644 --- a/tensorrt_llm/_torch/speculative/utils.py +++ b/tensorrt_llm/_torch/speculative/utils.py @@ -58,7 +58,6 @@ def get_spec_metadata(spec_config, num_layers=model_config.num_hidden_layers, hidden_size=model_config.hidden_size, max_num_tokens=max_num_tokens, - allow_advanced_sampling=spec_config.allow_advanced_sampling, use_rejection_sampling=use_rejection_sampling, vocab_size=vocab_size, spec_resource_manager=spec_resource_manager, From 511e3347fcbc1f4eb0c1b9196961d8a36e822051 Mon Sep 17 00:00:00 2001 From: ZhaoyangWang Date: Mon, 8 Jun 2026 05:10:26 -0700 Subject: [PATCH 13/21] [TRTLLM-12669][fix] Fix draft_decoder AttributeError for MTP Eagle mode 'MTPForCausalLM' does not store its constructor's 'model' argument as self.model, so getattr(draft_model.model, "d2t", None) raised AttributeError when draft_decoder was called in MTP Eagle mode. Use nested getattr to safely return None when draft_model has no 'model' attribute (MTP Eagle never uses a compressed vocabulary so d2t is always None for that mode anyway). Signed-off-by: ZhaoyangWang --- tensorrt_llm/_torch/speculative/eagle3.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 0002d2931d98..fffc5c8384c4 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -939,7 +939,8 @@ def _forward_linear_draft_loop(self, inputs, attn_metadata, spec_metadata, # flag and d2t for next-iter target-side verification. 
if spec_metadata.use_rejection_sampling: if not spec_metadata.is_all_greedy_sample: - d2t_param = getattr(draft_model.model, "d2t", None) + d2t_param = getattr(getattr(draft_model, 'model', None), "d2t", + None) spec_metadata.d2t = d2t_param.data if d2t_param is not None else None spec_metadata.draft_probs_valid = True else: @@ -1187,7 +1188,7 @@ def draft_decoder( the correct slice of spec_metadata.draft_probs. ''' - d2t = getattr(draft_model.model, "d2t", None) + d2t = getattr(getattr(draft_model, 'model', None), "d2t", None) if (spec_metadata.use_rejection_sampling and draft_step is not None and not spec_metadata.is_all_greedy_sample): return self._draft_sampler_advanced_for_rejection( From 1f670b99d38499e6179d766bcef76615b52ba40b Mon Sep 17 00:00:00 2001 From: ZhaoyangWang Date: Mon, 8 Jun 2026 17:40:38 -0700 Subject: [PATCH 14/21] [TRTLLM-12669][fix] Fix PARD buffer overflow and CUDA-graph-incompatible top_k_max Two bugs exposed by the new forced non-greedy CUDA graph capture pass: 1. SpecMetadata.populate_sampling_params_for_one_model: buffer size is tokens_per_request * max_num_requests, but warmup batches can have more total tokens when batch_size > max_num_requests. Fix by using max(static_required, actual_flat_size) for buffer allocation. 2. Eagle3 dynamic tree rejection: verify_dynamic_tree_rejection_from_logits_out computed top_k_max via boolean tensor indexing + .item(), both CUDA-graph-incompatible. 
Fix by: - Pre-computing top_k_max CPU-side in populate_sampling_params_for_one_model - Passing top_k_max=0 during stream capture (forces full-sort path, always correct) and the pre-computed value during eager execution - Adding top_k_max optional param to verify_dynamic_tree_rejection_from_logits_out Signed-off-by: ZhaoyangWang --- .../_torch/speculative/dynamic_tree_ops.py | 7 ++++++- .../_torch/speculative/eagle3_dynamic_tree.py | 9 +++++++++ tensorrt_llm/_torch/speculative/interface.py | 17 ++++++++++++++++- 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/speculative/dynamic_tree_ops.py b/tensorrt_llm/_torch/speculative/dynamic_tree_ops.py index c25c0c315523..bfb520f2d733 100644 --- a/tensorrt_llm/_torch/speculative/dynamic_tree_ops.py +++ b/tensorrt_llm/_torch/speculative/dynamic_tree_ops.py @@ -237,6 +237,7 @@ def verify_dynamic_tree_rejection_from_logits_out( offset: int | torch.Tensor = 0, d2t: torch.Tensor | None = None, skip_all_sampling_params: bool = False, + top_k_max: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Tree-aware rejection sampling from logits (three CUDA ops). @@ -266,9 +267,13 @@ def verify_dynamic_tree_rejection_from_logits_out( tree_valid = torch.ones(num_gens, dtype=torch.bool, device=candidates.device) tree_valid = tree_valid.contiguous() - if top_k is None: + if top_k_max is not None: + # Pre-computed CPU-side (CUDA-graph-safe): use as-is. + pass + elif top_k is None: top_k_max = 0 else: + # Fallback path (non-CUDA-graph contexts): compute from tensor. 
enabled_top_k = top_k[(top_k > 0) & (top_k < target_vocab_size)] top_k_max = int(enabled_top_k.max().item()) if enabled_top_k.numel() > 0 else 0 diff --git a/tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py b/tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py index ddfd62812eb5..65d2871be24b 100644 --- a/tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py +++ b/tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py @@ -904,6 +904,15 @@ def _sample_and_accept_dynamic_tree( offset=self.offset, d2t=self._d2t, skip_all_sampling_params=skip_all_sampling_params, + # During CUDA graph capture bake top_k_max=0 so the + # full-sort (always-correct) path is captured. Outside + # capture, pass the pre-computed value for the fast + # topk(kMax) path. + top_k_max=( + 0 + if torch.cuda.is_current_stream_capturing() + else getattr(spec_metadata, "top_k_max", None) + ), ) ) diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index d91c72374638..4446b6e2bad9 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -474,6 +474,9 @@ class SpecMetadata: skip_top_k: bool = False skip_top_p: bool = False has_greedy_requests: bool = False + # Pre-computed top_k_max scalar (CPU-side) to avoid CUDA-graph-incompatible + # dynamic boolean tensor indexing inside verify_dynamic_tree_rejection_from_logits_out. + top_k_max: int = 0 # Sampling parameters indexed per request. request_temperatures: Optional[torch.Tensor] = None request_top_ks: Optional[torch.Tensor] = None @@ -707,7 +710,12 @@ def _normalize_request_sampling_params( tokens_per_request = (self.max_total_draft_tokens + 1 if self.is_spec_dec_tree else self.max_draft_len + 1) - required_flat_size = tokens_per_request * self.max_num_requests + # Warmup batches may exceed max_num_requests * tokens_per_request (e.g. + # when CUDA-graph warmup passes use max_batch_size > max_num_requests). 
+ actual_flat_size = sum( + num_tokens for _, _, _, num_tokens in per_request_normalized) + required_flat_size = max(tokens_per_request * self.max_num_requests, + actual_flat_size) if self.temperatures is None or self.temperatures.numel( ) < required_flat_size: @@ -790,6 +798,13 @@ def _normalize_request_sampling_params( non_blocking=True, ) + # Pre-compute top_k_max on the CPU so CUDA-graph capture does not + # encounter boolean-tensor indexing (dynamic size) or .item() calls. + # DISABLE_TOPK_VAL (INT32_MAX) is the sentinel for "top-k disabled". + _disable_topk = torch.iinfo(torch.int32).max + self.top_k_max = max( + (tk for tk in request_top_ks if 0 < tk < _disable_topk), default=0) + class SpecWorkerBase(nn.Module, ABC): """ From e66f8b213deb02f80268b04630f5a54ab1ff1f21 Mon Sep 17 00:00:00 2001 From: ZhaoyangWang Date: Mon, 8 Jun 2026 22:54:58 -0700 Subject: [PATCH 15/21] [TRTLLM-12669][fix] disable rejection sampling during CUDA graph capture and fix PARD num_tokens shape mismatch - eagle3_dynamic_tree.py: _can_use_rejection_sampling now returns False when spec_metadata.is_cuda_graph is True. The rejection ops (compute_draft_probs_for_dynamic_tree_rejection_op) use a full-sort fallback with dynamic allocation that is incompatible with CUDA stream capture, causing cudaErrorStreamCaptureUnsupported. - interface.py: _sample_tokens_for_batch now derives num_tokens from logits.shape[0] instead of computing it from runtime_draft_len. For PARD under CUDA graph capture runtime_draft_len can be the PARD-max while the graph was built for a shorter draft_len, causing a shape mismatch in the torch.compiled sampling_batch_spec_dec_one_model. 
Signed-off-by: ZhaoyangWang --- tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py | 4 ++++ tensorrt_llm/_torch/speculative/interface.py | 8 +++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py b/tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py index 65d2871be24b..e7a6746d7c70 100644 --- a/tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py +++ b/tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py @@ -978,10 +978,14 @@ def _can_use_rejection_sampling(self, spec_metadata) -> bool: """ # Skip rejection sampling when the whole batch is greedy: argmax is # equivalent and avoids the rejection kernel cost. + # Also skip during CUDA graph capture/replay: the rejection ops use + # dynamic memory allocation (full-sort fallback) which is incompatible + # with stream capture. return ( spec_metadata.use_rejection_sampling and self._draft_depth_logits_cat is not None and not spec_metadata.is_all_greedy_sample + and not spec_metadata.is_cuda_graph ) def _finalize_dynamic_tree_verify_outputs( diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 4446b6e2bad9..c950730e6923 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -1568,9 +1568,11 @@ def _sample_tokens_for_batch( sampled_tokens: [num_tokens] - Sampled token ids """ if not spec_metadata.is_all_greedy_sample: - num_gens = batch_size - num_contexts - num_tokens = num_contexts + num_gens * ( - spec_metadata.runtime_draft_len + 1) + # Use logits.shape[0] directly: for PARD under CUDA graph capture + # runtime_draft_len may reflect the PARD-max while the captured + # graph was built for a shorter draft_len, causing a shape mismatch + # in sampling_batch_spec_dec_one_model (which is torch.compiled). 
+ num_tokens = logits.shape[0] temperatures = spec_metadata.temperatures[:num_tokens] top_ks = spec_metadata.top_ks[:num_tokens] From 4f599c23ca081bc797ae2bc5ce2e4993b772fa8d Mon Sep 17 00:00:00 2001 From: ZhaoyangWang Date: Wed, 10 Jun 2026 04:34:00 -0700 Subject: [PATCH 16/21] [TRTLLM-11508][fix] trim draft token to token_count when use_lm_head_tp_in_adp Signed-off-by: ZhaoyangWang --- tensorrt_llm/_torch/speculative/eagle3.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index fffc5c8384c4..261bc1dd7d3c 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -867,6 +867,11 @@ def _forward_linear_draft_loop(self, inputs, attn_metadata, spec_metadata, spec_metadata, batch_size, draft_step=i) + # When ADP+LM-head-TP pads logits to max_num_requests, the + # sampler returns max_num_requests tokens; trim back to the + # actual token_count so next_draft_tokens has the right shape. + if use_lm_head_tp_in_adp: + new_draft_token = new_draft_token[:token_count] next_draft_tokens.append(new_draft_token) # Update hidden states for the next iteration. From 20a94d24544eef18d3e0ab12374673fa4434929d Mon Sep 17 00:00:00 2001 From: ZhaoyangWang Date: Thu, 11 Jun 2026 03:17:41 -0700 Subject: [PATCH 17/21] [TRTLLM-12669][fix] slice padded draft logits before advanced sampling in ADP+LM-head-TP In the MTP-Eagle ADP + LM-head-TP path, draft logits are zero-padded to max_num_requests so every TP rank produces an identically-shaped tensor for the LM-head-TP all-gather. The refactored draft sampler applies per-request temperature/top_k/top_p tensors sized to token_count (== batch_size), so the padded logits ([max_num_requests, vocab]) failed to broadcast against the [batch_size, 1] temperature in apply_temperature, crashing torch.compile fake tensor tracing during executor worker init. 
Drop the padded rows before sampling (logits = logits[:token_count]) instead of trimming the sampled tokens afterwards. This keeps logits, next_draft_tokens and the draft_probs buffer token_count-sized and lets the per-request sampling params broadcast correctly. Signed-off-by: ZhaoyangWang --- tensorrt_llm/_torch/speculative/eagle3.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 261bc1dd7d3c..a43b7d019426 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -862,16 +862,22 @@ def _forward_linear_draft_loop(self, inputs, attn_metadata, spec_metadata, d2t, draft_step=i) + # When ADP+LM-head-TP pads logits to max_num_requests, the + # padded rows are zero-filled placeholders only required so + # every TP rank produces logits of identical shape for the + # LM-head-TP all-gather. Drop them *before* sampling: the + # per-request sampling params (temperatures/top_k/top_p) are + # sized to token_count (== batch_size), so the padded logits + # would otherwise fail to broadcast in apply_temperature. This + # also keeps next_draft_tokens and the draft_probs buffer + # token_count-sized without a post-hoc trim. + if use_lm_head_tp_in_adp: + logits = logits[:token_count] new_draft_token = self.draft_decoder(logits, draft_model, spec_metadata, batch_size, draft_step=i) - # When ADP+LM-head-TP pads logits to max_num_requests, the - # sampler returns max_num_requests tokens; trim back to the - # actual token_count so next_draft_tokens has the right shape. - if use_lm_head_tp_in_adp: - new_draft_token = new_draft_token[:token_count] next_draft_tokens.append(new_draft_token) # Update hidden states for the next iteration. 
From 7bfeba0bf70f747db740e294979fd3219837734f Mon Sep 17 00:00:00 2001 From: ZhaoyangWang Date: Fri, 12 Jun 2026 02:26:11 -0700 Subject: [PATCH 18/21] [TRTLLM-12669][fix] refresh is_all_greedy_sample before CUDA graph key selection The one-engine CUDA graph key includes is_all_greedy_sample to dispatch between the argmax fast-path and the advanced-sampling graph variant. The flag was only (re)computed inside populate_sampling_params_for_one_model, which runs in _prepare_inputs AFTER maybe_get_cuda_graph has already built the key. The key therefore used the previous iteration's stale flag, and warmup left it False (from the advanced-sampling capture pass). On the first real decode iteration a greedy batch would then replay the advanced-sampling graph while populate skips filling the sampling/draft_probs buffers, reading uninitialized slot-indexed data. For MTP with num_nextn>=2 this hung the executor (Hang detected on rank 0). Fix: - Extract the greediness detection into _scan_one_model_sampling (single source of truth) and add update_is_all_greedy_sample, called before the graph key is built so the key matches the buffers populate fills. populate now reuses the same scan. - Defensively reset spec_metadata.is_all_greedy_sample to True after CUDA graph warmup so the stale capture-only False does not seed the first iteration. 
Signed-off-by: ZhaoyangWang --- .../_torch/pyexecutor/model_engine.py | 18 ++++++ tensorrt_llm/_torch/speculative/interface.py | 62 ++++++++++++++----- 2 files changed, 66 insertions(+), 14 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index b4df20e60eaa..686ed4744e2a 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -1400,6 +1400,13 @@ def _run_capture_pass(force_non_greedy: bool, label: str) -> None: _run_capture_pass(force_non_greedy=True, label="advanced sampling") # Set the value back to the original value after cuda graph warmups are complete self.enable_spec_decode = self.is_spec_decode + # The advanced-sampling capture pass above leaves is_all_greedy_sample + # set to False on spec_metadata. Reset it to the default so the first + # real iteration's graph-key selection is not seeded with this + # capture-only value. (update_is_all_greedy_sample refreshes it every + # iteration; this is a defensive guard.) + if self.spec_metadata is not None: + self.spec_metadata.is_all_greedy_sample = True def _capture_piecewise_cuda_graphs(self, resource_manager: ResourceManager): """Captures piecewise CUDA graphs for context/prefill steps via torch.compile.""" @@ -4720,6 +4727,17 @@ def forward(self, self.runtime_draft_len) as padded_requests: self._pad_batch_seed_mrope_delta_cache(padded_requests) + # Refresh is_all_greedy_sample for the *current* batch BEFORE the + # CUDA graph key is built below. The key includes this flag to pick + # the argmax vs advanced-sampling graph variant; populate (inside + # _prepare_inputs) runs later and fills the matching GPU buffers. + # Without this pre-scan the key would use the previous iteration's + # stale value and could replay the advanced graph against + # unpopulated (greedy) buffers, hanging the run (e.g. MTP nextn>=2). 
+ if spec_metadata is not None: + spec_metadata.update_is_all_greedy_sample( + padded_requests.all_requests()) + maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph( padded_requests, enable_spec_decode=self.enable_spec_decode, diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index c950730e6923..a6232866cc36 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -36,6 +36,7 @@ if TYPE_CHECKING: from ..pyexecutor.guided_decoder import CapturableGuidedDecoder + from ..pyexecutor.llm_request import LlmRequest if IS_FLASHINFER_AVAILABLE: import flashinfer @@ -574,25 +575,20 @@ def maybe_capture_hidden_states(self, layer_id: int, model. Use this method to record them. By default, does nothing. """ - def populate_sampling_params_for_one_model( - self, requests: list["LlmRequest"]) -> None: - """ - Set up topp/topk/temperatures for 1-model sampler. + def _scan_one_model_sampling( + self, requests: list["LlmRequest"] + ) -> tuple[list[tuple[float, int, float, int]], list[int]]: + """Single source of truth for one-engine sampling-param detection. - Scans sampling configs to set skip_*/is_all_greedy_sample flags. When - any request needs sampling, also builds per-token/per-request lists - and copies them to GPU buffers; all-greedy batches skip this entirely. + Scans the batch's sampling configs and sets skip_*/has_greedy_requests/ + is_all_greedy_sample (honoring the warmup capture override). Returns + ``(per_request_normalized, per_request_slot_ids)`` for buffer + population. Does NOT allocate or fill GPU buffers, so it is safe to call + before the CUDA graph key is built. 
""" from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState from tensorrt_llm.sampling_params import SamplingParams - if not self.spec_dec_mode.use_one_engine(): - return - - if self.temperatures is None: - # Ensures determinism across ranks. - torch.manual_seed(0) - # Need to use a very small value for temperature when disabled to avoid division by 0 DISABLE_TEMP_VAL = 1e-5 # Very large values disable topk. @@ -708,6 +704,44 @@ def _normalize_request_sampling_params( for (_, _, _, num_tokens) in per_request_normalized ] + return per_request_normalized, per_request_slot_ids + + def update_is_all_greedy_sample(self, requests: list["LlmRequest"]) -> None: + """Refresh ``is_all_greedy_sample`` for the *current* batch. + + Must be called BEFORE the CUDA graph key is built (the key includes + ``is_all_greedy_sample`` to choose the argmax vs advanced-sampling graph + variant). ``populate_sampling_params_for_one_model`` runs later, inside + ``_prepare_inputs``, and re-derives the same flag while filling the GPU + sampling buffers. Computing the flag here first keeps the selected graph + consistent with the buffers ``populate`` fills; otherwise the key would + use the previous iteration's stale value and could replay the advanced + graph against unpopulated (greedy) buffers, which can hang/corrupt the + run (notably for MTP with num_nextn>=2). + """ + if not self.spec_dec_mode.use_one_engine(): + return + self._scan_one_model_sampling(requests) + + def populate_sampling_params_for_one_model( + self, requests: list["LlmRequest"]) -> None: + """ + Set up topp/topk/temperatures for 1-model sampler. + + Scans sampling configs to set skip_*/is_all_greedy_sample flags. When + any request needs sampling, also builds per-token/per-request lists + and copies them to GPU buffers; all-greedy batches skip this entirely. + """ + if not self.spec_dec_mode.use_one_engine(): + return + + if self.temperatures is None: + # Ensures determinism across ranks. 
+ torch.manual_seed(0) + + per_request_normalized, per_request_slot_ids = ( + self._scan_one_model_sampling(requests)) + tokens_per_request = (self.max_total_draft_tokens + 1 if self.is_spec_dec_tree else self.max_draft_len + 1) # Warmup batches may exceed max_num_requests * tokens_per_request (e.g. From 6cef5afe65b1c24bdbdadd8ef795788027f81052 Mon Sep 17 00:00:00 2001 From: ZhaoyangWang Date: Sat, 13 Jun 2026 10:01:18 -0700 Subject: [PATCH 19/21] [TRTLLM-12669][fix] keep MTP-Eagle greedy draft sampling TP-aware to avoid multi-GPU hang draft_decoder routed the all-greedy fast path to _draft_sampler_greedy, a plain torch.argmax. For MTP-Eagle with a tensor-parallel draft LM head (tp_size>1 without attention DP, or LM-head-TP in ADP) the draft logits are sharded along the vocab dim, so a per-rank argmax selects a different token on each rank. The ranks then desync on the speculative-decoding control flow and the next collective deadlocks, observed as a generation hang on rank 0 (e.g. DeepSeek-V3-Lite tp4 + mtp_nextn>=2 + cuda_graph + torch_compile). Restore the TP-aware path: for the all-greedy case, MTP-Eagle now uses draft_sampler(), which all-gathers the sharded draft logits before argmax (and falls back to a plain argmax when no TP gather is needed). Eagle3 (non-MTP) keeps its d2t-aware argmax. This matches the pre-refactor behavior. Root-caused and verified by local reproduction (DeepSeek-V3-Lite, tp4, mtp_nextn=2, cuda_graph, torch_compile): baseline passes, the refactor hangs, and this fix restores passing. 
Signed-off-by: ZhaoyangWang --- tensorrt_llm/_torch/speculative/eagle3.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index a43b7d019426..acf02fee409b 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -1200,8 +1200,20 @@ def draft_decoder( ''' d2t = getattr(getattr(draft_model, 'model', None), "d2t", None) - if (spec_metadata.use_rejection_sampling and draft_step is not None - and not spec_metadata.is_all_greedy_sample): + # All-greedy fast path must stay TP-aware. When the draft LM head is + # tensor-parallel (tp_size>1 without attention DP, or LM-head-TP in + # ADP), the draft logits are sharded along the vocab dim. A plain + # per-rank argmax then picks a different token on each rank, which + # desyncs the speculative-decoding control flow across ranks and + # deadlocks the next collective (observed as a generation hang on + # MTP-Eagle + TP). draft_sampler() all-gathers the sharded logits + # before argmax (and falls back to a plain argmax when no TP gather is + # needed). Eagle3 (non-MTP) keeps its d2t-aware argmax. + if spec_metadata.is_all_greedy_sample: + if self.is_mtp_eagle: + return self.draft_sampler(logits) + return self._draft_sampler_greedy(logits, d2t) + if spec_metadata.use_rejection_sampling and draft_step is not None: return self._draft_sampler_advanced_for_rejection( logits, spec_metadata, batch_size, d2t, draft_step) return self._draft_sampler_advanced(logits, spec_metadata, batch_size, From 92131291c951cf84ad48d815316d3b8f853285ba Mon Sep 17 00:00:00 2001 From: ZhaoyangWang Date: Sat, 13 Jun 2026 21:16:28 -0700 Subject: [PATCH 20/21] [TRTLLM-12669][fix] all-gather sharded draft logits before advanced (non-greedy) sampling under TP The non-greedy draft sampling path (_draft_sampler_advanced) has the same multi-GPU hazard as the greedy path that was just fixed. 
With a plain tensor-parallel draft LM head (tp_size>1 without attention DP) each rank only holds a vocab shard of the draft logits, so per-rank random sampling draws a different token on each rank, desyncs the speculative-decoding control flow and deadlocks the next collective (generation hang). Greedy could be repaired with draft_sampler()'s lightweight max+index all-gather, but random sampling needs the full distribution, so all-gather the sharded draft logits into the full vocab before advanced sampling. Every rank then samples from the same distribution with the shared seed. The LM-head-TP-in-ADP path is gathered upstream and is intentionally excluded. Verified by local reproduction (DeepSeek-V3-Lite, tp4, mtp_nextn=2, cuda_graph, torch_compile, non-greedy temperature/top_k/top_p): hangs without this gather, passes with it. Signed-off-by: ZhaoyangWang --- tensorrt_llm/_torch/speculative/eagle3.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index acf02fee409b..537572d00f60 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -1213,6 +1213,21 @@ def draft_decoder( if self.is_mtp_eagle: return self.draft_sampler(logits) return self._draft_sampler_greedy(logits, d2t) + # Non-greedy (advanced) draft sampling has the same TP hazard as the + # greedy path: when the draft LM head is plain tensor-parallel + # (tp_size>1 without attention DP), each rank only holds a vocab shard + # of the draft logits. Random per-rank sampling then draws different + # tokens on different ranks, desyncing the spec-decode control flow and + # deadlocking the next collective. All-gather the shards into the full + # vocab first so every rank samples from the same distribution with the + # shared seed. (Greedy uses draft_sampler()'s lighter max+index gather; + # random sampling needs the full distribution. 
The LM-head-TP-in-ADP + # case is handled upstream and must not be gathered again here.) + if (self.is_mtp_eagle and self.model_config is not None + and hasattr(self.model_config, 'mapping') + and self.model_config.mapping.tp_size > 1 + and not self.model_config.mapping.enable_attention_dp): + logits = allgather(logits, self.model_config.mapping, dim=-1) if spec_metadata.use_rejection_sampling and draft_step is not None: return self._draft_sampler_advanced_for_rejection( logits, spec_metadata, batch_size, d2t, draft_step) From b3bdee0cbcd008816ea3270230877da28d767f78 Mon Sep 17 00:00:00 2001 From: ZhaoyangWang Date: Sun, 14 Jun 2026 09:51:45 -0700 Subject: [PATCH 21/21] [TRTLLM-12669][fix] only route plain-TP greedy MTP-Eagle draft sampling through draft_sampler The previous fix routed every greedy MTP-Eagle draft step through draft_sampler(), but that call does not forward mapping_lm_head_tp. For the LM-head-TP-in-ADP configuration draft_sampler() then takes its ADP branch with a None mapping and crashes during warmup with "'NoneType' object has no attribute 'tp_group'" (Executor worker returned error), e.g. DeepSeek-R1 nvfp4 latency_adp_lmtp_tp4. Only plain tensor parallelism (tp_size>1 without attention DP) shards the draft logits over the vocab dim and needs draft_sampler()'s all-gather argmax. The LM-head-TP-in-ADP case already yields full-vocab logits per rank (gathered upstream) and the no-TP / Eagle3 cases need nothing, so all of those take the plain d2t-aware argmax (_draft_sampler_greedy), restoring the pre-regression behavior for ADP while keeping the plain-TP hang fix. 
Signed-off-by: ZhaoyangWang --- tensorrt_llm/_torch/speculative/eagle3.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 537572d00f60..68dac4c5e9f0 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -1210,7 +1210,18 @@ def draft_decoder( # before argmax (and falls back to a plain argmax when no TP gather is # needed). Eagle3 (non-MTP) keeps its d2t-aware argmax. if spec_metadata.is_all_greedy_sample: - if self.is_mtp_eagle: + # Only plain tensor parallelism (tp_size>1 without attention DP) + # shards the draft logits over the vocab dim and thus needs + # draft_sampler()'s all-gather argmax. The LM-head-TP-in-ADP case + # already produces full-vocab logits per rank (gathered upstream), + # and the no-TP / Eagle3 cases need nothing, so they take the plain + # d2t-aware argmax. (Routing ADP/LM-head-TP through draft_sampler + # without its mapping_lm_head_tp arg hits the None-mapping branch + # and crashes with 'NoneType has no attribute tp_group'.) + if (self.is_mtp_eagle and self.model_config is not None + and hasattr(self.model_config, 'mapping') + and self.model_config.mapping.tp_size > 1 + and not self.model_config.mapping.enable_attention_dp): return self.draft_sampler(logits) return self._draft_sampler_greedy(logits, d2t) # Non-greedy (advanced) draft sampling has the same TP hazard as the