@@ -328,6 +328,7 @@ def __init__(
328328 self .max_beam_width = max_beam_width
329329 self .max_draft_len = max_draft_len
330330 self .max_total_draft_tokens = max_total_draft_tokens
331+ self .use_spec_decode = self .model_engine .enable_spec_decode
331332 self .llm_args = self .model_engine .llm_args
332333 self .max_stats_len = max (self .llm_args .max_stats_len , 1 )
333334 self .max_num_tokens = self .llm_args .max_num_tokens
@@ -352,19 +353,16 @@ def __init__(
352353 self .num_fetch_requests = 0
353354 self .shutdown_event = threading .Event ()
354355
355- # Rolling acceptance tracking for spec decode (disable speculation if rolling acceptance is below threshold)
356- spec_config = getattr (self .model_engine , 'spec_config' , None )
357- self .acceptance_window = getattr (
358- spec_config , 'acceptance_window' ,
359- None ) if spec_config is not None else None
360- self .acceptance_length_threshold = getattr (
361- spec_config , 'acceptance_length_threshold' ,
362- None ) if spec_config is not None else None
356+ # Rolling true acceptance-rate tracking for permanent speculation
357+ # disable.
363358 self .speculation_permanently_disabled = False
364359 self .speculation_gate = None
365- if self .acceptance_window and self .acceptance_length_threshold is not None :
366- self .speculation_gate = SpeculationGate (
367- self .acceptance_window , self .acceptance_length_threshold )
360+ spec_config = getattr (self .model_engine , 'spec_config' , None )
361+ if spec_config is not None :
362+ window = getattr (spec_config , 'acceptance_rate_window_size' , None )
363+ threshold = getattr (spec_config , 'acceptance_rate_threshold' , None )
364+ if window and threshold is not None :
365+ self .speculation_gate = SpeculationGate (window , threshold )
368366
369367 # response used data
370368 self .response_lock = threading .Lock ()
@@ -1185,6 +1183,42 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
11851183 draft_latency_ms ) / float (iter_latency_ms )
11861184 return stats
11871185
1186+ def _record_batch_acceptance_rate (
1187+ self ,
1188+ scheduled_batch : ScheduledRequests ,
1189+ sample_state : SampleState ,
1190+ iteration_id : Optional [int ] = None ) -> Tuple [bool , Optional [float ]]:
1191+ if (self .speculation_gate is None
1192+ or self .speculation_permanently_disabled or self .is_warmup ):
1193+ return False , None
1194+
1195+ if (getattr (self .dist , 'has_pp' , False )
1196+ and not self .dist .is_last_pp_rank ):
1197+ return False , None
1198+ new_tokens_lens = getattr (sample_state .host , 'new_tokens_lens' , None )
1199+ if new_tokens_lens is None :
1200+ return False , None
1201+ new_tokens_lens_list = (new_tokens_lens .tolist () if hasattr (
1202+ new_tokens_lens , 'tolist' ) else list (new_tokens_lens ))
1203+ total_draft_tokens = 0
1204+ total_accepted_tokens = 0
1205+ for request in scheduled_batch .generation_requests :
1206+ draft_len = request .num_draft_tokens
1207+ if draft_len <= 0 or request .is_dummy :
1208+ continue
1209+ total_draft_tokens += draft_len
1210+ total_accepted_tokens += request .py_num_accepted_draft_tokens
1211+
1212+ if total_draft_tokens <= 0 :
1213+ return False , None
1214+
1215+ acceptance_rate = total_accepted_tokens / total_draft_tokens
1216+ disabled_now , avg = self .speculation_gate .record_acceptance_rate (
1217+ acceptance_rate , sample_id = iteration_id )
1218+ if disabled_now :
1219+ self .speculation_permanently_disabled = True
1220+ return disabled_now , avg
1221+
11881222 def _append_iter_stats (self ,
11891223 stats : IterationStats ,
11901224 req_stats : Optional [List [RequestStats ]] = None ):
@@ -1679,6 +1713,10 @@ def _handle_executed_batch(self, executed_batch: Optional[BatchStatePP]):
16791713 if executed_batch is not None :
16801714 with torch .cuda .nvtx .range ("_handle_executed_batch_pp" ):
16811715 self ._update_requests (executed_batch .sample_state )
1716+ self ._record_batch_acceptance_rate (
1717+ executed_batch .scheduled_requests ,
1718+ executed_batch .sample_state ,
1719+ iteration_id = self .iter_counter )
16821720
16831721 scheduled_requests = executed_batch .scheduled_requests
16841722 if self .kv_cache_transceiver :
@@ -1745,6 +1783,12 @@ def _handle_dynamic_draft_len(self,
17451783 if not hasattr (self .model_engine , 'max_draft_len' ):
17461784 return
17471785
1786+ if self .speculation_permanently_disabled :
1787+ for request in scheduled_batch .generation_requests :
1788+ request .py_draft_tokens = []
1789+ self .model_engine .runtime_draft_len = 0
1790+ return
1791+
17481792 if (self .model_engine .spec_config is not None
17491793 and self .model_engine .spec_config .draft_len_schedule is not None
17501794 and self .model_engine .spec_config .spec_dec_mode .
@@ -1846,7 +1890,6 @@ def _prepare_and_schedule_batch(self):
18461890 # with dummy draft tokens to make the scheduler aware of the fact
18471891 # that speculation is about to happen.
18481892 self ._prepare_draft_requests ()
1849-
18501893 scheduled_batch , fitting_disagg_gen_init_requests , num_fitting_reqs = self ._schedule (
18511894 )
18521895
@@ -2134,6 +2177,10 @@ def _executor_loop(self):
21342177
21352178 self ._update_request_states (scheduled_batch )
21362179 self ._update_requests (sample_state , self .resource_manager )
2180+ self ._record_batch_acceptance_rate (
2181+ scheduled_batch ,
2182+ sample_state ,
2183+ iteration_id = self .iter_counter )
21372184
21382185 self ._send_kv_async (scheduled_batch .all_requests ())
21392186
@@ -2376,6 +2423,10 @@ def _executor_loop_overlap(self):
23762423
23772424 if self .previous_batch is not None and should_process_previous_batch :
23782425 self ._update_requests (self .previous_batch .sample_state )
2426+ self ._record_batch_acceptance_rate (
2427+ self .previous_batch .scheduled_requests ,
2428+ self .previous_batch .sample_state ,
2429+ iteration_id = self .iter_counter )
23792430
23802431 self ._send_kv_async (
23812432 self .previous_batch .scheduled_requests .all_requests ())
@@ -3685,31 +3736,6 @@ def _handle_responses(self):
36853736 new_responses .append ((req_id , response ))
36863737
36873738 if request_done :
3688- if (self .drafter is not None and getattr (
3689- self .model_engine , 'enable_spec_decode' , False )
3690- and not self .speculation_permanently_disabled
3691- and not request .is_dummy and not self .is_warmup ):
3692- if self .speculation_gate is not None :
3693- # Response handling runs on multiple PP ranks. Only the last PP rank performs
3694- # sampling; restrict rolling stat updates to it to avoid overcounting.
3695- if (not getattr (self .dist , 'has_pp' ,
3696- False )) or self .dist .is_last_pp_rank :
3697- avg_decoded = getattr (
3698- request , 'avg_decoded_tokens_per_iter' , None )
3699- if avg_decoded is not None :
3700- disabled_now , _ = self .speculation_gate .record_avg_decoded (
3701- avg_decoded ,
3702- request_id = getattr (request , 'py_request_id' ,
3703- None ))
3704- if disabled_now :
3705- # disable speculation permanently
3706- # starting from next iteration, _prepare_and_schedule_batch will set self.use_spec_decode to False
3707- self .speculation_permanently_disabled = True
3708- else :
3709- logger .debug (
3710- f"Request { request .py_request_id } has no avg_decoded_tokens_per_iter"
3711- )
3712-
37133739 # If partial reuse is enabled, and the KV cache manager is not VSWA, and the PP size is 1,
37143740 # then we need to terminate the request. TODO: Remove this once disagg support from KVCache reuse
37153741 # path is fixed.
0 commit comments