@@ -436,8 +436,14 @@ class SpecMetadata:
436436 # Always set by model_engine.forward() before any downstream code reads it.
437437 runtime_draft_len : int = 0
438438
439- # For non-greedy sampling on 1-model.
440- allow_advanced_sampling : bool = False
439+ # Auto-detected per step from populated sampling params:
440+ # True if every request is greedy (no temp/top_k/top_p) and we can take
441+ # the argmax fast-path. False if any request needs sampling.
442+ # Used as part of the CUDA graph key so we capture two variants
443+ # (greedy fast-path vs advanced sampling) and dispatch at replay.
444+ # Defaults to True so non-one-engine paths (where populate is a no-op)
445+ # never accidentally select the advanced graph variant.
446+ is_all_greedy_sample : bool = True
441447 # Whether to use rejection sampling for one-model speculative decoding.
442448 use_rejection_sampling : bool = False
443449 # Sampling parameters for non-greedy sampling (per-request)
@@ -515,29 +521,21 @@ def populate_sampling_params_for_one_model(
515521 self , requests : list ["LlmRequest" ]) -> None :
516522 """
517523 Set up topp/topk/temperatures for 1-model sampler.
524+
525+ Scans sampling configs to set skip_*/is_all_greedy_sample flags. When
526+ any request needs sampling, also builds per-token/per-request lists
527+ and copies them to GPU buffers; all-greedy batches skip this entirely.
518528 """
519529 from tensorrt_llm ._torch .pyexecutor .llm_request import LlmRequestState
520530 from tensorrt_llm .sampling_params import SamplingParams
521531
522- if not self .allow_advanced_sampling or not self .spec_dec_mode .use_one_engine (
523- ):
532+ if not self .spec_dec_mode .use_one_engine ():
524533 return
525534
526535 if self .temperatures is None :
527536 # Ensures determinism across ranks.
528537 torch .manual_seed (0 )
529538
530- temperatures = []
531- top_ks = []
532- top_ps = []
533- request_temperatures = []
534- request_top_ks = []
535- request_top_ps = []
536- top_k_enabled = False
537- top_p_enabled = False
538- has_greedy_requests = False
539- temperature_enabled = False
540-
541539 # Need to use a very small value for temperature when disabled to avoid division by 0
542540 DISABLE_TEMP_VAL = 1e-5
543541 # Very large values disable topk.
@@ -583,6 +581,13 @@ def _normalize_request_sampling_params(
583581 is_greedy ,
584582 )
585583
584+ # Phase 1: collect per-request flags and normalized values.
585+ per_request_normalized : list [tuple [float , int , float , int ]] = []
586+ temperature_enabled = False
587+ top_k_enabled = False
588+ top_p_enabled = False
589+ has_greedy_requests = False
590+
586591 for request in requests :
587592 sampling_config = request .sampling_config
588593 temp_val = _first_or_none (sampling_config .temperature )
@@ -611,19 +616,24 @@ def _normalize_request_sampling_params(
611616 top_p_enabled |= use_top_p
612617 has_greedy_requests |= is_greedy
613618
614- request_temperatures .append (temp_val )
615- request_top_ks .append (tk_val )
616- request_top_ps .append (tp_val )
617- temperatures .extend (temp_val for _ in range (num_tokens ))
618- top_ks .extend (tk_val for _ in range (num_tokens ))
619- top_ps .extend (tp_val for _ in range (num_tokens ))
619+ per_request_normalized .append (
620+ (temp_val , tk_val , tp_val , num_tokens ))
621+
622+ self .skip_temperature = not temperature_enabled
623+ self .skip_top_k = not top_k_enabled
624+ self .skip_top_p = not top_p_enabled
625+ self .has_greedy_requests = has_greedy_requests
626+ # Used in the CUDA graph key to pick the argmax / advanced variant.
627+ self .is_all_greedy_sample = (self .skip_temperature and self .skip_top_k
628+ and self .skip_top_p )
620629
621630 tokens_per_request = (self .max_total_draft_tokens + 1 if
622631 self .is_spec_dec_tree else self .max_draft_len + 1 )
623632 required_flat_size = tokens_per_request * self .max_num_requests
624633
625634 if self .temperatures is None or self .temperatures .numel (
626635 ) < required_flat_size :
636+ # Allocate on first use (or regrow if too small); the captured graph reads from these stable addresses.
627637 self .temperatures = torch .ones (required_flat_size ,
628638 dtype = torch .float32 ,
629639 device = 'cuda' )
@@ -643,6 +653,27 @@ def _normalize_request_sampling_params(
643653 dtype = torch .float32 ,
644654 device = 'cuda' )
645655
656+ # All-greedy: sampler takes the argmax branch (and rejection sampling
657+ # is also bypassed for all-greedy), so the buffers are never read.
658+ # Skip the H->D copies.
659+ if self .is_all_greedy_sample :
660+ return
661+
662+ # Phase 2: build per-token / per-request lists and copy to GPU.
663+ temperatures : list [float ] = []
664+ top_ks : list [int ] = []
665+ top_ps : list [float ] = []
666+ request_temperatures : list [float ] = []
667+ request_top_ks : list [int ] = []
668+ request_top_ps : list [float ] = []
669+ for temp_val , tk_val , tp_val , num_tokens in per_request_normalized :
670+ request_temperatures .append (temp_val )
671+ request_top_ks .append (tk_val )
672+ request_top_ps .append (tp_val )
673+ temperatures .extend (temp_val for _ in range (num_tokens ))
674+ top_ks .extend (tk_val for _ in range (num_tokens ))
675+ top_ps .extend (tp_val for _ in range (num_tokens ))
676+
646677 self .temperatures [:len (temperatures )].copy_ (torch .tensor (
647678 temperatures , dtype = torch .float32 , pin_memory = prefer_pinned ()),
648679 non_blocking = True )
@@ -669,10 +700,6 @@ def _normalize_request_sampling_params(
669700 pin_memory = prefer_pinned ()),
670701 non_blocking = True ,
671702 )
672- self .skip_temperature = not temperature_enabled
673- self .skip_top_k = not top_k_enabled
674- self .skip_top_p = not top_p_enabled
675- self .has_greedy_requests = has_greedy_requests
676703
677704
678705class SpecWorkerBase (nn .Module , ABC ):
@@ -1004,8 +1031,11 @@ def _accept_draft_tokens(self, logits, draft_tokens, num_contexts,
10041031
10051032 def _can_use_rejection_sampling (self , spec_metadata : SpecMetadata ,
10061033 num_contexts : int ) -> bool :
1034+ # Skip rejection sampling when the whole batch is greedy: the
1035+ # accepted result is identical to argmax and the base path is cheaper.
10071036 return (spec_metadata .use_rejection_sampling
1008- and spec_metadata .draft_probs_valid and num_contexts == 0 )
1037+ and spec_metadata .draft_probs_valid and num_contexts == 0
1038+ and not spec_metadata .is_all_greedy_sample )
10091039
10101040 def _sample_and_accept_draft_tokens_rejection (
10111041 self ,
@@ -1282,7 +1312,7 @@ def _sample_tokens_for_batch(
12821312 Returns:
12831313 sampled_tokens: [num_tokens] - Sampled token ids
12841314 """
1285- if spec_metadata .allow_advanced_sampling :
1315+ if not spec_metadata .is_all_greedy_sample :
12861316 num_gens = batch_size - num_contexts
12871317 num_tokens = num_contexts + num_gens * (
12881318 spec_metadata .runtime_draft_len + 1 )
0 commit comments