diff --git a/aphrodite/v1/core/kv_cache_coordinator.py b/aphrodite/v1/core/kv_cache_coordinator.py index c3434fa56b..c63fe54552 100644 --- a/aphrodite/v1/core/kv_cache_coordinator.py +++ b/aphrodite/v1/core/kv_cache_coordinator.py @@ -34,6 +34,7 @@ def __init__( self, kv_cache_config: KVCacheConfig, max_model_len: int, + max_num_batched_tokens: int, use_eagle: bool, enable_caching: bool, enable_kv_cache_events: bool, @@ -62,6 +63,8 @@ def __init__( self.single_type_managers = tuple( get_manager_for_kv_cache_spec( kv_cache_spec=kv_cache_group.kv_cache_spec, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_model_len, block_pool=self.block_pool, enable_caching=enable_caching, kv_cache_group_id=i, @@ -258,6 +261,7 @@ def __init__( self, kv_cache_config: KVCacheConfig, max_model_len: int, + max_num_batched_tokens: int, use_eagle: bool, enable_kv_cache_events: bool, dcp_world_size: int, @@ -268,6 +272,7 @@ def __init__( super().__init__( kv_cache_config, max_model_len, + max_num_batched_tokens, use_eagle, False, enable_kv_cache_events, @@ -301,6 +306,7 @@ def __init__( self, kv_cache_config: KVCacheConfig, max_model_len: int, + max_num_batched_tokens: int, use_eagle: bool, enable_caching: bool, enable_kv_cache_events: bool, @@ -312,6 +318,7 @@ def __init__( super().__init__( kv_cache_config, max_model_len, + max_num_batched_tokens, use_eagle, enable_caching, enable_kv_cache_events, @@ -366,6 +373,7 @@ def __init__( self, kv_cache_config: KVCacheConfig, max_model_len: int, + max_num_batched_tokens: int, use_eagle: bool, enable_caching: bool, enable_kv_cache_events: bool, @@ -377,6 +385,7 @@ def __init__( super().__init__( kv_cache_config, max_model_len, + max_num_batched_tokens, use_eagle, enable_caching, enable_kv_cache_events, @@ -542,6 +551,7 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: def get_kv_cache_coordinator( kv_cache_config: KVCacheConfig, max_model_len: int, + max_num_batched_tokens: int, use_eagle: bool, enable_caching: bool, enable_kv_cache_events: bool, @@ -554,6 +564,7 @@ def get_kv_cache_coordinator( return KVCacheCoordinatorNoPrefixCache( kv_cache_config, max_model_len, + max_num_batched_tokens, use_eagle, enable_kv_cache_events, dcp_world_size=dcp_world_size, @@ -565,6 +576,7 @@ def get_kv_cache_coordinator( return UnitaryKVCacheCoordinator( kv_cache_config, max_model_len, + max_num_batched_tokens, use_eagle, enable_caching, enable_kv_cache_events, @@ -576,6 +588,7 @@ def get_kv_cache_coordinator( return HybridKVCacheCoordinator( kv_cache_config, max_model_len, + max_num_batched_tokens, use_eagle, enable_caching, enable_kv_cache_events, diff --git a/aphrodite/v1/core/kv_cache_manager.py b/aphrodite/v1/core/kv_cache_manager.py index f8d85823e4..1a627c08b0 100644 --- a/aphrodite/v1/core/kv_cache_manager.py +++ b/aphrodite/v1/core/kv_cache_manager.py @@ -100,6 +100,7 @@ def __init__( kv_cache_config: KVCacheConfig, max_model_len: int, hash_block_size: int, + max_num_batched_tokens: int | None = None, enable_caching: bool = True, use_eagle: bool = False, log_stats: bool = False, @@ -109,6 +110,11 @@ def __init__( metrics_collector: KVCacheMetricsCollector | None = None, ) -> None: self.max_model_len = max_model_len + # When unset, fall back to `max_model_len` so the recycling-aware cap + # collapses to the prior (uncapped) admission behavior. The scheduler + # always supplies the real value at runtime. + if max_num_batched_tokens is None: + max_num_batched_tokens = max_model_len self.enable_caching = enable_caching self.use_eagle = use_eagle @@ -122,6 +128,7 @@ def __init__( self.coordinator = get_kv_cache_coordinator( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, + max_num_batched_tokens=max_num_batched_tokens, use_eagle=self.use_eagle, enable_caching=self.enable_caching, enable_kv_cache_events=enable_kv_cache_events, diff --git a/aphrodite/v1/core/sched/scheduler.py b/aphrodite/v1/core/sched/scheduler.py index 3d6ba7f80c..94a1235902 100644 --- a/aphrodite/v1/core/sched/scheduler.py +++ b/aphrodite/v1/core/sched/scheduler.py @@ -210,6 +210,7 @@ def __init__( self.kv_cache_manager = KVCacheManager( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, + max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens, enable_caching=self.cache_config.enable_prefix_caching, use_eagle=self.use_eagle, log_stats=self.log_stats, diff --git a/aphrodite/v1/core/single_type_kv_cache_manager.py b/aphrodite/v1/core/single_type_kv_cache_manager.py index 8b89f7bd07..71c8d2e6dd 100644 --- a/aphrodite/v1/core/single_type_kv_cache_manager.py +++ b/aphrodite/v1/core/single_type_kv_cache_manager.py @@ -41,6 +41,7 @@ def __init__( kv_cache_group_id: int, dcp_world_size: int = 1, pcp_world_size: int = 1, + max_admission_blocks_per_request: int | None = None, ) -> None: """ Initializes the SingleTypeKVCacheManager. @@ -48,6 +49,12 @@ def __init__( kv_cache_spec: The kv_cache_spec for this manager. block_pool: The block pool. kv_cache_group_id: The id of the kv cache group of this manager. + max_admission_blocks_per_request: Recycling-aware per-request + block cap used by `get_num_blocks_to_allocate`. Only set for + spec types that recycle blocks across chunks (SWA, + chunked-local); `None` (the default) means no cap, which is + correct for full-attention-style specs that hold every + block until the request finishes. """ self.block_size = kv_cache_spec.block_size self.dcp_world_size = dcp_world_size @@ -57,6 +64,7 @@ def __init__( self.kv_cache_spec = kv_cache_spec self.block_pool = block_pool self.enable_caching = enable_caching + self._max_admission_blocks_per_request = max_admission_blocks_per_request self.new_block_ids: list[int] = [] # Mapping from request ID to blocks to track the blocks allocated @@ -105,6 +113,17 @@ def get_num_blocks_to_allocate( """ num_required_blocks = cdiv(num_tokens, self.block_size) + if self._max_admission_blocks_per_request is not None: + # Recycling-aware specs (SWA, chunked-local) cap the per-request + # reservation here so admission matches the startup pool sizer + # (`SlidingWindowSpec.max_admission_blocks_per_request` / its + # chunked-local counterpart). `remove_skipped_blocks` runs from + # `allocate_slots` before each chunk's `get_num_blocks_to_allocate`, + # so per-request peak real-held blocks <= this cap, which keeps + # `sum(reservations) <= pool` <=> `sum(peak_real_held) <= pool`. + # Drift between the two would re-introduce the deadlock from + # issue #39734 or, worse, mid-prefill OOM. + num_required_blocks = min(num_required_blocks, self._max_admission_blocks_per_request) num_req_blocks = len(self.req_to_blocks.get(request_id, ())) if request_id in self.num_cached_block: @@ -1043,7 +1062,20 @@ def __init__( } -def get_manager_for_kv_cache_spec(kv_cache_spec: KVCacheSpec, **kwargs) -> SingleTypeKVCacheManager: +def get_manager_for_kv_cache_spec( + kv_cache_spec: KVCacheSpec, + max_num_batched_tokens: int, + max_model_len: int, + **kwargs, +) -> SingleTypeKVCacheManager: manager_class = spec_manager_map[type(kv_cache_spec)] + # SlidingWindow / ChunkedLocalAttention managers recycle blocks across + # chunks; the runtime admission cap must match the recycling-aware bound + # the startup pool sizer uses (single source of truth: the spec method). + if isinstance(kv_cache_spec, (SlidingWindowSpec, ChunkedLocalAttentionSpec)): + kwargs["max_admission_blocks_per_request"] = kv_cache_spec.max_admission_blocks_per_request( + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_model_len, + ) manager = manager_class(kv_cache_spec, **kwargs) return manager diff --git a/aphrodite/v1/kv_cache_interface.py b/aphrodite/v1/kv_cache_interface.py index 39bd4c9a36..455d2b6c9c 100644 --- a/aphrodite/v1/kv_cache_interface.py +++ b/aphrodite/v1/kv_cache_interface.py @@ -332,17 +332,24 @@ def merge(cls, specs: list[Self]) -> Self: class ChunkedLocalAttentionSpec(AttentionSpec): attention_chunk_size: int - def max_memory_usage_bytes(self, aphrodite_config: AphroditeConfig) -> int: - max_model_len = aphrodite_config.model_config.max_model_len - max_num_batched_tokens = aphrodite_config.scheduler_config.max_num_batched_tokens + def max_admission_blocks_per_request(self, max_num_batched_tokens: int, max_model_len: int) -> int: + """Per-request admission cap, in blocks. - # During chunked prefill, we allocate KV cache for at most - # `self.attention_chunk_size` computed tokens plus the newly scheduled - # tokens. And we won't allocate KV cache for more than `max_model_len` - # tokens. + Single source of truth for both startup pool sizing + (`max_memory_usage_bytes`) and the runtime admission gate, so requests + admitted by startup can also be admitted at runtime. + """ + # During chunked prefill, we hold KV for at most one chunk window. num_tokens = min(self.attention_chunk_size + max_num_batched_tokens, max_model_len) + return cdiv(num_tokens, self.block_size) - return cdiv(num_tokens, self.block_size) * self.page_size_bytes + def max_memory_usage_bytes(self, aphrodite_config: AphroditeConfig) -> int: + max_model_len = aphrodite_config.model_config.max_model_len + max_num_batched_tokens = aphrodite_config.scheduler_config.max_num_batched_tokens + max_blocks = self.max_admission_blocks_per_request( + max_num_batched_tokens=max_num_batched_tokens, max_model_len=max_model_len + ) + return max_blocks * self.page_size_bytes @dataclass(frozen=True, kw_only=True) @@ -358,22 +365,34 @@ def __post_init__(self): def real_page_size_bytes(self) -> int: return self.block_size * self.num_kv_heads * (self.head_size + self.head_size_v) * get_dtype_size(self.dtype) + def max_admission_blocks_per_request(self, max_num_batched_tokens: int, max_model_len: int) -> int: + """Per-request admission cap, in blocks. + + Single source of truth for both startup pool sizing + (`max_memory_usage_bytes`) and the runtime admission gate. Per-request + real-held blocks plateau at this bound because + `SlidingWindowManager.remove_skipped_blocks` runs from `allocate_slots` + before each chunk's `get_num_blocks_to_allocate`. + """ + # During chunked prefill, we hold KV for the last `sliding_window-1` + # computed tokens plus the newly scheduled tokens, and never more + # than `max_model_len`. + num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens, max_model_len) + # +1 because the sliding window may not start from the beginning of + # the block. E.g. block size 4 and num_token 4 needs two blocks + # [XXCD][EF] to store the 6-token window [CDEF]. + return cdiv(num_tokens, self.block_size) + 1 + def max_memory_usage_bytes(self, aphrodite_config: AphroditeConfig) -> int: assert aphrodite_config.parallel_config.decode_context_parallel_size == 1, "DCP not support sliding window." max_model_len = aphrodite_config.model_config.max_model_len max_num_batched_tokens = aphrodite_config.scheduler_config.max_num_batched_tokens - # During chunked prefill, we allocate KV cache for the last - # `self.sliding_window-1` computed tokens plus the newly scheduled - # tokens. And we won't allocate KV cache for more than `max_model_len` - # tokens. - num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens, max_model_len) + max_blocks = self.max_admission_blocks_per_request( + max_num_batched_tokens=max_num_batched_tokens, max_model_len=max_model_len + ) - # +1 here because the sliding window may not start from the beginning - # of the block. For example, if the block size is 4 and num_token - # is 4, we need two blocks [XXCD] [EF] to store the sliding - # window [CDEF] of 6 tokens. - return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes + return max_blocks * self.page_size_bytes @dataclass(frozen=True, kw_only=True) diff --git a/aphrodite/v1/simple_kv_offload/manager.py b/aphrodite/v1/simple_kv_offload/manager.py index 500b68f746..7f42e2dd1e 100644 --- a/aphrodite/v1/simple_kv_offload/manager.py +++ b/aphrodite/v1/simple_kv_offload/manager.py @@ -107,6 +107,7 @@ def __init__( self.cpu_coordinator: KVCacheCoordinator = get_kv_cache_coordinator( kv_cache_config=self.cpu_kv_cache_config, max_model_len=aphrodite_config.model_config.max_model_len, + max_num_batched_tokens=aphrodite_config.scheduler_config.max_num_batched_tokens, use_eagle=False, enable_caching=True, enable_kv_cache_events=self.enable_kv_cache_events,