Skip to content

Commit 98f3fc9

Browse files
[RL] [KVCache] let cache transfer managers update key prefix after weight update and add unit tests (#7083)
* [test] add a few unit tests
* [feat] update key prefix when model weights are updated
* [test] try to fix test_worker_process
1 parent 9f3b3ce commit 98f3fc9

File tree

8 files changed

+636
-11
lines changed

8 files changed

+636
-11
lines changed

fastdeploy/cache_manager/cache_transfer_manager.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,14 @@ def _handle_resume(self):
10981098
logger.info("✅ Successfully resumed transfer")
10991099
return True
11001100

1101+
def _handle_update_weights(self):
    """Handle the ``update_weights`` control method for the cache transfer manager.

    When a storage backend is configured (``storage_backend_type`` is not
    ``None``) the cache key prefix is refreshed via ``_update_key_prefix``;
    otherwise the request is only logged and nothing is changed.

    Returns:
        bool: Always ``True``.
    """
    # Guard clause: without a storage backend there is no key prefix to refresh.
    if self.storage_backend_type is None:
        logger.info("💡 Cache storage backend is disabled, skip updating cache key prefix")
        return True

    self._update_key_prefix()
    logger.info("✅ Successfully updated cache key prefix after weight update")
    return True
1108+
11011109
def _handle_sleep(self):
11021110
if self.is_sleeping:
11031111
logger.info("💡 Cache transfer manager is already sleeping, no need to sleep again!")
@@ -1128,6 +1136,7 @@ def control_task(self, task: ControlRequest):
11281136
handlers = {
11291137
"pause": self._handle_pause,
11301138
"resume": self._handle_resume,
1139+
"update_weights": self._handle_update_weights,
11311140
"sleep": self._handle_sleep,
11321141
"wakeup": self._handle_wakeup,
11331142
}

fastdeploy/engine/common_engine.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,10 +1443,16 @@ def _control_pause(self, control_request: ControlRequest):
14431443
# pause cache transfer
14441444
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
14451445
self.llm_logger.info("Start to pause cache transfer.")
1446-
pause_transfer_request = ControlRequest(request_id="pause_transfer", method="pause")
1446+
pause_transfer_request = ControlRequest(
1447+
request_id=f"{control_request.request_id}_pause_transfer", method="pause"
1448+
)
14471449
self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, pause_transfer_request))
14481450
# Wait for cache_transfer responses
1449-
asyncio.run(self._wait_for_control_responses("pause_transfer", 60, executors=["cache_transfer"]))
1451+
asyncio.run(
1452+
self._wait_for_control_responses(
1453+
f"{pause_transfer_request.request_id}", 60, executors=["cache_transfer"]
1454+
)
1455+
)
14501456
self.llm_logger.info("Successfully paused cache transfer.")
14511457

14521458
self.resource_manager.cache_manager.reset()
@@ -1473,10 +1479,14 @@ def _control_resume(self, control_request: ControlRequest) -> Optional[dict]:
14731479
# resume cache transfer
14741480
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
14751481
self.llm_logger.info("Start to resume cache transfer.")
1476-
resume_transfer_request = ControlRequest(request_id="resume_transfer", method="resume")
1482+
resume_transfer_request = ControlRequest(
1483+
request_id=f"{control_request.request_id}_resume_transfer", method="resume"
1484+
)
14771485
self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, resume_transfer_request))
14781486
# Wait for cache_transfer responses
1479-
asyncio.run(self._wait_for_control_responses("resume_transfer", 60, executors=["cache_transfer"]))
1487+
asyncio.run(
1488+
self._wait_for_control_responses(resume_transfer_request.request_id, 60, executors=["cache_transfer"])
1489+
)
14801490
self.llm_logger.info("Successfully resumed cache transfer.")
14811491

14821492
self.llm_logger.info("Successfully resumed request generation.")
@@ -1531,6 +1541,19 @@ def _control_update_weights(self, control_request: ControlRequest) -> Optional[d
15311541
if new_version is not None:
15321542
self.cfg.model_config.version = new_version
15331543

1544+
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
1545+
self.llm_logger.info("Start to update cache-transfer metadata after weight update.")
1546+
update_cache_request = ControlRequest(
1547+
request_id=f"{control_request.request_id}_update_weights",
1548+
method="update_weights",
1549+
args=copy.deepcopy(control_request.args),
1550+
)
1551+
self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, update_cache_request))
1552+
asyncio.run(
1553+
self._wait_for_control_responses(update_cache_request.request_id, 60, executors=["cache_transfer"])
1554+
)
1555+
self.llm_logger.info("Successfully updated cache-transfer metadata after weight update.")
1556+
15341557
return responses
15351558

15361559
def _control_abort_requests(self, control_req: ControlRequest):

