Skip to content

Commit eaa50b7

Browse files
committed
[None][feat] Acceptance-rate-based speculation gate
Gate speculative decoding on a rolling per-request acceptance rate. When the moving-average acceptance rate (AR) over acceptance_rate_window_size falls below acceptance_rate_threshold, speculation is disabled for that request; when it rises back above the threshold, speculation re-engages.

- Wire SpeculationGate through the three hot-path call sites (PP, non-overlap, overlap scheduler) and hoist the early-return so the no-op-when-disabled guarantee is visible in the diff.
- Rename _record_batch_acceptance_rate -> _update_batch_acceptance_rate and document the overlap-scheduler interaction.
- Add the acceptance_rate_window_size / acceptance_rate_threshold knobs to TorchLlmArgs / DecodingBaseConfig.
- Add unit tests under tests/unittest/_torch/speculative/test_spec_gate.py.

Squashed from:
0e4906a AR based speculation off
e099d11 fix CI
ef245d7 Rename _record_batch_acceptance_rate and add overlap scheduler comment.
955bd69 [None][perf] Make AR gate no-op explicit at call sites

Signed-off-by: Zheyu Fu <zheyuf@NVIDIA.com>
1 parent b03b78f commit eaa50b7

4 files changed

Lines changed: 182 additions & 137 deletions

