Skip to content

Commit e1551e4

Browse files
committed
[BugFix][KVCache] fix cupy device id caching and pickle for _match_result
## Motivation

修复两个 bug:

1. `transfer_manager.py` 中每次调用 `cp.cuda.runtime.getDevice()` 存在隐患,应在初始化时缓存为实例变量,保证后续操作使用一致的设备 ID。
2. `request.py` 的 `__getstate__` 未跳过 `_match_result`,该字段包含 BlockNode 树的父子循环引用,pickle 时会触发 `RecursionError`;同时补充 `__setstate__` 确保 unpickle 后字段恢复为安全默认值。

## Modifications

- `transfer_manager.py`:初始化时调用 `cp.cuda.runtime.getDevice()` 并缓存到 `self._cupy_device_id`,后续 `with cp.cuda.Device(...)` 和日志均使用该缓存值。
- `request.py`:
  - `__getstate__` 中将 `_match_result` 加入跳过集合 `_SKIP_KEYS`,避免循环引用导致 pickle 失败。
  - 新增 `__setstate__`,unpickle 后将 `_block_hasher` 和 `_match_result` 恢复为 `None`。

## Usage or Command
1 parent 4b43eb7 commit e1551e4

2 files changed

Lines changed: 24 additions & 15 deletions

File tree

fastdeploy/cache_manager/v1/transfer_manager.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,12 @@ def __init__(
9292
# They run in parallel without waiting for each other
9393
# Using cupy to avoid affecting Paddle's internal stream state
9494
if _HAS_CUPY and paddle.is_compiled_with_cuda():
95-
cupy_current_device = cp.cuda.runtime.getDevice()
95+
self._cupy_device_id = cp.cuda.runtime.getDevice()
9696
logger.info(
9797
f"[TransferManager] Creating streams: local_rank={self._local_rank}, device_id={self._device_id}, "
98-
f"cupy_current_device={cupy_current_device}"
98+
f"cupy_device_id={self._cupy_device_id}"
9999
)
100-
with cp.cuda.Device(self._device_id):
100+
with cp.cuda.Device(self._cupy_device_id):
101101
self._input_stream = cp.cuda.Stream(non_blocking=False)
102102
self._output_stream = cp.cuda.Stream(non_blocking=False)
103103
logger.info(
@@ -447,12 +447,11 @@ def _swap_all_layers_async(
447447

448448
stream = self._output_stream if mode == 0 else self._input_stream
449449
try:
450-
cupy_current_device = cp.cuda.runtime.getDevice()
451450
logger.debug(
452451
f"[TransferManager] _swap_all_layers_async: local_rank={self._local_rank}, device_id={self._device_id}, "
453-
f"cupy_current_device={cupy_current_device}, stream_device={stream.device_id}, mode={mode}"
452+
f"cupy_device_id={self._cupy_device_id}, stream_device={stream.device_id}, mode={mode}"
454453
)
455-
with cp.cuda.Device(self._device_id):
454+
with cp.cuda.Device(self._cupy_device_id):
456455
with stream:
457456
swap_cache_all_layers(
458457
self._device_key_caches,
@@ -534,7 +533,7 @@ def _swap_single_layer_async(
534533
return False
535534

536535
try:
537-
with cp.cuda.Device(self._device_id):
536+
with cp.cuda.Device(self._cupy_device_id):
538537
with stream:
539538
swap_cache_per_layer_async(
540539
key_cache,
@@ -640,7 +639,7 @@ def record_input_stream_event(self) -> Any:
640639
if not _HAS_CUPY or self._input_stream is None:
641640
return None
642641
try:
643-
with cp.cuda.Device(self._device_id):
642+
with cp.cuda.Device(self._cupy_device_id):
644643
event = cp.cuda.Event()
645644
with self._input_stream:
646645
event.record()

fastdeploy/engine/request.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -453,20 +453,30 @@ def __getstate__(self):
453453
Custom getstate method for pickle support.
454454
Handles unpicklable attributes by filtering them from __dict__.
455455
"""
456-
# Create a filtered dictionary without problematic attributes
456+
# Attributes that cannot or need not be pickled for cross-process transfer.
457+
# _block_hasher: closure/callable, not picklable.
458+
# _match_result: contains BlockNode tree with parent<->children circular
459+
# references, which causes RecursionError during pickling.
460+
# async_process_futures: asyncio futures, not picklable.
461+
_SKIP_KEYS = {"_block_hasher", "_match_result"}
457462
filtered_dict = {}
458463
for key, value in self.__dict__.items():
459-
# Skip attributes that are known to contain unpicklable objects
460-
if key == "async_process_futures":
461-
filtered_dict[key] = []
462-
elif key == "_block_hasher":
463-
# Skip _block_hasher (closure function, cannot be pickled)
464+
if key in _SKIP_KEYS:
464465
continue
466+
elif key == "async_process_futures":
467+
filtered_dict[key] = []
465468
else:
466469
filtered_dict[key] = value
467-
468470
return filtered_dict
469471

472+
def __setstate__(self, state):
473+
self.__dict__.update(state)
474+
# Restore fields that were excluded from pickling with safe defaults.
475+
if "_block_hasher" not in self.__dict__:
476+
self._block_hasher = None
477+
if "_match_result" not in self.__dict__:
478+
self._match_result = None
479+
470480
def __eq__(self, other):
471481
"""
472482
EQ operator.

0 commit comments

Comments
 (0)