tests/cache_manager/test_cache_transfer_manager.py

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import tempfile
1818
import time
1919
import unittest
20-
from unittest.mock import MagicMock, patch
20+
from unittest.mock import MagicMock, Mock, patch
2121

2222
import paddle
2323

@@ -37,6 +37,7 @@ def enable_torch_proxy(scope=None):
3737
import fastdeploy.cache_manager.cache_transfer_manager as cache_transfer_manager
3838
from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask
3939
from fastdeploy.cache_manager.cache_transfer_manager import CacheTransferManager
40+
from fastdeploy.engine.request import ControlRequest
4041

4142

4243
# ==========================
@@ -121,6 +122,16 @@ def __init__(self, name, array, dtype, suffix, create=False):
121122
patcher_thread.start()
122123
self.addCleanup(patcher_thread.stop)
123124

125+
# --------------------------
126+
# mock FMQ
127+
# --------------------------
128+
patcher_fmq = patch("fastdeploy.cache_manager.cache_transfer_manager.FMQ")
129+
mock_fmq_cls = patcher_fmq.start()
130+
mock_fmq = MagicMock()
131+
mock_fmq.queue.return_value = MagicMock(name="ctrl_output_queue")
132+
mock_fmq_cls.return_value = mock_fmq
133+
self.addCleanup(patcher_fmq.stop)
134+
124135
# --------------------------
125136
# mock _init_cpu_cache and _init_gpu_cache
126137
# --------------------------
@@ -1515,6 +1526,111 @@ def resume_sleep(_):
15151526

15161527
self.assertFalse(self.manager.is_paused)
15171528

1529+
def test_init_control_builds_expected_queue_name(self):
    """_init_control requests a producer queue with the expected name and stores it."""
    ctrl_queue = MagicMock(name="ctrl_q")
    fake_fmq = MagicMock()
    fake_fmq.queue.return_value = ctrl_queue

    manager = self.manager
    manager.rank = 1
    manager.n_ranks = 4
    manager.local_data_parallel_id = 2
    manager.cache_queue_port = 8899

    with patch("fastdeploy.cache_manager.cache_transfer_manager.FMQ", return_value=fake_fmq):
        manager._init_control()

    # NOTE(review): "rank9" presumably derives from rank/n_ranks/local_data_parallel_id
    # set above — confirm against _init_control.
    fake_fmq.queue.assert_called_once_with("ctrl_c2e_rank9_8899", "producer")
    self.assertIs(manager.ctrl_output_queue, ctrl_queue)
1544+
1545+
def test_control_task_success_puts_control_response(self):
    """A successful handler run waits on the barrier and emits a 200 response."""
    manager = self.manager
    manager.cache_task_queue.barrier = MagicMock(wait=Mock())
    manager.ctrl_output_queue = MagicMock(name="ctrl_q")
    manager.ctrl_output_queue.put = Mock(return_value="coro")
    manager._handle_pause = MagicMock(return_value=True)

    # asyncio.run is patched out so the mocked put() result is never actually run.
    with patch("fastdeploy.cache_manager.cache_transfer_manager.asyncio.run"):
        manager.control_task(ControlRequest(request_id="ctrl-1", method="pause"))

    manager._handle_pause.assert_called_once()
    manager.cache_task_queue.barrier.wait.assert_called_once()

    put_mock = manager.ctrl_output_queue.put
    put_mock.assert_called_once()
    reply = put_mock.call_args.args[0]
    self.assertEqual(reply.request_id, "ctrl-1")
    self.assertEqual(reply.error_code, 200)
1560+
1561+
def test_control_task_unknown_method_returns_400(self):
    """An unrecognised control method is answered with a 400 response."""
    manager = self.manager
    manager.cache_task_queue.barrier = MagicMock(wait=Mock())
    manager.ctrl_output_queue = MagicMock(name="ctrl_q")
    manager.ctrl_output_queue.put = Mock(return_value="coro")

    with patch("fastdeploy.cache_manager.cache_transfer_manager.asyncio.run"):
        manager.control_task(ControlRequest(request_id="ctrl-2", method="unknown"))

    # The response object is the first positional argument given to put().
    reply = manager.ctrl_output_queue.put.call_args.args[0]
    self.assertEqual(reply.error_code, 400)
    self.assertIn("Unknown control method", reply.error_message)
