Skip to content

Commit 9cd2fe6

Browse files
committed
[None][feat] Acceptance-rate-based speculation gate
Gate speculative decoding on a rolling acceptance rate measured per decoding iteration across the scheduled batch. When the moving-average AR over acceptance_rate_window_size falls below acceptance_rate_threshold, speculation is permanently disabled for the executor (it does not re-engage in this change). - Wire SpeculationGate through the three hot-path call sites (PP, non-overlap, overlap scheduler) and hoist the early-return so the no-op-when-disabled guarantee is visible in the diff. - Rename _record_batch_acceptance_rate -> _update_batch_acceptance_rate and document the overlap-scheduler interaction. - Add the acceptance_rate_window_size / acceptance_rate_threshold knobs to TorchLlmArgs / DecodingBaseConfig. - Add unit tests under tests/unittest/_torch/speculative/test_spec_gate.py. Squashed from: 0e4906a AR based speculation off e099d11 fix CI ef245d7 Rename _record_batch_acceptance_rate and add overlap scheduler comment. 955bd69 [None][perf] Make AR gate no-op explicit at call sites Signed-off-by: Zheyu Fu <zheyuf@NVIDIA.com>
1 parent eddaa3a commit 9cd2fe6

4 files changed

Lines changed: 182 additions & 137 deletions

