Commit 5043d01

[KVCache][BugFix] fix storage prefetch nodes inserted at wrong radix tree position
## Motivation

After a three-tier KV Cache (Device → Host → Storage) prefetch completes, the second `match_prefix` call still hits only the device-tier blocks; the host blocks prefetched from storage cannot be found.

Root cause: `prepare_prefetch_metadata` called `radix_tree.insert` without passing `start_node`. As a result, the 8 new LOADING_FROM_STORAGE nodes were attached under the radix tree's root node (with storage hash h22 as a direct child of root) instead of being appended to the end of the existing 22-node chain (as children of node[21]). When `find_prefix` reaches node[21], h22 is not in node[21].children, so traversal stops immediately and only 22 nodes are ever returned.

This commit also fixes several related issues:

- `_match_storage` probed only the "key" kind. Mooncake's LRU may evict "value" on its own, which produces false hits. It now probes both key and value, and a block counts as a hit only when both exist.
- On a partial write, where some keys are written successfully and others fail, the keys already written are now rolled back automatically so that `_match_storage` cannot find a half-written block.
- `prepare_prefetch_metadata` now registers only nodes that are actually in LOADING_FROM_STORAGE status in prefetch_node_map, which avoids a spurious "unexpected status" warning when insert reuses existing HOST/DEVICE nodes.

## Modifications

- `cache_manager.py`
  - `match_prefix`: pass `start_node=matched_nodes[-1]` to `prepare_prefetch_metadata`
  - `prepare_prefetch_metadata`: add a `start_node` parameter and forward it to `_radix_tree.insert`
  - `prepare_prefetch_metadata`: register only LOADING_FROM_STORAGE nodes in prefetch_node_map
  - `_match_storage`: probe both the key and value kinds; a block is a hit only when both exist
- `storage/base.py`: add default implementations of `batch_exists` / `batch_delete`
- `storage/mooncake/connector.py`: Mooncake implementations of `batch_exists` / `batch_delete`
- `storage/staging_manager.py`: automatic rollback on partial write
- `transfer_manager.py`: emit diagnostic logs when prefetch/backup fails
- `tests/cache_manager/v1/test_cache_manager.py`: add regression test `TestPreparePrefixtMetadataStartNode`

## Usage or Command

```bash
# Run the regression test
source .venv/py310/bin/activate
PYTHONPATH=. python -m pytest tests/cache_manager/v1/test_cache_manager.py::TestPreparePrefixtMetadataStartNode -v
```
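
The root cause described in the Motivation can be reproduced with a toy tree. The sketch below is a simplified stand-in, not FastDeploy's radix tree: `Node`, `ToyRadixTree`, and the string statuses are hypothetical names chosen for illustration. It only assumes what the commit message states, namely that `insert` attaches under root when no `start_node` is given and that `find_prefix` stops at the first missing child.

```python
from typing import Dict, List, Optional


class Node:
    """Toy block node (hypothetical; not FastDeploy's BlockNode)."""

    def __init__(self, block_hash: str, status: str) -> None:
        self.block_hash = block_hash
        self.status = status
        self.children: Dict[str, "Node"] = {}


class ToyRadixTree:
    def __init__(self) -> None:
        self.root = Node("<root>", "ROOT")

    def insert(self, hashes: List[str], status: str, start_node: Optional[Node] = None) -> List[Node]:
        # Without start_node the new chain hangs directly under root.
        parent = start_node if start_node is not None else self.root
        nodes = []
        for h in hashes:
            child = parent.children.get(h)
            if child is None:
                child = Node(h, status)
                parent.children[h] = child
            nodes.append(child)
            parent = child
        return nodes

    def find_prefix(self, hashes: List[str]) -> List[Node]:
        # Walk from root; stop at the first hash missing from the current children.
        matched, current = [], self.root
        for h in hashes:
            child = current.children.get(h)
            if child is None:
                break
            matched.append(child)
            current = child
        return matched


all_hashes = [f"h{i}" for i in range(30)]

# Buggy flow: storage nodes are inserted without start_node.
buggy = ToyRadixTree()
buggy.insert(all_hashes[:22], "DEVICE")
buggy.insert(all_hashes[22:], "LOADING_FROM_STORAGE")
buggy_len = len(buggy.find_prefix(all_hashes))

# Fixed flow: storage nodes extend the chain after the last matched node.
fixed = ToyRadixTree()
device_nodes = fixed.insert(all_hashes[:22], "DEVICE")
fixed.insert(all_hashes[22:], "LOADING_FROM_STORAGE", start_node=device_nodes[-1])
fixed_len = len(fixed.find_prefix(all_hashes))

print(buggy_len, fixed_len)  # 22 30
```

In the buggy tree, `h22` ends up as a sibling of `h0` under root, which is why the lookup never gets past the 22 device nodes.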
1 parent 50d588f commit 5043d01

6 files changed

Lines changed: 211 additions & 19 deletions

fastdeploy/cache_manager/v1/cache_manager.py

Lines changed: 36 additions & 15 deletions
@@ -531,7 +531,8 @@ def match_prefix(
         # Step 2: Match Storage (if enabled and not skipped)
         if not skip_storage and self._storage_scheduler and remaining_hashes:
             storage_matches = self._match_storage(remaining_hashes)
-            result.storage_nodes = self.prepare_prefetch_metadata(storage_matches)
+            start_node = matched_nodes[-1] if matched_nodes else None
+            result.storage_nodes = self.prepare_prefetch_metadata(storage_matches, start_node=start_node)

         # Step 3: Increment ref count for matched blocks(only scheduling phase)
         if skip_storage:
@@ -562,11 +563,13 @@ def _match_storage(self, hash_values: List[str]) -> List[str]:
         consecutive prefix of hashes that are all present (prefix semantics
         are required because a cache miss in the middle breaks prefetch continuity).

-        Uses rank=0 key as a probe: if rank 0 has the block, all ranks
-        are assumed to have it (all ranks write storage synchronously).
+        Probes both rank=0 "key" and "value" kinds: a block is considered present
+        only when both exist. This avoids false positives from partial writes where
+        only one kind was stored, and prevents LRU asymmetry (probing only "key"
+        would keep it hot while "value" gets evicted by Mooncake).

         Storage key format (see cache_utils.storage_key_for_block):
-            "{hash_value}_0_key"
+            "{hash_value}_0_key" / "{hash_value}_0_value"

         Args:
             hash_values: List of block hash values to check, in prefix order.
@@ -584,21 +587,27 @@ def _match_storage(self, hash_values: List[str]) -> List[str]:
                 logger.warning("_match_storage: storage scheduler disconnected, skipping storage match")
                 return []

-            # Build probe keys using rank=0 (same format as storage_key_for_block)
-            probe_keys = [storage_key_for_block(h, 0, "key") for h in hash_values]
+            # Probe both key and value kinds for rank=0.
+            # Interleaved: [h0_key, h0_value, h1_key, h1_value, ...]
+            probe_keys = []
+            for h in hash_values:
+                probe_keys.append(storage_key_for_block(h, 0, "key"))
+                probe_keys.append(storage_key_for_block(h, 0, "value"))

-            # batch_exists returns a bool list aligned with probe_keys
             exist_flags = self._storage_scheduler.batch_exists(probe_keys)

-            # Return only the leading consecutive hit run
+            # A block is present only when both key and value exist.
             matched = []
-            for h, exists in zip(hash_values, exist_flags):
-                if not exists:
+            for i, h in enumerate(hash_values):
+                key_ok = exist_flags[i * 2]
+                val_ok = exist_flags[i * 2 + 1]
+                if not (key_ok and val_ok):
                     break
                 matched.append(h)

             logger.debug(
-                f"[CacheManager] _match_storage: probing {len(probe_keys)} keys, matched hashes: {len(matched)}"
+                f"[CacheManager] _match_storage: probing {len(hash_values)} blocks "
+                f"({len(probe_keys)} keys), matched={len(matched)}"
             )
             return matched
         except Exception:
@@ -1001,6 +1010,7 @@ def drain_pending_prefetches(self) -> List[PendingPrefetch]:
     def prepare_prefetch_metadata(
         self,
         storage_hashes: List[str],
+        start_node: Optional["BlockNode"] = None,
     ) -> Optional[List["BlockNode"]]:
         """
         Prepare metadata for storage prefetch operation.
@@ -1010,6 +1020,10 @@ def prepare_prefetch_metadata(

         Args:
             storage_hashes: List of storage hash values to prefetch
+            start_node: Node to start insertion from in the radix tree.
+                Must be the last matched node from find_prefix so that
+                the new LOADING_FROM_STORAGE nodes are attached as proper
+                extensions of the existing prefix chain.

         Returns:
             List of BlockNode objects if successful, None or empty list otherwise.
@@ -1032,17 +1046,24 @@ def prepare_prefetch_metadata(

             blocks = list(zip(storage_hashes, host_block_ids))
             prefetch_nodes, wasted_block_ids = self._radix_tree.insert(
-                blocks=blocks, cache_status=CacheStatus.LOADING_FROM_STORAGE
+                blocks=blocks, cache_status=CacheStatus.LOADING_FROM_STORAGE, start_node=start_node
             )
             # Release any blocks that were wasted due to node reuse
             if wasted_block_ids:
                 self._host_pool.release(wasted_block_ids)

-            # Register nodes in prefetch_node_map for fast status update on done
+            # Register only truly new LOADING_FROM_STORAGE nodes.
+            # insert() reuses existing nodes without updating their status, so nodes
+            # that were already HOST/DEVICE must be excluded — they don't need a
+            # storage transfer and would trigger a spurious "unexpected status" warning
+            # in update_storage_blocks_to_host.
+            actual_prefetch_nodes = []
             for node in prefetch_nodes:
-                self._prefetch_node_map[node.block_id] = node
+                if node.cache_status == CacheStatus.LOADING_FROM_STORAGE:
+                    self._prefetch_node_map[node.block_id] = node
+                    actual_prefetch_nodes.append(node)

-            return prefetch_nodes
+            return actual_prefetch_nodes
         except Exception as e:
             logger.error(f"prepare_prefetch_metadata error: {e}, {str(traceback.format_exc())}")
             return []
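
The interleaved key/value probe in `_match_storage` can be tried in isolation. The following is a minimal sketch under stated assumptions: `storage_key` and `match_leading_blocks` are hypothetical helpers, the key format follows the `"{hash_value}_0_key"` pattern quoted in the docstring, and `batch_exists` is any callable that returns one bool per key in order.

```python
from typing import Callable, List


def storage_key(hash_value: str, rank: int, kind: str) -> str:
    # Assumed format, modeled on the "{hash_value}_0_key" example in the docstring.
    return f"{hash_value}_{rank}_{kind}"


def match_leading_blocks(
    hash_values: List[str],
    batch_exists: Callable[[List[str]], List[bool]],
) -> List[str]:
    # Interleave probes: [h0_key, h0_value, h1_key, h1_value, ...]
    probe_keys = []
    for h in hash_values:
        probe_keys.append(storage_key(h, 0, "key"))
        probe_keys.append(storage_key(h, 0, "value"))

    flags = batch_exists(probe_keys)

    # Keep only the leading run of blocks whose key AND value both exist.
    matched = []
    for i, h in enumerate(hash_values):
        if not (flags[2 * i] and flags[2 * i + 1]):
            break
        matched.append(h)
    return matched


# Block "b" has lost its value, so the run stops there even though "c" is complete.
stored = {"a_0_key", "a_0_value", "b_0_key", "c_0_key", "c_0_value"}
result = match_leading_blocks(["a", "b", "c"], lambda keys: [k in stored for k in keys])
print(result)  # ['a']
```

Block `c` is fully present but is still excluded, because a miss in the middle breaks the consecutive prefix that prefetch relies on.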

fastdeploy/cache_manager/v1/storage/base.py

Lines changed: 14 additions & 0 deletions
@@ -295,6 +295,20 @@ def is_connected(self) -> bool:
         """Check if connected to storage."""
         return self._connected

+    def batch_exists(self, keys: List[str]) -> List[bool]:
+        """
+        Batch check key existence. Backends that support it should override.
+        Default returns False for all keys (conservative: assume missing).
+        """
+        return [False] * len(keys)
+
+    def batch_delete(self, keys: List[str]) -> List[bool]:
+        """
+        Delete multiple keys. Backends can override for efficiency.
+        Default falls back to calling delete() per key.
+        """
+        return [self.delete(k) for k in keys]
+
     def get_stats(self) -> Dict[str, Any]:
         """Get connector statistics."""
         return {
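
To see how these defaults behave for a backend that overrides only part of the interface, here is a small sketch. `BaseConnector` and `DictConnector` are hypothetical classes; only the bodies of the two batch defaults are taken from the diff above.

```python
from typing import Dict, List


class BaseConnector:
    """Hypothetical base class; only the two batch defaults mirror the diff."""

    def delete(self, key: str) -> bool:
        raise NotImplementedError

    def batch_exists(self, keys: List[str]) -> List[bool]:
        # Conservative default: report every key as missing.
        return [False] * len(keys)

    def batch_delete(self, keys: List[str]) -> List[bool]:
        # Generic fallback: one delete() call per key.
        return [self.delete(k) for k in keys]


class DictConnector(BaseConnector):
    """Toy backend that overrides batch_exists and inherits batch_delete."""

    def __init__(self) -> None:
        self.data: Dict[str, bytes] = {}

    def delete(self, key: str) -> bool:
        return self.data.pop(key, None) is not None

    def batch_exists(self, keys: List[str]) -> List[bool]:
        return [k in self.data for k in keys]


conn = DictConnector()
conn.data.update({"a_0_key": b"k", "a_0_value": b"v"})
exists_before = conn.batch_exists(["a_0_key", "b_0_key"])
deleted = conn.batch_delete(["a_0_key", "b_0_key"])
exists_after = conn.batch_exists(["a_0_key", "a_0_value"])
print(exists_before, deleted, exists_after)
```

A backend that leaves `batch_exists` at its default reports every key as missing, so a caller that treats "missing" as "no storage hit" would simply never prefetch from that backend.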

fastdeploy/cache_manager/v1/storage/mooncake/connector.py

Lines changed: 24 additions & 0 deletions
@@ -635,6 +635,15 @@ def batch_set(

         return final_results

+    def batch_exists(self, keys: List[str]) -> List[bool]:
+        """Batch check key existence."""
+        if not self._connected or self._base._store is None:
+            return [False] * len(keys)
+        if not keys:
+            return []
+        results, _ = self._base._batch_exists(keys)
+        return [r == 1 for r in results]
+
     # ------------------------------------------------------------------
     # Delete / clear
     # ------------------------------------------------------------------
@@ -661,6 +670,21 @@ def delete(self, key: str, timeout: int = 5) -> bool:
             self.logger.error(f"delete({key!r}) timed out after {timeout}s")
             return False

+    def batch_delete(self, keys: List[str]) -> List[bool]:
+        """
+        Delete multiple keys from the store (single attempt, no retry).
+
+        Used for cleaning up partial writes where some kinds succeeded
+        and others failed. Returns per-key success flags.
+        """
+        if not self._connected or self._base._store is None:
+            return [False] * len(keys)
+        results = []
+        for key in keys:
+            rc = self._base._store.remove(key)
+            results.append(rc == 0)
+        return results
+
     def clear(self) -> int:
         """
         Remove all objects from the store.

fastdeploy/cache_manager/v1/storage/staging_manager.py

Lines changed: 15 additions & 2 deletions
@@ -298,9 +298,22 @@ def batch_set_block(

         results = self._connector.batch_set(flat_keys, flat_ptrs, flat_sizes)

+        # Track which keys succeeded per block for partial-write cleanup.
+        block_ok_keys: Dict[int, List[str]] = {}
         for flat_idx, ok in enumerate(results):
-            if not ok:
-                block_success[flat_index[flat_idx]] = False
+            bi = flat_index[flat_idx]
+            if ok:
+                block_ok_keys.setdefault(bi, []).append(flat_keys[flat_idx])
+            else:
+                block_success[bi] = False
+
+        # Rollback: if a block failed but some of its keys were written,
+        # delete those keys so the block appears fully absent in storage.
+        # This prevents _match_storage from finding a half-written block.
+        keys_to_rollback = [key for bi, keys in block_ok_keys.items() if not block_success[bi] for key in keys]
+        if keys_to_rollback:
+            logger.warning(f"[StagingManager] partial write on {len(keys_to_rollback)} key(s), rolling back")
+            self._connector.batch_delete(keys_to_rollback)

         return block_success

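
The rollback logic can be exercised end to end with an in-memory store. This is a sketch only: `InMemoryStore`, its `fail_on` knob, and `write_blocks` are hypothetical, and the key format is assumed to match the `"{hash_value}_0_{kind}"` pattern used elsewhere in this commit.

```python
from typing import Dict, Iterable, List


class InMemoryStore:
    """Hypothetical connector whose batch_set fails for keys listed in fail_on."""

    def __init__(self, fail_on: Iterable[str] = ()) -> None:
        self.data: Dict[str, bytes] = {}
        self.fail_on = set(fail_on)

    def batch_set(self, keys: List[str], values: List[bytes]) -> List[bool]:
        results = []
        for k, v in zip(keys, values):
            ok = k not in self.fail_on
            if ok:
                self.data[k] = v
            results.append(ok)
        return results

    def batch_delete(self, keys: List[str]) -> List[bool]:
        return [self.data.pop(k, None) is not None for k in keys]


def write_blocks(store: InMemoryStore, blocks: Dict[str, Dict[str, bytes]]) -> Dict[str, bool]:
    # Flatten every (block, kind) pair into one batch, remembering the owning block.
    flat_keys, flat_vals, owner = [], [], []
    for block_hash, kinds in blocks.items():
        for kind, payload in kinds.items():
            flat_keys.append(f"{block_hash}_0_{kind}")
            flat_vals.append(payload)
            owner.append(block_hash)

    results = store.batch_set(flat_keys, flat_vals)

    success = {h: True for h in blocks}
    ok_keys: Dict[str, List[str]] = {}
    for key, h, ok in zip(flat_keys, owner, results):
        if ok:
            ok_keys.setdefault(h, []).append(key)
        else:
            success[h] = False

    # Roll back keys that were written for blocks that failed overall.
    rollback = [k for h, keys in ok_keys.items() if not success[h] for k in keys]
    if rollback:
        store.batch_delete(rollback)
    return success


store = InMemoryStore(fail_on={"b_0_value"})
status = write_blocks(
    store,
    {"a": {"key": b"k", "value": b"v"}, "b": {"key": b"k", "value": b"v"}},
)
print(status, sorted(store.data))
```

After the call, block `b` leaves no trace in the store, so a later existence probe sees it as fully absent rather than half written.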

fastdeploy/cache_manager/v1/transfer_manager.py

Lines changed: 62 additions & 2 deletions
@@ -879,7 +879,58 @@ def prefetch_from_storage(
             return [False] * len(hash_list)

         keys_per_kind, host_ptrs_per_kind = self._build_storage_io_args(hash_list)
-        return self._staging_manager.batch_get_block(keys_per_kind, host_ptrs_per_kind, cpu_block_list)
+        results = self._staging_manager.batch_get_block(keys_per_kind, host_ptrs_per_kind, cpu_block_list)
+
+        failed_indices = [i for i, ok in enumerate(results) if not ok]
+        if failed_indices and self._storage_connector is not None:
+            # For each failed block, check which storage keys are actually missing.
+            # keys_per_kind maps kind -> [key_for_block_0, key_for_block_1, ...]
+            probe_keys = []
+            probe_labels = []
+            for i in failed_indices:
+                for kind, keys in keys_per_kind.items():
+                    probe_keys.append(keys[i])
+                    probe_labels.append((i, cpu_block_list[i], hash_list[i], kind))
+
+            try:
+                exist_flags = self._storage_connector.batch_exists(probe_keys)
+
+                # Aggregate per-block: collect missing kinds and whether any kind exists
+                # block_idx -> {missing_kinds, existing_kinds}
+                block_diag: Dict[int, Dict] = {}
+                for (bi, cpu_bid, h, kind), ok in zip(probe_labels, exist_flags):
+                    if bi not in block_diag:
+                        block_diag[bi] = {"cpu_bid": cpu_bid, "hash": h, "missing": [], "existing": []}
+                    if ok:
+                        block_diag[bi]["existing"].append(kind)
+                    else:
+                        block_diag[bi]["missing"].append(kind)
+
+                # Blocks with at least one missing kind
+                partial_missing = {bi: v for bi, v in block_diag.items() if v["missing"]}
+                # Blocks where all kinds exist (pure transfer error)
+                pure_transfer_err = {bi: v for bi, v in block_diag.items() if not v["missing"]}
+
+                if partial_missing:
+                    detail = [
+                        f"cpu_block={v['cpu_bid']} hash={v['hash'][:16]}.. "
+                        f"missing_kinds={v['missing']} existing_kinds={v['existing']}"
+                        for v in partial_missing.values()
+                    ]
+                    logger.warning(
+                        f"[TransferManager] prefetch_from_storage: {len(partial_missing)} block(s) have missing keys — "
+                        + "; ".join(detail)
+                    )
+                if pure_transfer_err:
+                    detail = [f"cpu_block={v['cpu_bid']} hash={v['hash'][:16]}.." for v in pure_transfer_err.values()]
+                    logger.warning(
+                        f"[TransferManager] prefetch_from_storage: {len(pure_transfer_err)} block(s) keys exist but transfer failed — "
+                        + ", ".join(detail)
+                    )
+            except Exception as e:
+                logger.warning(f"[TransferManager] prefetch_from_storage: failed to probe missing keys: {e}")
+
+        return results

     def backup_to_storage(
         self,
@@ -924,4 +975,13 @@ def backup_to_storage(
             return [False] * len(cpu_block_list)

         keys_per_kind, host_ptrs_per_kind = self._build_storage_io_args(hash_list)
-        return self._staging_manager.batch_set_block(keys_per_kind, host_ptrs_per_kind, cpu_block_list)
+        results = self._staging_manager.batch_set_block(keys_per_kind, host_ptrs_per_kind, cpu_block_list)
+
+        failed = [(cpu_block_list[i], hash_list[i]) for i, ok in enumerate(results) if not ok]
+        if failed:
+            logger.warning(
+                f"[TransferManager] backup_to_storage: {len(failed)}/{len(cpu_block_list)} block(s) failed — "
+                + ", ".join(f"cpu_block={cb} hash={h[:16]}.." for cb, h in failed)
+            )
+
+        return results
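
The prefetch diagnostics split failed blocks into two groups: blocks with at least one missing kind, and blocks whose keys all exist but whose transfer still failed. A compact sketch of that classification follows; `classify_failures` is a hypothetical helper, and `batch_exists` is assumed to be any callable returning one bool per key in order.

```python
from typing import Callable, Dict, List, Tuple


def classify_failures(
    results: List[bool],
    keys_per_kind: Dict[str, List[str]],
    batch_exists: Callable[[List[str]], List[bool]],
) -> Tuple[Dict[int, List[str]], List[int]]:
    failed = [i for i, ok in enumerate(results) if not ok]

    # One probe per (failed block, kind); labels remember which block each probe belongs to.
    probe_keys, labels = [], []
    for i in failed:
        for kind, keys in keys_per_kind.items():
            probe_keys.append(keys[i])
            labels.append((i, kind))

    flags = batch_exists(probe_keys) if probe_keys else []

    missing: Dict[int, List[str]] = {i: [] for i in failed}
    for (i, kind), ok in zip(labels, flags):
        if not ok:
            missing[i].append(kind)

    partial_missing = {i: kinds for i, kinds in missing.items() if kinds}
    transfer_errors = [i for i, kinds in missing.items() if not kinds]
    return partial_missing, transfer_errors


keys_per_kind = {
    "key": ["a_0_key", "b_0_key", "c_0_key"],
    "value": ["a_0_value", "b_0_value", "c_0_value"],
}
stored = {"a_0_key", "a_0_value", "b_0_key", "c_0_key", "c_0_value"}
partial, transfer = classify_failures(
    [True, False, False], keys_per_kind, lambda ks: [k in stored for k in ks]
)
print(partial, transfer)  # {1: ['value']} [2]
```

Block 1 failed because its value is gone from storage, while block 2 has both kinds present, which points at the transfer path rather than at eviction or a partial write.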

tests/cache_manager/v1/test_cache_manager.py

Lines changed: 60 additions & 0 deletions
@@ -918,5 +918,65 @@ def test_issue_returns_none_when_host_cache_disabled(self):
         self.assertEqual(cm.get_pending_backup_count(), 0)


+class TestPreparePrefixtMetadataStartNode(unittest.TestCase):
+    """Regression test for the start_node bug in prepare_prefetch_metadata.
+
+    Before the fix, prepare_prefetch_metadata called radix_tree.insert without
+    start_node, which inserted LOADING_FROM_STORAGE nodes as children of root
+    (using storage hashes h22..h29 at depth 1) instead of as extensions of the
+    existing device prefix chain (at depth 22..29). As a result, a subsequent
+    find_prefix on the full hash list would traverse root → h0 → ... → h21,
+    then fail to find h22 as a child of node(21), and stop at 22 nodes — never
+    reaching the HOST nodes even after update_storage_blocks_to_host.
+    """
+
+    def test_find_prefix_finds_host_blocks_after_prefetch(self):
+        """After prepare_prefetch_metadata + update_storage_blocks_to_host,
+        find_prefix must return all 30 nodes (22 DEVICE + 8 HOST)."""
+        from fastdeploy.cache_manager.v1.metadata import CacheStatus
+
+        cm = create_cache_manager(total_block_num=50, num_cpu_blocks=20)
+        rt = cm._radix_tree
+
+        # Build 30 hashes: 22 for device, 8 for storage
+        all_hashes = [f"h{i}" for i in range(30)]
+        device_hashes = all_hashes[:22]
+        storage_hashes = all_hashes[22:]
+
+        # Insert 22 device blocks into the radix tree
+        device_block_ids = cm._device_pool.allocate(22)
+        self.assertIsNotNone(device_block_ids)
+        device_nodes, _ = rt.insert(
+            blocks=list(zip(device_hashes, device_block_ids)),
+            cache_status=CacheStatus.DEVICE,
+        )
+        self.assertEqual(len(device_nodes), 22)
+
+        # The last device node is the correct start_node for the storage insertion
+        last_device_node = device_nodes[-1]
+
+        # prepare_prefetch_metadata should attach storage nodes AFTER the last device node
+        storage_nodes = cm.prepare_prefetch_metadata(storage_hashes, start_node=last_device_node)
+        self.assertEqual(len(storage_nodes), 8)
+        for node in storage_nodes:
+            self.assertEqual(node.cache_status, CacheStatus.LOADING_FROM_STORAGE)
+
+        # Simulate prefetch completion: transition LOADING_FROM_STORAGE → HOST
+        storage_block_ids = [n.block_id for n in storage_nodes]
+        for node in storage_nodes:
+            cm._prefetch_node_map[node.block_id] = node
+        cm.update_storage_blocks_to_host(storage_block_ids)
+        for node in storage_nodes:
+            self.assertEqual(node.cache_status, CacheStatus.HOST)
+
+        # Now find_prefix on all 30 hashes must return 30 nodes
+        found = rt.find_prefix(all_hashes)
+        self.assertEqual(len(found), 30, f"Expected 30 nodes, got {len(found)}")
+        device_found = [n for n in found if n.is_on_device()]
+        host_found = [n for n in found if n.is_on_host()]
+        self.assertEqual(len(device_found), 22)
+        self.assertEqual(len(host_found), 8)
+
+
 if __name__ == "__main__":
     unittest.main()
