Skip to content

Commit 5d18984

Browse files
authored
fix(kvcache): buffer early layer0 signals (#7896)
1 parent e7815be commit 5d18984

2 files changed

Lines changed: 213 additions & 24 deletions

File tree

fastdeploy/cache_manager/cache_messager.py

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,8 @@ def __init__(
620620
dict() for _ in range(512)
621621
] # {'layer_id': {'prefilled_layer_idx': xx, 'prefilled_block_num': xx}}
622622
self.idx_cache_task_dict = {} # {'slot_idx': cache_info_dict}
623+
self.pending_layer0_signals = {}
624+
self.pending_layer0_signal_lock = threading.Lock()
623625
self.cache_prefilled_engine_ids_queue = (
624626
queue.Queue()
625627
) # [(slot_idx1, prefilled_token_num1), (slot_idx2, prefilled_token_num2)]
@@ -663,7 +665,28 @@ def _add_cache_task_thread(self):
663665
current_info["status"] = "init"
664666
logger.info(f"Get cache info and finish add cache task: {current_info}")
665667
self.cache_info[info["request_id"]] = current_info
666-
self.idx_cache_task_dict[current_info["current_id"]] = current_info
668+
current_id = current_info["current_id"]
669+
with self.engine_cache_task_thread_lock:
670+
self.idx_cache_task_dict[current_id] = current_info
671+
with self.pending_layer0_signal_lock:
672+
recovered_signal = self.pending_layer0_signals.pop(current_id, None)
673+
if recovered_signal is not None:
674+
_, prefilled_token_num = recovered_signal
675+
if prefilled_token_num <= current_info["need_prefill_tokens"]:
676+
recovered_signal_batch = [recovered_signal]
677+
logger.info(
678+
"cache_task_register_recover_layer0_signal: "
679+
f"current_id: {current_id}, "
680+
f"recovered_signal_batch: {recovered_signal_batch}"
681+
)
682+
self.cache_prefilled_engine_ids_queue.put(recovered_signal_batch)
683+
else:
684+
logger.info(
685+
"cache_task_register_drop_layer0_signal: "
686+
f"current_id: {current_id}, "
687+
f"recovered_signal: {recovered_signal}, "
688+
f"need_prefill_tokens: {current_info['need_prefill_tokens']}"
689+
)
667690
else:
668691
logger.info(f"Get cache info: {info}")
669692
self.cache_info[info["request_id"]] = info
@@ -842,9 +865,12 @@ def prefill_layerwise_send_cache_thread(self):
842865
logger.info(
843866
f"Put successful cache writing task in engine worker queue, req_id: {task['request_id']}, status: {task['status']}"
844867
)
845-
self.engine_cache_tasks[task["current_id"]] = dict()
868+
current_id = task["current_id"]
869+
self.engine_cache_tasks[current_id] = dict()
846870
del self.cache_info[task["request_id"]]
847-
del self.idx_cache_task_dict[task["current_id"]]
871+
del self.idx_cache_task_dict[current_id]
872+
with self.pending_layer0_signal_lock:
873+
self.pending_layer0_signals.pop(current_id, None)
848874
break
849875
except Exception as e:
850876
logger.error(f"prefill layerwise send cache thread has exception: {e} {traceback.format_exc()!s}")
@@ -856,32 +882,42 @@ def consume_signals(self):
856882
while True:
857883
try:
858884
get_output_kv_signal(kv_signal_data, self.rank_id, 1) # wait_flag
859-
if not self.cache_info:
860-
time.sleep(0.01)
861-
continue
862-
tasks_count = kv_signal_data[0]
885+
has_cache_info = bool(self.cache_info)
886+
tasks_count = kv_signal_data[0].item()
863887
if tasks_count == -1:
864888
continue
889+
if not has_cache_info:
890+
logger.debug("consume_signals get kv signal before cache info is ready")
865891
layer_id = kv_signal_data[1].item()
866892
if layer_id == self.num_layers - 1:
867893
logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id} self.rank_id {self.rank_id}")
868-
batch_engine_signals = []
894+
ready_engine_signals = []
895+
pending_engine_signals = []
869896
# format for signal to put in cache_prefilled_engine_ids_queue: [(engine_idx1, prefilled_token_num1), (engine_idx2, prefilled_token_num2)]
870897
with self.engine_cache_task_thread_lock:
871898
for bi in range(tasks_count):
872899
engine_idx = kv_signal_data[3 * bi + 2].item()
873900
chuck_token_offset = kv_signal_data[3 * bi + 3].item()
874901
current_seq_len = kv_signal_data[3 * bi + 4].item()
902+
prefilled_token_num = chuck_token_offset + current_seq_len
875903
self.engine_cache_tasks[engine_idx]["prefilled_layer_idx"] = layer_id
876-
self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = (
877-
chuck_token_offset + current_seq_len
878-
)
879-
batch_engine_signals.append((engine_idx, chuck_token_offset + current_seq_len))
880-
if layer_id == 0:
881-
logger.info(
882-
f"Put batch_engine_signals {batch_engine_signals} into cache_prefilled_engine_ids_queue"
883-
)
884-
self.cache_prefilled_engine_ids_queue.put(batch_engine_signals)
904+
self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = prefilled_token_num
905+
if layer_id == 0:
906+
if engine_idx in self.idx_cache_task_dict:
907+
ready_engine_signals.append((engine_idx, prefilled_token_num))
908+
else:
909+
pending_engine_signals.append((engine_idx, prefilled_token_num))
910+
if pending_engine_signals:
911+
with self.pending_layer0_signal_lock:
912+
for engine_idx, prefilled_token_num in pending_engine_signals:
913+
self.pending_layer0_signals[engine_idx] = (engine_idx, prefilled_token_num)
914+
if pending_engine_signals:
915+
logger.debug(f"cache_task_pending_layer0_signal: {pending_engine_signals}")
916+
if ready_engine_signals:
917+
logger.info(
918+
f"Put batch_engine_signals {ready_engine_signals} into cache_prefilled_engine_ids_queue"
919+
)
920+
self.cache_prefilled_engine_ids_queue.put(ready_engine_signals)
885921
except Exception as e:
886922
logger.error(f"Consume signals get exception: {e}")
887923

tests/cache_manager/test_cache_messager.py

Lines changed: 160 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,14 @@ def error(self, msg):
124124
self.messages.append(("error", msg))
125125

126126

127+
class _QueueRecorder:
128+
def __init__(self):
129+
self.items = []
130+
131+
def put(self, item):
132+
self.items.append(item)
133+
134+
127135
class _DummySignalValue:
128136
def __init__(self, sequence):
129137
self.sequence = list(sequence)
@@ -390,6 +398,111 @@ def test_cache_messager_v1_add_cache_task_thread(monkeypatch):
390398
assert messager.cache_info["req-2"]["status"] == "init"
391399

392400

401+
def test_cache_messager_v1_recovers_pending_layer0_signal(monkeypatch):
402+
dummy_queue = _DummyEngineWorkerQueue(
403+
cache_info_sequence=[
404+
[
405+
{
406+
"request_id": "req-pending",
407+
"src_block_ids": [0, 1],
408+
"dest_block_ids": [2],
409+
"current_id": 3,
410+
"need_prefill_tokens": 128,
411+
"transfer_protocol": "rdma",
412+
}
413+
]
414+
]
415+
)
416+
monkeypatch.setattr(cache_messager, "EngineWorkerQueue", lambda *args, **kwargs: dummy_queue)
417+
monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
418+
monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)
419+
420+
gpu_cache_kvs = _build_cache_kvs(dtype="float16", include_value_cache=True, num_layers=1)
421+
messager = cache_messager.CacheMessagerV1(
422+
splitwise_role="mixed",
423+
transfer_protocol="rdma",
424+
pod_ip="0.0.0.0",
425+
engine_worker_queue_port=9000,
426+
local_data_parallel_id=0,
427+
gpu_cache_kvs=gpu_cache_kvs,
428+
rank=0,
429+
nranks=1,
430+
num_layers=1,
431+
gpu_id=0,
432+
block_size=64,
433+
rdma_port="2222",
434+
)
435+
messager.cache_prefilled_engine_ids_queue = _QueueRecorder()
436+
messager.cache_info["req-pending"] = {
437+
"request_id": "req-pending",
438+
"src_block_ids": [0, 1],
439+
"dest_block_ids": [2],
440+
"current_id": 3,
441+
"need_prefill_tokens": 128,
442+
"transfer_protocol": "rdma",
443+
}
444+
messager.pending_layer0_signals[3] = (3, 64)
445+
messager.pending_layer0_signals[4] = (4, 64)
446+
447+
with pytest.raises(SystemExit):
448+
messager._add_cache_task_thread()
449+
450+
assert messager.pending_layer0_signals == {4: (4, 64)}
451+
assert messager.cache_prefilled_engine_ids_queue.items == [[(3, 64)]]
452+
453+
454+
def test_cache_messager_v1_drops_invalid_pending_layer0_signal(monkeypatch):
455+
dummy_queue = _DummyEngineWorkerQueue(
456+
cache_info_sequence=[
457+
[
458+
{
459+
"request_id": "req-pending",
460+
"src_block_ids": [0, 1],
461+
"dest_block_ids": [2],
462+
"current_id": 3,
463+
"need_prefill_tokens": 128,
464+
"transfer_protocol": "rdma",
465+
}
466+
]
467+
]
468+
)
469+
monkeypatch.setattr(cache_messager, "EngineWorkerQueue", lambda *args, **kwargs: dummy_queue)
470+
monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
471+
monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)
472+
473+
gpu_cache_kvs = _build_cache_kvs(dtype="float16", include_value_cache=True, num_layers=1)
474+
messager = cache_messager.CacheMessagerV1(
475+
splitwise_role="mixed",
476+
transfer_protocol="rdma",
477+
pod_ip="0.0.0.0",
478+
engine_worker_queue_port=9000,
479+
local_data_parallel_id=0,
480+
gpu_cache_kvs=gpu_cache_kvs,
481+
rank=0,
482+
nranks=1,
483+
num_layers=1,
484+
gpu_id=0,
485+
block_size=64,
486+
rdma_port="2222",
487+
)
488+
messager.cache_prefilled_engine_ids_queue = _QueueRecorder()
489+
messager.cache_info["req-pending"] = {
490+
"request_id": "req-pending",
491+
"src_block_ids": [0, 1],
492+
"dest_block_ids": [2],
493+
"current_id": 3,
494+
"need_prefill_tokens": 128,
495+
"transfer_protocol": "rdma",
496+
}
497+
messager.pending_layer0_signals[3] = (3, 256)
498+
499+
with pytest.raises(SystemExit):
500+
messager._add_cache_task_thread()
501+
502+
assert messager.pending_layer0_signals == {}
503+
assert messager.cache_prefilled_engine_ids_queue.items == []
504+
505+
393506
def test_cache_messager_v1_prefill_layerwise_send_cache_thread(monkeypatch):
394507
class _OneShotQueue:
395508
def __init__(self):
@@ -435,10 +548,12 @@ def get(self):
435548
}
436549
messager.engine_cache_tasks[0] = {"prefilled_layer_idx": 1, "prefilled_token_num": 64}
437550
messager.cache_info["req-3"] = messager.idx_cache_task_dict[0]
551+
messager.pending_layer0_signals = {0: (0, 64), 1: (1, 64)}
438552
with pytest.raises(SystemExit):
439553
messager.prefill_layerwise_send_cache_thread()
440554
assert dummy_queue.finished_req_payloads
441555
assert dummy_queue.finished_req_payloads[0][0][0] == "req-3"
556+
assert messager.pending_layer0_signals == {1: (1, 64)}
442557

443558

444559
def test_cache_messager_v1_handle_connect_task(monkeypatch):
@@ -562,13 +677,6 @@ def test_cache_messager_v1_consume_signals(monkeypatch):
562677
monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
563678
monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)
564679

565-
class _QueueRecorder:
566-
def __init__(self):
567-
self.items = []
568-
569-
def put(self, item):
570-
self.items.append(item)
571-
572680
counter = {"calls": 0}
573681

574682
def _fake_get_output_kv_signal(kv_signal_data, rank_id, wait_flag):
@@ -600,12 +708,57 @@ def _fake_get_output_kv_signal(kv_signal_data, rank_id, wait_flag):
600708
rdma_port="2222",
601709
)
602710
messager.cache_info["req-4"] = {"request_id": "req-4"}
711+
messager.idx_cache_task_dict[2] = {"request_id": "req-4", "current_id": 2}
603712
messager.cache_prefilled_engine_ids_queue = _QueueRecorder()
604713
with pytest.raises(SystemExit):
605714
messager.consume_signals()
606715
assert messager.cache_prefilled_engine_ids_queue.items == [[(2, 9)]]
607716

608717

718+
def test_cache_messager_v1_consume_signals_buffers_early_layer0(monkeypatch):
719+
monkeypatch.setattr(cache_messager, "EngineWorkerQueue", _DummyEngineWorkerQueue)
720+
monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
721+
monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)
722+
723+
signals = [(5, 7, 9), (5, 17, 19)]
724+
725+
def _fake_get_output_kv_signal(kv_signal_data, rank_id, wait_flag):
726+
if not signals:
727+
raise SystemExit
728+
engine_idx, chuck_token_offset, current_seq_len = signals.pop(0)
729+
data = np.full(kv_signal_data.shape, -1, dtype="int32")
730+
data[0] = 1
731+
data[1] = 0
732+
data[2] = engine_idx
733+
data[3] = chuck_token_offset
734+
data[4] = current_seq_len
735+
kv_signal_data.set_value(data)
736+
737+
monkeypatch.setattr(cache_messager, "get_output_kv_signal", _fake_get_output_kv_signal)
738+
gpu_cache_kvs = _build_cache_kvs(dtype="float16", include_value_cache=False, num_layers=1)
739+
messager = cache_messager.CacheMessagerV1(
740+
splitwise_role="mixed",
741+
transfer_protocol="rdma",
742+
pod_ip="0.0.0.0",
743+
engine_worker_queue_port=9000,
744+
local_data_parallel_id=0,
745+
gpu_cache_kvs=gpu_cache_kvs,
746+
rank=0,
747+
nranks=1,
748+
num_layers=1,
749+
gpu_id=0,
750+
block_size=64,
751+
rdma_port="2222",
752+
)
753+
messager.cache_prefilled_engine_ids_queue = _QueueRecorder()
754+
755+
with pytest.raises(SystemExit):
756+
messager.consume_signals()
757+
758+
assert messager.pending_layer0_signals == {5: (5, 36)}
759+
assert messager.cache_prefilled_engine_ids_queue.items == []
760+
761+
609762
def test_main_initializes_cache_and_exits(monkeypatch):
610763
monkeypatch.setattr(cache_messager, "set_device", lambda device: None)
611764
monkeypatch.setattr(cache_messager, "set_data_ipc", lambda tensor, name: None)

0 commit comments

Comments
 (0)