Skip to content

Commit 775bae6

Browse files
[TRTLLM-12669][chore] Address review feedback
- llm_args: keep allow_advanced_sampling as a deprecated no-op field with a logger warning when explicitly set, so removing it isn't an abrupt API break - llm_args: add TODO above the Eagle3-only rejection-sampling whitelist to track extending support to MTP / DraftTarget / PARD / DFlash / SaveHiddenStates / SA and unifying the dispatch in SpecMetadata - cuda_graph_runner: type spec_metadata as Optional[SpecMetadata] instead of Optional[Any] - model_engine: always initialize self.spec_metadata = None so the capture-pass can access it directly without a getattr() fallback - eagle3.draft_decoder: drop the dead Optional/_draft_sampler_greedy fallback; spec_metadata and batch_size are always passed by the sole caller Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
1 parent df3dc51 commit 775bae6

4 files changed

Lines changed: 46 additions & 19 deletions

File tree

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from ..memory_buffer_utils import get_memory_buffers
1616
from ..modules.multi_stream_utils import with_multi_stream
1717
from ..speculative.eagle3 import Eagle3ResourceManager
18+
from ..speculative.interface import SpecMetadata
1819
from ..speculative.spec_sampler_base import SampleStateTensorsSpec
1920
from ..speculative.utils import get_draft_kv_cache_manager
2021
from ..utils import make_weak_ref, piecewise_cuda_graph
@@ -202,7 +203,7 @@ def get_graph_key(
202203
batch: ScheduledRequests,
203204
new_tensors_device: Optional[SampleStateTensors] = None,
204205
spec_resource_manager: Optional[BaseResourceManager] = None,
205-
spec_metadata: Optional[Any] = None):
206+
spec_metadata: Optional[SpecMetadata] = None):
206207
batch_size = batch.batch_size
207208

