Commit d599fb5

[TRTLLM-12669][feat] Enable rejection sampling by default for Eagle3 one-model
Flip the default of `use_rejection_sampling` from `False` to `True` on DecodingBaseConfig. With the refactor of the all-greedy fast path in place, this is safe: the runtime guard in `_can_use_rejection_sampling` still requires a non-greedy batch, so all-greedy batches keep taking the argmax fast path unchanged. Only batches that already opted into non-greedy sampling now see the rejection sampling acceptance behavior.

Benchmark results on Qwen3-235B-A22B + Eagle3 (tp=8) show consistent +6.4% to +9.4% throughput and +3.4 to +4.3 pp acceptance rate across batch sizes 1-16 vs the exact-match baseline. Other Eagle3 deployments see smaller but uniformly positive acceptance-rate gains.

Two prior `raise ValueError` paths are converted to silent fallbacks so the new default does not break existing users:

- Non-Eagle3 spec configs (PARD, DFlash, MTP, ...) silently disable the flag in TorchLlmArgs post-validation, since rejection sampling is only wired up for Eagle3 one-model paths.
- SA-enhanced Eagle3 configs disable the flag in the per-config validator, since SA may override proposed draft tokens.

Users who want the prior exact-match behavior can still pass `use_rejection_sampling=False` explicitly.

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
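The dispatch described in the commit message can be sketched as a small decision function. This is only an illustration of the stated behavior, not the runtime code: `VerifyPath` and `choose_verify_path` are hypothetical names, and the real guard in `_can_use_rejection_sampling` may check further conditions (for example FlashInfer availability) that are not modeled here.

```python
from enum import Enum


class VerifyPath(Enum):
    """Hypothetical labels for the three verification behaviors named in the commit."""
    ARGMAX_FAST_PATH = "argmax_fast_path"
    REJECTION_SAMPLING = "rejection_sampling"
    EXACT_MATCH = "exact_match"


def choose_verify_path(use_rejection_sampling: bool,
                       batch_is_all_greedy: bool) -> VerifyPath:
    """Pick a draft-token verification path (illustrative sketch only)."""
    # All-greedy batches take the argmax fast path regardless of the flag.
    if batch_is_all_greedy:
        return VerifyPath.ARGMAX_FAST_PATH
    # Non-greedy batch with the flag on (the new default): rejection sampling.
    if use_rejection_sampling:
        return VerifyPath.REJECTION_SAMPLING
    # Non-greedy batch with the flag explicitly off: prior exact-match behavior.
    return VerifyPath.EXACT_MATCH
```

Under this sketch, flipping the default changes the result only for the non-greedy case, which matches the claim that all-greedy batches are unaffected.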
1 parent 83e7903 commit d599fb5

1 file changed

Lines changed: 17 additions & 14 deletions

tensorrt_llm/llmapi/llm_args.py

@@ -897,11 +897,13 @@ class DecodingBaseConfig(StrictBaseModel):
             "PyTorch backend only.")

     use_rejection_sampling: bool = Field(
-        default=False,
+        default=True,
         status="prototype",
         description=
-        "If true, enables rejection sampling for one-model speculative decoding paths. "
-        "This is intended for non-greedy sampling configurations on the PyTorch backend. "
+        "If true (default), enables rejection sampling for one-model speculative "
+        "decoding paths when the batch contains any non-greedy request. All-greedy "
+        "batches always take the argmax fast path regardless of this flag. Set to "
+        "false to fall back to exact-match verification on non-greedy batches. "
         "The non-dynamic-tree one-model path requires FlashInfer.")

     # If set, drafting is allowed to use chain drafter.
@@ -958,13 +960,14 @@ def validate_draft_len_schedule_and_sort(cls, v, info):

     @model_validator(mode='after')
     def validate_rejection_sampling_config(self):
-        """Reject SA-enhanced configurations that invalidate rejection sampling."""
+        """Disable rejection sampling when SA-enhanced configurations are
+        active, since SA may override the proposed draft tokens. This is a
+        silent fallback so the new default (True) does not break sa_config
+        users.
+        """
         if self.use_rejection_sampling and getattr(self, 'sa_config',
                                                    None) is not None:
-            raise ValueError(
-                "use_rejection_sampling is incompatible with sa_config "
-                "because SA enhancement may override the proposed draft tokens."
-            )
+            self.use_rejection_sampling = False
         return self

     @model_validator(mode='after')
@@ -4140,12 +4143,12 @@ def validate_speculative_config(self):
                     exclude={"decoding_type"})
                 self.speculative_config = Eagle3DecodingConfig(**eagle_data)

-            if self.speculative_config.use_rejection_sampling:
-                if not isinstance(self.speculative_config,
-                                  Eagle3DecodingConfig):
-                    raise ValueError(
-                        "use_rejection_sampling is only supported for "
-                        "PyTorch Eagle3 one-model speculative decoding paths.")
+            if self.speculative_config.use_rejection_sampling and not isinstance(
+                    self.speculative_config, Eagle3DecodingConfig):
+                # Rejection sampling is only wired up for Eagle3 one-model paths.
+                # Silently fall back for other spec types so the new default
+                # (True) does not break them.
+                self.speculative_config.use_rejection_sampling = False

             if isinstance(self.speculative_config, PARDDecodingConfig):
                 assert self.speculative_config.max_draft_len > 0, "PARD max_draft_len must be > 0"
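
The combined effect of the two silent fallbacks in this diff can be modeled with a standalone toy, shown here as a rough sketch. `ToySpecConfig` and `apply_silent_fallbacks` are hypothetical stand-ins, not TensorRT-LLM classes: the real code uses Pydantic validators and an `isinstance` check against `Eagle3DecodingConfig`, which the string comparison below only approximates.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToySpecConfig:
    """Hypothetical stand-in for a speculative decoding config."""
    decoding_type: str                   # e.g. "Eagle3", "PARD", "MTP"
    use_rejection_sampling: bool = True  # mirrors the new default
    sa_config: Optional[dict] = None     # any non-None value counts as SA-enhanced


def apply_silent_fallbacks(cfg: ToySpecConfig) -> ToySpecConfig:
    """Turn the flag off instead of raising, as the two changed validators do."""
    # Per-config validator: SA may override the proposed draft tokens.
    if cfg.use_rejection_sampling and cfg.sa_config is not None:
        cfg.use_rejection_sampling = False
    # Post-validation: rejection sampling is only wired up for Eagle3.
    if cfg.use_rejection_sampling and cfg.decoding_type != "Eagle3":
        cfg.use_rejection_sampling = False
    return cfg


eagle = apply_silent_fallbacks(ToySpecConfig("Eagle3"))
pard = apply_silent_fallbacks(ToySpecConfig("PARD"))
eagle_sa = apply_silent_fallbacks(ToySpecConfig("Eagle3", sa_config={}))
print(eagle.use_rejection_sampling, pard.use_rejection_sampling,
      eagle_sa.use_rejection_sampling)
```

One consequence of the silent design is that a user who explicitly sets the flag on an unsupported config no longer gets an error; in this toy, as in the diff, the flag is simply reset to `False`.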