File tree

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 67 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -410,19 +410,16 @@ def __init__(
410410
self.num_fetch_requests = 0
411411
self.shutdown_event = threading.Event()
412412

413-
# Rolling acceptance tracking for spec decode (disable speculation if rolling acceptance is below threshold)
414-
spec_config = getattr(self.model_engine, 'spec_config', None)
415-
self.acceptance_window = getattr(
416-
spec_config, 'acceptance_window',
417-
None) if spec_config is not None else None
418-
self.acceptance_length_threshold = getattr(
419-
spec_config, 'acceptance_length_threshold',
420-
None) if spec_config is not None else None
413+
# Rolling true acceptance-rate tracking for permanent speculation
414+
# disable.
421415
self.speculation_permanently_disabled = False
422416
self.speculation_gate = None
423-
if self.acceptance_window and self.acceptance_length_threshold is not None:
424-
self.speculation_gate = SpeculationGate(
425-
self.acceptance_window, self.acceptance_length_threshold)
417+
spec_config = getattr(self.model_engine, 'spec_config', None)
418+
if spec_config is not None:
419+
window = getattr(spec_config, 'acceptance_rate_window_size', None)
420+
threshold = getattr(spec_config, 'acceptance_rate_threshold', None)
421+
if window and threshold is not None:
422+
self.speculation_gate = SpeculationGate(window, threshold)
426423

427424
# response used data
428425
self.response_lock = threading.Lock()
@@ -1706,6 +1703,42 @@ def _update_iter_stats(
17061703

17071704
return stats
17081705

1706+
def _update_batch_acceptance_rate(
        self,
        scheduled_batch: ScheduledRequests,
        sample_state: SampleState,
        iteration_id: Optional[int] = None) -> Tuple[bool, Optional[float]]:
    """Feed one iteration's batch-wide draft acceptance rate to the gate.

    The acceptance rate is ``sum(accepted draft tokens) / sum(draft tokens)``
    over the generation requests of ``scheduled_batch``. Dummy requests and
    requests without draft tokens are left out of both sums.

    Args:
        scheduled_batch: Batch whose ``generation_requests`` are inspected.
        sample_state: Sampler output for that batch; only the presence of
            ``sample_state.host.new_tokens_lens`` is checked.
        iteration_id: Optional identifier forwarded to the gate as
            ``sample_id`` (used there for debug logging).

    Returns:
        ``(disabled_now, rolling_avg)`` as returned by
        ``SpeculationGate.record_acceptance_rate``, or ``(False, None)``
        when nothing was recorded.

    Side effects:
        Sets ``self.speculation_permanently_disabled`` to ``True`` when the
        gate reports that this sample caused the disable.
    """
    # No-op when the feature is off, already tripped, or during warmup.
    if (self.speculation_gate is None
            or self.speculation_permanently_disabled or self.is_warmup):
        return False, None

    # With pipeline parallelism only the last PP rank records samples.
    if (getattr(self.dist, 'has_pp', False)
            and not self.dist.is_last_pp_rank):
        return False, None

    # Only the presence of new_tokens_lens matters here; the per-request
    # accepted counts are read from the requests themselves below, so the
    # tensor is deliberately not converted to a Python list.
    # NOTE(review): presumably a missing new_tokens_lens means the sampler
    # produced no speculative result for this step - confirm.
    if getattr(sample_state.host, 'new_tokens_lens', None) is None:
        return False, None

    total_draft_tokens = 0
    total_accepted_tokens = 0
    for request in scheduled_batch.generation_requests:
        draft_len = request.num_draft_tokens
        if draft_len <= 0 or request.is_dummy:
            continue
        total_draft_tokens += draft_len
        total_accepted_tokens += request.py_num_accepted_draft_tokens

    # Nothing was drafted this iteration: no sample, and no division by zero.
    if total_draft_tokens <= 0:
        return False, None

    acceptance_rate = total_accepted_tokens / total_draft_tokens
    disabled_now, avg = self.speculation_gate.record_acceptance_rate(
        acceptance_rate, sample_id=iteration_id)
    if disabled_now:
        self.speculation_permanently_disabled = True
    return disabled_now, avg
1741+
17091742
def _append_iter_stats(self,
17101743
stats: IterationStats,
17111744
req_stats: Optional[List[RequestStats]] = None,
@@ -2447,6 +2480,11 @@ def _handle_executed_batch(self, executed_batch: Optional[BatchStatePP]):
24472480
if executed_batch is not None:
24482481
with torch.cuda.nvtx.range("_handle_executed_batch_pp"):
24492482
self._update_requests(executed_batch.sample_state)
2483+
if self.speculation_gate is not None:
2484+
self._update_batch_acceptance_rate(
2485+
executed_batch.scheduled_requests,
2486+
executed_batch.sample_state,
2487+
iteration_id=self.iter_counter)
24502488

24512489
scheduled_requests = executed_batch.scheduled_requests
24522490
if self.kv_cache_transceiver:
@@ -2514,6 +2552,12 @@ def _handle_dynamic_draft_len(self,
25142552
if not hasattr(self.model_engine, 'max_draft_len'):
25152553
return
25162554

2555+
if self.speculation_permanently_disabled:
2556+
for request in scheduled_batch.generation_requests:
2557+
request.py_draft_tokens = []
2558+
self.model_engine.runtime_draft_len = 0
2559+
return
2560+
25172561
if (self.model_engine.spec_config is not None
25182562
and self.model_engine.spec_config.draft_len_schedule is not None
25192563
and self.model_engine.spec_config.spec_dec_mode.
@@ -2675,7 +2719,6 @@ def _prepare_and_schedule_batch(self):
26752719
# with dummy draft tokens to make the scheduler aware of the fact
26762720
# that speculation is about to happen.
26772721
self._prepare_draft_requests()
2678-
26792722
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
26802723
)
26812724

@@ -3065,6 +3108,11 @@ def _executor_loop(self):
30653108

30663109
self._update_request_states(scheduled_batch)
30673110
self._update_requests(sample_state, self.resource_manager)
3111+
if self.speculation_gate is not None:
3112+
self._update_batch_acceptance_rate(
3113+
scheduled_batch,
3114+
sample_state,
3115+
iteration_id=self.iter_counter)
30683116

30693117
self._send_kv_async(scheduled_batch.all_requests())
30703118
self._flush_pending_transfer_responses()
@@ -3449,6 +3497,13 @@ def _executor_loop_overlap(self):
34493497

34503498
if self.previous_batch is not None and should_process_previous_batch:
34513499
self._update_requests(self.previous_batch.sample_state)
3500+
# Turning off speculative decoding when Acceptance Rate is low.
3501+
# In overlap scheduler path, it will do an extra iter with spec decode on.
3502+
if self.speculation_gate is not None:
3503+
self._update_batch_acceptance_rate(
3504+
self.previous_batch.scheduled_requests,
3505+
self.previous_batch.sample_state,
3506+
iteration_id=self.iter_counter)
34523507

34533508
self._send_kv_async(
34543509
self.previous_batch.scheduled_requests.all_requests())
@@ -5364,31 +5419,6 @@ def _handle_responses(self, emit_first_iter: bool = True):
53645419
new_responses.append((req_id, response))
53655420

53665421
if request_done:
5367-
if (self.drafter is not None and getattr(
5368-
self.model_engine, 'enable_spec_decode', False)
5369-
and not self.speculation_permanently_disabled
5370-
and not request.is_dummy and not self.is_warmup):
5371-
if self.speculation_gate is not None:
5372-
# Response handling runs on multiple PP ranks. Only the last PP rank performs
5373-
# sampling; restrict rolling stat updates to it to avoid overcounting.
5374-
if (not getattr(self.dist, 'has_pp',
5375-
False)) or self.dist.is_last_pp_rank:
5376-
avg_decoded = getattr(
5377-
request, 'avg_decoded_tokens_per_iter', None)
5378-
if avg_decoded is not None:
5379-
disabled_now, _ = self.speculation_gate.record_avg_decoded(
5380-
avg_decoded,
5381-
request_id=getattr(request, 'py_request_id',
5382-
None))
5383-
if disabled_now:
5384-
# disable speculation permanently
5385-
# starting from next iteration, _prepare_and_schedule_batch will set self.use_spec_decode to False
5386-
self.speculation_permanently_disabled = True
5387-
else:
5388-
logger.debug(
5389-
f"Request {request.py_request_id} has no avg_decoded_tokens_per_iter"
5390-
)
5391-
53925422
# TODO: Remove PP size == 1 gate for disagg + block reuse with PP > 1.
53935423
force_terminate_for_partial_reuse = (
53945424
self.enable_partial_reuse_for_disagg
Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,77 +1,82 @@
11
from collections import deque
2-
from typing import Optional, Tuple
2+
from typing import Deque, Optional, Tuple
33

44
from tensorrt_llm.logger import logger
55

66

77
class SpeculationGate:
    """
    Tracks a rolling average of true acceptance-rate samples over the last N
    speculation-enabled decoding iterations.

    Permanently disables speculation when the rolling average falls below the
    configured threshold. No verdict is produced until at least ``window``
    samples have been recorded. Once ``disabled`` is set, further samples are
    ignored until ``reset()`` is called.
    """

    def __init__(self, window: int, threshold: float):
        """
        Args:
            window: Number of most recent samples kept in the rolling average.
                ``None`` or a value <= 0 makes ``record_acceptance_rate`` a no-op.
            threshold: Rolling averages strictly below this value disable
                speculation. ``None`` makes ``record_acceptance_rate`` a no-op.
        """
        self.window = window
        self.threshold = threshold
        # Samples currently inside the window, oldest on the left.
        self.acceptance_rate_history: Deque[float] = deque()
        # Running sum of acceptance_rate_history, maintained incrementally.
        self.acceptance_rate_sum: float = 0.0
        # Total samples accepted since construction / last reset().
        self.num_recorded_samples = 0
        self.disabled = False
        logger.debug(
            f"[SpeculationGate] SpeculationGate initialized with window={self.window}, threshold={self.threshold}"
        )

    def reset(self) -> None:
        """Drop all recorded samples and re-enable the gate."""
        self.acceptance_rate_history.clear()
        self.acceptance_rate_sum = 0.0
        self.num_recorded_samples = 0
        self.disabled = False

    def record_acceptance_rate(
            self,
            acceptance_rate: float,
            sample_id: Optional[int] = None) -> Tuple[bool, Optional[float]]:
        """
        Record one speculation-enabled iteration's true acceptance rate.

        Args:
            acceptance_rate: Sample in [0.0, 1.0]. ``None`` is ignored.
            sample_id: Optional identifier, used only for debug logging.

        Returns (disabled_now, current_avg_acceptance_rate) where
        disabled_now is True only when this call causes permanent disable.
        The average is None while fewer than ``window`` samples have been
        recorded, when the gate is already disabled or unconfigured, and
        when ``acceptance_rate`` is None.

        Raises:
            ValueError: If ``acceptance_rate`` is NaN or outside [0.0, 1.0].
        """
        if self.disabled or self.window is None or self.window <= 0 or self.threshold is None:
            return False, None

        if acceptance_rate is None:
            return False, None

        acceptance_rate = float(acceptance_rate)
        # Test for membership rather than for being out of range: every
        # comparison with NaN is False, so a "< 0.0 or > 1.0" check would let
        # NaN through, poison the running sum, and keep the gate from ever
        # firing again (nan < threshold is also False).
        if not 0.0 <= acceptance_rate <= 1.0:
            raise ValueError("acceptance_rate must be in the range [0.0, 1.0], "
                             f"got {acceptance_rate}")

        if sample_id is not None:
            logger.debug(f"[SpeculationGate] Iteration {sample_id} recorded "
                         f"acceptance_rate={acceptance_rate:.3f}")

        # O(1) rolling update
        self.acceptance_rate_history.append(acceptance_rate)
        # NOTE(review): this f-string renders the whole window on every call,
        # even when debug logging is off - consider guarding it if it shows up
        # in profiles.
        logger.debug(f"[SpeculationGate] Acceptance-rate history: "
                     f"{self.acceptance_rate_history}")
        self.acceptance_rate_sum += acceptance_rate
        if len(self.acceptance_rate_history) > self.window:
            removed = self.acceptance_rate_history.popleft()
            self.acceptance_rate_sum -= removed

        self.num_recorded_samples += 1

        # Only judge once a full window of samples is available.
        if self.num_recorded_samples >= self.window:
            avg_acceptance_rate = (self.acceptance_rate_sum /
                                   len(self.acceptance_rate_history))
            if avg_acceptance_rate < self.threshold:
                self.disabled = True
                logger.info(
                    "[SpeculationGate] Speculative decoding disabled: "
                    f"rolling acceptance rate avg "
                    f"{avg_acceptance_rate:.3f} < threshold "
                    f"{self.threshold} over last {self.window} iterations")
                return True, avg_acceptance_rate
            return False, avg_acceptance_rate

        return False, None

tensorrt_llm/llmapi/llm_args.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,19 +1291,23 @@ class DecodingBaseConfig(StrictBaseModel):
12911291
load_format: Optional[str] = Field(
12921292
default=None, description="The load format of the speculative model.")
12931293

1294-
acceptance_window: Optional[NonNegativeInt] = Field(
1294+
acceptance_rate_window_size: Optional[NonNegativeInt] = Field(
12951295
default=None,
12961296
description=
1297-
"The rolling average window size (N) for acceptance length across completed requests. "
1297+
"The rolling average window size (N) for acceptance rate across "
1298+
"speculation-enabled decoding iterations. "
12981299
"If not set or set to 0, the feature is disabled. PyTorch backend only."
12991300
)
13001301

1301-
acceptance_length_threshold: Optional[NonNegativeFloat] = Field(
1302+
acceptance_rate_threshold: Optional[float] = Field(
13021303
default=None,
1303-
description=
1304-
"The threshold for average acceptance length; speculation will be disabled permanently once the "
1305-
"rolling average over the last N completed requests (N = acceptance_window) drops below this value. "
1306-
"PyTorch backend only.")
1304+
ge=0.0,
1305+
le=1.0,
1306+
description="The threshold for average true acceptance rate "
1307+
"(accepted_draft_tokens / drafted_tokens); speculation will be "
1308+
"disabled permanently once the rolling average over the last N "
1309+
"speculation-enabled decoding iterations "
1310+
"(N = acceptance_rate_window_size) drops below this value. ")
13071311

13081312
use_rejection_sampling: bool = Field(
13091313
default=False,

0 commit comments

Comments
 (0)