Skip to content

Commit 0927708

Browse files
authored
feat: support kv offload storing with multiple KV groups (#1644)
Signed-off-by: AlpinDale <alpindale@gmail.com>
1 parent 3b0472d commit 0927708

1 file changed

Lines changed: 78 additions & 44 deletions

File tree

  • aphrodite/distributed/kv_transfer/kv_connector/v1/offloading

aphrodite/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py

Lines changed: 78 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -101,6 +101,11 @@ def update_block_id_groups(self, new_block_id_groups: tuple[list[int], ...] | No
101101
for group_state, new_blocks in zip(self.group_states, new_block_id_groups):
102102
group_state.block_ids.extend(new_blocks)
103103

104+
def advance_stored_idx(self, num_offloadable_tokens: int) -> None:
    """Set each KV group's ``next_stored_block_idx`` from a token count.

    Every group state is paired by position with its group config. The
    group's index is set to the number of whole offloaded blocks that fit
    in ``num_offloadable_tokens``, i.e. the token count floor-divided by
    that group's ``offloaded_block_size``.

    Args:
        num_offloadable_tokens: Token count to convert into a per-group
            block index.
    """
    pairs = zip(self.config.kv_group_configs, self.group_states)
    for cfg, state in pairs:
        # Block sizes may differ between groups, so compute per group.
        tokens_per_block = cfg.offloaded_block_size
        state.next_stored_block_idx = num_offloadable_tokens // tokens_per_block
108+
104109

105110
class OffloadingConnectorScheduler:
106111
"""Implementation of Scheduler side methods"""
@@ -333,16 +338,14 @@ def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks, num_
333338
if self._blocks_being_loaded is not None:
334339
self._blocks_being_loaded.update(req_blocks_being_loaded)
335340

336-
def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
337-
# Below assertion will be removed once this function supports HMA
338-
assert len(self.config.kv_group_configs) == 1
339-
group_config = self.config.kv_group_configs[0]
340-
341+
def _get_reqs_to_store(self, scheduler_output: SchedulerOutput) -> dict[ReqId, TransferSpec]:
342+
block_size_factor = self.config.block_size_factor
341343
reqs_to_store: dict[ReqId, TransferSpec] = {}
342344
# iterate over both new and cached requests
343345
for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
344346
req_status = self._req_status[req_id]
345347
req_status.update_offload_keys()
348+
req = req_status.req
346349

347350
if preempted:
348351
for group_state in req_status.group_states:
@@ -351,64 +354,95 @@ def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
351354
if new_block_id_groups:
352355
req_status.update_block_id_groups(new_block_id_groups)
353356

354-
# Below assertion will be removed once this function supports HMA
355-
assert len(req_status.group_states) == 1
356-
group_state = req_status.group_states[0]
357-
358-
block_ids = group_state.block_ids
359-
360-
req = req_status.req
361-
new_tokens = scheduler_output.num_scheduled_tokens[req_id]
362-
expected_tokens = req.num_computed_tokens + new_tokens
357+
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
358+
num_tokens_after_batch = req.num_computed_tokens + num_scheduled_tokens
363359
# with async scheduling, some tokens may be missing
364-
total_tokens = min(expected_tokens, req.num_tokens)
365-
num_blocks = total_tokens // group_config.offloaded_block_size
366-
start_block_idx = group_state.next_stored_block_idx
367-
num_new_blocks = num_blocks - start_block_idx
368-
369-
if num_new_blocks <= 0:
360+
num_offloadable_tokens = min(num_tokens_after_batch, req.num_tokens)
361+
362+
# Filter out blocks skipped due to sliding window attention / SSM
363+
new_offload_keys: list[OffloadKey] = []
364+
for group_config, group_state in zip(self.config.kv_group_configs, req_status.group_states):
365+
num_blocks = num_offloadable_tokens // group_config.offloaded_block_size
366+
start_block_idx = group_state.next_stored_block_idx
367+
if num_blocks <= start_block_idx:
368+
continue
369+
offload_keys = group_state.offload_keys[start_block_idx:num_blocks]
370+
# For each block to offload, take the last corresponding GPU block.
371+
# e.g. if block size factor is 3 and GPU block IDs are
372+
# 1 5 6 7 2 4 9 3 8 then we'll take blocks 6 4 8.
373+
# We will use these GPU blocks to determine if the block needs
374+
# offloading, or (if the GPU block ID is 0) this block should
375+
# be skipped due to sliding window attention / SSM.
376+
# We know that if a block is skipped, then all the previous blocks
377+
# are skipped as well. This is why we take the last of each block.
378+
offload_block_ids = group_state.block_ids[
379+
start_block_idx * block_size_factor + block_size_factor - 1 : num_blocks
380+
* block_size_factor : block_size_factor
381+
]
382+
assert len(offload_keys) == len(offload_block_ids)
383+
384+
for offload_key, block_id in zip(offload_keys, offload_block_ids):
385+
if block_id != 0:
386+
new_offload_keys.append(offload_key)
387+
388+
if not new_offload_keys:
389+
req_status.advance_stored_idx(num_offloadable_tokens)
370390
continue
371391

372-
num_gpu_blocks = num_blocks * self.config.block_size_factor
373-
assert len(req.block_hashes) >= num_gpu_blocks
374-
375-
new_offload_keys = group_state.offload_keys[start_block_idx:num_blocks]
376392
store_output = self.manager.prepare_store(new_offload_keys, req_status.req_context)
377393
if store_output is None:
378-
logger.warning("Request %s: cannot store %s blocks", req_id, num_new_blocks)
394+
logger.warning("Request %s: cannot store blocks", req_id)
379395
continue
380396

381-
group_state.next_stored_block_idx = num_blocks
382-
383397
if not store_output.keys_to_store:
398+
req_status.advance_stored_idx(num_offloadable_tokens)
384399
continue
385-
keys_to_store = set(store_output.keys_to_store)
386400

387-
self.manager.touch(group_state.offload_keys[:num_blocks])
401+
for group_state in req_status.group_states:
402+
self.manager.touch(group_state.offload_keys)
388403

389-
dst_spec = store_output.store_spec
404+
keys_to_store = set(store_output.keys_to_store)
405+
406+
group_sizes: list[int] = []
407+
block_indices: list[int] = []
390408
src_block_ids: list[int] = []
391-
for idx, key in enumerate(new_offload_keys):
392-
if key not in keys_to_store:
393-
continue
394-
offloaded_block_idx = start_block_idx + idx
395-
gpu_block_idx = offloaded_block_idx * self.config.block_size_factor
396-
for i in range(self.config.block_size_factor):
397-
src_block_ids.append(block_ids[gpu_block_idx + i])
398-
src_spec = GPULoadStoreSpec(
399-
src_block_ids,
400-
group_sizes=(len(src_block_ids),),
401-
block_indices=(0,),
402-
)
409+
for group_config, group_state in zip(self.config.kv_group_configs, req_status.group_states):
410+
num_blocks = num_offloadable_tokens // group_config.offloaded_block_size
411+
start_block_idx = group_state.next_stored_block_idx
412+
block_ids = group_state.block_ids
413+
num_group_blocks = 0
414+
start_gpu_block_idx: int | None = None
415+
for idx, offload_key in enumerate(group_state.offload_keys[start_block_idx:num_blocks]):
416+
if offload_key not in keys_to_store:
417+
continue
418+
419+
offloaded_block_idx = start_block_idx + idx
420+
gpu_block_idx = offloaded_block_idx * block_size_factor
421+
num_group_blocks += block_size_factor
422+
for i in range(block_size_factor):
423+
block_id = block_ids[gpu_block_idx + i]
424+
if block_id == 0:
425+
# skipped blocks cannot appear after non-skipped blocks
426+
assert start_gpu_block_idx is None
427+
continue
428+
elif start_gpu_block_idx is None:
429+
start_gpu_block_idx = gpu_block_idx + i
430+
src_block_ids.append(block_id)
431+
group_sizes.append(num_group_blocks)
432+
block_indices.append(start_gpu_block_idx or 0)
433+
group_state.next_stored_block_idx = num_blocks
434+
435+
src_spec = GPULoadStoreSpec(src_block_ids, group_sizes=group_sizes, block_indices=block_indices)
436+
dst_spec = store_output.store_spec
403437

404438
reqs_to_store[req_id] = (src_spec, dst_spec)
405439
self._reqs_being_stored[req_id] |= keys_to_store
406440

407441
logger.debug(
408-
"Request %s offloading %s blocks starting from block #%d",
442+
"Request %s offloading %s blocks upto %d tokens",
409443
req_id,
410444
len(keys_to_store),
411-
start_block_idx,
445+
num_offloadable_tokens,
412446
)
413447

414448
return reqs_to_store

0 commit comments

Comments
 (0)