Skip to content

Commit 271d85c

Browse files
kevincheng2 and claude committed
feat: refactor storage prefetch - 3-phase architecture
Refactor _prefetch_storage_cache into three decoupled phases:
- Phase 1 (preprocess thread): CacheManager.prefetch_storage() does matching + enqueue
- Phase 2 (schedule thread): drain pending list, attach to batch_request for dispatch
- Phase 3 (receiver thread): zmq.Poller receives done msgs, stores results

Worker side: extract prefetch tasks from batch_request, execute via thread pool, send completion via ZMQ PUSH.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 6151c15 commit 271d85c

7 files changed

Lines changed: 445 additions & 174 deletions

File tree

fastdeploy/cache_manager/v1/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
CacheStatus,
2626
MatchResult,
2727
PDTransferMetadata,
28+
PendingPrefetch,
2829
StorageConfig,
2930
StorageMetadata,
3031
StorageType,
@@ -61,6 +62,7 @@
6162
"AsyncTaskHandler",
6263
"MatchResult",
6364
"StorageMetadata",
65+
"PendingPrefetch",
6466
"PDTransferMetadata",
6567
"StorageConfig",
6668
"StorageType",

fastdeploy/cache_manager/v1/cache_manager.py

Lines changed: 73 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,15 @@
3030
from .base import KVCacheBase
3131
from .block_pool import DeviceBlockPool, HostBlockPool
3232
from .cache_utils import storage_key_for_block
33-
from .metadata import BlockNode, CacheLevel, CacheStatus, CacheSwapMetadata, MatchResult
33+
from .metadata import (
34+
BlockNode,
35+
CacheLevel,
36+
CacheStatus,
37+
CacheSwapMetadata,
38+
MatchResult,
39+
PendingPrefetch,
40+
StorageMetadata,
41+
)
3442
from .radix_tree import RadixTree
3543
from .storage import create_storage_scheduler
3644

