Skip to content

Commit d6fd852

Browse files
[TRTLLM-12669][chore] Address review feedback
- llm_args: keep allow_advanced_sampling as a deprecated no-op field with a logger warning when explicitly set, so removing it isn't an abrupt API break
- llm_args: add TODO above the Eagle3-only rejection-sampling whitelist to track extending support to MTP / DraftTarget / PARD / DFlash / SaveHiddenStates / SA and unifying the dispatch in SpecMetadata
- cuda_graph_runner: type spec_metadata as Optional[SpecMetadata] instead of Optional[Any]
- model_engine: always initialize self.spec_metadata = None so the capture-pass can access it directly without a getattr() fallback
- eagle3.draft_decoder: drop the dead Optional/_draft_sampler_greedy fallback; spec_metadata and batch_size are always passed by the sole caller

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
1 parent 53b4415 commit d6fd852

4 files changed

Lines changed: 46 additions & 19 deletions

File tree

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
from ..memory_buffer_utils import get_memory_buffers
2020
from ..modules.multi_stream_utils import with_multi_stream
2121
from ..speculative.eagle3 import Eagle3ResourceManager
22+
from ..speculative.interface import SpecMetadata
2223
from ..speculative.spec_sampler_base import SampleStateTensorsSpec
2324
from ..speculative.utils import get_draft_kv_cache_manager
2425
from ..utils import make_weak_ref, piecewise_cuda_graph
@@ -206,7 +207,7 @@ def get_graph_key(
206207
batch: ScheduledRequests,
207208
new_tensors_device: Optional[SampleStateTensors] = None,
208209
spec_resource_manager: Optional[BaseResourceManager] = None,
209-
spec_metadata: Optional[Any] = None):
210+
spec_metadata: Optional[SpecMetadata] = None):
210211
batch_size = batch.batch_size
211212