208209
# Get the sequence length mode.
@@ -244,7 +245,7 @@ def maybe_get_cuda_graph(
244245
batch: ScheduledRequests,
245246
enable_spec_decode: bool,
246247
attn_metadata: Any,
247-
spec_metadata: Optional[Any] = None,
248+
spec_metadata: Optional[SpecMetadata] = None,
248249
draft_tokens_cuda: Optional[torch.Tensor] = None,
249250
new_tensors_device: Optional[SampleStateTensors] = None,
250251
spec_resource_manager: Optional[BaseResourceManager] = None,

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,6 @@ def __init__(
411411
sparse_attn_config=self.sparse_attention_config)
412412

413413
if self.is_spec_decode:
414-
self.spec_metadata = None
415414
update_spec_config_from_model_config(self.spec_config,
416415
self.model.config)
417416
max_num_draft_tokens = self.max_draft_loop_tokens * self.batch_size
@@ -464,6 +463,7 @@ def __init__(
464463
# NOTE: This can be simplified by decoupling the model config loading and
465464
# the model engine.
466465
self.attn_metadata = None
466+
self.spec_metadata = None
467467
self.iter_states = {}
468468
self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None
469469

@@ -1072,7 +1072,7 @@ def _capture_generation_cuda_graphs(self,
10721072
max_seq_len_list = [effective_max_seq_len]
10731073

10741074
def _run_capture_pass(force_non_greedy: bool, label: str) -> None:
1075-
spec_metadata = getattr(self, 'spec_metadata', None)
1075+
spec_metadata = self.spec_metadata
10761076
if force_non_greedy and spec_metadata is not None:
10771077
spec_metadata._force_non_greedy_for_capture = True
10781078
# maybe_get_cuda_graph reads spec_metadata.is_all_greedy_sample

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -798,14 +798,13 @@ def draft_decoder(
798798
self,
799799
logits: torch.Tensor,
800800
draft_model: nn.Module,
801-
spec_metadata: Optional[Eagle3OneModelSpecMetadata] = None,
802-
batch_size: Optional[int] = None,
801+
spec_metadata: Eagle3OneModelSpecMetadata,
802+
batch_size: int,
803803
draft_step: Optional[int] = None,
804804
):
805805
'''
806-
Sample draft tokens. When spec_metadata + batch_size are provided, use
807-
the target's per-request sampling params (temperature/top_k/top_p);
808-
otherwise fall back to argmax.
806+
Sample draft tokens using the target's per-request sampling params
807+
(temperature/top_k/top_p).
809808
810809
When rejection sampling is enabled and draft_step is provided, take the
811810
single-pass path that also scatters the draft prob distribution into the
@@ -814,23 +813,20 @@ def draft_decoder(
814813
Args:
815814
logits: [batch_size, vocab_size] - Draft model logits.
816815
draft_model: The draft model.
817-
spec_metadata: Carries per-request sampling param tensors. When
818-
None, sampling is forced greedy.
816+
spec_metadata: Carries per-request sampling param tensors.
819817
batch_size: Active requests, used to slice per-request tensors.
820818
draft_step: Current draft step index (0..max_draft_len-1). Required
821819
for the rejection-sampling code path so probs are written to
822820
the correct slice of spec_metadata.draft_probs.
823821
'''
824822

825823
d2t = getattr(draft_model.model, "d2t", None)
826-
if spec_metadata is not None and batch_size is not None:
827-
if (spec_metadata.use_rejection_sampling and draft_step is not None
828-
and not spec_metadata.is_all_greedy_sample):
829-
return self._draft_sampler_advanced_for_rejection(
830-
logits, spec_metadata, batch_size, d2t, draft_step)
831-
return self._draft_sampler_advanced(logits, spec_metadata,
832-
batch_size, d2t)
833-
return self._draft_sampler_greedy(logits, d2t)
824+
if (spec_metadata.use_rejection_sampling and draft_step is not None
825+
and not spec_metadata.is_all_greedy_sample):
826+
return self._draft_sampler_advanced_for_rejection(
827+
logits, spec_metadata, batch_size, d2t, draft_step)
828+
return self._draft_sampler_advanced(logits, spec_metadata, batch_size,
829+
d2t)
834830

835831
def prepare_1st_drafter_inputs(
836832
self,

tensorrt_llm/llmapi/llm_args.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,14 @@ class DecodingBaseConfig(StrictBaseModel):
906906
"false to fall back to exact-match verification on non-greedy batches. "
907907
"The non-dynamic-tree one-model path requires FlashInfer.")
908908

909+
allow_advanced_sampling: bool = Field(
910+
default=False,
911+
status="deprecated",
912+
description=
913+
"DEPRECATED: no-op kept for backward compatibility. Will be removed "
914+
"in a future release. Non-greedy sampling is now auto-detected per "
915+
"request; this flag no longer has any effect.")
916+
909917
# If set, drafting is allowed to use chain drafter.
910918
_allow_chain_drafter: bool = PrivateAttr(True)
911919
# If set, drafting uses greedy sampling, irrespective of sampling parameters.
@@ -970,6 +978,23 @@ def validate_rejection_sampling_config(self):
970978
self.use_rejection_sampling = False
971979
return self
972980

981+
@model_validator(mode='before')
982+
@classmethod
983+
def _warn_deprecated_allow_advanced_sampling(cls, data):
984+
"""Warn when users set the deprecated allow_advanced_sampling flag.
985+
986+
Non-greedy sampling is now auto-detected per request and always
987+
available, so the flag is a no-op; warn loudly so callers update
988+
their configs before the flag is removed.
989+
"""
990+
if isinstance(data, dict) and 'allow_advanced_sampling' in data:
991+
logger.warning(
992+
"DecodingBaseConfig: 'allow_advanced_sampling' is deprecated "
993+
"and will be removed in a future release. The flag has no "
994+
"effect — non-greedy sampling is now auto-detected per "
995+
"request.")
996+
return data
997+
973998
@model_validator(mode='after')
974999
# 1. Validate that max_concurrency and draft_len_schedule are mutually exclusive.
9751000
# 2. If max_concurrency is set, translate it to the corresponding draft_len_schedule.
@@ -4148,6 +4173,11 @@ def validate_speculative_config(self):
41484173
# Rejection sampling is only wired up for Eagle3 one-model paths.
41494174
# Silently fall back for other spec types so the new default
41504175
# (True) does not break them.
4176+
# TODO: extend rejection sampling to the remaining speculative
4177+
# decoding paths (MTP / DraftTarget / PARD / DFlash /
4178+
# SaveHiddenStates / SA) and unify the dispatch in SpecMetadata
4179+
# so new spec algorithms get rejection sampling for free; once
4180+
# all paths are covered this whitelist guard can be removed.
41514181
self.speculative_config.use_rejection_sampling = False
41524182

41534183
if isinstance(self.speculative_config, PARDDecodingConfig):

0 commit comments

Comments
 (0)