Skip to content

Commit dc37614

Browse files
committed
issue/340 - fix(scheduler): account for current prefill batch block reservations
1 parent f38eb57 commit dc37614

1 file changed

Lines changed: 16 additions & 1 deletion

File tree

python/infinilm/llm/scheduler.py

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,7 @@ def schedule(self) -> Optional[SchedulerOutput]:
8787
scheduled_requests = []
8888
is_prefill = False
8989
current_num_batched_tokens = 0
90+
current_prefill_extra_blocks = 0
9091

9192
# Process Waiting queue (prefill phase)
9293
while (
@@ -149,6 +150,7 @@ def schedule(self) -> Optional[SchedulerOutput]:
149150
if not self.can_accept_request(
150151
req,
151152
num_local_computed_tokens,
153+
current_prefill_extra_blocks,
152154
):
153155
logger.warning(
154156
"Insufficient KV cache blocks for request %s, deferring.",
@@ -216,6 +218,7 @@ def schedule(self) -> Optional[SchedulerOutput]:
216218
) // self.block_size
217219
continue
218220

221+
current_prefill_extra_blocks += self._get_prefill_extra_blocks(req)
219222
scheduled_requests.append(req)
220223

221224
num_tokens_this_step = req.get_prompt_length() - req.num_local_cached_tokens
@@ -388,7 +391,10 @@ def complete_requests(self, requests: List[InferenceRequest]):
388391
self.running_queue.sync_q.put(req)
389392

390393
def can_accept_request(
391-
self, request: InferenceRequest, num_local_computed_tokens: int
394+
self,
395+
request: InferenceRequest,
396+
num_local_computed_tokens: int,
397+
current_prefill_extra_blocks: int = 0,
392398
) -> bool:
393399
total_required_blocks = 0
394400

@@ -415,9 +421,18 @@ def can_accept_request(
415421
# hold prompt blocks but will also need decode blocks once promoted.
416422
total_required_blocks += self.pending_kv_decode_blocks
417423

424+
# Include decode headroom for requests accepted earlier in this batch.
425+
total_required_blocks += current_prefill_extra_blocks
426+
418427
# Compare with total usable blocks in cache manager
419428
return total_required_blocks <= self.cache_manager.get_total_usable_blocks()
420429

430+
def _get_prefill_extra_blocks(self, request: InferenceRequest) -> int:
    """Return how many more KV cache blocks ``request`` may still need.

    The worst-case length of the request is its prompt length plus
    ``sampling_params.max_tokens``. That length is rounded up to whole
    blocks of ``self.block_size`` tokens, and the blocks already present
    in ``request.block_table`` are subtracted.

    Args:
        request: Request whose outstanding block demand is estimated.

    Returns:
        The number of blocks not yet in ``request.block_table``; never
        negative.
    """
    # NOTE(review): assumes max_tokens is an int (not None) -- confirm
    # against the sampling-params definition.
    worst_case_tokens = (
        request.get_prompt_length() + request.sampling_params.max_tokens
    )
    # Ceiling division: a partially filled block still occupies a whole block.
    blocks_needed = (worst_case_tokens + self.block_size - 1) // self.block_size
    # Clamp at zero in case the request already holds more blocks than needed.
    return max(blocks_needed - len(request.block_table), 0)
435+
421436
def update_from_output(self, model_output):
422437
if self.connector is None or model_output.kv_connector_output is None:
423438
return

0 commit comments

Comments
 (0)