212213
# Get the sequence length mode.
@@ -248,7 +249,7 @@ def maybe_get_cuda_graph(
248249
batch: ScheduledRequests,
249250
enable_spec_decode: bool,
250251
attn_metadata: Any,
251-
spec_metadata: Optional[Any] = None,
252+
spec_metadata: Optional[SpecMetadata] = None,
252253
draft_tokens_cuda: Optional[torch.Tensor] = None,
253254
new_tensors_device: Optional[SampleStateTensors] = None,
254255
spec_resource_manager: Optional[BaseResourceManager] = None,

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -484,7 +484,6 @@ def __init__(
484484
sparse_attn_config=self.sparse_attention_config)
485485

486486
if self.is_spec_decode:
487-
self.spec_metadata = None
488487
update_spec_config_from_model_config(self.spec_config,
489488
self.model.config)
490489
max_num_draft_tokens = self.max_draft_loop_tokens * self.batch_size
@@ -538,6 +537,7 @@ def __init__(
538537
# the model engine.
539538
self.attn_metadata = None
540539
self.encoder_attn_metadata = None
540+
self.spec_metadata = None
541541
self.iter_states = {}
542542
self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None
543543

@@ -1214,7 +1214,7 @@ def _capture_generation_cuda_graphs(self,
12141214
max_seq_len_list = [effective_max_seq_len]
12151215

12161216
def _run_capture_pass(force_non_greedy: bool, label: str) -> None:
1217-
spec_metadata = getattr(self, 'spec_metadata', None)
1217+
spec_metadata = self.spec_metadata
12181218
if force_non_greedy and spec_metadata is not None:
12191219
spec_metadata._force_non_greedy_for_capture = True
12201220
# maybe_get_cuda_graph reads spec_metadata.is_all_greedy_sample

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 11 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -819,14 +819,13 @@ def draft_decoder(
819819
self,
820820
logits: torch.Tensor,
821821
draft_model: nn.Module,
822-
spec_metadata: Optional[Eagle3OneModelSpecMetadata] = None,
823-
batch_size: Optional[int] = None,
822+
spec_metadata: Eagle3OneModelSpecMetadata,
823+
batch_size: int,
824824
draft_step: Optional[int] = None,
825825
):
826826
'''
827-
Sample draft tokens. When spec_metadata + batch_size are provided, use
828-
the target's per-request sampling params (temperature/top_k/top_p);
829-
otherwise fall back to argmax.
827+
Sample draft tokens using the target's per-request sampling params
828+
(temperature/top_k/top_p).
830829
831830
When rejection sampling is enabled and draft_step is provided, take the
832831
single-pass path that also scatters the draft prob distribution into the
@@ -835,23 +834,20 @@ def draft_decoder(
835834
Args:
836835
logits: [batch_size, vocab_size] - Draft model logits.
837836
draft_model: The draft model.
838-
spec_metadata: Carries per-request sampling param tensors. When
839-
None, sampling is forced greedy.
837+
spec_metadata: Carries per-request sampling param tensors.
840838
batch_size: Active requests, used to slice per-request tensors.
841839
draft_step: Current draft step index (0..max_draft_len-1). Required
842840
for the rejection-sampling code path so probs are written to
843841
the correct slice of spec_metadata.draft_probs.
844842
'''
845843

846844
d2t = getattr(draft_model.model, "d2t", None)
847-
if spec_metadata is not None and batch_size is not None:
848-
if (spec_metadata.use_rejection_sampling and draft_step is not None
849-
and not spec_metadata.is_all_greedy_sample):
850-
return self._draft_sampler_advanced_for_rejection(
851-
logits, spec_metadata, batch_size, d2t, draft_step)
852-
return self._draft_sampler_advanced(logits, spec_metadata,
853-
batch_size, d2t)
854-
return self._draft_sampler_greedy(logits, d2t)
845+
if (spec_metadata.use_rejection_sampling and draft_step is not None
846+
and not spec_metadata.is_all_greedy_sample):
847+
return self._draft_sampler_advanced_for_rejection(
848+
logits, spec_metadata, batch_size, d2t, draft_step)
849+
return self._draft_sampler_advanced(logits, spec_metadata, batch_size,
850+
d2t)
855851

856852
def prepare_1st_drafter_inputs(
857853
self,

tensorrt_llm/llmapi/llm_args.py

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1066,6 +1066,14 @@ class DecodingBaseConfig(StrictBaseModel):
10661066
"false to fall back to exact-match verification on non-greedy batches. "
10671067
"The non-dynamic-tree one-model path requires FlashInfer.")
10681068

1069+
allow_advanced_sampling: bool = Field(
1070+
default=False,
1071+
status="deprecated",
1072+
description=
1073+
"DEPRECATED: no-op kept for backward compatibility. Will be removed "
1074+
"in a future release. Non-greedy sampling is now auto-detected per "
1075+
"request; this flag no longer has any effect.")
1076+
10691077
# If set, drafting is allowed to use chain drafter.
10701078
_allow_chain_drafter: bool = PrivateAttr(True)
10711079
# If set, drafting uses greedy sampling, irrespective of sampling parameters.
@@ -1130,6 +1138,23 @@ def validate_rejection_sampling_config(self):
11301138
self.use_rejection_sampling = False
11311139
return self
11321140

1141+
@model_validator(mode='before')
1142+
@classmethod
1143+
def _warn_deprecated_allow_advanced_sampling(cls, data):
1144+
"""Warn when users set the deprecated allow_advanced_sampling flag.
1145+
1146+
Non-greedy sampling is now auto-detected per request and always
1147+
available, so the flag is a no-op; warn loudly so callers update
1148+
their configs before the flag is removed.
1149+
"""
1150+
if isinstance(data, dict) and 'allow_advanced_sampling' in data:
1151+
logger.warning(
1152+
"DecodingBaseConfig: 'allow_advanced_sampling' is deprecated "
1153+
"and will be removed in a future release. The flag has no "
1154+
"effect — non-greedy sampling is now auto-detected per "
1155+
"request.")
1156+
return data
1157+
11331158
@model_validator(mode='after')
11341159
# 1. Validate that max_concurrency and draft_len_schedule are mutually exclusive.
11351160
# 2. If max_concurrency is set, translate it to the corresponding draft_len_schedule.
@@ -4423,6 +4448,11 @@ def validate_speculative_config(self):
44234448
# Rejection sampling is only wired up for Eagle3 one-model paths.
44244449
# Silently fall back for other spec types so the new default
44254450
# (True) does not break them.
4451+
# TODO: extend rejection sampling to the remaining speculative
4452+
# decoding paths (MTP / DraftTarget / PARD / DFlash /
4453+
# SaveHiddenStates / SA) and unify the dispatch in SpecMetadata
4454+
# so new spec algorithms get rejection sampling for free; once
4455+
# all paths are covered this whitelist guard can be removed.
44264456
self.speculative_config.use_rejection_sampling = False
44274457

44284458
if isinstance(self.speculative_config, PARDDecodingConfig):

0 commit comments

Comments
 (0)