1572+
1573+
def test_control_task_exception_returns_500(self):
    """A handler that raises is reported as a 500 response."""
    manager = self.manager
    manager.cache_task_queue.barrier = MagicMock(wait=Mock())
    manager.ctrl_output_queue = MagicMock(name="ctrl_q")
    manager.ctrl_output_queue.put = Mock(return_value="coro")

    # Patches are entered in the same order as before: failing handler first,
    # then the asyncio.run stub.
    with patch.object(manager, "_handle_sleep", side_effect=RuntimeError("boom")):
        with patch("fastdeploy.cache_manager.cache_transfer_manager.asyncio.run"):
            manager.control_task(ControlRequest(request_id="ctrl-3", method="sleep"))

    reply = manager.ctrl_output_queue.put.call_args.args[0]
    self.assertEqual(reply.error_code, 500)
    self.assertIn("Failed to execute sleep", reply.error_message)
1587+
1588+
def test_handle_resume_updates_key_prefix_for_storage_backend(self):
    """Resuming a paused manager with a storage backend also refreshes the key prefix."""
    manager = self.manager
    manager.resume = MagicMock()
    manager._update_key_prefix = MagicMock()
    manager.storage_backend_type = "mooncake"
    manager.is_paused = True

    outcome = manager._handle_resume()

    self.assertTrue(outcome)
    manager.resume.assert_called_once()
    manager._update_key_prefix.assert_called_once()
1599+
1600+
def test_handle_update_weights_updates_key_prefix_for_storage_backend(self):
    """With a storage backend set, _handle_update_weights refreshes the key prefix once."""
    manager = self.manager
    manager._update_key_prefix = MagicMock()
    manager.storage_backend_type = "mooncake"

    outcome = manager._handle_update_weights()

    self.assertTrue(outcome)
    manager._update_key_prefix.assert_called_once()
1608+
1609+
def test_handle_update_weights_skips_without_storage_backend(self):
    """Without a storage backend, _handle_update_weights succeeds but leaves the prefix alone."""
    manager = self.manager
    manager._update_key_prefix = MagicMock()
    manager.storage_backend_type = None

    outcome = manager._handle_update_weights()

    self.assertTrue(outcome)
    manager._update_key_prefix.assert_not_called()
1617+
1618+
def test_handle_sleep_and_wakeup_are_idempotent(self):
    """Sleeping while asleep, or waking while awake, succeeds without touching the caches."""
    manager = self.manager
    manager.is_sleeping = True
    # Replace every cache clear/init hook so any call to it can be detected.
    for hook in ("_clear_cpu_cache", "_clear_gpu_cache", "_init_cpu_cache", "_init_gpu_cache"):
        setattr(manager, hook, MagicMock())

    # Already sleeping: a second sleep must not clear anything.
    self.assertTrue(manager._handle_sleep())
    manager._clear_cpu_cache.assert_not_called()
    manager._clear_gpu_cache.assert_not_called()

    # Already awake: a wakeup must not re-initialise anything.
    manager.is_sleeping = False
    self.assertTrue(manager._handle_wakeup())
    manager._init_cpu_cache.assert_not_called()
    manager._init_gpu_cache.assert_not_called()
1633+
15181634
def test_submit_task_decrements_inflight_on_task_error(self):
15191635
class DummyPool:
15201636
def submit(self, fn, *args):

tests/engine/test_common_engine.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def enable_torch_proxy(scope=None):
3939

4040
paddle.compat = _PaddleCompat()
4141

