@@ -328,6 +328,7 @@ def __init__(
328328 self .max_beam_width = max_beam_width
329329 self .max_draft_len = max_draft_len
330330 self .max_total_draft_tokens = max_total_draft_tokens
331+ self .use_spec_decode = self .model_engine .enable_spec_decode
331332 self .llm_args = self .model_engine .llm_args
332333 self .max_stats_len = max (self .llm_args .max_stats_len , 1 )
333334 self .max_num_tokens = self .llm_args .max_num_tokens
@@ -352,19 +353,16 @@ def __init__(
352353 self .num_fetch_requests = 0
353354 self .shutdown_event = threading .Event ()
354355
355- # Rolling acceptance tracking for spec decode (disable speculation if rolling acceptance is below threshold)
356- spec_config = getattr (self .model_engine , 'spec_config' , None )
357- self .acceptance_window = getattr (
358- spec_config , 'acceptance_window' ,
359- None ) if spec_config is not None else None
360- self .acceptance_length_threshold = getattr (
361- spec_config , 'acceptance_length_threshold' ,
362- None ) if spec_config is not None else None
356+ # Rolling true acceptance-rate tracking for permanent speculation
357+ # disable.
363358 self .speculation_permanently_disabled = False
364359 self .speculation_gate = None
365- if self .acceptance_window and self .acceptance_length_threshold is not None :
366- self .speculation_gate = SpeculationGate (
367- self .acceptance_window , self .acceptance_length_threshold )
360+ spec_config = getattr (self .model_engine , 'spec_config' , None )
361+ if spec_config is not None :
362+ window = getattr (spec_config , 'acceptance_rate_window_size' , None )
363+ threshold = getattr (spec_config , 'acceptance_rate_threshold' , None )
364+ if window and threshold is not None :
365+ self .speculation_gate = SpeculationGate (window , threshold )
368366
369367 # response used data
370368 self .response_lock = threading .Lock ()
@@ -1170,6 +1168,42 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
11701168 draft_latency_ms ) / float (iter_latency_ms )
11711169 return stats
11721170
1171+ def _record_batch_acceptance_rate (
1172+ self ,
1173+ scheduled_batch : ScheduledRequests ,
1174+ sample_state : SampleState ,
1175+ iteration_id : Optional [int ] = None ) -> Tuple [bool , Optional [float ]]:
1176+ if (self .speculation_gate is None
1177+ or self .speculation_permanently_disabled or self .is_warmup ):
1178+ return False , None
1179+
1180+ if (getattr (self .dist , 'has_pp' , False )
1181+ and not self .dist .is_last_pp_rank ):
1182+ return False , None
1183+ new_tokens_lens = getattr (sample_state .host , 'new_tokens_lens' , None )
1184+ if new_tokens_lens is None :
1185+ return False , None
1186+ new_tokens_lens_list = (new_tokens_lens .tolist () if hasattr (
1187+ new_tokens_lens , 'tolist' ) else list (new_tokens_lens ))
1188+ total_draft_tokens = 0
1189+ total_accepted_tokens = 0
1190+ for request in scheduled_batch .generation_requests :
1191+ draft_len = request .num_draft_tokens
1192+ if draft_len <= 0 or request .is_dummy :
1193+ continue
1194+ total_draft_tokens += draft_len
1195+ total_accepted_tokens += request .py_num_accepted_draft_tokens
1196+
1197+ if total_draft_tokens <= 0 :
1198+ return False , None
1199+
1200+ acceptance_rate = total_accepted_tokens / total_draft_tokens
1201+ disabled_now , avg = self .speculation_gate .record_acceptance_rate (
1202+ acceptance_rate , sample_id = iteration_id )
1203+ if disabled_now :
1204+ self .speculation_permanently_disabled = True
1205+ return disabled_now , avg
1206+
11731207 def _append_iter_stats (self ,
11741208 stats : IterationStats ,
11751209 req_stats : Optional [List [RequestStats ]] = None ):
@@ -1664,6 +1698,10 @@ def _handle_executed_batch(self, executed_batch: Optional[BatchStatePP]):
16641698 if executed_batch is not None :
16651699 with torch .cuda .nvtx .range ("_handle_executed_batch_pp" ):
16661700 self ._update_requests (executed_batch .sample_state )
1701+ self ._record_batch_acceptance_rate (
1702+ executed_batch .scheduled_requests ,
1703+ executed_batch .sample_state ,
1704+ iteration_id = self .iter_counter )
16671705
16681706 scheduled_requests = executed_batch .scheduled_requests
16691707 if self .kv_cache_transceiver :
@@ -1730,6 +1768,12 @@ def _handle_dynamic_draft_len(self,
17301768 if not hasattr (self .model_engine , 'max_draft_len' ):
17311769 return
17321770
1771+ if self .speculation_permanently_disabled :
1772+ for request in scheduled_batch .generation_requests :
1773+ request .py_draft_tokens = []
1774+ self .model_engine .runtime_draft_len = 0
1775+ return
1776+
17331777 if (self .model_engine .spec_config is not None
17341778 and self .model_engine .spec_config .draft_len_schedule is not None
17351779 and self .model_engine .spec_config .spec_dec_mode .
@@ -1857,7 +1901,6 @@ def _prepare_and_schedule_batch(self):
18571901 # with dummy draft tokens to make the scheduler aware of the fact
18581902 # that speculation is about to happen.
18591903 self ._prepare_draft_requests ()
1860-
18611904 scheduled_batch , fitting_disagg_gen_init_requests , num_fitting_reqs = self ._schedule (
18621905 )
18631906
@@ -2069,6 +2112,10 @@ def _executor_loop(self):
20692112
20702113 self ._update_request_states (scheduled_batch )
20712114 self ._update_requests (sample_state , self .resource_manager )
2115+ self ._record_batch_acceptance_rate (
2116+ scheduled_batch ,
2117+ sample_state ,
2118+ iteration_id = self .iter_counter )
20722119
20732120 self ._send_kv_async (scheduled_batch .all_requests ())
20742121
@@ -2340,6 +2387,10 @@ def _executor_loop_overlap(self):
23402387
23412388 if self .previous_batch is not None and should_process_previous_batch :
23422389 self ._update_requests (self .previous_batch .sample_state )
2390+ self ._record_batch_acceptance_rate (
2391+ self .previous_batch .scheduled_requests ,
2392+ self .previous_batch .sample_state ,
2393+ iteration_id = self .iter_counter )
23432394
23442395 self ._send_kv_async (
23452396 self .previous_batch .scheduled_requests .all_requests ())
@@ -3609,31 +3660,6 @@ def _handle_responses(self):
36093660 new_responses .append ((req_id , response ))
36103661
36113662 if request_done :
3612- if (self .drafter is not None and getattr (
3613- self .model_engine , 'enable_spec_decode' , False )
3614- and not self .speculation_permanently_disabled
3615- and not request .is_dummy and not self .is_warmup ):
3616- if self .speculation_gate is not None :
3617- # Response handling runs on multiple PP ranks. Only the last PP rank performs
3618- # sampling; restrict rolling stat updates to it to avoid overcounting.
3619- if (not getattr (self .dist , 'has_pp' ,
3620- False )) or self .dist .is_last_pp_rank :
3621- avg_decoded = getattr (
3622- request , 'avg_decoded_tokens_per_iter' , None )
3623- if avg_decoded is not None :
3624- disabled_now , _ = self .speculation_gate .record_avg_decoded (
3625- avg_decoded ,
3626- request_id = getattr (request , 'py_request_id' ,
3627- None ))
3628- if disabled_now :
3629- # disable speculation permanently
3630- # starting from next iteration, _prepare_and_schedule_batch will set self.use_spec_decode to False
3631- self .speculation_permanently_disabled = True
3632- else :
3633- logger .debug (
3634- f"Request { request .py_request_id } has no avg_decoded_tokens_per_iter"
3635- )
3636-
36373663 # If partial reuse is enabled, and the KV cache manager is not VSWA, and the PP size is 1,
36383664 # then we need to terminate the request. TODO: Remove this once disagg support from KVCache reuse
36393665 # path is fixed.
0 commit comments