Skip to content

Commit 42769da

Browse files
authored
fix: cap SWA/chunked-local runtime admission to startup pool-sizing bound (#1659)
Signed-off-by: AlpinDale <alpindale@gmail.com>
1 parent f48068d commit 42769da

6 files changed

Lines changed: 92 additions & 19 deletions

File tree

aphrodite/v1/core/kv_cache_coordinator.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
self,
3535
kv_cache_config: KVCacheConfig,
3636
max_model_len: int,
37+
max_num_batched_tokens: int,
3738
use_eagle: bool,
3839
enable_caching: bool,
3940
enable_kv_cache_events: bool,
@@ -62,6 +63,8 @@ def __init__(
6263
self.single_type_managers = tuple(
6364
get_manager_for_kv_cache_spec(
6465
kv_cache_spec=kv_cache_group.kv_cache_spec,
66+
max_num_batched_tokens=max_num_batched_tokens,
67+
max_model_len=max_model_len,
6568
block_pool=self.block_pool,
6669
enable_caching=enable_caching,
6770
kv_cache_group_id=i,
@@ -258,6 +261,7 @@ def __init__(
258261
self,
259262
kv_cache_config: KVCacheConfig,
260263
max_model_len: int,
264+
max_num_batched_tokens: int,
261265
use_eagle: bool,
262266
enable_kv_cache_events: bool,
263267
dcp_world_size: int,
@@ -268,6 +272,7 @@ def __init__(
268272
super().__init__(
269273
kv_cache_config,
270274
max_model_len,
275+
max_num_batched_tokens,
271276
use_eagle,
272277
False,
273278
enable_kv_cache_events,
@@ -301,6 +306,7 @@ def __init__(
301306
self,
302307
kv_cache_config: KVCacheConfig,
303308
max_model_len: int,
309+
max_num_batched_tokens: int,
304310
use_eagle: bool,
305311
enable_caching: bool,
306312
enable_kv_cache_events: bool,
@@ -312,6 +318,7 @@ def __init__(
312318
super().__init__(
313319
kv_cache_config,
314320
max_model_len,
321+
max_num_batched_tokens,
315322
use_eagle,
316323
enable_caching,
317324
enable_kv_cache_events,
@@ -366,6 +373,7 @@ def __init__(
366373
self,
367374
kv_cache_config: KVCacheConfig,
368375
max_model_len: int,
376+
max_num_batched_tokens: int,
369377
use_eagle: bool,
370378
enable_caching: bool,
371379
enable_kv_cache_events: bool,
@@ -377,6 +385,7 @@ def __init__(
377385
super().__init__(
378386
kv_cache_config,
379387
max_model_len,
388+
max_num_batched_tokens,
380389
use_eagle,
381390
enable_caching,
382391
enable_kv_cache_events,
@@ -542,6 +551,7 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList:
542551
def get_kv_cache_coordinator(
543552
kv_cache_config: KVCacheConfig,
544553
max_model_len: int,
554+
max_num_batched_tokens: int,
545555
use_eagle: bool,
546556
enable_caching: bool,
547557
enable_kv_cache_events: bool,
@@ -554,6 +564,7 @@ def get_kv_cache_coordinator(
554564
return KVCacheCoordinatorNoPrefixCache(
555565
kv_cache_config,
556566
max_model_len,
567+
max_num_batched_tokens,
557568
use_eagle,
558569
enable_kv_cache_events,
559570
dcp_world_size=dcp_world_size,
@@ -565,6 +576,7 @@ def get_kv_cache_coordinator(
565576
return UnitaryKVCacheCoordinator(
566577
kv_cache_config,
567578
max_model_len,
579+
max_num_batched_tokens,
568580
use_eagle,
569581
enable_caching,
570582
enable_kv_cache_events,
@@ -576,6 +588,7 @@ def get_kv_cache_coordinator(
576588
return HybridKVCacheCoordinator(
577589
kv_cache_config,
578590
max_model_len,
591+
max_num_batched_tokens,
579592
use_eagle,
580593
enable_caching,
581594
enable_kv_cache_events,

aphrodite/v1/core/kv_cache_manager.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def __init__(
100100
kv_cache_config: KVCacheConfig,
101101
max_model_len: int,
102102
hash_block_size: int,
103+
max_num_batched_tokens: int | None = None,
103104
enable_caching: bool = True,
104105
use_eagle: bool = False,
105106
log_stats: bool = False,
@@ -109,6 +110,11 @@ def __init__(
109110
metrics_collector: KVCacheMetricsCollector | None = None,
110111
) -> None:
111112
self.max_model_len = max_model_len
113+
# When unset, fall back to `max_model_len` so the recycling-aware cap
114+
# collapses to the prior (uncapped) admission behavior. The scheduler
115+
# always supplies the real value at runtime.
116+
if max_num_batched_tokens is None:
117+
max_num_batched_tokens = max_model_len
112118

113119
self.enable_caching = enable_caching
114120
self.use_eagle = use_eagle
@@ -122,6 +128,7 @@ def __init__(
122128
self.coordinator = get_kv_cache_coordinator(
123129
kv_cache_config=kv_cache_config,
124130
max_model_len=self.max_model_len,
131+
max_num_batched_tokens=max_num_batched_tokens,
125132
use_eagle=self.use_eagle,
126133
enable_caching=self.enable_caching,
127134
enable_kv_cache_events=enable_kv_cache_events,

aphrodite/v1/core/sched/scheduler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ def __init__(
210210
self.kv_cache_manager = KVCacheManager(
211211
kv_cache_config=kv_cache_config,
212212
max_model_len=self.max_model_len,
213+
max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens,
213214
enable_caching=self.cache_config.enable_prefix_caching,
214215
use_eagle=self.use_eagle,
215216
log_stats=self.log_stats,

aphrodite/v1/core/single_type_kv_cache_manager.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,20 @@ def __init__(
4141
kv_cache_group_id: int,
4242
dcp_world_size: int = 1,
4343
pcp_world_size: int = 1,
44+
max_admission_blocks_per_request: int | None = None,
4445
) -> None:
4546
"""
4647
Initializes the SingleTypeKVCacheManager.
4748
Args:
4849
kv_cache_spec: The kv_cache_spec for this manager.
4950
block_pool: The block pool.
5051
kv_cache_group_id: The id of the kv cache group of this manager.
52+
max_admission_blocks_per_request: Recycling-aware per-request
53+
block cap used by `get_num_blocks_to_allocate`. Only set for
54+
spec types that recycle blocks across chunks (SWA,
55+
chunked-local); `None` (the default) means no cap, which is
56+
correct for full-attention-style specs that hold every
57+
block until the request finishes.
5158
"""
5259
self.block_size = kv_cache_spec.block_size
5360
self.dcp_world_size = dcp_world_size
@@ -57,6 +64,7 @@ def __init__(
5764
self.kv_cache_spec = kv_cache_spec
5865
self.block_pool = block_pool
5966
self.enable_caching = enable_caching
67+
self._max_admission_blocks_per_request = max_admission_blocks_per_request
6068
self.new_block_ids: list[int] = []
6169

6270
# Mapping from request ID to blocks to track the blocks allocated
@@ -105,6 +113,17 @@ def get_num_blocks_to_allocate(
105113
"""
106114

107115
num_required_blocks = cdiv(num_tokens, self.block_size)
116+
if self._max_admission_blocks_per_request is not None:
117+
# Recycling-aware specs (SWA, chunked-local) cap the per-request
118+
# reservation here so admission matches the startup pool sizer
119+
# (`SlidingWindowSpec.max_admission_blocks_per_request` / its
120+
# chunked-local counterpart). `remove_skipped_blocks` runs from
121+
# `allocate_slots` before each chunk's `get_num_blocks_to_allocate`,
122+
# so per-request peak real-held blocks <= this cap, which keeps
123+
# `sum(reservations) <= pool` <=> `sum(peak_real_held) <= pool`.
124+
# Drift between the two would re-introduce the deadlock from
125+
# issue #39734 or, worse, mid-prefill OOM.
126+
num_required_blocks = min(num_required_blocks, self._max_admission_blocks_per_request)
108127
num_req_blocks = len(self.req_to_blocks.get(request_id, ()))
109128

110129
if request_id in self.num_cached_block:
@@ -1043,7 +1062,20 @@ def __init__(
10431062
}
10441063

10451064

1046-
def get_manager_for_kv_cache_spec(kv_cache_spec: KVCacheSpec, **kwargs) -> SingleTypeKVCacheManager:
1065+
def get_manager_for_kv_cache_spec(
1066+
kv_cache_spec: KVCacheSpec,
1067+
max_num_batched_tokens: int,
1068+
max_model_len: int,
1069+
**kwargs,
1070+
) -> SingleTypeKVCacheManager:
10471071
manager_class = spec_manager_map[type(kv_cache_spec)]
1072+
# SlidingWindow / ChunkedLocalAttention managers recycle blocks across
1073+
# chunks; the runtime admission cap must match the recycling-aware bound
1074+
# the startup pool sizer uses (single source of truth: the spec method).
1075+
if isinstance(kv_cache_spec, (SlidingWindowSpec, ChunkedLocalAttentionSpec)):
1076+
kwargs["max_admission_blocks_per_request"] = kv_cache_spec.max_admission_blocks_per_request(
1077+
max_num_batched_tokens=max_num_batched_tokens,
1078+
max_model_len=max_model_len,
1079+
)
10481080
manager = manager_class(kv_cache_spec, **kwargs)
10491081
return manager

aphrodite/v1/kv_cache_interface.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -332,17 +332,24 @@ def merge(cls, specs: list[Self]) -> Self:
332332
class ChunkedLocalAttentionSpec(AttentionSpec):
333333
attention_chunk_size: int
334334

335-
def max_memory_usage_bytes(self, aphrodite_config: AphroditeConfig) -> int:
336-
max_model_len = aphrodite_config.model_config.max_model_len
337-
max_num_batched_tokens = aphrodite_config.scheduler_config.max_num_batched_tokens
335+
def max_admission_blocks_per_request(self, max_num_batched_tokens: int, max_model_len: int) -> int:
336+
"""Per-request admission cap, in blocks.
338337
339-
# During chunked prefill, we allocate KV cache for at most
340-
# `self.attention_chunk_size` computed tokens plus the newly scheduled
341-
# tokens. And we won't allocate KV cache for more than `max_model_len`
342-
# tokens.
338+
Single source of truth for both startup pool sizing
339+
(`max_memory_usage_bytes`) and the runtime admission gate, so requests
340+
admitted by startup can also be admitted at runtime.
341+
"""
342+
# During chunked prefill, we hold KV for at most one chunk window.
343343
num_tokens = min(self.attention_chunk_size + max_num_batched_tokens, max_model_len)
344+
return cdiv(num_tokens, self.block_size)
344345

345-
return cdiv(num_tokens, self.block_size) * self.page_size_bytes
346+
def max_memory_usage_bytes(self, aphrodite_config: AphroditeConfig) -> int:
347+
max_model_len = aphrodite_config.model_config.max_model_len
348+
max_num_batched_tokens = aphrodite_config.scheduler_config.max_num_batched_tokens
349+
max_blocks = self.max_admission_blocks_per_request(
350+
max_num_batched_tokens=max_num_batched_tokens, max_model_len=max_model_len
351+
)
352+
return max_blocks * self.page_size_bytes
346353

347354

348355
@dataclass(frozen=True, kw_only=True)
@@ -358,22 +365,34 @@ def __post_init__(self):
358365
def real_page_size_bytes(self) -> int:
359366
return self.block_size * self.num_kv_heads * (self.head_size + self.head_size_v) * get_dtype_size(self.dtype)
360367

368+
def max_admission_blocks_per_request(self, max_num_batched_tokens: int, max_model_len: int) -> int:
369+
"""Per-request admission cap, in blocks.
370+
371+
Single source of truth for both startup pool sizing
372+
(`max_memory_usage_bytes`) and the runtime admission gate. Per-request
373+
real-held blocks plateau at this bound because
374+
`SlidingWindowManager.remove_skipped_blocks` runs from `allocate_slots`
375+
before each chunk's `get_num_blocks_to_allocate`.
376+
"""
377+
# During chunked prefill, we hold KV for the last `sliding_window-1`
378+
# computed tokens plus the newly scheduled tokens, and never more
379+
# than `max_model_len`.
380+
num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens, max_model_len)
381+
# +1 because the sliding window may not start from the beginning of
382+
# the block. E.g. block size 4 and num_token 4 needs two blocks
383+
# [XXCD][EF] to store the 6-token window [CDEF].
384+
return cdiv(num_tokens, self.block_size) + 1
385+
361386
def max_memory_usage_bytes(self, aphrodite_config: AphroditeConfig) -> int:
362387
assert aphrodite_config.parallel_config.decode_context_parallel_size == 1, "DCP not support sliding window."
363388
max_model_len = aphrodite_config.model_config.max_model_len
364389
max_num_batched_tokens = aphrodite_config.scheduler_config.max_num_batched_tokens
365390

366-
# During chunked prefill, we allocate KV cache for the last
367-
# `self.sliding_window-1` computed tokens plus the newly scheduled
368-
# tokens. And we won't allocate KV cache for more than `max_model_len`
369-
# tokens.
370-
num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens, max_model_len)
391+
max_blocks = self.max_admission_blocks_per_request(
392+
max_num_batched_tokens=max_num_batched_tokens, max_model_len=max_model_len
393+
)
371394

372-
# +1 here because the sliding window may not start from the beginning
373-
# of the block. For example, if the block size is 4 and num_token
374-
# is 4, we need two blocks [XXCD] [EF] to store the sliding
375-
# window [CDEF] of 6 tokens.
376-
return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
395+
return max_blocks * self.page_size_bytes
377396

378397

379398
@dataclass(frozen=True, kw_only=True)

aphrodite/v1/simple_kv_offload/manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def __init__(
107107
self.cpu_coordinator: KVCacheCoordinator = get_kv_cache_coordinator(
108108
kv_cache_config=self.cpu_kv_cache_config,
109109
max_model_len=aphrodite_config.model_config.max_model_len,
110+
max_num_batched_tokens=aphrodite_config.scheduler_config.max_num_batched_tokens,
110111
use_eagle=False,
111112
enable_caching=True,
112113
enable_kv_cache_events=self.enable_kv_cache_events,

0 commit comments

Comments
 (0)