42+
from fastdeploy.cache_manager.cache_data import CacheStatus
4243
from fastdeploy.engine.args_utils import EngineArgs
4344
from fastdeploy.engine.common_engine import (
4445
EngineService,
@@ -1117,6 +1118,27 @@ def test_control_update_weights_updates_cfg_version(self):
11171118
self.assertEqual(eng.cfg.model_config.version, "new-version")
11181119
self._detach_finalizer(eng)
11191120

1121+
def test_control_update_weights_updates_cache_transfer_metadata(self):
    """_control_update_weights forwards an update_weights control task to cache transfer."""
    engine = self._make_mixed_engine()
    engine.is_paused = True
    engine._pause_cond = threading.Condition()
    # A non-zero CPU block count is one of the conditions that enables the cache-transfer path.
    engine.cfg.cache_config.num_cpu_blocks = 1
    engine._call_worker = Mock(return_value=[{"version": "new-version"}])
    engine.cache_task_queue = Mock(put_transfer_task=Mock())
    engine._wait_for_control_responses = AsyncMock(return_value=[{"ok": True}])

    request = ControlRequest(request_id="ctrl", method="update_weights")
    outcome = engine._control_update_weights(request)

    self.assertEqual(outcome, [{"version": "new-version"}])

    task = engine.cache_task_queue.put_transfer_task.call_args.args[0]
    status = task[0]
    forwarded = task[1]
    self.assertEqual(status, CacheStatus.CTRL)
    self.assertEqual(forwarded.method, "update_weights")
    self.assertIn("update_weights", forwarded.request_id)
    engine._wait_for_control_responses.assert_awaited_once_with(
        forwarded.request_id, 60, executors=["cache_transfer"]
    )
    self._detach_finalizer(engine)
1141+
11201142
def test_control_pause_and_resume_paths(self):
11211143
eng = self._make_mixed_engine()
11221144
eng.is_paused = False

tests/entrypoints/test_engine_client.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import paddle
2626
import pytest
2727

28-
from fastdeploy.engine.request import ControlRequest
28+
from fastdeploy.engine.request import ControlRequest, ControlResponse
2929
from fastdeploy.entrypoints.engine_client import EngineClient
3030
from fastdeploy.inter_communicator import (
3131
KVCacheStatus,
@@ -1882,6 +1882,65 @@ def test_valid_parameters_and_control_timeout(minimal_engine_client):
18821882
assert resp.error_code == 500
18831883

18841884

1885+
def test_run_control_method_uses_send_pyobj_for_mm_requests(minimal_engine_client):
    """With enable_mm set, the control request goes out via send_pyobj, not send_json."""
    client = minimal_engine_client
    reply_queue = asyncio.Queue()
    # Pre-load the reply that run_control_method will read from the queue.
    asyncio.run(reply_queue.put(({"request_id": "mm-1", "status": 200, "msg": "ok"},)))
    dealer = Mock(write=Mock())
    client.enable_mm = True
    client.connection_manager = MagicMock(get_connection=AsyncMock(return_value=(dealer, reply_queue)))

    with patch("fastdeploy.entrypoints.engine_client.envs.ZMQ_SEND_BATCH_DATA", 0):
        response = asyncio.run(client.run_control_method(ControlRequest(request_id="mm-1", method="ping")))

    assert response.error_code == 200
    client.zmq_client.send_pyobj.assert_called_once()
    client.zmq_client.send_json.assert_not_called()
1898+
1899+
1900+
def test_run_control_method_adds_worker_pid_in_batch_mode(minimal_engine_client):
    """With ZMQ_SEND_BATCH_DATA patched to 1, the JSON payload carries the client's worker pid."""
    client = minimal_engine_client
    reply_queue = asyncio.Queue()
    # Pre-load the reply that run_control_method will read from the queue.
    asyncio.run(reply_queue.put(({"request_id": "batch-1", "status": 200, "msg": "ok"},)))
    client.connection_manager = MagicMock(get_connection=AsyncMock(return_value=(None, reply_queue)))

    with patch("fastdeploy.entrypoints.engine_client.envs.ZMQ_SEND_BATCH_DATA", 1):
        response = asyncio.run(client.run_control_method(ControlRequest(request_id="batch-1", method="ping")))

    assert response.error_code == 200
    sent_payload = client.zmq_client.send_json.call_args.args[0]
    assert sent_payload["zmq_worker_pid"] == client.worker_pid
1913+
1914+
1915+
def test_run_control_method_generic_exception_returns_error(minimal_engine_client):
    """A failure while reading the reply queue surfaces as a 500 response with the message."""
    client = minimal_engine_client
    broken_queue = MagicMock()
    broken_queue.get = AsyncMock(side_effect=RuntimeError("queue failed"))
    dealer = Mock(write=Mock())
    client.connection_manager = MagicMock(get_connection=AsyncMock(return_value=(dealer, broken_queue)))

    with patch("fastdeploy.entrypoints.engine_client.envs.ZMQ_SEND_BATCH_DATA", 0):
        response = asyncio.run(client.run_control_method(ControlRequest(request_id="r3", method="m")))

    assert response.error_code == 500
    assert "queue failed" in response.error_message
1926+
1927+
1928+
def test_run_control_method_sync_uses_threadsafe_bridge(minimal_engine_client):
    """run_control_method_sync goes through asyncio.run_coroutine_threadsafe and returns its result."""
    client = minimal_engine_client
    request = ControlRequest(request_id="sync-1", method="ping")
    bridged_future = Mock(result=Mock(return_value=ControlResponse("sync-1", 200, "Success")))
    client.run_control_method = AsyncMock(return_value=ControlResponse("sync-1", 200, "Success"))

    bridge_target = "fastdeploy.entrypoints.engine_client.asyncio.run_coroutine_threadsafe"
    with patch(bridge_target, return_value=bridged_future) as bridge:
        response = client.run_control_method_sync(request, Mock())

    assert response.error_code == 200
    bridge.assert_called_once()
    # NOTE(review): the first positional argument is presumably the coroutine produced by
    # the AsyncMock; closing it avoids leaving it un-awaited — confirm.
    bridge.call_args.args[0].close()
1942+
1943+
18851944
def test_rearrange_and_redundant_branch_matrix(minimal_engine_client):
18861945
cfg = create_mock_fd_config(enable_eplb=True)
18871946
cfg.parallel_config.tensor_parallel_rank = 0

0 commit comments

Comments
 (0)