Skip to content

Commit af51fc4

Browse files
authored
[PD Disaggregation] Write the cache of preempted req to storage and refine PD Disaggregation (#7107)
* Write the cache of preempted req to storage * up * fix
1 parent 3651113 commit af51fc4

5 files changed

Lines changed: 35 additions & 19 deletions

File tree

fastdeploy/cache_manager/cache_transfer_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -796,7 +796,7 @@ def read_storage_task(self, task: ReadStorageTask):
796796
try:
797797
valid_gpu_block_ids = self._run_read_storage(
798798
task.task_id,
799-
task.token_ids[: match_block_num * self.block_size],
799+
task.token_ids[: match_block_num * self.block_size] if task.token_ids else None,
800800
task.start_read_block_idx,
801801
k_cache_keys,
802802
v_cache_keys,

fastdeploy/cache_manager/prefix_cache_manager.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -462,12 +462,12 @@ def update_cache_config(self, cache_config):
462462
main_process_metrics.free_gpu_block_num.set(self.num_gpu_blocks)
463463
main_process_metrics.available_gpu_resource.set(1.0)
464464

465-
def can_allocate_gpu_blocks(self, num_blocks: int):
465+
def can_allocate_gpu_blocks(self, num_blocks: int, try_free_gpu_blocks: bool = True):
466466
"""
467467
Check if num_blocks gpu blocks can be allocated.
468468
"""
469469
if len(self.gpu_free_block_list) < num_blocks:
470-
if self.cache_config.enable_prefix_caching:
470+
if self.cache_config.enable_prefix_caching and try_free_gpu_blocks:
471471
self.free_block_ids(num_blocks)
472472
if len(self.gpu_free_block_list) < num_blocks:
473473
return False
@@ -814,7 +814,7 @@ def request_match_blocks(self, task: Request, block_size, *args):
814814
# 2. prepare cpu cache: allocate gpu cache for matched cpu blocks, wait for data transfer to complete
815815
gpu_recv_block_ids = []
816816
match_cpu_blocks_num = len(match_cpu_block_ids)
817-
if self.can_allocate_gpu_blocks(num_blocks=match_cpu_blocks_num):
817+
if self.can_allocate_gpu_blocks(num_blocks=match_cpu_blocks_num, try_free_gpu_blocks=False):
818818
if match_cpu_blocks_num > 0:
819819
logger.debug(
820820
f"request_match_blocks: req_id {req_id}, allocate {match_cpu_blocks_num} block to receive cpu cache"
@@ -845,7 +845,7 @@ def request_match_blocks(self, task: Request, block_size, *args):
845845
match_storage_block_ids = []
846846

847847
if self.kvcache_storage_backend and no_match_token_num >= block_size:
848-
if not self.can_allocate_gpu_blocks(num_blocks=no_match_block_num):
848+
if not self.can_allocate_gpu_blocks(num_blocks=no_match_block_num, try_free_gpu_blocks=False):
849849
raise Exception(
850850
"request_match_blocks: Not enough GPU memory to allocate cache for matched Storage Cache"
851851
)
@@ -881,7 +881,7 @@ def request_match_blocks(self, task: Request, block_size, *args):
881881
read_storage_task = ReadStorageTask(
882882
task_id=req_id,
883883
keys=no_match_block_keys,
884-
token_ids=input_token_ids,
884+
token_ids=input_token_ids if self.kvcache_storage_backend == "attention_store" else None,
885885
gpu_block_ids=gpu_recv_storage_block_ids,
886886
start_read_block_idx=match_token_num // block_size,
887887
)
@@ -1141,8 +1141,11 @@ def write_cache_to_storage(self, request: Request):
11411141
token_ids = request.prompt_token_ids
11421142
if isinstance(token_ids, np.ndarray):
11431143
token_ids = token_ids.tolist()
1144+
11441145
if self.config.cache_config.enable_output_caching:
1145-
token_ids += request.output_token_ids
1146+
input_token_ids = token_ids + request.output_token_ids
1147+
else:
1148+
input_token_ids = token_ids
11461149

11471150
req_id = request.request_id
11481151
keys = []
@@ -1159,7 +1162,7 @@ def write_cache_to_storage(self, request: Request):
11591162
write_storage_task = WriteStorageTask(
11601163
task_id=req_id,
11611164
keys=keys,
1162-
token_ids=token_ids,
1165+
token_ids=input_token_ids if self.kvcache_storage_backend == "attention_store" else None,
11631166
gpu_block_ids=gpu_block_ids,
11641167
)
11651168
logger.debug(f"issue write storage task: {write_storage_task}")
@@ -1193,16 +1196,18 @@ def write_cache_to_storage_decode(self, request: Request):
11931196
token_ids = list(token_ids)
11941197

11951198
if self.config.cache_config.enable_output_caching:
1196-
token_ids = token_ids + request.output_token_ids
1199+
input_token_ids = token_ids + request.output_token_ids
1200+
else:
1201+
input_token_ids = token_ids
11971202

11981203
# 2. Calculate cache keys using chained hash (consistent with P instance)
11991204
keys = []
12001205
prefix_block_key = [] # Initial is empty list
12011206
block_size = self.config.cache_config.block_size
12021207
mm_idx = 0 # Multimodal index for tracking position in mm_inputs
12031208

1204-
for i in range(0, len(token_ids), block_size):
1205-
block_token_ids = token_ids[i : i + block_size]
1209+
for i in range(0, len(input_token_ids), block_size):
1210+
block_token_ids = input_token_ids[i : i + block_size]
12061211
if len(block_token_ids) < block_size:
12071212
break # Do not cache incomplete block
12081213

@@ -1236,7 +1241,7 @@ def write_cache_to_storage_decode(self, request: Request):
12361241
write_storage_task = WriteStorageTask(
12371242
task_id=req_id,
12381243
keys=keys,
1239-
token_ids=token_ids,
1244+
token_ids=input_token_ids if self.kvcache_storage_backend == "attention_store" else None,
12401245
gpu_block_ids=gpu_block_ids,
12411246
)
12421247

@@ -2166,7 +2171,7 @@ def recv_data_transfer_result(self):
21662171
event_type = data[0]
21672172

21682173
if event_type.value == CacheStatus.STORAGE2GPU.value:
2169-
logger.info(f"recv_data_transfer_result: {data}")
2174+
logger.debug(f"recv_data_transfer_result: {data}")
21702175
task_id, hash_keys, block_ids = data[1:]
21712176
if task_id not in self.storage_prefetch_block_ids:
21722177
self.storage_prefetch_block_ids[task_id] = []
@@ -2177,7 +2182,7 @@ def recv_data_transfer_result(self):
21772182
if task_id in self.task_prefetch_event:
21782183
self.task_prefetch_event[task_id].set()
21792184
elif event_type.value == CacheStatus.GPU2STORAGE.value:
2180-
logger.info(f"recv_data_transfer_result: {data}")
2185+
logger.debug(f"recv_data_transfer_result: {data}")
21812186
task_id, hash_keys, block_ids = data[1:]
21822187
if task_id in self.task_write_back_event:
21832188
self.task_write_back_event[task_id].set()

fastdeploy/engine/common_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -910,7 +910,7 @@ def _fetch_request():
910910
self.split_connector.send_splitwise_tasks([task], task.idx)
911911
status, msg = self.split_connector.check_decode_allocated(task)
912912
if not status:
913-
self.llm_logger.error(
913+
self.llm_logger.warning(
914914
f"D failed to allocate resource for request {task.request_id}, try again."
915915
)
916916
time.sleep(0.05)

fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -367,9 +367,15 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re
367367
del self.requests[preempted_req.request_id]
368368
if preempted_req.request_id in self.req_dict:
369369
del self.req_dict[preempted_req.request_id]
370+
if envs.FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST:
371+
if self.config.cache_config.kvcache_storage_backend:
372+
self.cache_manager.write_cache_to_storage_decode(preempted_req)
370373
self._free_blocks(preempted_req)
371374
llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
372375
else:
376+
if envs.FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST:
377+
if self.config.cache_config.kvcache_storage_backend:
378+
self.cache_manager.write_cache_to_storage(preempted_req)
373379
self._free_blocks(preempted_req)
374380
preempted_req.num_cached_blocks = 0
375381
self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
@@ -399,7 +405,7 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re
399405
self.can_relax_prefill_strategy = False
400406
return can_schedule
401407

402-
def _get_can_schedule_prefill_threshold_block(self, request, num_chunk_new_block):
408+
def _get_can_schedule_prefill_threshold_block(self, num_chunk_new_block):
403409
if self.can_relax_prefill_strategy:
404410
can_schedule_block_num_threshold = num_chunk_new_block
405411
else:
@@ -987,7 +993,7 @@ def _allocate_decode_and_extend():
987993
continue
988994
num_new_block = self.get_new_block_nums(request, num_new_tokens)
989995
can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block(
990-
request, num_new_block
996+
num_new_block
991997
)
992998
# Allocate blocks to prefill
993999
if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold):
@@ -1052,7 +1058,7 @@ def _allocate_decode_and_extend():
10521058
continue
10531059
num_new_block = self.get_new_block_nums(request, num_new_tokens)
10541060
can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block(
1055-
request, num_new_block
1061+
num_new_block
10561062
)
10571063
# Allocate blocks to prefill
10581064
if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold):
@@ -1392,7 +1398,8 @@ def preallocate_resource_in_d(self, request: Request):
13921398
return False
13931399
if self.available_batch() == 0:
13941400
return False
1395-
if not self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks):
1401+
total_need_blocks = self._get_can_schedule_prefill_threshold_block(need_prealloc_prefill_blocks)
1402+
if not self.cache_manager.can_allocate_gpu_blocks(total_need_blocks):
13961403
return False
13971404

13981405
request.block_tables = self.cache_manager.allocate_gpu_blocks(

fastdeploy/envs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,10 @@ def _validate_split_kv_size(value: int) -> int:
252252
# When v1 is enabled, the legacy /clear_load_weight and /update_model_weight
253253
# will adopt this new communication pattern.
254254
"FD_ENABLE_V1_UPDATE_WEIGHTS": lambda: bool(int(os.getenv("FD_ENABLE_V1_UPDATE_WEIGHTS", "0"))),
255+
# Whether to save the cache of output token for preempted request to storage.
256+
"FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool(
257+
int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1"))
258+
),
255259
}
256260

257261

0 commit comments

Comments
 (0)