diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py index 10901d87c520..8c449283aa2e 100644 --- a/examples/llm-api/quickstart_advanced.py +++ b/examples/llm-api/quickstart_advanced.py @@ -190,9 +190,6 @@ def add_llm_args(parser): default=False, action='store_true') parser.add_argument('--dynamic_tree_max_topK', type=int, default=None) - parser.add_argument('--allow_advanced_sampling', - default=False, - action='store_true') parser.add_argument('--eagle3_model_arch', type=str, default="llama3", @@ -294,7 +291,6 @@ def setup_llm(args, **kwargs): eagle_choices=args.eagle_choices, use_dynamic_tree=args.use_dynamic_tree, dynamic_tree_max_topK=args.dynamic_tree_max_topK, - allow_advanced_sampling=args.allow_advanced_sampling, eagle3_model_arch=args.eagle3_model_arch, max_total_draft_tokens=args.max_total_draft_tokens) elif spec_decode_algo == "DFLASH": diff --git a/examples/models/core/nemotron/README_nemotron_super_v3.md b/examples/models/core/nemotron/README_nemotron_super_v3.md index e78992359c19..1e59febce82e 100644 --- a/examples/models/core/nemotron/README_nemotron_super_v3.md +++ b/examples/models/core/nemotron/README_nemotron_super_v3.md @@ -144,7 +144,6 @@ kv_cache_config: speculative_config: decoding_type: MTP max_draft_len: 5 - allow_advanced_sampling: true cuda_graph_config: max_batch_size: 64 enable_padding: true diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py index 69bd91cd2178..f5e0820989a9 100644 --- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py +++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py @@ -18,6 +18,7 @@ from ..memory_buffer_utils import get_memory_buffers from ..modules.multi_stream_utils import with_multi_stream from ..speculative.eagle3 import Eagle3ResourceManager +from ..speculative.interface import SpecMetadata from ..speculative.spec_sampler_base import SampleStateTensorsSpec from ..speculative.utils import get_draft_kv_cache_manager from ..utils import make_weak_ref, piecewise_cuda_graph @@ -29,7 +30,7 @@ # A large prime number used for dummy request IDs to avoid collisions CUDA_GRAPH_DUMMY_REQUEST_ID = (1 << 64) - 1 -KeyType: TypeAlias = Tuple[int, int, bool, bool] +KeyType: TypeAlias = Tuple[int, int, bool, bool, bool] @dataclass @@ -197,19 +198,28 @@ def get_graph_key( self, batch: ScheduledRequests, new_tensors_device: Optional[SampleStateTensors] = None, - spec_resource_manager: Optional[BaseResourceManager] = None): + spec_resource_manager: Optional[BaseResourceManager] = None, + spec_metadata: Optional[SpecMetadata] = None): batch_size = batch.batch_size # Get the sequence length mode. short_seq_len_mode = self._get_seq_len_mode(batch, new_tensors_device) + # Spec one-engine sampler has two code paths (argmax fast-path vs + # advanced sampling kernel). Include this in the key so we capture + # both variants and dispatch at replay based on actual batch state. + # Default to True (greedy fast-path) when the metadata doesn't carry + # this field (non-one-engine paths or non-spec batches). + is_all_greedy_sample = bool( + getattr(spec_metadata, "is_all_greedy_sample", True)) + if self.config.is_draft_model and spec_resource_manager is not None and isinstance( spec_resource_manager, Eagle3ResourceManager): # If 'is_first_draft' is True, even with tree decoding, the length of draft_len will only be 'max_draft_len', not 'max_total_draft_token'. # Because we will pad the input to 'max_draft_len' length for the first draft layer. draft_len = self.config.original_max_draft_len if spec_resource_manager.is_first_draft else 0 key = (batch_size, draft_len, spec_resource_manager.is_first_draft, - short_seq_len_mode) + short_seq_len_mode, is_all_greedy_sample) else: # With dynamic spec decode, the draft length may be zero even when enable_spec_decode is True, # so we need to get the draft length from the batch instead of using enable_spec_decode. @@ -219,7 +229,8 @@ def get_graph_key( draft_len = max(draft_len_list) assert len( set(draft_len_list)) == 1, "All draft lengths must be the same" - key = (batch_size, draft_len, False, short_seq_len_mode) + key = (batch_size, draft_len, False, short_seq_len_mode, + is_all_greedy_sample) return key def __del__(self): @@ -230,7 +241,7 @@ def maybe_get_cuda_graph( batch: ScheduledRequests, enable_spec_decode: bool, attn_metadata: Any, - spec_metadata: Optional[Any] = None, + spec_metadata: Optional[SpecMetadata] = None, draft_tokens_cuda: Optional[torch.Tensor] = None, new_tensors_device: Optional[SampleStateTensors] = None, spec_resource_manager: Optional[BaseResourceManager] = None, @@ -273,7 +284,7 @@ def maybe_get_cuda_graph( # can replay CUDA graphs using the cache. return None, None, None key = self.get_graph_key(batch, new_tensors_device, - spec_resource_manager) + spec_resource_manager, spec_metadata) if key in self.graphs: return self.graph_metadata[key][ diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 152ed8c825fe..686ed4744e2a 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -497,7 +497,6 @@ def __init__( sparse_attn_config=self.sparse_attention_config) if self.is_spec_decode: - self.spec_metadata = None update_spec_config_from_model_config(self.spec_config, self.model.config) max_num_draft_tokens = self.max_draft_loop_tokens * self.batch_size @@ -551,6 +550,7 @@ def __init__( # the model engine. self.attn_metadata = None self.encoder_attn_metadata = None + self.spec_metadata = None self.iter_states = {} self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None @@ -1343,33 +1343,70 @@ def _capture_generation_cuda_graphs(self, else: max_seq_len_list = [effective_max_seq_len] - for bs, draft_len in graphs_to_capture: - if bs > self.batch_size: - continue - - for max_seq_len in max_seq_len_list: - warmup_request = self._create_cuda_graph_warmup_request( - resource_manager, bs, draft_len, max_seq_len) - with self._release_batch_context(warmup_request, - resource_manager) as batch: - if batch is None: - # No KV cache space, cannot continue capturing graphs + def _run_capture_pass(force_non_greedy: bool, label: str) -> None: + spec_metadata = self.spec_metadata + if force_non_greedy and spec_metadata is not None: + spec_metadata._force_non_greedy_for_capture = True + # maybe_get_cuda_graph reads spec_metadata.is_all_greedy_sample + # to build the graph cache key BEFORE populate runs inside + # _prepare_inputs. Pre-flip it here so the very first capture + # in this pass uses the non-greedy key; populate's override + # below will keep it False on every subsequent iteration. + spec_metadata.is_all_greedy_sample = False + try: + for bs, draft_len in graphs_to_capture: + if bs > self.batch_size: continue - logger.info( - f"Run generation-only CUDA graph warmup for batch size={bs}, draft_len={draft_len}, max_seq_len={max_seq_len}" - ) - self.enable_spec_decode = draft_len > 0 or self.is_draft_model or ( - self.spec_config is not None - and self.spec_config.spec_dec_mode.use_one_engine()) - self._update_draft_inference_state_for_warmup( - batch, draft_len > 0, resource_manager) - self.runtime_draft_len = draft_len - self.forward(batch, - new_tensors_device=None, - resource_manager=resource_manager) - torch.cuda.synchronize() + + for max_seq_len in max_seq_len_list: + warmup_request = self._create_cuda_graph_warmup_request( + resource_manager, bs, draft_len, max_seq_len) + with self._release_batch_context( + warmup_request, resource_manager) as batch: + if batch is None: + # No KV cache space, cannot continue capturing graphs + continue + logger.info( + f"Run generation-only CUDA graph warmup ({label}) " + f"for batch size={bs}, draft_len={draft_len}, " + f"max_seq_len={max_seq_len}") + self.enable_spec_decode = draft_len > 0 or self.is_draft_model or ( + self.spec_config is not None and + self.spec_config.spec_dec_mode.use_one_engine()) + self._update_draft_inference_state_for_warmup( + batch, draft_len > 0, resource_manager) + self.runtime_draft_len = draft_len + self.forward(batch, + new_tensors_device=None, + resource_manager=resource_manager) + torch.cuda.synchronize() + finally: + if force_non_greedy and spec_metadata is not None: + spec_metadata._force_non_greedy_for_capture = False + + # Pass 1: greedy fast-path (dummy requests carry no sampling params, + # so is_all_greedy_sample is naturally True). + _run_capture_pass(force_non_greedy=False, label="greedy") + # Pass 2: advanced sampling variant. Required because on-the-fly capture + # is disabled outside warmup, so any inference batch that contains a + # non-greedy request would otherwise fall back to eager. Only meaningful + # for one-engine spec dec (where is_all_greedy_sample participates in + # the graph key); other paths default to True and would never key into + # this variant. + needs_non_greedy_capture = ( + self.spec_config is not None + and self.spec_config.spec_dec_mode.use_one_engine()) + if needs_non_greedy_capture: + _run_capture_pass(force_non_greedy=True, label="advanced sampling") # Set the value back to the original value after cuda graph warmups are complete self.enable_spec_decode = self.is_spec_decode + # The advanced-sampling capture pass above leaves is_all_greedy_sample + # set to False on spec_metadata. Reset it to the default so the first + # real iteration's graph-key selection is not seeded with this + # capture-only value. (update_is_all_greedy_sample refreshes it every + # iteration; this is a defensive guard.) + if self.spec_metadata is not None: + self.spec_metadata.is_all_greedy_sample = True def _capture_piecewise_cuda_graphs(self, resource_manager: ResourceManager): """Captures piecewise CUDA graphs for context/prefill steps via torch.compile.""" @@ -4690,6 +4727,17 @@ def forward(self, self.runtime_draft_len) as padded_requests: self._pad_batch_seed_mrope_delta_cache(padded_requests) + # Refresh is_all_greedy_sample for the *current* batch BEFORE the + # CUDA graph key is built below. The key includes this flag to pick + # the argmax vs advanced-sampling graph variant; populate (inside + # _prepare_inputs) runs later and fills the matching GPU buffers. + # Without this pre-scan the key would use the previous iteration's + # stale value and could replay the advanced graph against + # unpopulated (greedy) buffers, hanging the run (e.g. MTP nextn>=2). + if spec_metadata is not None: + spec_metadata.update_is_all_greedy_sample( + padded_requests.all_requests()) + maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph( padded_requests, enable_spec_decode=self.enable_spec_decode, diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 25ab939758a9..797f2fd48666 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -406,12 +406,6 @@ def create_py_executor( ) llm_args.disable_overlap_scheduler = True - if spec_config is not None and spec_config.spec_dec_mode.use_one_engine(): - if not spec_config.allow_advanced_sampling: - logger.warning( - f"Falling back to greedy decoding for {spec_config.decoding_type}. If you " - "want to use non-greedy sampling, please set allow_advanced_sampling=True." - ) # Check FLASHINFER compatibility with one-engine speculative decoding if llm_args.attn_backend == "FLASHINFER": raise ValueError( diff --git a/tensorrt_llm/_torch/speculative/dynamic_tree_ops.py b/tensorrt_llm/_torch/speculative/dynamic_tree_ops.py index c25c0c315523..bfb520f2d733 100644 --- a/tensorrt_llm/_torch/speculative/dynamic_tree_ops.py +++ b/tensorrt_llm/_torch/speculative/dynamic_tree_ops.py @@ -237,6 +237,7 @@ def verify_dynamic_tree_rejection_from_logits_out( offset: int | torch.Tensor = 0, d2t: torch.Tensor | None = None, skip_all_sampling_params: bool = False, + top_k_max: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Tree-aware rejection sampling from logits (three CUDA ops). @@ -266,9 +267,13 @@ def verify_dynamic_tree_rejection_from_logits_out( tree_valid = torch.ones(num_gens, dtype=torch.bool, device=candidates.device) tree_valid = tree_valid.contiguous() - if top_k is None: + if top_k_max is not None: + # Pre-computed CPU-side (CUDA-graph-safe): use as-is. + pass + elif top_k is None: top_k_max = 0 else: + # Fallback path (non-CUDA-graph contexts): compute from tensor. enabled_top_k = top_k[(top_k > 0) & (top_k < target_vocab_size)] top_k_max = int(enabled_top_k.max().item()) if enabled_top_k.numel() > 0 else 0 diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 6acef9ed348f..68dac4c5e9f0 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -772,9 +772,9 @@ def _forward_linear_draft_loop(self, inputs, attn_metadata, spec_metadata, runtime_draft_len = spec_metadata.runtime_draft_len num_gens = batch_size - num_contexts next_draft_tokens = [] - draft_logits_list = [] last_tokens_idx = torch.cumsum( attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1 + position_ids = inputs["position_ids"] with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager): for i in range(runtime_draft_len): @@ -862,31 +862,22 @@ def _forward_linear_draft_loop(self, inputs, attn_metadata, spec_metadata, d2t, draft_step=i) - # Sample the next draft token. - # MTP Eagle: TP-aware sampler; when ADP+LM-head-TP is active - # logits are padded to max_num_requests across TP ranks, so - # the result must be trimmed back to token_count. - # Eagle3: simple greedy sampling; d2t remaps vocab indices when - # the draft model uses a compressed vocabulary. - if self.is_mtp_eagle: - if use_lm_head_tp_in_adp: - mapping_lm_head_tp = draft_model.mtp_layers[ - 0].shared_head.mapping_lm_head_tp - new_draft_token = self.draft_sampler( - logits, mapping_lm_head_tp) - new_draft_token = new_draft_token[:token_count] - else: - new_draft_token = self.draft_sampler(logits) - else: - d2t = getattr(draft_model.model, "d2t", None) - new_draft_token = self._draft_sampler_greedy(logits, d2t) - - # Stash unpadded Eagle3 draft logits for rejection sampling on - # the next iteration. MTP Eagle's logits may be ADP-padded to - # max_num_requests, so we skip them here. - if not self.is_mtp_eagle and spec_metadata.use_rejection_sampling: - draft_logits_list.append(logits.clone()) - + # When ADP+LM-head-TP pads logits to max_num_requests, the + # padded rows are zero-filled placeholders only required so + # every TP rank produces logits of identical shape for the + # LM-head-TP all-gather. Drop them *before* sampling: the + # per-request sampling params (temperatures/top_k/top_p) are + # sized to token_count (== batch_size), so the padded logits + # would otherwise fail to broadcast in apply_temperature. This + # also keeps next_draft_tokens and the draft_probs buffer + # token_count-sized without a post-hoc trim. + if use_lm_head_tp_in_adp: + logits = logits[:token_count] + new_draft_token = self.draft_decoder(logits, + draft_model, + spec_metadata, + batch_size, + draft_step=i) next_draft_tokens.append(new_draft_token) # Update hidden states for the next iteration. @@ -952,11 +943,19 @@ def _forward_linear_draft_loop(self, inputs, attn_metadata, spec_metadata, gen_draft_tokens) next_draft_tokens[num_contexts:] = gen_draft_tokens - if spec_metadata.use_rejection_sampling and draft_logits_list: - d2t_param = getattr(draft_model.model, "d2t", None) - spec_metadata.d2t = d2t_param.data if d2t_param is not None else None - self._compute_and_store_draft_probs(draft_logits_list, - spec_metadata, batch_size) + # Probs were already scattered into the slot-indexed buffer by + # _draft_sampler_advanced_for_rejection on each draft step (non-greedy + # batches only). All-greedy batches skip storage — rejection sampling + # will be bypassed by _can_use_rejection_sampling. Finalize the validity + # flag and d2t for next-iter target-side verification. + if spec_metadata.use_rejection_sampling: + if not spec_metadata.is_all_greedy_sample: + d2t_param = getattr(getattr(draft_model, 'model', None), "d2t", + None) + spec_metadata.d2t = d2t_param.data if d2t_param is not None else None + spec_metadata.draft_probs_valid = True + else: + spec_metadata.draft_probs_valid = False return next_draft_tokens @@ -1174,6 +1173,78 @@ def sample_and_accept_draft_tokens( return self._accept_draft_tokens(logits, draft_tokens, num_contexts, batch_size, spec_metadata) + def draft_decoder( + self, + logits: torch.Tensor, + draft_model: nn.Module, + spec_metadata: Eagle3OneModelSpecMetadata, + batch_size: int, + draft_step: Optional[int] = None, + ): + ''' + Sample draft tokens using the target's per-request sampling params + (temperature/top_k/top_p). + + When rejection sampling is enabled and draft_step is provided, take the + single-pass path that also scatters the draft prob distribution into the + slot-indexed buffer (avoids a redundant softmax later). + + Args: + logits: [batch_size, vocab_size] - Draft model logits. + draft_model: The draft model. + spec_metadata: Carries per-request sampling param tensors. + batch_size: Active requests, used to slice per-request tensors. + draft_step: Current draft step index (0..max_draft_len-1). Required + for the rejection-sampling code path so probs are written to + the correct slice of spec_metadata.draft_probs. + ''' + + d2t = getattr(getattr(draft_model, 'model', None), "d2t", None) + # All-greedy fast path must stay TP-aware. When the draft LM head is + # tensor-parallel (tp_size>1 without attention DP, or LM-head-TP in + # ADP), the draft logits are sharded along the vocab dim. A plain + # per-rank argmax then picks a different token on each rank, which + # desyncs the speculative-decoding control flow across ranks and + # deadlocks the next collective (observed as a generation hang on + # MTP-Eagle + TP). draft_sampler() all-gathers the sharded logits + # before argmax (and falls back to a plain argmax when no TP gather is + # needed). Eagle3 (non-MTP) keeps its d2t-aware argmax. + if spec_metadata.is_all_greedy_sample: + # Only plain tensor parallelism (tp_size>1 without attention DP) + # shards the draft logits over the vocab dim and thus needs + # draft_sampler()'s all-gather argmax. The LM-head-TP-in-ADP case + # already produces full-vocab logits per rank (gathered upstream), + # and the no-TP / Eagle3 cases need nothing, so they take the plain + # d2t-aware argmax. (Routing ADP/LM-head-TP through draft_sampler + # without its mapping_lm_head_tp arg hits the None-mapping branch + # and crashes with 'NoneType has no attribute tp_group'.) + if (self.is_mtp_eagle and self.model_config is not None + and hasattr(self.model_config, 'mapping') + and self.model_config.mapping.tp_size > 1 + and not self.model_config.mapping.enable_attention_dp): + return self.draft_sampler(logits) + return self._draft_sampler_greedy(logits, d2t) + # Non-greedy (advanced) draft sampling has the same TP hazard as the + # greedy path: when the draft LM head is plain tensor-parallel + # (tp_size>1 without attention DP), each rank only holds a vocab shard + # of the draft logits. Random per-rank sampling then draws different + # tokens on different ranks, desyncing the spec-decode control flow and + # deadlocking the next collective. All-gather the shards into the full + # vocab first so every rank samples from the same distribution with the + # shared seed. (Greedy uses draft_sampler()'s lighter max+index gather; + # random sampling needs the full distribution. The LM-head-TP-in-ADP + # case is handled upstream and must not be gathered again here.) + if (self.is_mtp_eagle and self.model_config is not None + and hasattr(self.model_config, 'mapping') + and self.model_config.mapping.tp_size > 1 + and not self.model_config.mapping.enable_attention_dp): + logits = allgather(logits, self.model_config.mapping, dim=-1) + if spec_metadata.use_rejection_sampling and draft_step is not None: + return self._draft_sampler_advanced_for_rejection( + logits, spec_metadata, batch_size, d2t, draft_step) + return self._draft_sampler_advanced(logits, spec_metadata, batch_size, + d2t) + def prepare_1st_drafter_inputs( self, input_ids: torch.LongTensor, diff --git a/tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py b/tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py index 47376001166d..e7a6746d7c70 100644 --- a/tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py +++ b/tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py @@ -904,6 +904,15 @@ def _sample_and_accept_dynamic_tree( offset=self.offset, d2t=self._d2t, skip_all_sampling_params=skip_all_sampling_params, + # During CUDA graph capture bake top_k_max=0 so the + # full-sort (always-correct) path is captured. Outside + # capture, pass the pre-computed value for the fast + # topk(kMax) path. + top_k_max=( + 0 + if torch.cuda.is_current_stream_capturing() + else getattr(spec_metadata, "top_k_max", None) + ), ) ) @@ -967,7 +976,17 @@ def _can_use_rejection_sampling(self, spec_metadata) -> bool: Returns: True if rejection sampling is enabled and the draft logit buffer is allocated """ - return spec_metadata.use_rejection_sampling and self._draft_depth_logits_cat is not None + # Skip rejection sampling when the whole batch is greedy: argmax is + # equivalent and avoids the rejection kernel cost. + # Also skip during CUDA graph capture/replay: the rejection ops use + # dynamic memory allocation (full-sort fallback) which is incompatible + # with stream capture. + return ( + spec_metadata.use_rejection_sampling + and self._draft_depth_logits_cat is not None + and not spec_metadata.is_all_greedy_sample + and not spec_metadata.is_cuda_graph + ) def _finalize_dynamic_tree_verify_outputs( self, diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index c62111f0f511..a6232866cc36 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -36,13 +36,15 @@ if TYPE_CHECKING: from ..pyexecutor.guided_decoder import CapturableGuidedDecoder + from ..pyexecutor.llm_request import LlmRequest if IS_FLASHINFER_AVAILABLE: import flashinfer from .one_model_sampler import (compute_probs_from_logits, rejection_sampling_one_model, - sampling_batch_spec_dec_one_model) + sampling_batch_spec_dec_one_model, + sampling_batch_spec_dec_one_model_for_rejection) # Environment variable name for forcing the number of accepted tokens in speculative decoding FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS" @@ -454,8 +456,14 @@ class SpecMetadata: # Always set by model_engine.forward() before any downstream code reads it. runtime_draft_len: int = 0 - # For non-greedy sampling on 1-model. - allow_advanced_sampling: bool = False + # Auto-detected per step from populated sampling params: + # True if every request is greedy (no temp/top_k/top_p) and we can take + # the argmax fast-path. False if any request needs sampling. + # Used as part of the CUDA graph key so we capture two variants + # (greedy fast-path vs advanced sampling) and dispatch at replay. + # Defaults to True so non-one-engine paths (where populate is a no-op) + # never accidentally select the advanced graph variant. + is_all_greedy_sample: bool = True # Whether to use rejection sampling for one-model speculative decoding. use_rejection_sampling: bool = False # Sampling parameters for non-greedy sampling (per-request) @@ -467,6 +475,9 @@ class SpecMetadata: skip_top_k: bool = False skip_top_p: bool = False has_greedy_requests: bool = False + # Pre-computed top_k_max scalar (CPU-side) to avoid CUDA-graph-incompatible + # dynamic boolean tensor indexing inside verify_dynamic_tree_rejection_from_logits_out. + top_k_max: int = 0 # Sampling parameters indexed per request. request_temperatures: Optional[torch.Tensor] = None request_top_ks: Optional[torch.Tensor] = None @@ -475,15 +486,34 @@ class SpecMetadata: use_sampling_params_for_draft_tokens: bool = False # Vocab size used for draft_probs buffer allocation. vocab_size: int = 0 - # Draft probabilities buffer for rejection sampling, stored flat. + # Draft probabilities buffer for rejection sampling, indexed by py_seq_slot + # so per-request data is stable across iterations regardless of batch + # composition shifts (chunking ctx, gen completion, new ctx joining). + # Shape: [max_num_requests, max_draft_len, vocab_size]. draft_probs: Optional[torch.Tensor] = None draft_probs_vocab_size: int = 0 # Whether draft_probs contains valid data. draft_probs_valid: bool = False # Last dimension size of the draft logits/probs stored in draft_probs. draft_probs_last_dim: int = 0 + # Per-request slot ids (py_seq_slot) for the current batch, in batch order. + # Used to scatter draft probs by slot at write time and gather them by slot + # at the next iter's verify. Shape: [max_num_requests], dtype=long. + batch_slot_ids: Optional[torch.Tensor] = None # Draft-to-target vocab offset tensor. d2t: Optional[torch.Tensor] = None + # Pre-allocated scratch for draft probs expanded to the target vocab size. + # Filled with zeros once at prepare(); each rejection iter only overwrites + # the positions selected by d2t (or [:draft_vocab] when there is no d2t), + # so the zeros outside those positions persist across iterations and we + # avoid a per-iter 64 MB zero-fill on the (max_num_requests, max_draft_len, + # vocab_size) tensor. Shape: [max_num_requests, max_draft_len, vocab_size]. + full_draft_probs: Optional[torch.Tensor] = None + # Cached d2t-projected target vocab indices, computed once on first use + # (d2t is a model-static tensor). Replaces the per-iter + # arange + (source + d2t) % vocab_size kernel sequence inside the d2t + # padding step. Shape: [draft_vocab_size], dtype long. + d2t_target_indices: Optional[torch.Tensor] = None def __post_init__(self): pass @@ -494,12 +524,28 @@ def prepare(self): """ if (self.use_rejection_sampling and self.draft_probs is None and self.vocab_size > 0): - buffer_size = (self.max_num_requests * self.max_draft_len * - self.vocab_size) - self.draft_probs = torch.empty(buffer_size, - dtype=torch.float32, - device='cuda') + # 3D [slot, draft_step, vocab] so we can scatter/gather by slot id + # and avoid the brittle "batch position == buffer position" mapping. + self.draft_probs = torch.empty( + (self.max_num_requests, self.max_draft_len, self.vocab_size), + dtype=torch.float32, + device='cuda') self.draft_probs_vocab_size = self.vocab_size + if (self.use_rejection_sampling and self.batch_slot_ids is None + and self.max_num_requests > 0): + self.batch_slot_ids = torch.empty((self.max_num_requests, ), + dtype=torch.long, + device='cuda') + if (self.use_rejection_sampling and self.full_draft_probs is None + and self.vocab_size > 0): + # Zero-fill once. Subsequent iters only overwrite the d2t-mapped + # positions (constant across iters since d2t is model-static), so + # untouched positions stay 0 forever — saves the per-iter 64 MB + # zero-fill in _sample_and_accept_draft_tokens_rejection. + self.full_draft_probs = torch.zeros( + (self.max_num_requests, self.max_draft_len, self.vocab_size), + dtype=torch.float32, + device='cuda') def create_cuda_graph_metadata(self, max_batch_size: int): """ @@ -529,33 +575,20 @@ def maybe_capture_hidden_states(self, layer_id: int, model. Use this method to record them. By default, does nothing. """ - def populate_sampling_params_for_one_model( - self, requests: list["LlmRequest"]) -> None: - """ - Set up topp/topk/temperatures for 1-model sampler. + def _scan_one_model_sampling( + self, requests: list["LlmRequest"] + ) -> tuple[list[tuple[float, int, float, int]], list[int]]: + """Single source of truth for one-engine sampling-param detection. + + Scans the batch's sampling configs and sets skip_*/has_greedy_requests/ + is_all_greedy_sample (honoring the warmup capture override). Returns + ``(per_request_normalized, per_request_slot_ids)`` for buffer + population. Does NOT allocate or fill GPU buffers, so it is safe to call + before the CUDA graph key is built. """ from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState from tensorrt_llm.sampling_params import SamplingParams - if not self.allow_advanced_sampling or not self.spec_dec_mode.use_one_engine( - ): - return - - if self.temperatures is None: - # Ensures determinism across ranks. - torch.manual_seed(0) - - temperatures = [] - top_ks = [] - top_ps = [] - request_temperatures = [] - request_top_ks = [] - request_top_ps = [] - top_k_enabled = False - top_p_enabled = False - has_greedy_requests = False - temperature_enabled = False - # Need to use a very small value for temperature when disabled to avoid division by 0 DISABLE_TEMP_VAL = 1e-5 # Very large values disable topk. @@ -601,6 +634,14 @@ def _normalize_request_sampling_params( is_greedy, ) + # Phase 1: collect per-request flags and normalized values. + per_request_normalized: list[tuple[float, int, float, int]] = [] + temperature_enabled = False + top_k_enabled = False + top_p_enabled = False + has_greedy_requests = False + per_request_slot_ids: list[int] = [] + for request in requests: sampling_config = request.sampling_config temp_val = _first_or_none(sampling_config.temperature) @@ -629,19 +670,90 @@ def _normalize_request_sampling_params( top_p_enabled |= use_top_p has_greedy_requests |= is_greedy - request_temperatures.append(temp_val) - request_top_ks.append(tk_val) - request_top_ps.append(tp_val) - temperatures.extend(temp_val for _ in range(num_tokens)) - top_ks.extend(tk_val for _ in range(num_tokens)) - top_ps.extend(tp_val for _ in range(num_tokens)) + per_request_normalized.append( + (temp_val, tk_val, tp_val, num_tokens)) + # py_seq_slot is a stable per-request id used to scatter / gather + # draft probs across iterations. Dummies / unallocated slots fall + # back to 0 (any valid index is fine — the data at that slot will + # be overwritten on the next real iteration before being read). + per_request_slot_ids.append( + request.py_seq_slot if request.py_seq_slot is not None else 0) + + self.skip_temperature = not temperature_enabled + self.skip_top_k = not top_k_enabled + self.skip_top_p = not top_p_enabled + self.has_greedy_requests = has_greedy_requests + # Used in the CUDA graph key to pick the argmax / advanced variant. + self.is_all_greedy_sample = (self.skip_temperature and self.skip_top_k + and self.skip_top_p) + + # Warmup-time override (set via runtime attribute by the model engine): + # force the advanced-sampling code path so the CUDA graph for the + # (is_all_greedy_sample=False) key gets captured. Dummy warmup requests + # carry no sampling params, so the natural detection above always + # returns True; this branch substitutes synthetic non-greedy scalars + # into the per-request data and lets Phase 2 run normally to populate + # the GPU buffers used by the captured kernels. + if getattr(self, '_force_non_greedy_for_capture', False): + self.skip_temperature = False + self.skip_top_k = False + self.skip_top_p = False + self.is_all_greedy_sample = False + per_request_normalized = [ + (0.7, 50, 0.9, num_tokens) + for (_, _, _, num_tokens) in per_request_normalized + ] + + return per_request_normalized, per_request_slot_ids + + def update_is_all_greedy_sample(self, requests: list["LlmRequest"]) -> None: + """Refresh ``is_all_greedy_sample`` for the *current* batch. + + Must be called BEFORE the CUDA graph key is built (the key includes + ``is_all_greedy_sample`` to choose the argmax vs advanced-sampling graph + variant). ``populate_sampling_params_for_one_model`` runs later, inside + ``_prepare_inputs``, and re-derives the same flag while filling the GPU + sampling buffers. Computing the flag here first keeps the selected graph + consistent with the buffers ``populate`` fills; otherwise the key would + use the previous iteration's stale value and could replay the advanced + graph against unpopulated (greedy) buffers, which can hang/corrupt the + run (notably for MTP with num_nextn>=2). + """ + if not self.spec_dec_mode.use_one_engine(): + return + self._scan_one_model_sampling(requests) + + def populate_sampling_params_for_one_model( + self, requests: list["LlmRequest"]) -> None: + """ + Set up topp/topk/temperatures for 1-model sampler. + + Scans sampling configs to set skip_*/is_all_greedy_sample flags. When + any request needs sampling, also builds per-token/per-request lists + and copies them to GPU buffers; all-greedy batches skip this entirely. + """ + if not self.spec_dec_mode.use_one_engine(): + return + + if self.temperatures is None: + # Ensures determinism across ranks. + torch.manual_seed(0) + + per_request_normalized, per_request_slot_ids = ( + self._scan_one_model_sampling(requests)) tokens_per_request = (self.max_total_draft_tokens + 1 if self.is_spec_dec_tree else self.max_draft_len + 1) - required_flat_size = tokens_per_request * self.max_num_requests + # Warmup batches may exceed max_num_requests * tokens_per_request (e.g. + # when CUDA-graph warmup passes use max_batch_size > max_num_requests). + actual_flat_size = sum( + num_tokens for _, _, _, num_tokens in per_request_normalized) + required_flat_size = max(tokens_per_request * self.max_num_requests, + actual_flat_size) if self.temperatures is None or self.temperatures.numel( ) < required_flat_size: + # Allocate once; the captured graph reads from these stable addresses. self.temperatures = torch.ones(required_flat_size, dtype=torch.float32, device='cuda') @@ -661,6 +773,38 @@ def _normalize_request_sampling_params( dtype=torch.float32, device='cuda') + # Always-populate the per-request slot id table when rejection sampling + # is configured: it's tiny (max_num_requests longs) and needed at + # draft-sampler time to scatter draft probs by slot. + if self.use_rejection_sampling and self.batch_slot_ids is not None: + self.batch_slot_ids[:len(per_request_slot_ids)].copy_( + torch.tensor(per_request_slot_ids, + dtype=torch.long, + pin_memory=prefer_pinned()), + non_blocking=True, + ) + + # All-greedy: sampler takes the argmax branch (and rejection sampling + # is also bypassed for all-greedy), so the per-token buffers are never + # read. Skip the heavier H->D copies. + if self.is_all_greedy_sample: + return + + # Phase 2: build per-token / per-request lists and copy to GPU. + temperatures: list[float] = [] + top_ks: list[int] = [] + top_ps: list[float] = [] + request_temperatures: list[float] = [] + request_top_ks: list[int] = [] + request_top_ps: list[float] = [] + for temp_val, tk_val, tp_val, num_tokens in per_request_normalized: + request_temperatures.append(temp_val) + request_top_ks.append(tk_val) + request_top_ps.append(tp_val) + temperatures.extend(temp_val for _ in range(num_tokens)) + top_ks.extend(tk_val for _ in range(num_tokens)) + top_ps.extend(tp_val for _ in range(num_tokens)) + self.temperatures[:len(temperatures)].copy_(torch.tensor( temperatures, dtype=torch.float32, pin_memory=prefer_pinned()), non_blocking=True) @@ -687,10 +831,13 @@ def _normalize_request_sampling_params( pin_memory=prefer_pinned()), non_blocking=True, ) - self.skip_temperature = not temperature_enabled - self.skip_top_k = not top_k_enabled - self.skip_top_p = not top_p_enabled - self.has_greedy_requests = has_greedy_requests + + # Pre-compute top_k_max on the CPU so CUDA-graph capture does not + # encounter boolean-tensor indexing (dynamic size) or .item() calls. + # DISABLE_TOPK_VAL (INT32_MAX) is the sentinel for "top-k disabled". + _disable_topk = torch.iinfo(torch.int32).max + self.top_k_max = max( + (tk for tk in request_top_ks if 0 < tk < _disable_topk), default=0) class SpecWorkerBase(nn.Module, ABC): @@ -1012,109 +1159,176 @@ def _accept_draft_tokens(self, logits, draft_tokens, num_contexts, batch_size, spec_metadata): """ Accept draft tokens with optional rejection sampling support. + + Mixed batches (num_contexts > 0) are supported: context rows take the + first sampled target token via the base logic, and rejection sampling + runs on the gen subset. Draft probs for the gen subset are gathered + from the slot-indexed buffer by `py_seq_slot`. """ - if self._can_use_rejection_sampling(spec_metadata, num_contexts): + num_gens = batch_size - num_contexts + if num_gens > 0 and self._can_use_rejection_sampling(spec_metadata): draft_len = draft_tokens.shape[1] stored_vocab = (spec_metadata.draft_probs_last_dim if spec_metadata.draft_probs_last_dim > 0 else spec_metadata.draft_probs_vocab_size) - draft_probs = spec_metadata.draft_probs[:batch_size * draft_len * - stored_vocab].reshape( - batch_size, draft_len, - stored_vocab) + # Gather the slot rows for the gen subset. The buffer was filled + # at the previous draft step indexed by py_seq_slot, so each gen + # request reads back exactly its own probs, regardless of batch + # composition changes since then. + gen_slot_ids = spec_metadata.batch_slot_ids[num_contexts:batch_size] + draft_probs = spec_metadata.draft_probs[ + gen_slot_ids, :draft_len, :stored_vocab] return self._sample_and_accept_draft_tokens_rejection( - logits, draft_tokens, draft_probs, batch_size, spec_metadata) + logits, draft_tokens, draft_probs, num_contexts, batch_size, + spec_metadata) return self._sample_and_accept_draft_tokens_base( logits, draft_tokens, num_contexts, batch_size, spec_metadata) - def _can_use_rejection_sampling(self, spec_metadata: SpecMetadata, - num_contexts: int) -> bool: + def _can_use_rejection_sampling(self, spec_metadata: SpecMetadata) -> bool: + # Skip rejection sampling when the whole batch is greedy: the accepted + # result is identical to argmax and the base path is cheaper. Mixed + # batches (context + gen) are handled via slot-indexed draft probs and + # are split inside _sample_and_accept_draft_tokens_rejection. return (spec_metadata.use_rejection_sampling - and spec_metadata.draft_probs_valid and num_contexts == 0) + and spec_metadata.draft_probs_valid + and not spec_metadata.is_all_greedy_sample) def _sample_and_accept_draft_tokens_rejection( self, logits: torch.Tensor, draft_tokens: torch.Tensor, draft_probs: torch.Tensor, + num_contexts: int, batch_size: int, spec_metadata, ): """ Rejection-sampling acceptance for one-model speculative decoding. + + Mixed batches are handled by treating the two subsets separately: + - context rows (first `num_contexts`) take the target's sampled first + token; no draft tokens to verify. + - generation rows (`[num_contexts:batch_size]`) run the rejection + sampling kernel on slot-gathered draft probs. + + Per-token sampling-parameter tensors (`temperatures / top_ks / top_ps`) + are laid out as `[ctx (1 each), gen (draft_len+1 each)]`, matching the + logits layout, so slicing is symmetric for both subsets. """ device = logits.device vocab_size = logits.shape[-1] + num_gens = batch_size - num_contexts + runtime_draft_len = draft_tokens.shape[1] if logits.dim() == 1: logits = logits.unsqueeze(0) - runtime_draft_len = draft_tokens.shape[1] - draft_vocab_size = draft_probs.shape[-1] - num_target_tokens = batch_size * (runtime_draft_len + 1) - - temperatures = spec_metadata.temperatures[:num_target_tokens] - # Pass None instead of an all-disabled tensor so the C++ op can short-circuit - # on a host-side check rather than a `.item()` sync, which would break - # CUDA graph capture. - top_ks = None if spec_metadata.skip_top_k else spec_metadata.top_ks[: - num_target_tokens] - top_ps = None if spec_metadata.skip_top_p else spec_metadata.top_ps[: - num_target_tokens] - - target_probs_flat = compute_probs_from_logits(logits.clone(), - temperatures, top_ks, - top_ps) - target_probs = target_probs_flat.reshape(batch_size, - runtime_draft_len + 1, - vocab_size) - - assert draft_probs.shape[1] == runtime_draft_len, ( - f"draft_probs draft length mismatch: {draft_probs.shape[1]} != " - f"{runtime_draft_len}") - d2t = getattr(spec_metadata, "d2t", None) - if draft_vocab_size != vocab_size: - full_draft_probs = torch.zeros( - (batch_size, runtime_draft_len, vocab_size), - dtype=torch.float32, - device=device) - if d2t is not None: - assert d2t.numel() == draft_vocab_size, ( - f"d2t size mismatch: {d2t.numel()} != {draft_vocab_size}") - d2t = d2t.to(device=device) - source_indices = torch.arange(draft_vocab_size, - device=device, - dtype=torch.long) - target_indices = (source_indices + d2t) % vocab_size - full_draft_probs[:, :runtime_draft_len, - target_indices] = draft_probs + accepted_tokens = torch.empty((batch_size, runtime_draft_len + 1), + dtype=torch.int, + device=device) + num_accepted_tokens = torch.ones(batch_size, + dtype=torch.int, + device=device) + + # === Context subset: sample target's first token directly === + if num_contexts > 0: + ctx_target_tokens = self._sample_tokens_for_batch( + logits[:num_contexts], spec_metadata, num_contexts, + num_contexts) + accepted_tokens[:num_contexts, 0] = ctx_target_tokens + + # === Generation subset: rejection sampling on the gen slice === + if num_gens > 0: + num_gen_logits = num_gens * (runtime_draft_len + 1) + gen_logits = logits[num_contexts:num_contexts + num_gen_logits] + gen_start = num_contexts + gen_end = num_contexts + num_gen_logits + + temperatures = spec_metadata.temperatures[gen_start:gen_end] + # Pass None instead of an all-disabled tensor so the C++ op can short-circuit + # on a host-side check rather than a `.item()` sync, which would break + # CUDA graph capture. + top_ks = (None if spec_metadata.skip_top_k else + spec_metadata.top_ks[gen_start:gen_end]) + top_ps = (None if spec_metadata.skip_top_p else + spec_metadata.top_ps[gen_start:gen_end]) + + target_probs_flat = compute_probs_from_logits( + gen_logits, temperatures, top_ks, top_ps) + target_probs = target_probs_flat.reshape(num_gens, + runtime_draft_len + 1, + vocab_size) + + draft_vocab_size = draft_probs.shape[-1] + assert draft_probs.shape[0] == num_gens, ( + f"draft_probs batch mismatch: {draft_probs.shape[0]} != " + f"num_gens={num_gens}") + assert draft_probs.shape[1] == runtime_draft_len, ( + f"draft_probs draft length mismatch: {draft_probs.shape[1]} != " + f"{runtime_draft_len}") + d2t = getattr(spec_metadata, "d2t", None) + if draft_vocab_size != vocab_size: + # Use the pre-allocated buffer from spec_metadata.prepare() + # (zero-filled once at init; untouched positions stay 0). + # Falls back to per-iter allocation if the buffer is not + # configured, e.g. when use_rejection_sampling was off at + # prepare() time. + if spec_metadata.full_draft_probs is not None: + full_draft_probs = spec_metadata.full_draft_probs[:num_gens] + else: + full_draft_probs = torch.zeros( + (num_gens, runtime_draft_len, vocab_size), + dtype=torch.float32, + device=device) + if d2t is not None: + assert d2t.numel() == draft_vocab_size, ( + f"d2t size mismatch: {d2t.numel()} != {draft_vocab_size}" + ) + # d2t is model-static; compute target_indices once and + # cache on spec_metadata to skip the arange + add + mod + # kernel sequence on every iter. + target_indices = spec_metadata.d2t_target_indices + if target_indices is None: + source_indices = torch.arange(draft_vocab_size, + device=device, + dtype=torch.long) + target_indices = (source_indices + + d2t.to(device=device)) % vocab_size + spec_metadata.d2t_target_indices = target_indices + full_draft_probs[:, :runtime_draft_len, + target_indices] = draft_probs + else: + assert draft_vocab_size < vocab_size + full_draft_probs[:, :runtime_draft_len, : + draft_vocab_size] = (draft_probs) else: - assert draft_vocab_size < vocab_size - full_draft_probs[:, :runtime_draft_len, :draft_vocab_size] = ( - draft_probs) - else: - full_draft_probs = draft_probs - - full_draft_tokens = draft_tokens.to(torch.int32).contiguous() - - if self.seed is None: - self.seed = torch.tensor([0], dtype=torch.int64, device=device) - if self.offset is None: - self.offset = torch.tensor([0], dtype=torch.int64, device=device) - self.seed += 1 - self.seed %= 2**31 + full_draft_probs = draft_probs + + full_draft_tokens = draft_tokens.to(torch.int32).contiguous() + + if self.seed is None: + self.seed = torch.tensor([0], dtype=torch.int64, device=device) + if self.offset is None: + self.offset = torch.tensor([0], + dtype=torch.int64, + device=device) + self.seed += 1 + self.seed %= 2**31 + + gen_accepted, gen_num_accepted = rejection_sampling_one_model( + draft_probs=full_draft_probs, + draft_token_ids=full_draft_tokens, + target_probs=target_probs, + deterministic=True, + seed=self.seed, + offset=self.offset, + ) - accepted_tokens, num_accepted_tokens = rejection_sampling_one_model( - draft_probs=full_draft_probs, - draft_token_ids=full_draft_tokens, - target_probs=target_probs, - deterministic=True, - seed=self.seed, - offset=self.offset, - ) + accepted_tokens[num_contexts:] = gen_accepted + num_accepted_tokens[num_contexts:] = gen_num_accepted num_accepted_tokens = self._apply_force_accepted_tokens( - num_accepted_tokens, 0, draft_tokens.shape[1]) + num_accepted_tokens, num_contexts, runtime_draft_len) return accepted_tokens, num_accepted_tokens def _draft_sampler_greedy(self, logits: torch.Tensor, d2t=None): @@ -1136,47 +1350,127 @@ def _draft_sampler_greedy(self, logits: torch.Tensor, d2t=None): return draft_tokens.type(torch.int32) - def _compute_and_store_draft_probs( + def _draft_sampler_advanced( self, - draft_logits_list: List[torch.Tensor], - spec_metadata: SpecMetadata, + logits: torch.Tensor, + spec_metadata: "SpecMetadata", + batch_size: int, + d2t: Optional[torch.Tensor] = None, + ): + """ + Draft token sampling using per-request sampling parameters from the + target's sampling config. Falls back to argmax when the batch is + all-greedy. + + Args: + logits: [batch_size, vocab_size] - Draft model logits (one row per + request, since each draft step emits one token per request). + spec_metadata: Source of per-request temperatures / top_k / top_p + tensors populated by populate_sampling_params_for_one_model. + batch_size: Number of active requests in the batch. + d2t: Optional dictionary offset tensor for vocab mapping. + + Returns: + draft_tokens: [batch_size] - Sampled draft token ids (int32) + """ + if spec_metadata.is_all_greedy_sample: + return self._draft_sampler_greedy(logits, d2t) + + temperatures = spec_metadata.request_temperatures[:batch_size] + top_ks = spec_metadata.request_top_ks[:batch_size] + top_ps = spec_metadata.request_top_ps[:batch_size] + + if self.use_flashinfer: + top_ks = top_ks.clamp(min=1, max=logits.shape[-1] - 1) + if self.seed is None: + self.seed = torch.tensor([0], + dtype=torch.int64, + device=logits.device) + self.offset = torch.tensor([0], + dtype=torch.int64, + device=logits.device) + self.seed += 1 + self.seed %= (2**31) + + draft_tokens = sampling_batch_spec_dec_one_model( + logits, + temperatures, + top_ks, + top_ps, + use_flashinfer=self.use_flashinfer, + seed=self.seed, + offset=self.offset) + + if d2t is not None: + draft_tokens = d2t[draft_tokens] + draft_tokens + + return draft_tokens.type(torch.int32) + + def _draft_sampler_advanced_for_rejection( + self, + logits: torch.Tensor, + spec_metadata: "SpecMetadata", batch_size: int, + d2t: Optional[torch.Tensor] = None, + draft_step: int = 0, ): """ - Compute draft probabilities and store them for next-step rejection sampling. + Rejection-sampling-aware variant of ``_draft_sampler_advanced``. + + Single-pass compute + sample + scatter: computes the per-request prob + distribution once via TRT-LLM's fused ``compute_probs_from_logits`` + (temp + top_k + top_p + softmax + greedy override in one CUDA kernel), + samples the draft token from that distribution, and scatters the same + probs into the slot-indexed ``spec_metadata.draft_probs`` buffer for + next-iter rejection verification. Replaces the previous two-stage path + (flashinfer fused sampling kernel + a redundant softmax pass to store + probs). + + All-greedy batches take the cheaper argmax path — + ``_can_use_rejection_sampling`` will bypass rejection for those anyway. """ - draft_tokens_per_request = len(draft_logits_list) - vocab_size = draft_logits_list[0].shape[-1] - device = draft_logits_list[0].device - - draft_logits = torch.stack(draft_logits_list, dim=0) - draft_logits_flat = draft_logits.transpose(0, 1).reshape(-1, vocab_size) - - num_draft_tokens = batch_size * draft_tokens_per_request - if spec_metadata.request_temperatures is not None: - draft_temps = spec_metadata.request_temperatures[:batch_size].repeat_interleave( - draft_tokens_per_request) - draft_top_ks = ( - spec_metadata.request_top_ks[:batch_size].repeat_interleave( - draft_tokens_per_request) if not spec_metadata.skip_top_k - and spec_metadata.request_top_ks is not None else None) - draft_top_ps = ( - spec_metadata.request_top_ps[:batch_size].repeat_interleave( - draft_tokens_per_request) if not spec_metadata.skip_top_p - and spec_metadata.request_top_ps is not None else None) - else: - draft_temps = torch.ones(num_draft_tokens, device=device) - draft_top_ks = None - draft_top_ps = None - - draft_probs_flat = compute_probs_from_logits(draft_logits_flat, - draft_temps, draft_top_ks, - draft_top_ps) - num_elements = batch_size * draft_tokens_per_request * vocab_size - spec_metadata.draft_probs[:num_elements].copy_( - draft_probs_flat.flatten()) - spec_metadata.draft_probs_last_dim = vocab_size - spec_metadata.draft_probs_valid = True + if spec_metadata.is_all_greedy_sample: + return self._draft_sampler_greedy(logits, d2t) + + temperatures = spec_metadata.request_temperatures[:batch_size] + top_ks = spec_metadata.request_top_ks[:batch_size] + top_ps = spec_metadata.request_top_ps[:batch_size] + + if self.seed is None: + self.seed = torch.tensor([0], + dtype=torch.int64, + device=logits.device) + self.offset = torch.tensor([0], + dtype=torch.int64, + device=logits.device) + self.seed += 1 + self.seed %= (2**31) + + draft_tokens, probs = sampling_batch_spec_dec_one_model_for_rejection( + logits, + temperatures, + top_ks, + top_ps, + seed=self.seed, + offset=self.offset, + ) + + # Scatter probs into the slot-indexed buffer (shaped + # [max_num_requests, max_draft_len, vocab_size]). Each request's data + # always lands at its stable py_seq_slot row regardless of batch + # composition shifts across iterations. + assert spec_metadata.batch_slot_ids is not None, ( + "batch_slot_ids must be populated by " + "populate_sampling_params_for_one_model before draft probs storage") + batch_slots = spec_metadata.batch_slot_ids[:batch_size] + vocab = probs.shape[-1] + spec_metadata.draft_probs[batch_slots, draft_step, :vocab] = probs + spec_metadata.draft_probs_last_dim = vocab + + if d2t is not None: + draft_tokens = d2t[draft_tokens] + draft_tokens + + return draft_tokens.type(torch.int32) def _execute_guided_decoder_if_present(self, logits): """Execute guided decoder on target model logits if available.""" @@ -1307,10 +1601,12 @@ def _sample_tokens_for_batch( Returns: sampled_tokens: [num_tokens] - Sampled token ids """ - if spec_metadata.allow_advanced_sampling: - num_gens = batch_size - num_contexts - num_tokens = num_contexts + num_gens * ( - spec_metadata.runtime_draft_len + 1) + if not spec_metadata.is_all_greedy_sample: + # Use logits.shape[0] directly: for PARD under CUDA graph capture + # runtime_draft_len may reflect the PARD-max while the captured + # graph was built for a shorter draft_len, causing a shape mismatch + # in sampling_batch_spec_dec_one_model (which is torch.compiled). + num_tokens = logits.shape[0] temperatures = spec_metadata.temperatures[:num_tokens] top_ks = spec_metadata.top_ks[:num_tokens] diff --git a/tensorrt_llm/_torch/speculative/one_model_sampler.py b/tensorrt_llm/_torch/speculative/one_model_sampler.py index 7e3b06b383cc..6734b5e9f79a 100644 --- a/tensorrt_llm/_torch/speculative/one_model_sampler.py +++ b/tensorrt_llm/_torch/speculative/one_model_sampler.py @@ -5,9 +5,20 @@ from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE if IS_FLASHINFER_AVAILABLE: - from flashinfer.sampling import chain_speculative_sampling, top_k_top_p_sampling_from_logits + from flashinfer.sampling import ( + chain_speculative_sampling, + sampling_from_probs, + top_k_mask_logits, + top_k_top_p_sampling_from_logits, + top_p_renorm_probs, + ) + from flashinfer.sampling import softmax as flashinfer_softmax else: chain_speculative_sampling = None + sampling_from_probs = None + flashinfer_softmax = None + top_k_mask_logits = None + top_p_renorm_probs = None top_k_top_p_sampling_from_logits = None @@ -114,9 +125,29 @@ def compute_probs_from_logits( skip_temperature: bool = False, ) -> torch.Tensor: """ - Compute probabilities from logits with temperature, top-k, and top-p applied. + Compute filtered + normalized probs from logits (temperature + top_k + + top_p + softmax). Picks the fastest path for the input device: + + 1. CUDA + flashinfer: ``top_k_mask_logits`` → fused ``softmax+temp`` → + ``top_p_renorm_probs`` (all O(N) radix). ``skip_temperature`` ignored. + 2. CUDA, no flashinfer: ``compute_probs_from_logits_op`` (sort-based, + O(N log N)). + 3. CPU: manual PyTorch fallback. """ + if logits.is_cuda and IS_FLASHINFER_AVAILABLE: + # Fast path: flashinfer composition (O(N) per row, friendly to small + # batch sizes). skip_temperature is ignored — flashinfer's softmax + # always applies the temperature tensor. + if top_k is not None: + logits = top_k_mask_logits(logits, top_k) + probs = flashinfer_softmax(logits, temperatures) + if top_p is not None: + probs = top_p_renorm_probs(probs, top_p) + return probs + if logits.is_cuda: + # CUDA without flashinfer: fall back to the C++ op (slower sort-based + # top-k path, but works without flashinfer). return torch.ops.trtllm.compute_probs_from_logits_op( logits, temperatures, top_k, top_p, skip_temperature ) @@ -125,7 +156,6 @@ def compute_probs_from_logits( logits = apply_temperature(logits, temperatures) logits = apply_top_k_top_p(logits, top_k, top_p) probs = logits.softmax(dim=-1, dtype=torch.float32) - # Greedy rows should remain exactly one-hot so rejection sampling does not # spuriously reject numerically-near argmax tokens. greedy_temp_threshold = 1e-4 @@ -135,6 +165,29 @@ def compute_probs_from_logits( return torch.where(is_greedy.unsqueeze(1), one_hot, probs) +def sampling_batch_spec_dec_one_model_for_rejection( + logits: torch.Tensor, + temperatures: torch.Tensor, + top_k: torch.Tensor, + top_p: torch.Tensor, + seed: Optional[torch.Tensor] = None, + offset: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Rejection-sampling-aware draft sampler: returns BOTH the sampled tokens + AND the prob distribution they were sampled from, so the downstream + rejection-sampling path can reuse the probs without a second softmax + + temp/top_k/top_p pass. + """ + if sampling_from_probs is None: + raise RuntimeError( + "Rejection sampling for one-model speculative decoding requires flashinfer" + ) + probs = compute_probs_from_logits(logits, temperatures, top_k, top_p) + tokens = sampling_from_probs(probs, deterministic=True, seed=seed, offset=offset) + return tokens, probs + + def rejection_sampling_one_model( draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py index 9c4284878b06..91f60243834c 100644 --- a/tensorrt_llm/_torch/speculative/utils.py +++ b/tensorrt_llm/_torch/speculative/utils.py @@ -58,7 +58,6 @@ def get_spec_metadata(spec_config, num_layers=model_config.num_hidden_layers, hidden_size=model_config.hidden_size, max_num_tokens=max_num_tokens, - allow_advanced_sampling=spec_config.allow_advanced_sampling, use_rejection_sampling=use_rejection_sampling, vocab_size=vocab_size, spec_resource_manager=spec_resource_manager, @@ -71,7 +70,6 @@ def get_spec_metadata(spec_config, mtp_num_modules=spec_config.max_draft_len, max_num_requests=max_num_requests, mtp_hidden_states_manager=spec_resource_manager, - allow_advanced_sampling=spec_config.allow_advanced_sampling, ) if spec_config.spec_dec_mode.is_mtp_eagle(): return Eagle3SpecMetadata( @@ -117,7 +115,6 @@ def get_spec_metadata(spec_config, hidden_size=model_config.hidden_size, max_num_tokens=max_num_tokens, layers_to_capture=spec_config.eagle3_layers_to_capture, - allow_advanced_sampling=spec_config.allow_advanced_sampling, use_rejection_sampling=use_rejection_sampling, vocab_size=vocab_size, spec_resource_manager=spec_resource_manager, @@ -130,7 +127,6 @@ def get_spec_metadata(spec_config, max_total_draft_tokens=spec_config.tokens_per_gen_step - 1, spec_dec_mode=spec_config.spec_dec_mode, max_num_requests=max_num_requests, - allow_advanced_sampling=spec_config.allow_advanced_sampling, spec_resource_manager=spec_resource_manager, ) if spec_config.spec_dec_mode.is_dflash(): @@ -140,7 +136,6 @@ def get_spec_metadata(spec_config, max_total_draft_tokens=spec_config.tokens_per_gen_step - 1, spec_dec_mode=spec_config.spec_dec_mode, max_num_requests=max_num_requests, - allow_advanced_sampling=spec_config.allow_advanced_sampling, layers_to_capture=target_layer_ids, hidden_size=model_config.hidden_size, max_num_tokens=max_num_tokens, @@ -153,7 +148,6 @@ def get_spec_metadata(spec_config, spec_dec_mode=spec_config.spec_dec_mode, max_num_requests=max_num_requests, max_num_tokens=max_num_tokens, - allow_advanced_sampling=spec_config.allow_advanced_sampling, ) if spec_config.spec_dec_mode.is_save_hidden_states(): return SaveHiddenStatesSpecMetadata( diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 6425cb7d2f2d..9b317050a81d 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1120,21 +1120,23 @@ class DecodingBaseConfig(StrictBaseModel): "rolling average over the last N completed requests (N = acceptance_window) drops below this value. " "PyTorch backend only.") - allow_advanced_sampling: bool = Field( + use_rejection_sampling: bool = Field( default=False, status="prototype", description= - "If true, allows non-greedy sampling when speculation is used. Only applicable " - "to 1-model code paths; non-greedy sampling is always enabled on 2-model paths." - ) + "If true, enables rejection sampling for one-model speculative decoding " + "paths when the batch contains any non-greedy request. All-greedy batches " + "always take the argmax fast path regardless of this flag. Set to false " + "(default) to use exact-match verification on non-greedy batches. " + "The non-dynamic-tree one-model path requires FlashInfer.") - use_rejection_sampling: bool = Field( + allow_advanced_sampling: bool = Field( default=False, - status="prototype", + status="deprecated", description= - "If true, enables rejection sampling for one-model speculative decoding paths. " - "This is intended for non-greedy sampling configurations on the PyTorch backend. " - "The non-dynamic-tree one-model path requires FlashInfer.") + "DEPRECATED: no-op kept for backward compatibility. Will be removed " + "in a future release. Non-greedy sampling is now auto-detected per " + "request; this flag no longer has any effect.") # If set, drafting is allowed to use chain drafter. _allow_chain_drafter: bool = PrivateAttr(True) @@ -1190,15 +1192,33 @@ def validate_draft_len_schedule_and_sort(cls, v, info): @model_validator(mode='after') def validate_rejection_sampling_config(self): - """Reject SA-enhanced configurations that invalidate rejection sampling.""" + """Disable rejection sampling when SA-enhanced configurations are + active, since SA may override the proposed draft tokens. This is a + silent fallback so the new default (True) does not break sa_config + users. + """ if self.use_rejection_sampling and getattr(self, 'sa_config', None) is not None: - raise ValueError( - "use_rejection_sampling is incompatible with sa_config " - "because SA enhancement may override the proposed draft tokens." - ) + self.use_rejection_sampling = False return self + @model_validator(mode='before') + @classmethod + def _warn_deprecated_allow_advanced_sampling(cls, data): + """Warn when users set the deprecated allow_advanced_sampling flag. + + Non-greedy sampling is now auto-detected per request and always + available, so the flag is a no-op; warn loudly so callers update + their configs before the flag is removed. + """ + if isinstance(data, dict) and 'allow_advanced_sampling' in data: + logger.warning( + "DecodingBaseConfig: 'allow_advanced_sampling' is deprecated " + "and will be removed in a future release. The flag has no " + "effect — non-greedy sampling is now auto-detected per " + "request.") + return data + @model_validator(mode='after') # 1. Validate that max_concurrency and draft_len_schedule are mutually exclusive. # 2. If max_concurrency is set, translate it to the corresponding draft_len_schedule. @@ -4535,12 +4555,17 @@ def validate_speculative_config(self): exclude={"decoding_type"}) self.speculative_config = Eagle3DecodingConfig(**eagle_data) - if self.speculative_config.use_rejection_sampling: - if not isinstance(self.speculative_config, - Eagle3DecodingConfig): - raise ValueError( - "use_rejection_sampling is only supported for " - "PyTorch Eagle3 one-model speculative decoding paths.") + if self.speculative_config.use_rejection_sampling and not isinstance( + self.speculative_config, Eagle3DecodingConfig): + # Rejection sampling is only wired up for Eagle3 one-model paths. + # Silently fall back for other spec types so the new default + # (True) does not break them. + # TODO: extend rejection sampling to the remaining speculative + # decoding paths (MTP / DraftTarget / PARD / DFlash / + # SaveHiddenStates / SA) and unify the dispatch in SpecMetadata + # so new spec algorithms get rejection sampling for free; once + # all paths are covered this whitelist guard can be removed. + self.speculative_config.use_rejection_sampling = False if isinstance(self.speculative_config, PARDDecodingConfig): assert self.speculative_config.max_draft_len > 0, "PARD max_draft_len must be > 0" diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 41e3d93944ea..7a72ffebc88f 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -201,7 +201,6 @@ def test_eagle3_rejection_dynamic_tree_smoke(self, use_dynamic_tree, max_draft_len=4, speculative_model=eagle_model_dir, eagle3_one_model=True, - allow_advanced_sampling=True, use_rejection_sampling=True, ) max_batch_size = 1 @@ -5435,8 +5434,7 @@ def test_eagle3_4gpus(self, v2_kv_cache, moe_backend, one_model, draft_len = 3 spec_config = Eagle3DecodingConfig(max_draft_len=draft_len, speculative_model=eagle_model_dir, - eagle3_one_model=one_model, - allow_advanced_sampling=True) + eagle3_one_model=one_model) max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN llm = LLM(self.MODEL_PATH, @@ -5501,8 +5499,7 @@ def test_eagle3_vswa_reuse_4gpus(self, one_model, mocker): draft_len = 3 spec_config = Eagle3DecodingConfig(max_draft_len=draft_len, speculative_model=eagle_model_dir, - eagle3_one_model=one_model, - allow_advanced_sampling=True) + eagle3_one_model=one_model) max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN llm = LLM(self.MODEL_PATH, @@ -5565,8 +5562,7 @@ def test_eagle3_guided_decoding_4gpus(self, one_model, mocker): draft_len = 3 spec_config = Eagle3DecodingConfig(max_draft_len=draft_len, speculative_model=eagle_model_dir, - eagle3_one_model=one_model, - allow_advanced_sampling=True) + eagle3_one_model=one_model) max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN llm = LLM(self.MODEL_PATH, diff --git a/tests/integration/defs/examples/serve/test_configs/Nemotron3_Super_120B_NVFP4.yml b/tests/integration/defs/examples/serve/test_configs/Nemotron3_Super_120B_NVFP4.yml index 72a1df6265e7..ed1cf2667329 100644 --- a/tests/integration/defs/examples/serve/test_configs/Nemotron3_Super_120B_NVFP4.yml +++ b/tests/integration/defs/examples/serve/test_configs/Nemotron3_Super_120B_NVFP4.yml @@ -23,4 +23,3 @@ print_iter_log: true speculative_config: decoding_type: MTP num_nextn_predict_layers: 3 - allow_advanced_sampling: true diff --git a/tests/integration/defs/perf/pytorch_model_config.py b/tests/integration/defs/perf/pytorch_model_config.py index 66d1dd334002..716dc6ef46df 100644 --- a/tests/integration/defs/perf/pytorch_model_config.py +++ b/tests/integration/defs/perf/pytorch_model_config.py @@ -526,7 +526,6 @@ def get_model_yaml_config(model_label: str, 'speculative_config': { 'decoding_type': 'MTP', 'num_nextn_predict_layers': 3, - 'allow_advanced_sampling': True, }, } }, diff --git a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py index 8a393f21ad35..29459e4903dc 100644 --- a/tests/unittest/_torch/speculative/test_eagle3.py +++ b/tests/unittest/_torch/speculative/test_eagle3.py @@ -864,7 +864,6 @@ def test_llama_eagle3_rejection_sampling_modes(use_dynamic_tree: bool, max_draft_len=max_draft_len, speculative_model=eagle_model, eagle3_one_model=True, - allow_advanced_sampling=True, use_rejection_sampling=True, ) if use_dynamic_tree: