Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 53 additions & 17 deletions fastdeploy/cache_manager/cache_messager.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,8 @@ def __init__(
self.engine_cache_task_thread_lock = threading.Lock()
self.engine_cache_tasks = [dict() for _ in range(512)]
self.idx_cache_task_dict = {}
self.pending_layer0_signals = {}
self.pending_layer0_signal_lock = threading.Lock()
self.cache_prefilled_engine_ids_queue = queue.Queue() # keep batch slot index for each prefill step
if splitwise_role == "prefill":
consume_signals_thread = threading.Thread(target=self.consume_signals)
Expand Down Expand Up @@ -661,7 +663,28 @@ def _add_cache_task_thread(self):
current_info["status"] = "init"
logger.info(f"Get cache info from D: finish add cache task: {current_info}")
self.cache_info[info["request_id"]] = current_info
self.idx_cache_task_dict[current_info["current_id"]] = current_info
current_id = current_info["current_id"]
with self.engine_cache_task_thread_lock:
self.idx_cache_task_dict[current_id] = current_info
with self.pending_layer0_signal_lock:
recovered_signal = self.pending_layer0_signals.pop(current_id, None)
if recovered_signal is not None:
_, prefilled_token_num = recovered_signal
if prefilled_token_num <= current_info["need_prefill_tokens"]:
recovered_signal_batch = [recovered_signal]
logger.info(
"cache_task_register_recover_layer0_signal: "
f"current_id: {current_id}, "
f"recovered_signal_batch: {recovered_signal_batch}"
)
self.cache_prefilled_engine_ids_queue.put(recovered_signal_batch)
else:
logger.info(
"cache_task_register_drop_layer0_signal: "
f"current_id: {current_id}, "
f"recovered_signal: {recovered_signal}, "
f"need_prefill_tokens: {current_info['need_prefill_tokens']}"
)
else:
logger.info(f"Get cache info from P: {info}")
self.cache_info[info["request_id"]] = info
Expand Down Expand Up @@ -842,9 +865,12 @@ def prefill_layerwise_send_cache_thread(self):
logger.info(
f"Put successful cache writing task in engine worker queue, req_id: {task['request_id']}, status: {task['status']}"
)
self.engine_cache_tasks[task["current_id"]] = dict()
current_id = task["current_id"]
self.engine_cache_tasks[current_id] = dict()
del self.cache_info[task["request_id"]]
del self.idx_cache_task_dict[task["current_id"]]
del self.idx_cache_task_dict[current_id]
with self.pending_layer0_signal_lock:
self.pending_layer0_signals.pop(current_id, None)
break
except Exception as e:
logger.error(f"prefill layerwise send cache thread has exception: {e} {traceback.format_exc()!s}")
Expand All @@ -856,32 +882,42 @@ def consume_signals(self):
while True:
try:
get_output_kv_signal(kv_signal_data, self.rank_id, 1) # wait_flag
if not self.cache_info:
time.sleep(0.01)
continue
tasks_count = kv_signal_data[0]
has_cache_info = bool(self.cache_info)
tasks_count = kv_signal_data[0].item()
if tasks_count == -1:
continue
if not has_cache_info:
logger.debug("consume_signals get kv signal before cache info is ready")
layer_id = kv_signal_data[1].item()
if layer_id == self.num_layers - 1:
logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id} self.rank_id {self.rank_id}")
batch_engine_signals = []
ready_engine_signals = []
pending_engine_signals = []
# format for signal to put in cache_prefilled_engine_ids_queue: [(engine_idx1, prefilled_token_num1), (engine_idx2, prefilled_token_num2)]
with self.engine_cache_task_thread_lock:
for bi in range(tasks_count):
engine_idx = kv_signal_data[3 * bi + 2].item()
chuck_token_offset = kv_signal_data[3 * bi + 3].item()
current_seq_len = kv_signal_data[3 * bi + 4].item()
prefilled_token_num = chuck_token_offset + current_seq_len
self.engine_cache_tasks[engine_idx]["prefilled_layer_idx"] = layer_id
self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = (
chuck_token_offset + current_seq_len
)
batch_engine_signals.append((engine_idx, chuck_token_offset + current_seq_len))
if layer_id == 0:
logger.info(
f"Put batch_engine_signals {batch_engine_signals} into cache_prefilled_engine_ids_queue"
)
self.cache_prefilled_engine_ids_queue.put(batch_engine_signals)
self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = prefilled_token_num
if layer_id == 0:
if engine_idx in self.idx_cache_task_dict:
ready_engine_signals.append((engine_idx, prefilled_token_num))
else:
pending_engine_signals.append((engine_idx, prefilled_token_num))
if pending_engine_signals:
with self.pending_layer0_signal_lock:
for engine_idx, prefilled_token_num in pending_engine_signals:
self.pending_layer0_signals[engine_idx] = (engine_idx, prefilled_token_num)
if pending_engine_signals:
logger.debug(f"cache_task_pending_layer0_signal: {pending_engine_signals}")
if ready_engine_signals:
logger.info(
f"Put batch_engine_signals {ready_engine_signals} into cache_prefilled_engine_ids_queue"
)
self.cache_prefilled_engine_ids_queue.put(ready_engine_signals)
except Exception as e:
logger.error(f"Consume signals get exception: {e}, {traceback.format_exc()}")

Expand Down
167 changes: 160 additions & 7 deletions tests/cache_manager/test_cache_messager.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,14 @@ def error(self, msg):
self.messages.append(("error", msg))


class _QueueRecorder:
def __init__(self):
self.items = []

def put(self, item):
self.items.append(item)


class _DummySignalValue:
def __init__(self, sequence):
self.sequence = list(sequence)
Expand Down Expand Up @@ -380,6 +388,111 @@ def test_cache_messager_v1_add_cache_task_thread(monkeypatch):
assert messager.cache_info["req-2"]["status"] == "init"


def test_cache_messager_v1_recovers_pending_layer0_signal(monkeypatch):
    """A buffered layer-0 signal is replayed once its cache task registers.

    Slot 3 has a pending signal of 64 tokens, which does not exceed the
    task's ``need_prefill_tokens`` of 128, so the test expects it to be put
    on the prefilled-ids queue. The pending entry for the unrelated slot 4
    must stay untouched.
    """

    def _pending_request():
        # Fresh dict (and fresh lists) on every call so the queue payload and
        # the pre-registered cache_info entry never share mutable state.
        return {
            "request_id": "req-pending",
            "src_block_ids": [0, 1],
            "dest_block_ids": [2],
            "current_id": 3,
            "need_prefill_tokens": 128,
            "transfer_protocol": "rdma",
        }

    dummy_queue = _DummyEngineWorkerQueue(cache_info_sequence=[[_pending_request()]])
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", lambda *_args, **_kwargs: dummy_queue)
    monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)

    messager = cache_messager.CacheMessagerV1(
        splitwise_role="mixed",
        transfer_protocol="rdma",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=0,
        gpu_cache_kvs=_build_cache_kvs(dtype="float16", include_value_cache=True, num_layers=1),
        rank=0,
        nranks=1,
        num_layers=1,
        gpu_id=0,
        block_size=64,
        rdma_port="2222",
    )
    recorder = _QueueRecorder()
    messager.cache_prefilled_engine_ids_queue = recorder
    messager.cache_info["req-pending"] = _pending_request()
    for slot in (3, 4):
        messager.pending_layer0_signals[slot] = (slot, 64)

    # The worker loop only ends when the dummy queue raises SystemExit.
    with pytest.raises(SystemExit):
        messager._add_cache_task_thread()

    assert messager.pending_layer0_signals == {4: (4, 64)}
    assert recorder.items == [[(3, 64)]]


def test_cache_messager_v1_drops_invalid_pending_layer0_signal(monkeypatch):
    """A buffered layer-0 signal larger than the task's budget is discarded.

    Slot 3 has a pending signal of 256 tokens while the registering task
    only needs 128 prefill tokens. The test expects the pending entry to be
    removed and nothing to reach the prefilled-ids queue.
    """

    def _pending_request():
        # Fresh dict (and fresh lists) on every call so the queue payload and
        # the pre-registered cache_info entry never share mutable state.
        return {
            "request_id": "req-pending",
            "src_block_ids": [0, 1],
            "dest_block_ids": [2],
            "current_id": 3,
            "need_prefill_tokens": 128,
            "transfer_protocol": "rdma",
        }

    dummy_queue = _DummyEngineWorkerQueue(cache_info_sequence=[[_pending_request()]])
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", lambda *_args, **_kwargs: dummy_queue)
    monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)

    messager = cache_messager.CacheMessagerV1(
        splitwise_role="mixed",
        transfer_protocol="rdma",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=0,
        gpu_cache_kvs=_build_cache_kvs(dtype="float16", include_value_cache=True, num_layers=1),
        rank=0,
        nranks=1,
        num_layers=1,
        gpu_id=0,
        block_size=64,
        rdma_port="2222",
    )
    recorder = _QueueRecorder()
    messager.cache_prefilled_engine_ids_queue = recorder
    messager.cache_info["req-pending"] = _pending_request()
    messager.pending_layer0_signals[3] = (3, 256)

    # The worker loop only ends when the dummy queue raises SystemExit.
    with pytest.raises(SystemExit):
        messager._add_cache_task_thread()

    assert messager.pending_layer0_signals == {}
    assert recorder.items == []


def test_cache_messager_v1_prefill_layerwise_send_cache_thread(monkeypatch):
class _OneShotQueue:
def __init__(self):
Expand Down Expand Up @@ -425,10 +538,12 @@ def get(self):
}
messager.engine_cache_tasks[0] = {"prefilled_layer_idx": 1, "prefilled_token_num": 64}
messager.cache_info["req-3"] = messager.idx_cache_task_dict[0]
messager.pending_layer0_signals = {0: (0, 64), 1: (1, 64)}
with pytest.raises(SystemExit):
messager.prefill_layerwise_send_cache_thread()
assert dummy_queue.finished_req_payloads
assert dummy_queue.finished_req_payloads[0][0][0] == "req-3"
assert messager.pending_layer0_signals == {1: (1, 64)}


def test_cache_messager_v1_handle_connect_task(monkeypatch):
Expand Down Expand Up @@ -552,13 +667,6 @@ def test_cache_messager_v1_consume_signals(monkeypatch):
monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)

class _QueueRecorder:
def __init__(self):
self.items = []

def put(self, item):
self.items.append(item)

counter = {"calls": 0}

def _fake_get_output_kv_signal(kv_signal_data, rank_id, wait_flag):
Expand Down Expand Up @@ -590,12 +698,57 @@ def _fake_get_output_kv_signal(kv_signal_data, rank_id, wait_flag):
rdma_port="2222",
)
messager.cache_info["req-4"] = {"request_id": "req-4"}
messager.idx_cache_task_dict[2] = {"request_id": "req-4", "current_id": 2}
messager.cache_prefilled_engine_ids_queue = _QueueRecorder()
with pytest.raises(SystemExit):
messager.consume_signals()
assert messager.cache_prefilled_engine_ids_queue.items == [[(2, 9)]]


def test_cache_messager_v1_consume_signals_buffers_early_layer0(monkeypatch):
    """Layer-0 signals for a slot with no registered task are buffered.

    Two layer-0 signals arrive for engine slot 5 while ``idx_cache_task_dict``
    has no entry for it. The test expects only the latest one
    (17 + 19 = 36 tokens) to remain in ``pending_layer0_signals`` and nothing
    to be put on the prefilled-ids queue.
    """
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", _DummyEngineWorkerQueue)
    monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)

    # (engine_idx, chunk token offset, current seq len) per fake signal.
    remaining = [(5, 7, 9), (5, 17, 19)]

    def _fake_get_output_kv_signal(kv_signal_data, rank_id, wait_flag):
        # Once every scripted signal is consumed, break out of the endless loop.
        if len(remaining) == 0:
            raise SystemExit
        engine_idx, chunk_offset, seq_len = remaining.pop(0)
        payload = np.full(kv_signal_data.shape, -1, dtype="int32")
        # Slots 0..4: tasks_count=1, layer_id=0, then the single task triple.
        for slot, value in enumerate((1, 0, engine_idx, chunk_offset, seq_len)):
            payload[slot] = value
        kv_signal_data.set_value(payload)

    monkeypatch.setattr(cache_messager, "get_output_kv_signal", _fake_get_output_kv_signal)

    messager = cache_messager.CacheMessagerV1(
        splitwise_role="mixed",
        transfer_protocol="rdma",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=0,
        gpu_cache_kvs=_build_cache_kvs(dtype="float16", include_value_cache=False, num_layers=1),
        rank=0,
        nranks=1,
        num_layers=1,
        gpu_id=0,
        block_size=64,
        rdma_port="2222",
    )
    recorder = _QueueRecorder()
    messager.cache_prefilled_engine_ids_queue = recorder

    with pytest.raises(SystemExit):
        messager.consume_signals()

    assert messager.pending_layer0_signals == {5: (5, 36)}
    assert recorder.items == []


def test_main_initializes_cache_and_exits(monkeypatch):
monkeypatch.setattr(cache_messager, "set_device", lambda device: None)
monkeypatch.setattr(cache_messager, "set_data_ipc", lambda tensor, name: None)
Expand Down
Loading