@@ -332,17 +332,24 @@ def merge(cls, specs: list[Self]) -> Self:
332332class ChunkedLocalAttentionSpec (AttentionSpec ):
333333 attention_chunk_size : int
334334
335- def max_memory_usage_bytes (self , aphrodite_config : AphroditeConfig ) -> int :
336- max_model_len = aphrodite_config .model_config .max_model_len
337- max_num_batched_tokens = aphrodite_config .scheduler_config .max_num_batched_tokens
335+ def max_admission_blocks_per_request (self , max_num_batched_tokens : int , max_model_len : int ) -> int :
336+ """Per-request admission cap, in blocks.
338337
339- # During chunked prefill, we allocate KV cache for at most
340- # `self.attention_chunk_size` computed tokens plus the newly scheduled
341- # tokens. And we won't allocate KV cache for more than `max_model_len`
342- # tokens.
338+ Single source of truth for both startup pool sizing
339+ (`max_memory_usage_bytes`) and the runtime admission gate, so a request
340+ that fits the pool sized at startup can also be admitted at runtime.
341+ """
342+ # During chunked prefill, we hold KV for at most `attention_chunk_size` computed tokens plus the newly scheduled tokens, capped at `max_model_len`.
343343 num_tokens = min (self .attention_chunk_size + max_num_batched_tokens , max_model_len )
344+ return cdiv (num_tokens , self .block_size )
344345
345- return cdiv (num_tokens , self .block_size ) * self .page_size_bytes
346+ def max_memory_usage_bytes (self , aphrodite_config : AphroditeConfig ) -> int :
347+ max_model_len = aphrodite_config .model_config .max_model_len
348+ max_num_batched_tokens = aphrodite_config .scheduler_config .max_num_batched_tokens
349+ max_blocks = self .max_admission_blocks_per_request (
350+ max_num_batched_tokens = max_num_batched_tokens , max_model_len = max_model_len
351+ )
352+ return max_blocks * self .page_size_bytes
346353
347354
348355@dataclass (frozen = True , kw_only = True )
@@ -358,22 +365,34 @@ def __post_init__(self):
358365 def real_page_size_bytes (self ) -> int :
359366 return self .block_size * self .num_kv_heads * (self .head_size + self .head_size_v ) * get_dtype_size (self .dtype )
360367
368+ def max_admission_blocks_per_request (self , max_num_batched_tokens : int , max_model_len : int ) -> int :
369+ """Per-request admission cap, in blocks.
370+
371+ Single source of truth for both startup pool sizing
372+ (`max_memory_usage_bytes`) and the runtime admission gate. Per-request
373+ real-held blocks plateau at this bound because
374+ `SlidingWindowManager.remove_skipped_blocks` runs from `allocate_slots`
375+ before each chunk's `get_num_blocks_to_allocate`.
376+ """
377+ # During chunked prefill, we hold KV for the last `sliding_window-1`
378+ # computed tokens plus the newly scheduled tokens, and never more
379+ # than `max_model_len`.
380+ num_tokens = min (self .sliding_window - 1 + max_num_batched_tokens , max_model_len )
381+ # +1 because the sliding window may not start from the beginning of
382+ # the block. E.g. with block size 4 and a 4-token window, two blocks
383+ # [XXCD][EF] are needed to hold the window [CDEF] of a 6-token sequence.
384+ return cdiv (num_tokens , self .block_size ) + 1
385+
361386 def max_memory_usage_bytes (self , aphrodite_config : AphroditeConfig ) -> int :
362387 assert aphrodite_config .parallel_config .decode_context_parallel_size == 1 , "DCP not support sliding window."
363388 max_model_len = aphrodite_config .model_config .max_model_len
364389 max_num_batched_tokens = aphrodite_config .scheduler_config .max_num_batched_tokens
365390
366- # During chunked prefill, we allocate KV cache for the last
367- # `self.sliding_window-1` computed tokens plus the newly scheduled
368- # tokens. And we won't allocate KV cache for more than `max_model_len`
369- # tokens.
370- num_tokens = min (self .sliding_window - 1 + max_num_batched_tokens , max_model_len )
391+ max_blocks = self .max_admission_blocks_per_request (
392+ max_num_batched_tokens = max_num_batched_tokens , max_model_len = max_model_len
393+ )
371394
372- # +1 here because the sliding window may not start from the beginning
373- # of the block. For example, if the block size is 4 and num_token
374- # is 4, we need two blocks [XXCD] [EF] to store the sliding
375- # window [CDEF] of 6 tokens.
376- return (cdiv (num_tokens , self .block_size ) + 1 ) * self .page_size_bytes
395+ return max_blocks * self .page_size_bytes
377396
378397
379398@dataclass (frozen = True , kw_only = True )
0 commit comments