@@ -111,6 +119,10 @@ def __init__(
111119
# used to quickly update status to HOST once prefetch completes.
112120
self._prefetch_node_map: Dict[int, BlockNode] = {}
113121

122+
# Pending prefetch queue: tasks waiting to be dispatched by scheduler
123+
self._pending_prefetch_list: List[PendingPrefetch] = []
124+
self._pending_prefetch_lock = threading.Lock()
125+
114126
# Storage scheduler (create using factory method if backend is configured)
115127
self._storage_scheduler = create_storage_scheduler(self.cache_config)
116128

@@ -504,10 +516,11 @@ def match_prefix(
504516
# Split matched_nodes into device blocks and host blocks
505517
if self.enable_host_cache:
506518
for node in matched_nodes:
507-
if node.is_on_device():
508-
result.device_nodes.append(node)
509-
elif node.is_on_host():
510-
result.host_nodes.append(node)
519+
pass
520+
# if node.is_on_device():
521+
# result.device_nodes.append(node)
522+
# elif node.is_on_host():
523+
# result.host_nodes.append(node)
511524
else:
512525
result.device_nodes = matched_nodes
513526

@@ -968,6 +981,61 @@ def load_from_host(self, block_indices: List[int]) -> bool:
968981

969982
# ============ Prefetch Methods ============
970983

984+
def prefetch_storage(self, request: "Request") -> bool:
985+
"""
986+
Execute storage matching and enqueue prefetch info for later dispatch.
987+
988+
Called from the preprocess thread. Does match_prefix(skip_storage=False)
989+
to probe storage, allocate host blocks, and enqueue PendingPrefetch
990+
into the pending list. The scheduler will drain and dispatch later.
991+
992+
Args:
993+
request: The request to prefetch cache for.
994+
995+
Returns:
996+
True if storage blocks were matched and enqueued, False otherwise.
997+
"""
998+
if not self.enable_prefix_caching:
999+
return False
1000+
1001+
self.match_prefix(request, skip_storage=False)
1002+
match_result = request.match_result
1003+
request.match_result = None
1004+
1005+
if match_result is None or match_result.matched_storage_nums == 0:
1006+
return False
1007+
1008+
storage_nodes = match_result.storage_nodes
1009+
host_block_ids = [node.block_id for node in storage_nodes]
1010+
hash_values = [node.hash_value for node in storage_nodes]
1011+
1012+
metadata = StorageMetadata(
1013+
hash_values=hash_values,
1014+
block_ids=host_block_ids,
1015+
direction="load",
1016+
)
1017+
1018+
pending = PendingPrefetch(
1019+
request_id=request.request_id,
1020+
metadata=metadata,
1021+
host_block_ids=host_block_ids,
1022+
)
1023+
with self._pending_prefetch_lock:
1024+
self._pending_prefetch_list.append(pending)
1025+
1026+
logger.info(
1027+
f"[Debug][StoragePrefetch] request_id={request.request_id} "
1028+
f"storage_matched={match_result.matched_storage_nums} blocks, enqueued for dispatch"
1029+
)
1030+
return True
1031+
1032+
def drain_pending_prefetches(self) -> List[PendingPrefetch]:
1033+
"""Atomically drain all pending prefetch tasks for scheduler dispatch."""
1034+
with self._pending_prefetch_lock:
1035+
items = self._pending_prefetch_list
1036+
self._pending_prefetch_list = []
1037+
return items
1038+
9711039
def prepare_prefetch_metadata(
9721040
self,
9731041
storage_hashes: List[str],

fastdeploy/cache_manager/v1/metadata.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -409,6 +409,23 @@ class StorageMetadata:
409409
extra_params: Dict[str, Any] = field(default_factory=dict)
410410

411411

412+
@dataclass
413+
class PendingPrefetch:
414+
"""
415+
Represents a pending storage prefetch task enqueued by CacheManager,
416+
waiting to be dispatched to workers by the scheduler.
417+
418+
Attributes:
419+
request_id: The request that triggered this prefetch.
420+
metadata: StorageMetadata with hash_values and block_ids for the transfer.
421+
host_block_ids: Pre-allocated host block IDs (for cleanup on failure).
422+
"""
423+
424+
request_id: str = ""
425+
metadata: "StorageMetadata" = field(default_factory=lambda: StorageMetadata())
426+
host_block_ids: List[int] = field(default_factory=list)
427+
428+
412429
@dataclass
413430
class PDTransferMetadata:
414431
"""

fastdeploy/engine/common_engine.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1168,6 +1168,12 @@ def _fetch_request():
11681168
task.metrics.decode_inference_start_time = time.time()
11691169
elif not task.has_been_preempted_before:
11701170
task.metrics.inference_start_time = time.time()
1171+
if batch_request.storage_prefetch_tasks:
1172+
self.llm_logger.info(
1173+
f"[Debug][StoragePrefetch][Dispatch] put_tasks with "
1174+
f"{len(batch_request.storage_prefetch_tasks)} prefetch tasks, "
1175+
f"{len(batch_request.requests)} inference requests"
1176+
)
11711177
self.engine_worker_queue.put_tasks((batch_request, self.resource_manager.real_bsz))
11721178
else:
11731179
# When there are no actual tasks to schedule, send an empty task batch to EP workers.

fastdeploy/engine/request.py

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@
3434
from typing_extensions import TypeVar
3535

3636
from fastdeploy import envs
37-
from fastdeploy.cache_manager.v1.metadata import CacheSwapMetadata
37+
from fastdeploy.cache_manager.v1.metadata import CacheSwapMetadata, PendingPrefetch
3838
from fastdeploy.engine.pooling_params import PoolingParams
3939
from fastdeploy.engine.sampling_params import SamplingParams
4040
from fastdeploy.entrypoints.openai.protocol import (
@@ -618,6 +618,7 @@ def __init__(self):
618618

619619
self.cache_swap_metadata: Optional[CacheSwapMetadata] = None
620620
self.cache_evict_metadata: Optional[CacheSwapMetadata] = None
621+
self.storage_prefetch_tasks: Optional[List[PendingPrefetch]] = None
621622

622623
def add_request(self, request):
623624
if hasattr(request, "cache_swap_metadata") and request.cache_swap_metadata:
@@ -659,9 +660,17 @@ def append_evict_metadata(self, metadata: List[CacheSwapMetadata]):
659660
hash_values=meta.hash_values,
660661
)
661662

663+
def append_prefetch_tasks(self, tasks: List[PendingPrefetch]):
664+
if self.storage_prefetch_tasks is None:
665+
self.storage_prefetch_tasks = []
666+
self.storage_prefetch_tasks.extend(tasks)
667+
662668
def __repr__(self):
663669
requests_repr = repr(self.requests)
664-
return f"BatchRequest(requests={requests_repr}, swap_metadata={self.cache_swap_metadata}, evict_metadata={self.cache_evict_metadata})"
670+
return (
671+
f"BatchRequest(requests={requests_repr}, swap_metadata={self.cache_swap_metadata}, "
672+
f"evict_metadata={self.cache_evict_metadata}, prefetch_tasks={self.storage_prefetch_tasks})"
673+
)
665674

666675
def __getstate__(self):
667676
state = self.__dict__.copy()
@@ -688,14 +697,19 @@ def __getitem__(self, index):
688697
return self.requests[index]
689698

690699
def __len__(self):
691-
return len(self.requests)
700+
count = len(self.requests)
701+
if self.storage_prefetch_tasks:
702+
count += len(self.storage_prefetch_tasks)
703+
return count
692704

693705
def append(self, batch_request: "BatchRequest"):
694706
self.requests.extend(batch_request.requests)
695707
if batch_request.cache_swap_metadata:
696708
self.append_swap_metadata([batch_request.cache_swap_metadata])
697709
if batch_request.cache_evict_metadata:
698710
self.append_evict_metadata([batch_request.cache_evict_metadata])
711+
if batch_request.storage_prefetch_tasks:
712+
self.append_prefetch_tasks(batch_request.storage_prefetch_tasks)
699713

700714
def extend(self, batch_requests: list["BatchRequest"]):
701715
for br in batch_requests:

0 commit comments

Comments
 (0)