@@ -87,6 +87,7 @@ def schedule(self) -> Optional[SchedulerOutput]:
8787 scheduled_requests = []
8888 is_prefill = False
8989 current_num_batched_tokens = 0
90+ current_prefill_extra_blocks = 0
9091
9192 # Process Waiting queue (prefill phase)
9293 while (
@@ -149,6 +150,7 @@ def schedule(self) -> Optional[SchedulerOutput]:
149150 if not self .can_accept_request (
150151 req ,
151152 num_local_computed_tokens ,
153+ current_prefill_extra_blocks ,
152154 ):
153155 logger .warning (
154156 "Insufficient KV cache blocks for request %s, deferring." ,
@@ -216,6 +218,7 @@ def schedule(self) -> Optional[SchedulerOutput]:
216218 ) // self .block_size
217219 continue
218220
221+ current_prefill_extra_blocks += self ._get_prefill_extra_blocks (req )
219222 scheduled_requests .append (req )
220223
221224 num_tokens_this_step = req .get_prompt_length () - req .num_local_cached_tokens
@@ -388,7 +391,10 @@ def complete_requests(self, requests: List[InferenceRequest]):
388391 self .running_queue .sync_q .put (req )
389392
390393 def can_accept_request (
391- self , request : InferenceRequest , num_local_computed_tokens : int
394+ self ,
395+ request : InferenceRequest ,
396+ num_local_computed_tokens : int ,
397+ current_prefill_extra_blocks : int = 0 ,
392398 ) -> bool :
393399 total_required_blocks = 0
394400
@@ -415,9 +421,18 @@ def can_accept_request(
415421 # hold prompt blocks but will also need decode blocks once promoted.
416422 total_required_blocks += self .pending_kv_decode_blocks
417423
424+ # Include decode headroom for requests accepted earlier in this batch.
425+ total_required_blocks += current_prefill_extra_blocks
426+
418427 # Compare with total usable blocks in cache manager
419428 return total_required_blocks <= self .cache_manager .get_total_usable_blocks ()
420429
430+ def _get_prefill_extra_blocks (self , request : InferenceRequest ) -> int :
431+ total_length = request .get_prompt_length ()
432+ total_length += request .sampling_params .max_tokens
433+ total_required_blocks = (total_length + self .block_size - 1 ) // self .block_size
434+ return max (total_required_blocks - len (request .block_table ), 0 )
435+
421436 def update_from_output (self , model_output ):
422437 if self .connector is None or model_output .kv_connector_output is None :
423438 return
0 commit comments