Skip to content

Commit 3d662e5

Browse files
committed
AR based speculation off
Signed-off-by: Zheyu Fu <zheyuf@NVIDIA.com>
1 parent 2335ef8 commit 3d662e5

4 files changed

Lines changed: 177 additions & 140 deletions

File tree

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 63 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ def __init__(
328328
self.max_beam_width = max_beam_width
329329
self.max_draft_len = max_draft_len
330330
self.max_total_draft_tokens = max_total_draft_tokens
331+
self.use_spec_decode = self.model_engine.enable_spec_decode
331332
self.llm_args = self.model_engine.llm_args
332333
self.max_stats_len = max(self.llm_args.max_stats_len, 1)
333334
self.max_num_tokens = self.llm_args.max_num_tokens
@@ -352,19 +353,16 @@ def __init__(
352353
self.num_fetch_requests = 0
353354
self.shutdown_event = threading.Event()
354355

355-
# Rolling acceptance tracking for spec decode (disable speculation if rolling acceptance is below threshold)
356-
spec_config = getattr(self.model_engine, 'spec_config', None)
357-
self.acceptance_window = getattr(
358-
spec_config, 'acceptance_window',
359-
None) if spec_config is not None else None
360-
self.acceptance_length_threshold = getattr(
361-
spec_config, 'acceptance_length_threshold',
362-
None) if spec_config is not None else None
356+
# Rolling true acceptance-rate tracking for permanent speculation
357+
# disable.
363358
self.speculation_permanently_disabled = False
364359
self.speculation_gate = None
365-
if self.acceptance_window and self.acceptance_length_threshold is not None:
366-
self.speculation_gate = SpeculationGate(
367-
self.acceptance_window, self.acceptance_length_threshold)
360+
spec_config = getattr(self.model_engine, 'spec_config', None)
361+
if spec_config is not None:
362+
window = getattr(spec_config, 'acceptance_rate_window_size', None)
363+
threshold = getattr(spec_config, 'acceptance_rate_threshold', None)
364+
if window and threshold is not None:
365+
self.speculation_gate = SpeculationGate(window, threshold)
368366

369367
# response used data
370368
self.response_lock = threading.Lock()
@@ -1185,6 +1183,42 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
11851183
draft_latency_ms) / float(iter_latency_ms)
11861184
return stats
11871185

1186+
def _record_batch_acceptance_rate(
        self,
        scheduled_batch: "ScheduledRequests",
        sample_state: "SampleState",
        iteration_id: Optional[int] = None) -> Tuple[bool, Optional[float]]:
    """Feed one iteration's batch-wide draft acceptance rate to the gate.

    The acceptance rate is computed over the whole batch as
    ``sum(accepted draft tokens) / sum(drafted tokens)`` across the
    generation requests that actually carried draft tokens, and is then
    handed to ``self.speculation_gate.record_acceptance_rate``.

    Args:
        scheduled_batch: Batch whose ``generation_requests`` are inspected.
        sample_state: Sampler output for the batch; only the presence of
            ``sample_state.host.new_tokens_lens`` is checked.
        iteration_id: Optional identifier forwarded to the gate as
            ``sample_id``.

    Returns:
        ``(disabled_now, rolling_avg)`` exactly as returned by the gate, or
        ``(False, None)`` when no sample was recorded.
    """
    # Nothing to record without a gate, once speculation has already been
    # switched off for good, or while warming up.
    if (self.speculation_gate is None or self.speculation_permanently_disabled
            or self.is_warmup):
        return False, None

    # With pipeline parallelism only the last PP rank records samples.
    # NOTE(review): presumably because only that rank performs sampling, so
    # recording elsewhere would overcount -- confirm.
    if (getattr(self.dist, 'has_pp', False)
            and not self.dist.is_last_pp_rank):
        return False, None

    # Only the presence of the per-request token-length output matters here;
    # its values are not needed to compute the rate, so it is not converted
    # to a Python list.
    if getattr(sample_state.host, 'new_tokens_lens', None) is None:
        return False, None

    total_draft_tokens = 0
    total_accepted_tokens = 0
    for request in scheduled_batch.generation_requests:
        draft_len = request.num_draft_tokens
        # Requests without draft tokens and dummy requests do not contribute
        # to the sample.
        if draft_len <= 0 or request.is_dummy:
            continue
        total_draft_tokens += draft_len
        total_accepted_tokens += request.py_num_accepted_draft_tokens

    # No drafted tokens in this batch: there is no rate to report.
    if total_draft_tokens <= 0:
        return False, None

    acceptance_rate = total_accepted_tokens / total_draft_tokens
    disabled_now, avg = self.speculation_gate.record_acceptance_rate(
        acceptance_rate, sample_id=iteration_id)
    if disabled_now:
        # Latch the decision on the executor so later iterations skip
        # speculation.
        self.speculation_permanently_disabled = True
    return disabled_now, avg
1221+
11881222
def _append_iter_stats(self,
11891223
stats: IterationStats,
11901224
req_stats: Optional[List[RequestStats]] = None):
@@ -1679,6 +1713,10 @@ def _handle_executed_batch(self, executed_batch: Optional[BatchStatePP]):
16791713
if executed_batch is not None:
16801714
with torch.cuda.nvtx.range("_handle_executed_batch_pp"):
16811715
self._update_requests(executed_batch.sample_state)
1716+
self._record_batch_acceptance_rate(
1717+
executed_batch.scheduled_requests,
1718+
executed_batch.sample_state,
1719+
iteration_id=self.iter_counter)
16821720

16831721
scheduled_requests = executed_batch.scheduled_requests
16841722
if self.kv_cache_transceiver:
@@ -1745,6 +1783,12 @@ def _handle_dynamic_draft_len(self,
17451783
if not hasattr(self.model_engine, 'max_draft_len'):
17461784
return
17471785

1786+
if self.speculation_permanently_disabled:
1787+
for request in scheduled_batch.generation_requests:
1788+
request.py_draft_tokens = []
1789+
self.model_engine.runtime_draft_len = 0
1790+
return
1791+
17481792
if (self.model_engine.spec_config is not None
17491793
and self.model_engine.spec_config.draft_len_schedule is not None
17501794
and self.model_engine.spec_config.spec_dec_mode.
@@ -1846,7 +1890,6 @@ def _prepare_and_schedule_batch(self):
18461890
# with dummy draft tokens to make the scheduler aware of the fact
18471891
# that speculation is about to happen.
18481892
self._prepare_draft_requests()
1849-
18501893
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
18511894
)
18521895

@@ -2134,6 +2177,10 @@ def _executor_loop(self):
21342177

21352178
self._update_request_states(scheduled_batch)
21362179
self._update_requests(sample_state, self.resource_manager)
2180+
self._record_batch_acceptance_rate(
2181+
scheduled_batch,
2182+
sample_state,
2183+
iteration_id=self.iter_counter)
21372184

21382185
self._send_kv_async(scheduled_batch.all_requests())
21392186

@@ -2376,6 +2423,10 @@ def _executor_loop_overlap(self):
23762423

23772424
if self.previous_batch is not None and should_process_previous_batch:
23782425
self._update_requests(self.previous_batch.sample_state)
2426+
self._record_batch_acceptance_rate(
2427+
self.previous_batch.scheduled_requests,
2428+
self.previous_batch.sample_state,
2429+
iteration_id=self.iter_counter)
23792430

23802431
self._send_kv_async(
23812432
self.previous_batch.scheduled_requests.all_requests())
@@ -3685,31 +3736,6 @@ def _handle_responses(self):
36853736
new_responses.append((req_id, response))
36863737

36873738
if request_done:
3688-
if (self.drafter is not None and getattr(
3689-
self.model_engine, 'enable_spec_decode', False)
3690-
and not self.speculation_permanently_disabled
3691-
and not request.is_dummy and not self.is_warmup):
3692-
if self.speculation_gate is not None:
3693-
# Response handling runs on multiple PP ranks. Only the last PP rank performs
3694-
# sampling; restrict rolling stat updates to it to avoid overcounting.
3695-
if (not getattr(self.dist, 'has_pp',
3696-
False)) or self.dist.is_last_pp_rank:
3697-
avg_decoded = getattr(
3698-
request, 'avg_decoded_tokens_per_iter', None)
3699-
if avg_decoded is not None:
3700-
disabled_now, _ = self.speculation_gate.record_avg_decoded(
3701-
avg_decoded,
3702-
request_id=getattr(request, 'py_request_id',
3703-
None))
3704-
if disabled_now:
3705-
# disable speculation permanently
3706-
# starting from next iteration, _prepare_and_schedule_batch will set self.use_spec_decode to False
3707-
self.speculation_permanently_disabled = True
3708-
else:
3709-
logger.debug(
3710-
f"Request {request.py_request_id} has no avg_decoded_tokens_per_iter"
3711-
)
3712-
37133739
# If partial reuse is enabled, and the KV cache manager is not VSWA, and the PP size is 1,
37143740
# then we need to terminate the request. TODO: Remove this once disagg support from KVCache reuse
37153741
# path is fixed.
Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,77 +1,82 @@
11
from collections import deque
2-
from typing import Optional, Tuple
2+
from typing import Deque, Optional, Tuple
33

44
from tensorrt_llm.logger import logger
55

66

77
class SpeculationGate:
    """
    Tracks a rolling average of true acceptance-rate samples over the last N
    speculation-enabled decoding iterations.

    Permanently disables speculation when the rolling average falls below the
    configured threshold. The average is only evaluated once at least
    ``window`` samples have been recorded.
    """

    def __init__(self, window: int, threshold: float):
        """
        Args:
            window: Number of most recent samples kept in the rolling window.
                ``None`` or a non-positive value makes the gate inert.
            threshold: Minimum acceptable rolling average; ``None`` makes the
                gate inert.
        """
        self.window = window
        self.threshold = threshold
        # Samples currently inside the window, oldest first.
        self.acceptance_rate_history: Deque[float] = deque()
        # Running sum of the samples held in ``acceptance_rate_history``.
        self.acceptance_rate_sum: float = 0.0
        # Total number of samples accepted since construction/reset.
        self.num_recorded_samples = 0
        self.disabled = False
        logger.debug(
            f"[SpeculationGate] SpeculationGate initialized with window={self.window}, threshold={self.threshold}"
        )

    def reset(self) -> None:
        """Drop all recorded samples and re-enable the gate."""
        self.acceptance_rate_history.clear()
        self.acceptance_rate_sum = 0.0
        self.num_recorded_samples = 0
        self.disabled = False

    def record_acceptance_rate(
            self,
            acceptance_rate: float,
            sample_id: Optional[int] = None) -> Tuple[bool, Optional[float]]:
        """
        Record one speculation-enabled iteration's true acceptance rate.

        Args:
            acceptance_rate: Sample in ``[0.0, 1.0]``. ``None`` is ignored.
            sample_id: Optional identifier used only for debug logging.

        Returns:
            (disabled_now, current_avg_acceptance_rate) where disabled_now is
            True only when this call causes permanent disable. The average is
            None while the gate is inert, already disabled, or still filling
            its window.

        Raises:
            ValueError: If the sample is NaN or outside ``[0.0, 1.0]``.
        """
        if self.disabled or self.window is None or self.window <= 0 or self.threshold is None:
            return False, None

        if acceptance_rate is None:
            return False, None

        acceptance_rate = float(acceptance_rate)
        # Chained comparison on purpose: NaN compares False against
        # everything, so it is rejected here too. A NaN sample would otherwise
        # poison the running sum permanently (it stays NaN even after the
        # sample leaves the window) and the gate could never trigger.
        if not 0.0 <= acceptance_rate <= 1.0:
            raise ValueError("acceptance_rate must be in the range [0.0, 1.0], "
                             f"got {acceptance_rate}")

        if sample_id is not None:
            logger.debug(f"[SpeculationGate] Iteration {sample_id} recorded "
                         f"acceptance_rate={acceptance_rate:.3f}")

        # O(1) rolling update: push the new sample, evict the oldest one once
        # the window is exceeded, and keep the running sum in step.
        self.acceptance_rate_history.append(acceptance_rate)
        self.acceptance_rate_sum += acceptance_rate
        if len(self.acceptance_rate_history) > self.window:
            removed = self.acceptance_rate_history.popleft()
            self.acceptance_rate_sum -= removed

        # Log a constant-size summary rather than the whole window: this runs
        # once per decoding iteration and the message is built eagerly.
        logger.debug(
            f"[SpeculationGate] Acceptance-rate window: "
            f"{len(self.acceptance_rate_history)}/{self.window} samples, "
            f"rolling sum={self.acceptance_rate_sum:.3f}")

        self.num_recorded_samples += 1

        # Do not judge the drafter before a full window has been observed.
        if self.num_recorded_samples >= self.window:
            avg_acceptance_rate = (self.acceptance_rate_sum /
                                   len(self.acceptance_rate_history))
            if avg_acceptance_rate < self.threshold:
                self.disabled = True
                logger.info(
                    "[SpeculationGate] Speculative decoding disabled: "
                    f"rolling acceptance rate avg "
                    f"{avg_acceptance_rate:.3f} < threshold "
                    f"{self.threshold} over last {self.window} iterations")
                return True, avg_acceptance_rate
            return False, avg_acceptance_rate

        return False, None

tensorrt_llm/llmapi/llm_args.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -814,19 +814,23 @@ class DecodingBaseConfig(StrictBaseModel):
814814
load_format: Optional[str] = Field(
815815
default=None, description="The load format of the speculative model.")
816816

817-
acceptance_window: Optional[NonNegativeInt] = Field(
817+
acceptance_rate_window_size: Optional[NonNegativeInt] = Field(
818818
default=None,
819819
description=
820-
"The rolling average window size (N) for acceptance length across completed requests. "
820+
"The rolling average window size (N) for acceptance rate across "
821+
"speculation-enabled decoding iterations. "
821822
"If not set or set to 0, the feature is disabled. PyTorch backend only."
822823
)
823824

824-
acceptance_length_threshold: Optional[NonNegativeFloat] = Field(
825+
acceptance_rate_threshold: Optional[float] = Field(
825826
default=None,
826-
description=
827-
"The threshold for average acceptance length; speculation will be disabled permanently once the "
828-
"rolling average over the last N completed requests (N = acceptance_window) drops below this value. "
829-
"PyTorch backend only.")
827+
ge=0.0,
828+
le=1.0,
829+
description="The threshold for average true acceptance rate "
830+
"(accepted_draft_tokens / drafted_tokens); speculation will be "
831+
"disabled permanently once the rolling average over the last N "
832+
"speculation-enabled decoding iterations "
833+
"(N = acceptance_rate_window_size) drops below this value. ")
830834

831835
allow_advanced_sampling: bool = Field(
832836
default=False,

0 commit comments

Comments
 (0)