File tree

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 67 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -399,19 +399,16 @@ def __init__(
399399
self.num_fetch_requests = 0
400400
self.shutdown_event = threading.Event()
401401

402-
# Rolling acceptance tracking for spec decode (disable speculation if rolling acceptance is below threshold)
403-
spec_config = getattr(self.model_engine, 'spec_config', None)
404-
self.acceptance_window = getattr(
405-
spec_config, 'acceptance_window',
406-
None) if spec_config is not None else None
407-
self.acceptance_length_threshold = getattr(
408-
spec_config, 'acceptance_length_threshold',
409-
None) if spec_config is not None else None
402+
# Rolling true acceptance-rate tracking for permanent speculation
403+
# disable.
410404
self.speculation_permanently_disabled = False
411405
self.speculation_gate = None
412-
if self.acceptance_window and self.acceptance_length_threshold is not None:
413-
self.speculation_gate = SpeculationGate(
414-
self.acceptance_window, self.acceptance_length_threshold)
406+
spec_config = getattr(self.model_engine, 'spec_config', None)
407+
if spec_config is not None:
408+
window = getattr(spec_config, 'acceptance_rate_window_size', None)
409+
threshold = getattr(spec_config, 'acceptance_rate_threshold', None)
410+
if window and threshold is not None:
411+
self.speculation_gate = SpeculationGate(window, threshold)
415412

416413
# response used data
417414
self.response_lock = threading.Lock()
@@ -1666,6 +1663,42 @@ def _update_iter_stats(
16661663

16671664
return stats
16681665

1666+
def _update_batch_acceptance_rate(
1667+
self,
1668+
scheduled_batch: ScheduledRequests,
1669+
sample_state: SampleState,
1670+
iteration_id: Optional[int] = None) -> Tuple[bool, Optional[float]]:
1671+
if (self.speculation_gate is None
1672+
or self.speculation_permanently_disabled or self.is_warmup):
1673+
return False, None
1674+
1675+
if (getattr(self.dist, 'has_pp', False)
1676+
and not self.dist.is_last_pp_rank):
1677+
return False, None
1678+
new_tokens_lens = getattr(sample_state.host, 'new_tokens_lens', None)
1679+
if new_tokens_lens is None:
1680+
return False, None
1681+
new_tokens_lens_list = (new_tokens_lens.tolist() if hasattr(
1682+
new_tokens_lens, 'tolist') else list(new_tokens_lens))
1683+
total_draft_tokens = 0
1684+
total_accepted_tokens = 0
1685+
for request in scheduled_batch.generation_requests:
1686+
draft_len = request.num_draft_tokens
1687+
if draft_len <= 0 or request.is_dummy:
1688+
continue
1689+
total_draft_tokens += draft_len
1690+
total_accepted_tokens += request.py_num_accepted_draft_tokens
1691+
1692+
if total_draft_tokens <= 0:
1693+
return False, None
1694+
1695+
acceptance_rate = total_accepted_tokens / total_draft_tokens
1696+
disabled_now, avg = self.speculation_gate.record_acceptance_rate(
1697+
acceptance_rate, sample_id=iteration_id)
1698+
if disabled_now:
1699+
self.speculation_permanently_disabled = True
1700+
return disabled_now, avg
1701+
16691702
def _append_iter_stats(self,
16701703
stats: IterationStats,
16711704
req_stats: Optional[List[RequestStats]] = None,
@@ -2406,6 +2439,11 @@ def _handle_executed_batch(self, executed_batch: Optional[BatchStatePP]):
24062439
if executed_batch is not None:
24072440
with torch.cuda.nvtx.range("_handle_executed_batch_pp"):
24082441
self._update_requests(executed_batch.sample_state)
2442+
if self.speculation_gate is not None:
2443+
self._update_batch_acceptance_rate(
2444+
executed_batch.scheduled_requests,
2445+
executed_batch.sample_state,
2446+
iteration_id=self.iter_counter)
24092447

24102448
scheduled_requests = executed_batch.scheduled_requests
24112449
if self.kv_cache_transceiver:
@@ -2473,6 +2511,12 @@ def _handle_dynamic_draft_len(self,
24732511
if not hasattr(self.model_engine, 'max_draft_len'):
24742512
return
24752513

2514+
if self.speculation_permanently_disabled:
2515+
for request in scheduled_batch.generation_requests:
2516+
request.py_draft_tokens = []
2517+
self.model_engine.runtime_draft_len = 0
2518+
return
2519+
24762520
if (self.model_engine.spec_config is not None
24772521
and self.model_engine.spec_config.draft_len_schedule is not None
24782522
and self.model_engine.spec_config.spec_dec_mode.
@@ -2608,7 +2652,6 @@ def _prepare_and_schedule_batch(self):
26082652
# with dummy draft tokens to make the scheduler aware of the fact
26092653
# that speculation is about to happen.
26102654
self._prepare_draft_requests()
2611-
26122655
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
26132656
)
26142657

@@ -2986,6 +3029,11 @@ def _executor_loop(self):
29863029

29873030
self._update_request_states(scheduled_batch)
29883031
self._update_requests(sample_state, self.resource_manager)
3032+
if self.speculation_gate is not None:
3033+
self._update_batch_acceptance_rate(
3034+
scheduled_batch,
3035+
sample_state,
3036+
iteration_id=self.iter_counter)
29893037

29903038
self._send_kv_async(scheduled_batch.all_requests())
29913039
self._flush_pending_transfer_responses()
@@ -3368,6 +3416,13 @@ def _executor_loop_overlap(self):
33683416

33693417
if self.previous_batch is not None and should_process_previous_batch:
33703418
self._update_requests(self.previous_batch.sample_state)
3419+
# Turn off speculative decoding when the acceptance rate is low.
3420+
# In the overlap scheduler path, one extra iteration still runs with speculative decoding on.
3421+
if self.speculation_gate is not None:
3422+
self._update_batch_acceptance_rate(
3423+
self.previous_batch.scheduled_requests,
3424+
self.previous_batch.sample_state,
3425+
iteration_id=self.iter_counter)
33713426

33723427
self._send_kv_async(
33733428
self.previous_batch.scheduled_requests.all_requests())
@@ -4993,31 +5048,6 @@ def _handle_responses(self, emit_first_iter: bool = True):
49935048
new_responses.append((req_id, response))
49945049

49955050
if request_done:
4996-
if (self.drafter is not None and getattr(
4997-
self.model_engine, 'enable_spec_decode', False)
4998-
and not self.speculation_permanently_disabled
4999-
and not request.is_dummy and not self.is_warmup):
5000-
if self.speculation_gate is not None:
5001-
# Response handling runs on multiple PP ranks. Only the last PP rank performs
5002-
# sampling; restrict rolling stat updates to it to avoid overcounting.
5003-
if (not getattr(self.dist, 'has_pp',
5004-
False)) or self.dist.is_last_pp_rank:
5005-
avg_decoded = getattr(
5006-
request, 'avg_decoded_tokens_per_iter', None)
5007-
if avg_decoded is not None:
5008-
disabled_now, _ = self.speculation_gate.record_avg_decoded(
5009-
avg_decoded,
5010-
request_id=getattr(request, 'py_request_id',
5011-
None))
5012-
if disabled_now:
5013-
# disable speculation permanently
5014-
# starting from next iteration, _prepare_and_schedule_batch will set self.use_spec_decode to False
5015-
self.speculation_permanently_disabled = True
5016-
else:
5017-
logger.debug(
5018-
f"Request {request.py_request_id} has no avg_decoded_tokens_per_iter"
5019-
)
5020-
50215051
# TODO: Remove PP size == 1 gate for disagg + block reuse with PP > 1.
50225052
force_terminate_for_partial_reuse = (
50235053
self.enable_partial_reuse_for_disagg
Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,77 +1,82 @@
11
from collections import deque
2-
from typing import Optional, Tuple
2+
from typing import Deque, Optional, Tuple
33

44
from tensorrt_llm.logger import logger
55

66

77
class SpeculationGate:
88
"""
9-
Tracks rolling average of accepted draft tokens per iteration over the last N completed requests.
10-
Permanently disables speculation when average falls below a threshold.
11-
"""
9+
Tracks a rolling average of true acceptance-rate samples over the last N
10+
speculation-enabled decoding iterations.
11+
12+
Permanently disables speculation when the rolling average falls below the
13+
configured threshold.
14+
"""
1215

1316
def __init__(self, window: int, threshold: float):
1417
self.window = window
1518
self.threshold = threshold
16-
self.acceptance_history: Deque[float] = deque()
17-
self.acceptance_sum: float = 0.0
18-
self.num_completed_for_acceptance = 0
19+
self.acceptance_rate_history: Deque[float] = deque()
20+
self.acceptance_rate_sum: float = 0.0
21+
self.num_recorded_samples = 0
1922
self.disabled = False
2023
logger.debug(
2124
f"[SpeculationGate] SpeculationGate initialized with window={self.window}, threshold={self.threshold}"
2225
)
2326

2427
def reset(self) -> None:
25-
self.acceptance_history.clear()
26-
self.acceptance_sum = 0.0
27-
self.num_completed_for_acceptance = 0
28+
self.acceptance_rate_history.clear()
29+
self.acceptance_rate_sum = 0.0
30+
self.num_recorded_samples = 0
2831
self.disabled = False
2932

30-
def record_avg_decoded(
33+
def record_acceptance_rate(
3134
self,
32-
avg_decoded_tokens_per_iter: float,
33-
request_id: Optional[int] = None) -> Tuple[bool, Optional[float]]:
35+
acceptance_rate: float,
36+
sample_id: Optional[int] = None) -> Tuple[bool, Optional[float]]:
3437
"""
35-
Record a completed request's avg_decoded_tokens_per_iter.
36-
Returns (disabled_now, current_avg_accept) where disabled_now is True only when the call causes disable.
38+
Record one speculation-enabled iteration's true acceptance rate.
39+
40+
Returns (disabled_now, current_avg_acceptance_rate) where
41+
disabled_now is True only when this call causes permanent disable.
3742
"""
3843
if self.disabled or self.window is None or self.window <= 0 or self.threshold is None:
3944
return False, None
4045

41-
# Extra Guard: if caller passed None, skip updating the rolling stats
42-
if avg_decoded_tokens_per_iter is None:
46+
if acceptance_rate is None:
4347
return False, None
4448

45-
accepted_len = 0.0
46-
accepted_len = max(0.0, float(avg_decoded_tokens_per_iter) - 1.0)
49+
acceptance_rate = float(acceptance_rate)
50+
if acceptance_rate < 0.0 or acceptance_rate > 1.0:
51+
raise ValueError("acceptance_rate must be in the range [0.0, 1.0], "
52+
f"got {acceptance_rate}")
4753

48-
# Log per-request completion for debug
49-
if request_id is not None:
50-
logger.debug(
51-
f"[SpeculationGate] Request {request_id} completed: avg_decoded={avg_decoded_tokens_per_iter if avg_decoded_tokens_per_iter is not None else 'None'}, accepted_len={accepted_len:.3f}"
52-
)
54+
if sample_id is not None:
55+
logger.debug(f"[SpeculationGate] Iteration {sample_id} recorded "
56+
f"acceptance_rate={acceptance_rate:.3f}")
5357

5458
# O(1) rolling update
55-
self.acceptance_history.append(accepted_len)
56-
logger.debug(
57-
f"[SpeculationGate] Acceptance history: {self.acceptance_history}")
58-
self.acceptance_sum += accepted_len
59-
if len(self.acceptance_history) > self.window:
60-
removed = self.acceptance_history.popleft()
61-
self.acceptance_sum -= removed
59+
self.acceptance_rate_history.append(acceptance_rate)
60+
logger.debug(f"[SpeculationGate] Acceptance-rate history: "
61+
f"{self.acceptance_rate_history}")
62+
self.acceptance_rate_sum += acceptance_rate
63+
if len(self.acceptance_rate_history) > self.window:
64+
removed = self.acceptance_rate_history.popleft()
65+
self.acceptance_rate_sum -= removed
6266

63-
self.num_completed_for_acceptance += 1
67+
self.num_recorded_samples += 1
6468

65-
if self.num_completed_for_acceptance >= self.window:
66-
avg_accept = self.acceptance_sum / len(self.acceptance_history)
67-
if avg_accept < self.threshold:
69+
if self.num_recorded_samples >= self.window:
70+
avg_acceptance_rate = (self.acceptance_rate_sum /
71+
len(self.acceptance_rate_history))
72+
if avg_acceptance_rate < self.threshold:
6873
self.disabled = True
6974
logger.info(
70-
f"[SpeculationGate] Speculative decoding disabled: rolling acceptance avg {avg_accept:.3f} < threshold {self.threshold} over last {self.window} requests"
71-
)
72-
return True, avg_accept
73-
else:
74-
# speculation is still enabled
75-
return False, avg_accept
75+
"[SpeculationGate] Speculative decoding disabled: "
76+
f"rolling acceptance rate avg "
77+
f"{avg_acceptance_rate:.3f} < threshold "
78+
f"{self.threshold} over last {self.window} iterations")
79+
return True, avg_acceptance_rate
80+
return False, avg_acceptance_rate
7681

7782
return False, None

tensorrt_llm/llmapi/llm_args.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,19 +1057,23 @@ class DecodingBaseConfig(StrictBaseModel):
10571057
load_format: Optional[str] = Field(
10581058
default=None, description="The load format of the speculative model.")
10591059

1060-
acceptance_window: Optional[NonNegativeInt] = Field(
1060+
acceptance_rate_window_size: Optional[NonNegativeInt] = Field(
10611061
default=None,
10621062
description=
1063-
"The rolling average window size (N) for acceptance length across completed requests. "
1063+
"The rolling average window size (N) for acceptance rate across "
1064+
"speculation-enabled decoding iterations. "
10641065
"If not set or set to 0, the feature is disabled. PyTorch backend only."
10651066
)
10661067

1067-
acceptance_length_threshold: Optional[NonNegativeFloat] = Field(
1068+
acceptance_rate_threshold: Optional[float] = Field(
10681069
default=None,
1069-
description=
1070-
"The threshold for average acceptance length; speculation will be disabled permanently once the "
1071-
"rolling average over the last N completed requests (N = acceptance_window) drops below this value. "
1072-
"PyTorch backend only.")
1070+
ge=0.0,
1071+
le=1.0,
1072+
description="The threshold for average true acceptance rate "
1073+
"(accepted_draft_tokens / drafted_tokens); speculation will be "
1074+
"disabled permanently once the rolling average over the last N "
1075+
"speculation-enabled decoding iterations "
1076+
"(N = acceptance_rate_window_size) drops below this value. ")
10731077

10741078
allow_advanced_sampling: bool = Field(
10751079
default=False,

0 commit comments

Comments
 (0)