diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py index 1a81cfd652f..a0409e4e725 100644 --- a/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py @@ -26,7 +26,7 @@ KVCacheStorage, logger, ) -from fastdeploy.cache_manager.transfer_factory.utils import get_rdma_nics +from fastdeploy.cache_manager.v1.cache_utils import get_rdma_nics from fastdeploy.platforms import current_platform from fastdeploy.utils import get_host_ip diff --git a/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py b/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py index 121d8d3d51c..4835549227e 100644 --- a/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py +++ b/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py @@ -16,7 +16,7 @@ import traceback -from fastdeploy.cache_manager.transfer_factory.utils import get_rdma_nics +from fastdeploy.cache_manager.v1.cache_utils import get_rdma_nics from fastdeploy.utils import get_logger logger = get_logger("cache_messager", "cache_messager.log") diff --git a/fastdeploy/cache_manager/transfer_factory/utils.py b/fastdeploy/cache_manager/transfer_factory/utils.py index 61ae72cab7c..fbfce6ca5c2 100644 --- a/fastdeploy/cache_manager/transfer_factory/utils.py +++ b/fastdeploy/cache_manager/transfer_factory/utils.py @@ -14,36 +14,6 @@ # limitations under the License. """ -import importlib -import subprocess +from fastdeploy.cache_manager.v1.cache_utils import get_rdma_nics -from fastdeploy.platforms import current_platform -from fastdeploy.utils import get_logger - -logger = get_logger("cache_messager", "cache_messager.log") - - -def get_rdma_nics(): - res = importlib.resources.files("fastdeploy.cache_manager.transfer_factory") / "get_rdma_nics.sh" - with importlib.resources.as_file(res) as path: - file_path = str(path) - - nic_type = current_platform.device_name - command = ["bash", file_path, nic_type] - result = subprocess.run( - command, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - check=False, - ) - logger.info(f"get_rdma_nics command: {command}") - logger.info(f"get_rdma_nics output: {result.stdout}") - if result.returncode != 0: - raise RuntimeError(f"Failed to execute script `get_rdma_nics.sh`: {result.stderr.strip()}") - - env_name, env_value = result.stdout.strip().split("=") - if env_name != "KVCACHE_RDMA_NICS": - raise ValueError(f"Unexpected variable name: {env_name}, expected 'KVCACHE_RDMA_NICS'") - - return env_value +__all__ = ["get_rdma_nics"] diff --git a/fastdeploy/cache_manager/v1/__init__.py b/fastdeploy/cache_manager/v1/__init__.py index ca9380f8528..6e71c0a3ead 100644 --- a/fastdeploy/cache_manager/v1/__init__.py +++ b/fastdeploy/cache_manager/v1/__init__.py @@ -17,7 +17,7 @@ from .base import KVCacheBase from .cache_controller import CacheController from .cache_manager import CacheManager -from .cache_utils import LayerDoneCounter, LayerSwapTimeoutError +from .cache_utils import LayerDoneCounter, LayerSwapTimeoutError, get_rdma_nics from .metadata import ( AsyncTaskHandler, BlockNode, @@ -25,6 +25,7 @@ CacheStatus, MatchResult, PDTransferMetadata, + PendingPrefetch, StorageConfig, StorageMetadata, StorageType, @@ -49,6 +50,7 @@ "LayerSwapTimeoutError", # Utils "LayerDoneCounter", + "get_rdma_nics", # Metadata "CacheBlockMetadata", "BlockNode", @@ -60,6 +62,7 @@ "AsyncTaskHandler", "MatchResult", "StorageMetadata", + "PendingPrefetch", "PDTransferMetadata", "StorageConfig", "StorageType", diff --git a/fastdeploy/cache_manager/v1/cache_controller.py b/fastdeploy/cache_manager/v1/cache_controller.py index 53b7292179f..2278961c2d1 100644 --- a/fastdeploy/cache_manager/v1/cache_controller.py +++ b/fastdeploy/cache_manager/v1/cache_controller.py @@ -18,6 +18,7 @@ import os import threading import time +import traceback from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Any, Dict, List, Optional @@ -111,6 +112,11 @@ def write_policy(self) -> Optional[str]: return self.cache_config.write_policy return None + @property + def storage_enabled(self) -> bool: + """Whether a storage connector is available for Host↔Storage transfers.""" + return getattr(self._transfer_manager, "_storage_connector", None) is not None + def _should_wait_for_swap_out(self) -> bool: """ Determine if swap-out operations should wait synchronously. @@ -147,7 +153,17 @@ def submit_swap_tasks( # Note: evict returns LayerDoneCounter but we don't wait on it layer-by-layer # (except in write_back mode where we wait synchronously via wait_all) if evict_metadata is not None: - evict_counter = self.evict_device_to_host(evict_metadata) + # Build StorageMetadata when storage is enabled and hash_values are available, + # so that D2H eviction is automatically followed by Host→Storage backup. + storage_metadata = None + if self.storage_enabled and evict_metadata.hash_values: + storage_metadata = StorageMetadata( + hash_values=evict_metadata.hash_values, + block_ids=evict_metadata.dst_block_ids, + direction="evict", + ) + + evict_counter = self.evict_device_to_host(evict_metadata, storage_metadata) self._pending_evict_counters.append(evict_counter) # Step 3: For write_back, wait for evict to complete before submitting swap-in @@ -622,6 +638,15 @@ def initialize_host_cache( # Share host_cache_kvs_map with transfer manager self._transfer_manager.set_host_cache_kvs_map(self.host_cache_kvs_map) + # Propagate block shape so transfer manager can compute per-block byte offsets + # for prefetch_from_storage / backup_to_storage. + self._transfer_manager.set_host_block_shape( + key_shape=self._host_key_cache_shape, + value_shape=self._host_value_cache_shape, + scale_shape=self._host_cache_scale_shape, + cache_item_bytes=cache_item_bytes, + ) + def get_host_cache_kvs_map(self) -> Dict[str, Any]: """ Get the Host KV Cache pointer dictionary. @@ -641,6 +666,7 @@ def _submit_swap_task( transfer_fn_all: callable, transfer_fn_layer: callable, force_all_layers: bool = False, + on_success: callable = None, ) -> LayerDoneCounter: """ Submit a single swap transfer task (internal method). @@ -658,6 +684,8 @@ def _submit_swap_task( transfer_fn_all: All-layer transfer function, signature (src_ids, dst_ids) -> bool. transfer_fn_layer: Layer-by-layer transfer function, signature (layer_indices, on_layer_complete, src_ids, dst_ids) -> bool. force_all_layers: If True, always use all-layers mode (used for D2H evict). + on_success: Optional callback invoked after a successful transfer, + signature () -> None. Runs in the same worker thread. Returns: LayerDoneCounter instance for tracking layer completion. @@ -755,9 +783,14 @@ def _do_transfer(): meta.success = result.success meta.error_message = result.error_message - except Exception as e: - import traceback + # Chain next task on success + if result.success and on_success is not None: + try: + on_success() + except Exception as cb_err: + logger.error(f"[SwapTask] on_success callback failed: {cb_err}\n" f"{traceback.format_exc()}") + except Exception as e: traceback.print_exc() logger.error( f"[SwapTask] {src_location.value}->{dst_location.value} " @@ -807,6 +840,7 @@ def load_host_to_device( def evict_device_to_host( self, swap_metadata: CacheSwapMetadata, + storage_metadata: Optional[StorageMetadata] = None, ) -> LayerDoneCounter: """ Evict device cache to host (async). @@ -818,10 +852,24 @@ def evict_device_to_host( swap_metadata: CacheSwapMetadata containing: - src_block_ids: Source device block IDs - dst_block_ids: Destination host block IDs + storage_metadata: Optional StorageMetadata. If provided, a + backup_host_to_storage task is automatically submitted + after the D2H transfer succeeds (chained in the same + worker thread). Returns: LayerDoneCounter for tracking layer completion. """ + on_success = None + if storage_metadata is not None: + host_block_ids = swap_metadata.dst_block_ids + + def _on_success_backup(): + logger.debug(f"[EvictAndBackup] D2H done, chaining backup to storage " f"host_blocks={host_block_ids}") + self.backup_host_to_storage(host_block_ids, storage_metadata) + + on_success = _on_success_backup + layer_counter = self._submit_swap_task( meta=swap_metadata, src_location=CacheLevel.DEVICE, @@ -829,6 +877,7 @@ def evict_device_to_host( transfer_fn_all=lambda src_ids, dst_ids: self._transfer_manager.evict_to_host_async(src_ids, dst_ids), transfer_fn_layer=None, force_all_layers=True, # Eviction always uses output_stream for all-layers async transfer + on_success=on_success, ) return layer_counter @@ -854,35 +903,43 @@ def prefetch_from_storage( handler = AsyncTaskHandler() - # TODO: Implement storage prefetch logic - handler.set_error("Storage prefetch not implemented yet") - - return handler - - def backup_device_to_storage( - self, - device_block_ids: List[int], - metadata: StorageMetadata, - ) -> AsyncTaskHandler: - """ - Backup device cache to storage (async). + hash_values = metadata.hash_values + block_ids = metadata.block_ids - Backup KV cache from device memory to external storage - for reuse by subsequent requests. + if not hash_values or not block_ids: + logger.info(f"[StoragePrefetch] skip: empty hash_values={hash_values}, block_ids={block_ids}") + handler.set_error("Empty hash_values or block_ids in StorageMetadata") + return handler - Args: - device_block_ids: Device block IDs to backup. - metadata: Storage transfer metadata. - - Returns: - AsyncTaskHandler for tracking the async transfer task. - """ - - handler = AsyncTaskHandler() + def _do_prefetch(): + try: + start_time = time.time() + results = self._transfer_manager.prefetch_from_storage( + hash_list=hash_values, + cpu_block_list=block_ids, + ) + elapsed = time.time() - start_time - # TODO: Implement storage backup logic - handler.set_error("Storage backup not implemented yet") + success = all(results) + if success: + logger.debug( + f"[StoragePrefetch] success hash_values={hash_values} " + f"block_ids={block_ids} elapsed={elapsed*1000:.3f}ms" + ) + handler.set_result(results) + else: + failed_indices = [i for i, ok in enumerate(results) if not ok] + logger.warning( + f"[StoragePrefetch] partial failure " + f"failed_indices={failed_indices} elapsed={elapsed*1000:.3f}ms" + ) + handler.set_error(f"Storage prefetch failed for blocks at indices {failed_indices}") + except Exception as e: + traceback.print_exc() + logger.error(f"[StoragePrefetch] EXCEPTION: {e}\n{traceback.format_exc()}") + handler.set_error(str(e)) + self._executor.submit(_do_prefetch) return handler def backup_host_to_storage( @@ -905,9 +962,42 @@ def backup_host_to_storage( handler = AsyncTaskHandler() - # TODO: Implement storage backup logic - handler.set_error("Storage backup not implemented yet") + hash_values = metadata.hash_values + + if not host_block_ids or not hash_values: + logger.info(f"[StorageBackup] skip: empty host_block_ids={host_block_ids}, " f"hash_values={hash_values}") + handler.set_error("Empty host_block_ids or hash_values in StorageMetadata") + return handler + + def _do_backup(): + try: + start_time = time.time() + results = self._transfer_manager.backup_to_storage( + cpu_block_list=host_block_ids, + hash_list=hash_values, + ) + elapsed = time.time() - start_time + + success = all(results) + if success: + logger.debug( + f"[StorageBackup] success host_block_ids={host_block_ids} " + f"hash_values={hash_values} elapsed={elapsed*1000:.3f}ms" + ) + handler.set_result(results) + else: + failed_indices = [i for i, ok in enumerate(results) if not ok] + logger.warning( + f"[StorageBackup] partial failure " + f"failed_indices={failed_indices} elapsed={elapsed*1000:.3f}ms" + ) + handler.set_error(f"Storage backup failed for blocks at indices {failed_indices}") + except Exception as e: + traceback.print_exc() + logger.error(f"[StorageBackup] EXCEPTION: {e}\n{traceback.format_exc()}") + handler.set_error(str(e)) + self._executor.submit(_do_backup) return handler def send_to_node( diff --git a/fastdeploy/cache_manager/v1/cache_manager.py b/fastdeploy/cache_manager/v1/cache_manager.py index 6e7a0b47869..ee175b8dd4f 100644 --- a/fastdeploy/cache_manager/v1/cache_manager.py +++ b/fastdeploy/cache_manager/v1/cache_manager.py @@ -29,7 +29,16 @@ from .base import KVCacheBase from .block_pool import DeviceBlockPool, HostBlockPool -from .metadata import BlockNode, CacheLevel, CacheStatus, CacheSwapMetadata, MatchResult +from .cache_utils import storage_key_for_block +from .metadata import ( + BlockNode, + CacheLevel, + CacheStatus, + CacheSwapMetadata, + MatchResult, + PendingPrefetch, + StorageMetadata, +) from .radix_tree import RadixTree from .storage import create_storage_scheduler @@ -106,6 +115,14 @@ def __init__( self._pending_backup: List[Tuple[List[BlockNode], List[int]]] = [] self._pending_block_ids: List[int] = [] + # Mapping from host_block_id -> BlockNode for LOADING_FROM_STORAGE blocks, + # used to quickly update status to HOST once prefetch completes. + self._prefetch_node_map: Dict[int, BlockNode] = {} + + # Pending prefetch queue: tasks waiting to be dispatched by scheduler + self._pending_prefetch_list: List[PendingPrefetch] = [] + self._pending_prefetch_lock = threading.Lock() + # Storage scheduler (create using factory method if backend is configured) self._storage_scheduler = create_storage_scheduler(self.cache_config) @@ -259,7 +276,14 @@ def allocate_device_blocks( if self.enable_host_cache and match_result.matched_host_nums > 0: device_blocks = allocated[: match_result.matched_host_nums] + host_refs_before = [(n.block_id, n.ref_count) for n in match_result.host_nodes] free_host_block_ids = self._radix_tree.swap_to_device(match_result.host_nodes, device_blocks) + host_refs_after = [(n.block_id, n.ref_count) for n in match_result.host_nodes] + logger.info( + f"[Debug][allocate_device_blocks] request_id={request.request_id} " + f"swap host->device: host_refs: {host_refs_before} -> {host_refs_after}, " + f"host_blocks_released={free_host_block_ids}, device_blocks={device_blocks}" + ) logger.debug( f"[allocate_device_blocks] request_id={request.request_id} " f"swap host->device: host_block_ids={free_host_block_ids} -> device_block_ids={device_blocks}" @@ -294,6 +318,11 @@ def allocate_device_blocks( device_nodes, wasted_block_ids = self._radix_tree.insert(blocks=blocks, start_node=start_node) match_result.device_nodes.extend(device_nodes) + logger.info( + f"[Debug][allocate_device_blocks] request_id={request.request_id} " + f"insert uncached: nodes={[(n.block_id, n.cache_status.name, n.ref_count) for n in device_nodes]}, " + f"wasted={wasted_block_ids}" + ) inserted_block_ids = [n.block_id for n in device_nodes] logger.debug( @@ -405,18 +434,6 @@ def resize_device_pool(self, new_num_blocks: int) -> bool: # These methods provide backward compatibility with PrefixCacheManager interface # for resource_manager.py - def write_cache_to_storage(self, req: Any) -> None: - """ - Write request cache to storage if storage is enabled. - - Args: - req: The request object containing cache data to write - """ - if self._storage_scheduler is None: - return - # TODO: Implement storage write logic when storage is enabled - pass - @property def gpu_free_block_list(self) -> List[int]: """ @@ -481,7 +498,7 @@ def update_cache_config(self, new_cfg) -> None: def match_prefix( self, request: Request, - skip_storage: bool = True, + skip_storage: bool = False, ) -> None: """ Execute three-level cache matching (Device -> Host -> Storage). @@ -508,6 +525,11 @@ def match_prefix( # Step 1: Match Device and Host cache via RadixTree matched_nodes = self._radix_tree.find_prefix(block_hashes) + logger.info( + f"[Debug][match_prefix] request_id={request.request_id} skip_storage={skip_storage} " + f"find_prefix matched={len(matched_nodes)} nodes, " + f"refs={[(n.block_id, n.cache_status.name, n.ref_count) for n in matched_nodes]}" + ) # Split matched_nodes into device blocks and host blocks if self.enable_host_cache: @@ -526,11 +548,16 @@ def match_prefix( # Step 2: Match Storage (if enabled and not skipped) if not skip_storage and self._storage_scheduler and remaining_hashes: storage_matches = self._match_storage(remaining_hashes) - result.storage_nodes = self.prepare_prefetch_metadata(storage_matches) + start_node = matched_nodes[-1] if matched_nodes else None + result.storage_nodes = self.prepare_prefetch_metadata(storage_matches, start_node=start_node) - # Step 3: Increment ref count for matched blocks(only first match node) - if not (self._storage_scheduler and skip_storage): + # Step 3: Increment ref count for matched blocks(only scheduling phase) + if skip_storage: self._radix_tree.increment_ref_nodes(matched_nodes) + logger.info( + f"[Debug][match_prefix] request_id={request.request_id} " + f"after increment_ref: refs={[(n.block_id, n.cache_status.name, n.ref_count) for n in matched_nodes]}" + ) logger.info( f"match_prefix for request_id: {request.request_id} total_hashes: {len(block_hashes)}, " @@ -551,24 +578,63 @@ def match_prefix( def _match_storage(self, hash_values: List[str]) -> List[str]: """ - Match hash values against storage. + Match hash values against storage using per-layer keys. + + Checks each hash for existence in storage and returns the longest + consecutive prefix of hashes that are all present (prefix semantics + are required because a cache miss in the middle breaks prefetch continuity). + + Probes rank=0 per-layer keys: a block is considered present only when + ALL of its ``2 * num_layers`` per-layer keys exist. This avoids false + positives from partial writes where only some layers were stored. + + Storage key format (see cache_utils.storage_key_for_block): + "{hash_value}_0_key_{layer_idx}" / "{hash_value}_0_value_{layer_idx}" Args: - hash_values: List of hash values to check + hash_values: List of block hash values to check, in prefix order. Returns: - List of hashes that exist in storage + The leading sub-list of hash_values whose blocks all exist in storage. """ if not self._storage_scheduler: return [] try: if not self._storage_scheduler.is_connected(): - self._storage_scheduler.connect() - - existence_map = self._storage_scheduler.query(hash_values) - return [h for h, exists in existence_map.items() if exists] + logger.warning("_match_storage: storage scheduler disconnected, skipping storage match") + return [] + + num_layers = self.model_config.num_hidden_layers + per_block_keys = 2 * num_layers # key + value per layer + + # Probe all per-layer keys for rank=0. + # Flat layout: [h0_key_0, h0_value_0, ..., h0_key_L-1, h0_value_L-1, + # h1_key_0, h1_value_0, ...] + probe_keys = [] + for h in hash_values: + for layer_idx in range(num_layers): + probe_keys.append(storage_key_for_block(h, 0, "key", layer_idx)) + probe_keys.append(storage_key_for_block(h, 0, "value", layer_idx)) + + exist_flags = self._storage_scheduler.batch_exists(probe_keys) + + # A block is present only when all per-layer keys exist. + matched = [] + for block_idx, h in enumerate(hash_values): + start = block_idx * per_block_keys + ok = all(exist_flags[start : start + per_block_keys]) + if not ok: + break + matched.append(h) + + logger.debug( + f"[CacheManager] _match_storage: probing {len(hash_values)} blocks " + f"({len(probe_keys)} per-layer keys), matched={len(matched)}" + ) + return matched except Exception: + logger.warning("_match_storage failed", exc_info=True) return [] # ============ Eviction Methods ============ @@ -701,8 +767,17 @@ def request_finish( uncached_blocks.extend(request.block_tables[match_result.matched_device_nums :]) # Decrement ref count - blocks become evictable if ref_count reaches 0 + refs_before = [(n.block_id, n.cache_status.name, n.ref_count) for n in match_result.device_nodes] self._radix_tree.decrement_ref_nodes(match_result.device_nodes) + refs_after = [(n.block_id, n.cache_status.name, n.ref_count) for n in match_result.device_nodes] self._device_pool.release(uncached_blocks) + logger.info( + f"[Debug][request_finish] request_id={request.request_id} " + f"decrement device_nodes: before={refs_before}, after={refs_after}, " + f"evictable_device={len(self._radix_tree._evictable_device)}, " + f"evictable_host={len(self._radix_tree._evictable_host)}, " + f"available_device={self._device_pool.available_blocks()}" + ) cached_block_ids = [n.block_id for n in match_result.device_nodes] logger.debug( @@ -768,6 +843,7 @@ def issue_pending_backup_to_batch_request( all_device_block_ids = [] all_host_block_ids = [] + all_hash_values = [] freed_host_ids = [] for nodes, host_block_ids in self._pending_backup: @@ -792,9 +868,10 @@ def issue_pending_backup_to_batch_request( # Mark nodes as backed up self._radix_tree.backup_blocks(valid_nodes, valid_host_ids) - # Collect device block IDs + # Collect device block IDs and hash values all_device_block_ids.extend([node.block_id for node in valid_nodes]) all_host_block_ids.extend(valid_host_ids) + all_hash_values.extend([node.hash_value for node in valid_nodes]) # Release invalid host block allocations if freed_host_ids: @@ -811,6 +888,7 @@ def issue_pending_backup_to_batch_request( dst_block_ids=all_host_block_ids, src_type=CacheLevel.DEVICE, dst_type=CacheLevel.HOST, + hash_values=all_hash_values, ) return evict_metadata @@ -873,45 +951,6 @@ def check_and_add_pending_backup( except Exception as e: logger.error(f"check_and_add_pending_backup error: {e}, {str(traceback.format_exc())}") - # ============ Host/Device Transfer Coordination ============ - - def offload_to_host(self, block_indices: List[int]) -> bool: - """ - Offload blocks from device to host memory. - - This is a coordination method. Actual data transfer happens in Worker. - - Args: - block_indices: List of block indices to offload - - Returns: - True if successful, False otherwise - """ - try: - with self._lock: - # Allocate host blocks - host_indices = self._host_pool.allocate(len(block_indices)) - if host_indices is None or len(host_indices) != len(block_indices): - # Not enough host memory, release what we allocated - if host_indices: - self._host_pool.release(host_indices) - return False - - # Perform the offload (actual data transfer would happen in Worker) - for i, dev_idx in enumerate(block_indices): - host_idx = host_indices[i] - metadata = self._device_pool.get_metadata(dev_idx) - if metadata: - self._host_pool.set_metadata(host_idx, metadata) - - # Release device blocks - self._device_pool.release(block_indices) - - return True - except Exception as e: - logger.error(f"offload_to_host error: {e}, {str(traceback.format_exc())}") - return False - def load_from_host(self, block_indices: List[int]) -> bool: """ Load blocks from host to device memory. @@ -945,9 +984,90 @@ def load_from_host(self, block_indices: List[int]) -> bool: # ============ Prefetch Methods ============ + def prefetch_storage(self, request: "Request") -> bool: + """ + Execute storage matching and enqueue prefetch info for later dispatch. + + Called from the preprocess thread. Does match_prefix(skip_storage=False) + to probe storage, allocate host blocks, and enqueue PendingPrefetch + into the pending list. The scheduler will drain and dispatch later. + + Args: + request: The request to prefetch cache for. + + Returns: + True if storage blocks were matched and enqueued, False otherwise. + """ + if not self.enable_prefix_caching: + return False + + self.match_prefix(request, skip_storage=False) + match_result = request.match_result + request.match_result = None + + if match_result is None or match_result.matched_storage_nums == 0: + return False + + storage_nodes = match_result.storage_nodes + host_block_ids = [node.block_id for node in storage_nodes] + hash_values = [node.hash_value for node in storage_nodes] + + metadata = StorageMetadata( + hash_values=hash_values, + block_ids=host_block_ids, + direction="load", + ) + + pending = PendingPrefetch( + request_id=request.request_id, + metadata=metadata, + host_block_ids=host_block_ids, + ) + with self._pending_prefetch_lock: + self._pending_prefetch_list.append(pending) + + logger.info( + f"[Debug][StoragePrefetch] request_id={request.request_id} " + f"storage_matched={match_result.matched_storage_nums} blocks, enqueued for dispatch" + ) + return True + + def drain_pending_prefetches(self) -> List[PendingPrefetch]: + """Atomically drain all pending prefetch tasks for scheduler dispatch.""" + with self._pending_prefetch_lock: + items = self._pending_prefetch_list + self._pending_prefetch_list = [] + return items + + def cancel_pending_prefetch(self, request_id: str) -> List[int]: + """ + Cancel a pending prefetch by request ID. + + Removes the matching entry from the pending list (not yet dispatched to + workers) and returns its pre-allocated host block IDs for cleanup. + + Args: + request_id: The request ID to cancel. + + Returns: + List of host block IDs that were pre-allocated, or empty list if + no matching pending prefetch was found. + """ + with self._pending_prefetch_lock: + host_block_ids: List[int] = [] + remaining = [] + for item in self._pending_prefetch_list: + if item.request_id == request_id: + host_block_ids = item.host_block_ids + else: + remaining.append(item) + self._pending_prefetch_list = remaining + return host_block_ids + def prepare_prefetch_metadata( self, storage_hashes: List[str], + start_node: Optional["BlockNode"] = None, ) -> Optional[List["BlockNode"]]: """ Prepare metadata for storage prefetch operation. @@ -957,6 +1077,10 @@ def prepare_prefetch_metadata( Args: storage_hashes: List of storage hash values to prefetch + start_node: Node to start insertion from in the radix tree. + Must be the last matched node from find_prefix so that + the new LOADING_FROM_STORAGE nodes are attached as proper + extensions of the existing prefix chain. Returns: List of BlockNode objects if successful, None or empty list otherwise. @@ -964,7 +1088,7 @@ def prepare_prefetch_metadata( (may differ from originally allocated if node was reused). """ if not storage_hashes: - return None + return [] try: with self._lock: @@ -972,24 +1096,112 @@ def prepare_prefetch_metadata( if not self.can_allocate_host_blocks(len(storage_hashes)): return [] - # Allocate host blocks for prefetch - host_block_ids = self._host_pool.allocate(len(storage_hashes)) - if host_block_ids is None or len(host_block_ids) == 0: + # Allocate host blocks for prefetch (evicts evictable host blocks if needed) + host_block_ids = self.allocate_host_blocks(len(storage_hashes)) + if not host_block_ids: return [] blocks = list(zip(storage_hashes, host_block_ids)) prefetch_nodes, wasted_block_ids = self._radix_tree.insert( - blocks=blocks, cache_status=CacheStatus.LOADING_FROM_STORAGE + blocks=blocks, cache_status=CacheStatus.LOADING_FROM_STORAGE, start_node=start_node + ) + logger.info( + f"[Debug][prepare_prefetch_metadata] after insert: " + f"prefetch_nodes={[(n.block_id, n.cache_status.name, n.ref_count) for n in prefetch_nodes]}, " + f"wasted={wasted_block_ids}" ) # Release any blocks that were wasted due to node reuse if wasted_block_ids: self._host_pool.release(wasted_block_ids) - return prefetch_nodes + # Register only truly new LOADING_FROM_STORAGE nodes. + # insert() reuses existing nodes without updating their status, so nodes + # that were already HOST/DEVICE must be excluded — they don't need a + # storage transfer and would trigger a spurious "unexpected status" warning + # in update_storage_blocks_to_host. + actual_prefetch_nodes = [] + for node in prefetch_nodes: + if node.cache_status == CacheStatus.LOADING_FROM_STORAGE: + self._prefetch_node_map[node.block_id] = node + actual_prefetch_nodes.append(node) + + return actual_prefetch_nodes except Exception as e: logger.error(f"prepare_prefetch_metadata error: {e}, {str(traceback.format_exc())}") return [] + def update_storage_blocks_to_host(self, host_block_ids: List[int]) -> None: + """ + Mark storage-prefetched blocks as HOST after data transfer completes. + + Called by Scheduler when all TP workers report prefetch done for a batch + of blocks. Transitions block status LOADING_FROM_STORAGE → HOST so that + these blocks become eligible for swap-in scheduling. + + Args: + host_block_ids: List of host block IDs that finished loading. + """ + if not host_block_ids: + return + try: + with self._lock: + updated = 0 + for block_id in host_block_ids: + node = self._prefetch_node_map.pop(block_id, None) + if node is None: + logger.warning( + f"[StoragePrefetch] update_storage_blocks_to_host: " + f"block_id={block_id} not found in prefetch_node_map" + ) + continue + if node.cache_status == CacheStatus.LOADING_FROM_STORAGE: + old_ref = node.ref_count + node.cache_status = CacheStatus.HOST + updated += 1 + # Balance the ref_count from insert() in prepare_prefetch_metadata. + # The prefetch is complete, the node should be in idle-cached state (ref=0). + self._radix_tree.decrement_ref_nodes([node]) + logger.info( + f"[Debug][update_storage_blocks_to_host] block_id={block_id} " + f"LFS->HOST, ref: {old_ref}->{node.ref_count}" + ) + else: + logger.warning( + f"[StoragePrefetch] update_storage_blocks_to_host: " + f"block_id={block_id} unexpected status={node.cache_status}" + ) + logger.info( + f"[StoragePrefetch] update_storage_blocks_to_host: " + f"requested={len(host_block_ids)}, updated={updated}" + ) + except Exception as e: + logger.error(f"update_storage_blocks_to_host error: {e}, {str(traceback.format_exc())}") + + def abort_prefetch_blocks(self, host_block_ids: List[int]) -> None: + """ + Abort in-flight prefetch blocks on failure. + + Removes nodes from the prefetch_node_map, deletes them from the RadixTree, + and releases their host pool blocks. Called when the storage→CPU transfer + fails so that LOADING_FROM_STORAGE blocks do not leak. + + Args: + host_block_ids: List of host block IDs whose prefetch should be aborted. + """ + if not host_block_ids: + return + try: + with self._lock: + for block_id in host_block_ids: + node = self._prefetch_node_map.pop(block_id, None) + if node is None: + continue + self._radix_tree._remove_node_from_tree(node) + self._host_pool.release(host_block_ids) + logger.warning(f"[StoragePrefetch] abort_prefetch_blocks: released {len(host_block_ids)} host blocks") + except Exception as e: + logger.error(f"abort_prefetch_blocks error: {e}, {str(traceback.format_exc())}") + # ============ Reset Methods ============ def reset_cache(self) -> bool: diff --git a/fastdeploy/cache_manager/v1/cache_utils.py b/fastdeploy/cache_manager/v1/cache_utils.py index 589d2c46e7a..4f5cac2b625 100644 --- a/fastdeploy/cache_manager/v1/cache_utils.py +++ b/fastdeploy/cache_manager/v1/cache_utils.py @@ -15,7 +15,9 @@ """ import hashlib +import importlib import pickle +import subprocess import threading import time from typing import Any, Callable, Dict, List, Optional, Sequence, Set @@ -23,6 +25,34 @@ from paddleformers.utils.log import logger +def get_rdma_nics() -> str: + from fastdeploy.platforms import current_platform + + res = importlib.resources.files("fastdeploy.cache_manager.transfer_factory") / "get_rdma_nics.sh" + with importlib.resources.as_file(res) as path: + file_path = str(path) + + nic_type = current_platform.device_name + command = ["bash", file_path, nic_type] + result = subprocess.run( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=False, + ) + logger.info(f"get_rdma_nics command: {command}") + logger.info(f"get_rdma_nics output: {result.stdout}") + if result.returncode != 0: + raise RuntimeError(f"Failed to execute script `get_rdma_nics.sh`: {result.stderr.strip()}") + + env_name, env_value = result.stdout.strip().split("=") + if env_name != "KVCACHE_RDMA_NICS": + raise ValueError(f"Unexpected variable name: {env_name}, expected 'KVCACHE_RDMA_NICS'") + + return env_value + + class LayerDoneCounter: """ Independent synchronization primitive for tracking layer completion of a single transfer. @@ -422,6 +452,30 @@ class LayerSwapTimeoutError(Exception): pass +# ============ Storage Key Computation ============ + + +def storage_key_for_block(hash_value: str, local_rank: int, kind: str, layer_idx: int | None = None) -> str: + """Build a storage key for a single block / kind, optionally per-layer. + + Key format (per-block): ``{hash_value}_{local_rank}_{kind}`` + Key format (per-layer): ``{hash_value}_{local_rank}_{kind}_{layer_idx}`` + + Args: + hash_value: Block hash value (from Scheduler). + local_rank: Local rank index of the current process. + kind: One of "key", "value", "key_scale", "value_scale". + layer_idx: Optional layer index for per-layer keys. + + Returns: + Storage key string. + """ + base = f"{hash_value}_{local_rank}_{kind}" + if layer_idx is not None: + return f"{base}_{layer_idx}" + return base + + # ============ Block Hash Computation ============ diff --git a/fastdeploy/cache_manager/v1/metadata.py b/fastdeploy/cache_manager/v1/metadata.py index 5337eeb5458..0996aaf3345 100644 --- a/fastdeploy/cache_manager/v1/metadata.py +++ b/fastdeploy/cache_manager/v1/metadata.py @@ -409,6 +409,23 @@ class StorageMetadata: extra_params: Dict[str, Any] = field(default_factory=dict) +@dataclass +class PendingPrefetch: + """ + Represents a pending storage prefetch task enqueued by CacheManager, + waiting to be dispatched to workers by the scheduler. + + Attributes: + request_id: The request that triggered this prefetch. + metadata: StorageMetadata with hash_values and block_ids for the transfer. + host_block_ids: Pre-allocated host block IDs (for cleanup on failure). + """ + + request_id: str = "" + metadata: "StorageMetadata" = field(default_factory=lambda: StorageMetadata()) + host_block_ids: List[int] = field(default_factory=list) + + @dataclass class PDTransferMetadata: """ diff --git a/fastdeploy/cache_manager/v1/radix_tree.py b/fastdeploy/cache_manager/v1/radix_tree.py index aea19835878..337c51fd508 100644 --- a/fastdeploy/cache_manager/v1/radix_tree.py +++ b/fastdeploy/cache_manager/v1/radix_tree.py @@ -252,7 +252,11 @@ def find_prefix( break node = node.children[block_hash] - if node.cache_status in (CacheStatus.DELETING, CacheStatus.SWAP_TO_HOST): + if node.cache_status in ( + CacheStatus.DELETING, + CacheStatus.SWAP_TO_HOST, + CacheStatus.LOADING_FROM_STORAGE, + ): break node.touch() diff --git a/fastdeploy/cache_manager/v1/storage/__init__.py b/fastdeploy/cache_manager/v1/storage/__init__.py index b1c986b9a4e..45054697076 100644 --- a/fastdeploy/cache_manager/v1/storage/__init__.py +++ b/fastdeploy/cache_manager/v1/storage/__init__.py @@ -64,55 +64,43 @@ def create_storage_scheduler( scheduler = MooncakeStorageScheduler(config) - elif config.kvcache_storage_backend == "attention_store": - from .attnstore.connector import AttnStoreScheduler - - scheduler = AttnStoreScheduler(config) - else: - raise ValueError( - f"Unsupported storage type: {config.kvcache_storage_backend}. " - f"Supported types: mooncake, attention_store, local" - ) + raise ValueError(f"Unsupported storage type: {config.kvcache_storage_backend}. " "Supported types: mooncake") # Attempt connection if scheduler is not None: if not scheduler.connect(): - # Log warning but still return the scheduler - pass + raise RuntimeError( + f"Failed to connect to storage backend '{config.kvcache_storage_backend}'. " + "Check server address, credentials, and network connectivity." + ) return scheduler def create_storage_connector( config: Any, + tp_rank: Optional[int] = None, ) -> Optional[StorageConnector]: """ Create a StorageConnector instance based on configuration. This is a factory function that creates the appropriate StorageConnector based on the storage backend type specified in the configuration. + The caller is responsible for calling ``connector.connect()`` when ready. Args: config: Configuration object, can be: - CacheConfig: FastDeploy configuration object - Dict: Dictionary with 'storage_type' and backend-specific settings - StorageConfig: StorageConfig dataclass instance + tp_rank: Tensor-parallel rank, passed to the connector for RDMA NIC + selection (Mooncake only). When None, the connector uses the + default device selection strategy. Returns: - StorageConnector instance if successful, None otherwise - - Example: - # Using CacheConfig - connector = create_storage_connector(fd_config) - - # Using dict config - config = { - 'storage_type': 'mooncake', - 'server_addr': 'localhost:8080', - 'buffer_size': 1024 * 1024, - } - connector = create_storage_connector(config) + StorageConnector instance (not yet connected), or None if no backend + is configured. """ if config.kvcache_storage_backend is None: return None @@ -123,24 +111,10 @@ def create_storage_connector( if config.kvcache_storage_backend == "mooncake": from .mooncake.connector import MooncakeStorageConnector - connector = MooncakeStorageConnector(config) - - elif config.kvcache_storage_backend == "attention_store": - from .attnstore.connector import AttnStoreConnector - - connector = AttnStoreConnector(config) + connector = MooncakeStorageConnector(config, tp_rank=tp_rank) else: - raise ValueError( - f"Unsupported storage type: {config.kvcache_storage_backend}. " - f"Supported types: mooncake, attention_store, local" - ) - - # Attempt connection - if connector is not None: - if not connector.connect(): - # Log warning but still return the connector - pass + raise ValueError(f"Unsupported storage type: {config.kvcache_storage_backend}. " "Supported types: mooncake") return connector diff --git a/fastdeploy/cache_manager/v1/storage/attnstore/connector.py b/fastdeploy/cache_manager/v1/storage/attnstore/connector.py index c63f1b74d68..f2265ca1ea3 100644 --- a/fastdeploy/cache_manager/v1/storage/attnstore/connector.py +++ b/fastdeploy/cache_manager/v1/storage/attnstore/connector.py @@ -51,33 +51,12 @@ def disconnect(self) -> None: """Disconnect from AttnStore.""" self._connected = False - def exists(self, key: str) -> bool: - """Check if key exists in AttnStore.""" + def batch_exists(self, keys: List[str]) -> List[bool]: + """Batch check existence of multiple keys.""" if not self._connected: - return False - # Placeholder implementation - return False - - def query(self, keys: List[str]) -> Dict[str, bool]: - """Query multiple keys for existence.""" - if not self._connected: - return {k: False for k in keys} - # Placeholder implementation - return {k: False for k in keys} - - def get_metadata(self, key: str) -> Optional[Dict[str, Any]]: - """Get metadata for a key.""" - if not self._connected: - return None - # Placeholder implementation - return None - - def list_keys(self, prefix: str = "") -> List[str]: - """List keys with a given prefix.""" - if not self._connected: - return [] + return [False] * len(keys) # Placeholder implementation - return [] + return [False] * len(keys) class AttnStoreConnector(StorageConnector): @@ -111,30 +90,26 @@ def disconnect(self) -> None: """Disconnect from AttnStore.""" self._connected = False - def get(self, key: str, dst_buffer: Any) -> bool: - """Get data from AttnStore.""" + def batch_get( + self, + keys: List[str], + dst_ptrs: List[int], + sizes: List[int], + ) -> List[bool]: + """Batch get multiple objects from storage via zero-copy.""" if not self._connected: - return False - # Placeholder implementation - return False - - def set(self, key: str, src_buffer: Any, size: int) -> bool: - """Set data in AttnStore.""" - if not self._connected: - return False + return [False] * len(keys) # Placeholder implementation - return False - - def delete(self, key: str) -> bool: - """Delete data from AttnStore.""" - if not self._connected: - return False - # Placeholder implementation - return False - - def clear(self, prefix: str = "") -> int: - """Clear data from AttnStore.""" + return [False] * len(keys) + + def batch_set( + self, + keys: List[str], + src_ptrs: List[int], + sizes: List[int], + ) -> List[bool]: + """Batch set multiple objects into storage via zero-copy.""" if not self._connected: - return 0 + return [False] * len(keys) # Placeholder implementation - return 0 + return [False] * len(keys) diff --git a/fastdeploy/cache_manager/v1/storage/base.py b/fastdeploy/cache_manager/v1/storage/base.py index 3ad64480e9d..ea8b248025b 100644 --- a/fastdeploy/cache_manager/v1/storage/base.py +++ b/fastdeploy/cache_manager/v1/storage/base.py @@ -24,25 +24,34 @@ class StorageScheduler(ABC): Abstract base class for storage scheduler operations. Used by CacheManager (Scheduler process) to query storage - existence and metadata without performing actual data transfer. + existence without performing actual data transfer. + + Minimal interface for backend implementations: + - ``connect`` / ``disconnect`` — lifecycle + - ``batch_exists`` — the only query method required by CacheManager """ def __init__(self, config: Optional[Dict[str, Any]] = None): - """ - Initialize the storage scheduler. + from fastdeploy.utils import get_logger - Args: - config: Storage configuration - """ self.config = config or {} self._lock = threading.RLock() self._connected = False + self.logger = get_logger("mooncake_storage", "cache_manager.log") + + # ------------------------------------------------------------------ + # Abstract methods — must be implemented by every backend + # ------------------------------------------------------------------ @abstractmethod def connect(self) -> bool: """ Connect to the storage backend. + Implementations must be idempotent: if already connected + (``self._connected is True``) return ``True`` immediately without + re-initialising the underlying client. + Returns: True if connection was successful """ @@ -54,56 +63,21 @@ def disconnect(self) -> None: pass @abstractmethod - def exists(self, key: str) -> bool: - """ - Check if a key exists in storage. - - Args: - key: Storage key to check - - Returns: - True if key exists - """ - pass - - @abstractmethod - def query(self, keys: List[str]) -> Dict[str, bool]: + def batch_exists(self, keys: List[str]) -> List[bool]: """ - Query multiple keys for existence. + Batch check existence of multiple keys. Args: - keys: List of keys to query + keys: List of storage keys to check Returns: - Dictionary mapping keys to existence status + List of booleans corresponding to each key's existence """ pass - @abstractmethod - def get_metadata(self, key: str) -> Optional[Dict[str, Any]]: - """ - Get metadata for a key. - - Args: - key: Storage key - - Returns: - Metadata dictionary or None if not found - """ - pass - - @abstractmethod - def list_keys(self, prefix: str = "") -> List[str]: - """ - List keys with a given prefix. - - Args: - prefix: Key prefix to filter - - Returns: - List of matching keys - """ - pass + # ------------------------------------------------------------------ + # Concrete methods + # ------------------------------------------------------------------ def is_connected(self) -> bool: """Check if connected to storage.""" @@ -123,24 +97,38 @@ class StorageConnector(ABC): Used by CacheController (Worker process) to perform actual data transfer operations with the storage backend. + + All get/set operations use zero-copy semantics: callers pass raw memory + pointers (int) and sizes (int, bytes) so the backend can perform direct + RDMA transfers without an intermediate copy. + + Minimal interface for backend implementations: + - ``connect`` / ``disconnect`` — lifecycle + - ``batch_get`` — prefetch from storage + - ``batch_set`` — backup to storage """ def __init__(self, config: Optional[Dict[str, Any]] = None): - """ - Initialize the storage connector. + from paddleformers.utils.log import logger - Args: - config: Storage configuration - """ self.config = config or {} self._lock = threading.RLock() self._connected = False + self.logger = logger + + # ------------------------------------------------------------------ + # Abstract methods — must be implemented by every backend + # ------------------------------------------------------------------ @abstractmethod def connect(self) -> bool: """ Connect to the storage backend. + Implementations must be idempotent: if already connected + (``self._connected is True``) return ``True`` immediately without + re-initialising the underlying client. + Returns: True if connection was successful """ @@ -152,59 +140,86 @@ def disconnect(self) -> None: pass @abstractmethod - def get(self, key: str, dst_buffer: Any) -> bool: + def batch_get( + self, + keys: List[str], + dst_ptrs: List[int], + sizes: List[int], + ) -> List[bool]: """ - Get data from storage. + Batch get multiple objects from storage into pre-allocated zero-copy buffers. Args: - key: Storage key - dst_buffer: Destination buffer to write data + keys: List of storage keys + dst_ptrs: List of destination memory pointers (must be registered if RDMA) + sizes: List of expected sizes in bytes Returns: - True if get was successful + List of booleans indicating success for each key """ pass @abstractmethod - def set(self, key: str, src_buffer: Any, size: int) -> bool: + def batch_set( + self, + keys: List[str], + src_ptrs: List[int], + sizes: List[int], + ) -> List[bool]: """ - Set data in storage. + Batch set multiple objects into storage from zero-copy source buffers. Args: - key: Storage key - src_buffer: Source buffer to read data from - size: Size of data in bytes + keys: List of storage keys + src_ptrs: List of source memory pointers (must be registered if RDMA) + sizes: List of data sizes in bytes Returns: - True if set was successful + List of booleans indicating success for each key """ pass - @abstractmethod - def delete(self, key: str) -> bool: + # ------------------------------------------------------------------ + # Concrete methods — backends may override for efficiency + # ------------------------------------------------------------------ + + def register_buffer(self, buffer_ptr: int, buffer_size: int) -> None: """ - Delete data from storage. + Register a memory buffer with the storage backend for zero-copy transfer. + + This must be called before using the buffer pointer in get/set operations + when the backend requires RDMA memory registration (e.g., Mooncake). + Backends that do not need registration can leave this as a no-op. Args: - key: Storage key to delete + buffer_ptr: Raw pointer (int) to the start of the memory region + buffer_size: Size of the memory region in bytes - Returns: - True if deletion was successful + Raises: + RuntimeError: If registration fails """ pass - @abstractmethod - def clear(self, prefix: str = "") -> int: + def batch_exists(self, keys: List[str]) -> List[bool]: + """ + Batch check key existence. Backends that support it should override. + Default returns False for all keys (conservative: assume missing). """ - Clear data from storage. + return [False] * len(keys) - Args: - prefix: Key prefix to clear (empty for all) + def batch_delete(self, keys: List[str]) -> List[bool]: + """ + Delete multiple keys. Backends can override for efficiency. + Default returns False for all keys. + """ + return [False] * len(keys) - Returns: - Number of keys cleared + def clear(self) -> int: """ - pass + Clear all data from storage. Optional — backends that support it + should override. Default is a no-op returning 0. + """ + return 0 def is_connected(self) -> bool: """Check if connected to storage.""" diff --git a/fastdeploy/cache_manager/v1/storage/mooncake/connector.py b/fastdeploy/cache_manager/v1/storage/mooncake/connector.py index a8e0d01010d..d344e5f35f0 100644 --- a/fastdeploy/cache_manager/v1/storage/mooncake/connector.py +++ b/fastdeploy/cache_manager/v1/storage/mooncake/connector.py @@ -14,155 +14,728 @@ # limitations under the License. """ +import json +import os +import time +import traceback +import uuid +from dataclasses import dataclass from typing import Any, Dict, List, Optional +from fastdeploy.utils import get_host_ip + from ..base import StorageConnector, StorageScheduler +DEFAULT_GLOBAL_SEGMENT_SIZE = 1024 * 1024 * 1024 # 1 GiB +DEFAULT_LOCAL_BUFFER_SIZE = 1024 * 1024 # 1 MB +DEFAULT_MC_MAX_MR_SIZE = 4 * 1024 * 1024 * 1024 # 4 GB +MIN_MC_MAX_MR_SIZE = 1024 * 1024 * 1024 # 1 GB +MAX_MC_MAX_MR_SIZE = 6 * 1024 * 1024 * 1024 # 6 GB -class MooncakeStorageScheduler(StorageScheduler): + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + + +@dataclass +class MooncakeStorageConfig: """ - Mooncake storage scheduler for Scheduler process. + Configuration for Mooncake distributed store. - Provides query operations for Mooncake distributed storage. + Loaded with the following priority (highest first): + 1. Explicit keyword arguments passed to ``from_config`` + 2. JSON config file at ``MOONCAKE_CONFIG_PATH`` + 3. Individual environment variables """ - def __init__(self, config: Optional[Dict[str, Any]] = None): + local_hostname: str + metadata_server: str + master_server_addr: str + global_segment_size: int + local_buffer_size: int + protocol: str + rdma_devices: str + + # --------------------------------------------------------------------------- + + @staticmethod + def create(extra: Optional[Dict[str, Any]] = None) -> "MooncakeStorageConfig": """ - Initialize Mooncake storage scheduler. + Load config from (in priority order): + 1. ``extra`` dict (e.g. from CacheConfig.kvcache_storage_config) + 2. JSON file at ``MOONCAKE_CONFIG_PATH`` + 3. Environment variables Args: - config: Configuration with keys: - - server_addr: Mooncake server address - - namespace: Storage namespace - - timeout: Connection timeout + extra: Optional dict of override values (takes highest priority). + + Returns: + Populated ``MooncakeStorageConfig`` instance. + """ + extra = extra or {} + + # --- base from env / file --- + file_path = os.getenv("MOONCAKE_CONFIG_PATH") + host_ip = get_host_ip() + file_cfg: Dict[str, Any] = {} + + if file_path: + if not os.path.exists(file_path): + raise FileNotFoundError(f"MOONCAKE_CONFIG_PATH points to non-existent file: {file_path}") + with open(file_path) as f: + file_cfg = json.load(f) + + def _get(key: str, default: Any = None) -> Any: + """extra > file > env > default""" + if key in extra: + return extra[key] + if key in file_cfg: + return file_cfg[key] + env_map = { + "local_hostname": "MOONCAKE_LOCAL_HOSTNAME", + "metadata_server": "MOONCAKE_METADATA_SERVER", + "master_server_addr": "MOONCAKE_MASTER_SERVER_ADDR", + "global_segment_size": "MOONCAKE_GLOBAL_SEGMENT_SIZE", + "local_buffer_size": "MOONCAKE_LOCAL_BUFFER_SIZE", + "protocol": "MOONCAKE_PROTOCOL", + "rdma_devices": "MOONCAKE_RDMA_DEVICES", + } + if key in env_map: + return os.environ.get(env_map[key], default) + return default + + local_hostname = _get("local_hostname", host_ip) + metadata_server = _get("metadata_server") + master_server_addr = _get("master_server_addr") + global_segment_size = int(_get("global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE)) + local_buffer_size = int(_get("local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE)) + protocol = _get("protocol", "rdma") + rdma_devices = _get("rdma_devices", "") + + if metadata_server is None or master_server_addr is None: + raise ValueError( + "Both MOONCAKE_METADATA_SERVER and MOONCAKE_MASTER_SERVER_ADDR must be provided " + "(via extra config, config file, or environment variables)." + ) + if local_hostname == "localhost": + raise ValueError("local_hostname must not be 'localhost'; Mooncake requires a real IP or hostname.") + + # Auto-detect RDMA NICs if not provided + if rdma_devices == "" and protocol == "rdma": + try: + from fastdeploy.cache_manager.v1.cache_utils import get_rdma_nics + + rdma_devices = get_rdma_nics() + except Exception: + pass + + return MooncakeStorageConfig( + local_hostname=local_hostname, + metadata_server=metadata_server, + master_server_addr=master_server_addr, + global_segment_size=global_segment_size, + local_buffer_size=local_buffer_size, + protocol=protocol, + rdma_devices=rdma_devices, + ) + + def select_rdma_device(self, tp_rank: int) -> None: + """Select a single RDMA device from a comma-separated list by TP rank.""" + devices = [d.strip() for d in self.rdma_devices.split(",") if d.strip()] + if devices: + self.rdma_devices = devices[tp_rank % len(devices)] + + +# --------------------------------------------------------------------------- +# Shared helper — wraps the raw MooncakeDistributedStore +# --------------------------------------------------------------------------- + + +class _MooncakeStoreBase: + """ + Thin wrapper around ``mooncake.store.MooncakeDistributedStore`` shared by + both the Scheduler and Connector implementations. + """ + + def __init__(self, logger) -> None: + self._store = None # MooncakeDistributedStore instance + self.logger = logger + self.mc_max_mr_size = DEFAULT_MC_MAX_MR_SIZE + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + # Minimal segment size for the Scheduler process. + # The Scheduler only calls batch_is_exist (metadata query, no RDMA transfer), + # so there is no need to allocate a large global segment. + _SCHEDULER_SEGMENT_SIZE = 64 * 1024 * 1024 # 64 MB + + def _setup_store( + self, + cfg: MooncakeStorageConfig, + tp_rank: Optional[int] = None, + scheduler_mode: bool = False, + ) -> None: + """ + Import the SDK and call ``store.setup()``. + + Args: + cfg: Populated ``MooncakeStorageConfig``. + tp_rank: When provided, selects a single RDMA device from the + comma-separated ``cfg.rdma_devices`` list by rank modulo. + Must be provided for Connector (worker) instances. + scheduler_mode: When True, selects the first RDMA device from the + list (scheduler has no tp_rank) and uses a small + ``global_segment_size`` since no actual data transfer happens. + """ + try: + from mooncake.store import MooncakeDistributedStore + except ImportError as e: + raise ImportError( + "mooncake package not found. Install it by following " + "https://kvcache-ai.github.io/Mooncake/python-api-reference/mooncake-store.html" + ) from e + + if tp_rank is not None: + # Worker path: pick one device per TP rank + cfg.select_rdma_device(tp_rank) + elif scheduler_mode: + # Scheduler path: Mooncake setup() expects a single device name, + # not a comma-separated list. Pick the first available device. + cfg.select_rdma_device(0) + + # Scheduler does not transfer data — avoid allocating a large segment. + if scheduler_mode: + cfg.global_segment_size = self._SCHEDULER_SEGMENT_SIZE + + host_ip = get_host_ip() + os.environ.setdefault("MC_TCP_BIND_ADDRESS", host_ip) + + # Configure MC_MAX_MR_SIZE for buffer registration + raw_mr_size = int(os.environ.get("MC_MAX_MR_SIZE", 0)) + if raw_mr_size == 0: + self.mc_max_mr_size = DEFAULT_MC_MAX_MR_SIZE + elif raw_mr_size < MIN_MC_MAX_MR_SIZE: + self.mc_max_mr_size = MIN_MC_MAX_MR_SIZE + elif raw_mr_size > MAX_MC_MAX_MR_SIZE: + self.mc_max_mr_size = MAX_MC_MAX_MR_SIZE + else: + self.mc_max_mr_size = raw_mr_size + os.environ["MC_MAX_MR_SIZE"] = str(self.mc_max_mr_size) + + self._store = MooncakeDistributedStore() + ret = self._store.setup( + local_hostname=cfg.local_hostname, + metadata_server=cfg.metadata_server, + global_segment_size=cfg.global_segment_size, + local_buffer_size=cfg.local_buffer_size, + protocol=cfg.protocol, + rdma_devices=cfg.rdma_devices, + master_server_addr=cfg.master_server_addr, + ) + if ret != 0: + raise RuntimeError(f"MooncakeDistributedStore.setup() returned error code {ret}") + self.logger.info("MooncakeDistributedStore connected successfully.") + + def _teardown_store(self) -> None: + """Release the store (destructor handles cleanup).""" + self._store = None + + # ------------------------------------------------------------------ + # Warmup + # ------------------------------------------------------------------ + + def _warmup(self, prefix: str = "fd") -> None: + """Send a small test key to verify connectivity.""" + key = f"{prefix}_mooncake_warmup_{uuid.uuid4().hex}" + value = bytes(1 * 1024 * 1024) # 1 MB + rc = self._store.put(key, value) + if rc != 0: + raise RuntimeError(f"Warmup put failed for key={key}, rc={rc}") + rc = self._store.is_exist(key) + if rc != 1: + raise RuntimeError(f"Warmup exists check failed for key={key}, rc={rc}") + self._store.get(key) + self._store.remove(key) + + # ------------------------------------------------------------------ + # Low-level zero-copy primitives + # ------------------------------------------------------------------ + + def _batch_put( + self, + keys: List[str], + src_ptrs: List[int], + sizes: List[int], + ) -> List[int]: + """ + Call ``store.batch_put_from``. + + Returns: + List of ints: 0 = success, negative = error. + """ + tic = time.perf_counter() + results: List[int] = self._store.batch_put_from(keys, src_ptrs, sizes) + elapsed = time.perf_counter() - tic + success = results.count(0) + total = len(keys) + if success == total: + self.logger.debug(f"batch_put {total} keys in {elapsed:.4f}s") + else: + self.logger.error(f"batch_put: {total - success}/{total} keys failed, elapsed={elapsed:.4f}s") + if success > 0: + total_bytes = sum(s for r, s in zip(results, sizes) if r == 0) + speed_gbs = total_bytes / (elapsed * 1024**3) if elapsed > 0 else float("inf") + self.logger.debug(f"batch_put throughput: {total_bytes / 1024**3:.4f} GB @ {speed_gbs:.4f} GB/s") + return results + + def _batch_get( + self, + keys: List[str], + dst_ptrs: List[int], + sizes: List[int], + ) -> List[int]: + """ + Call ``store.batch_get_into``. + + Returns: + List of ints: bytes_read (> 0) = success, negative = error. + """ + tic = time.perf_counter() + results: List[int] = self._store.batch_get_into(keys, dst_ptrs, sizes) + elapsed = time.perf_counter() - tic + success = sum(1 for r in results if r > 0) + total = len(keys) + if success == total: + self.logger.debug(f"batch_get {total} keys in {elapsed:.4f}s") + else: + self.logger.error(f"batch_get: {total - success}/{total} keys failed, elapsed={elapsed:.4f}s") + if success > 0: + total_bytes = sum(s for r, s in zip(results, sizes) if r > 0) + speed_gbs = total_bytes / (elapsed * 1024**3) if elapsed > 0 else float("inf") + self.logger.debug(f"batch_get throughput: {total_bytes / 1024**3:.4f} GB @ {speed_gbs:.4f} GB/s") + return results + + def _batch_exists(self, keys: List[str]) -> tuple: + """ + Call ``store.batch_is_exist``. + + Returns: + Tuple of (results, elapsed_ms): + results: List of ints, 1 = exists, 0 = not found. + elapsed_ms: Time taken in milliseconds. + """ + tic = time.perf_counter() + results: List[int] = self._store.batch_is_exist(keys) + elapsed_exists_ms = (time.perf_counter() - tic) * 1000 + return results, elapsed_exists_ms + + +# --------------------------------------------------------------------------- +# StorageScheduler implementation — Scheduler process +# --------------------------------------------------------------------------- + + +class MooncakeStorageScheduler(StorageScheduler): + """ + Mooncake storage scheduler for the Scheduler (controller) process. + + Only performs existence queries and metadata lookups — never transfers data. + Uses the same underlying ``MooncakeDistributedStore`` so it can call + ``batch_is_exist`` efficiently via RDMA. + """ + + def __init__(self, config: Any = None): + """ + Args: + config: Either a ``CacheConfig``-style object (with + ``kvcache_storage_config`` attribute) or a plain dict. """ super().__init__(config) - self._client = None + self._base = _MooncakeStoreBase(self.logger) + self._mc_config: Optional[MooncakeStorageConfig] = None + + # ------------------------------------------------------------------ + # StorageScheduler interface + # ------------------------------------------------------------------ def connect(self) -> bool: - """Connect to Mooncake storage.""" + """Connect to Mooncake store.""" + if self._connected: + return True try: - # Initialize Mooncake client - # This would be implemented with actual Mooncake SDK - # import mooncake - # self._client = mooncake.Client(**self.config) + extra = self._extract_extra_config(self.config) + self._mc_config = MooncakeStorageConfig.create(extra) + self._base._setup_store(self._mc_config, scheduler_mode=True) + self._base._warmup("fd_scheduler") self._connected = True + self.logger.info("MooncakeStorageScheduler connected.") return True - except Exception: + except Exception as e: + self.logger.error(f"MooncakeStorageScheduler connect failed: {e}\n{traceback.format_exc()}") self._connected = False return False def disconnect(self) -> None: - """Disconnect from Mooncake storage.""" - self._client = None + """Disconnect from Mooncake store.""" + self._base._teardown_store() self._connected = False def exists(self, key: str) -> bool: - """Check if key exists in Mooncake storage.""" - if not self._connected or self._client is None: + """Check if a single key exists.""" + if not self._connected or self._base._store is None: return False + results, _ = self._base._batch_exists([key]) + return results[0] == 1 + + def batch_exists(self, keys: List[str]) -> List[bool]: + """Batch check key existence.""" + if not self._connected or self._base._store is None: + return [False] * len(keys) + results, _ = self._base._batch_exists(keys) + return [r == 1 for r in results] + + def query_prefix_count( + self, + k_keys: List[str], + v_keys: List[str], + k_scale_keys: Optional[List[str]] = None, + v_scale_keys: Optional[List[str]] = None, + ) -> int: + """ + Return the number of consecutive valid KV cache blocks from the start. - # Placeholder implementation - # return self._client.exists(key) - return False + Mirrors the logic of ``MooncakeStore.query()`` in the v1 transfer_factory. + """ + if not self._connected or self._base._store is None: + return 0 - def query(self, keys: List[str]) -> Dict[str, bool]: - """Query multiple keys for existence.""" - if not self._connected or self._client is None: - return {k: False for k in keys} + assert len(k_keys) == len(v_keys), "k_keys and v_keys must have the same length" - # Placeholder implementation - # return self._client.batch_exists(keys) - return {k: False for k in keys} + has_scale = k_scale_keys is not None and v_scale_keys is not None + all_keys = k_keys + v_keys + if has_scale: + assert ( + len(k_scale_keys) == len(v_scale_keys) == len(k_keys) + ), "scale key lists must have the same length as k/v key lists" + all_keys = all_keys + k_scale_keys + v_scale_keys - def get_metadata(self, key: str) -> Optional[Dict[str, Any]]: - """Get metadata for a key.""" - if not self._connected or self._client is None: - return None + exist_map = dict(zip(all_keys, self._base._batch_exists(all_keys)[0])) - # Placeholder implementation - # return self._client.get_metadata(key) - return None + count = 0 + if has_scale: + for k, v, ks, vs in zip(k_keys, v_keys, k_scale_keys, v_scale_keys): + if not (exist_map[k] and exist_map[v] and exist_map[ks] and exist_map[vs]): + break + count += 1 + else: + for k, v in zip(k_keys, v_keys): + if not (exist_map[k] and exist_map[v]): + break + count += 1 - def list_keys(self, prefix: str = "") -> List[str]: - """List keys with a given prefix.""" - if not self._connected or self._client is None: - return [] + return count + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ - # Placeholder implementation - # return self._client.list_keys(prefix) - return [] + @staticmethod + def _extract_extra_config(config: Any) -> Dict[str, Any]: + """Extract the mooncake-specific sub-config from a CacheConfig or dict.""" + if config is None: + return {} + if isinstance(config, dict): + return config.get("kvcache_storage_config", config) + # CacheConfig-style object + return getattr(config, "kvcache_storage_config", None) or {} + + +# --------------------------------------------------------------------------- +# StorageConnector implementation — Worker process +# --------------------------------------------------------------------------- class MooncakeStorageConnector(StorageConnector): """ - Mooncake storage connector for Worker process. + Mooncake storage connector for Worker processes. + + Performs zero-copy data transfer using ``batch_put_from`` / ``batch_get_into`` + from ``MooncakeDistributedStore``. - Provides data transfer operations for Mooncake distributed storage. + Memory model + ------------ + Data flows between Mooncake distributed store and the **CPU cache** (pinned + host memory), never directly to/from GPU blocks. The typical lifecycle is: + + 1. The CacheController allocates a contiguous pinned-CPU memory pool. + 2. It calls ``register_buffer(pool_ptr, pool_size)`` once to register the + entire pool with Mooncake for zero-copy RDMA access. + 3. For each eviction / prefetch it calls ``batch_set`` / ``batch_get`` + with raw pointers into that pool. + + ``global_segment_size`` must be at least as large as the registered buffer. + Pass the actual per-rank CPU cache size via ``cpu_cache_size`` so the value + is set correctly at setup time. """ - def __init__(self, config: Optional[Dict[str, Any]] = None): + def __init__( + self, + config: Any = None, + tp_rank: Optional[int] = None, + cpu_cache_size: Optional[int] = None, + ): """ - Initialize Mooncake storage connector. - Args: - config: Configuration with keys: - - server_addr: Mooncake server address - - namespace: Storage namespace - - transfer_timeout: Transfer timeout - - buffer_size: Transfer buffer size + config: Either a ``CacheConfig``-style object or a plain dict. + tp_rank: Tensor-parallel rank used for RDMA NIC selection. + cpu_cache_size: Size in bytes of the pinned CPU memory pool that + will be registered via ``register_buffer``. When provided, + overrides ``global_segment_size`` from config so that the + Mooncake segment exactly covers the registered buffer. + If omitted, the value from config / env is used as-is. """ super().__init__(config) - self._client = None + self._base = _MooncakeStoreBase(self.logger) + self._mc_config: Optional[MooncakeStorageConfig] = None + self._tp_rank = tp_rank + self._cpu_cache_size = cpu_cache_size + + # ------------------------------------------------------------------ + # StorageConnector interface + # ------------------------------------------------------------------ def connect(self) -> bool: - """Connect to Mooncake storage.""" + """Connect to Mooncake store.""" + if self._connected: + return True try: - # Initialize Mooncake client - # This would be implemented with actual Mooncake SDK + extra = self._extract_extra_config(self.config) + self._mc_config = MooncakeStorageConfig.create(extra) + + # Override global_segment_size with the actual CPU cache size when + # provided. This ensures the Mooncake segment covers the buffer + # that will be registered via register_buffer(). + if self._cpu_cache_size is not None: + self.logger.info( + f"Overriding global_segment_size with cpu_cache_size=" + f"{self._cpu_cache_size / 1024**3:.3f} GB (tp_rank={self._tp_rank})" + ) + self._mc_config.global_segment_size = self._cpu_cache_size + + self._base._setup_store(self._mc_config, tp_rank=self._tp_rank) + self._base._warmup("fd_worker") self._connected = True + self.logger.info(f"MooncakeStorageConnector connected (tp_rank={self._tp_rank}).") return True - except Exception: + except Exception as e: + self.logger.error(f"MooncakeStorageConnector connect failed: {e}\n{traceback.format_exc()}") self._connected = False return False def disconnect(self) -> None: - """Disconnect from Mooncake storage.""" - self._client = None + """Disconnect from Mooncake store.""" + self._base._teardown_store() self._connected = False - def get(self, key: str, dst_buffer: Any) -> bool: - """Get data from Mooncake storage.""" - if not self._connected or self._client is None: - return False + def register_buffer(self, buffer_ptr: int, buffer_size: int) -> None: + """ + Register a memory buffer with the Mooncake store for zero-copy RDMA. - # Placeholder implementation - # return self._client.get(key, dst_buffer) - return False + Must be called before using ``buffer_ptr`` in any get/set operation. + If buffer_size exceeds ``mc_max_mr_size`` the buffer is split into + multiple chunks, each registered separately. - def set(self, key: str, src_buffer: Any, size: int) -> bool: - """Set data in Mooncake storage.""" - if not self._connected or self._client is None: - return False + Args: + buffer_ptr: Raw pointer (int) to the memory region start. + buffer_size: Size in bytes. - # Placeholder implementation - # return self._client.set(key, src_buffer, size) - return False + Raises: + RuntimeError: If the store is not connected or registration fails. + """ + if self._base._store is None: + raise RuntimeError("MooncakeStorageConnector is not connected; call connect() first.") + + max_mr_size = self._base.mc_max_mr_size + if buffer_size <= max_mr_size: + ret = self._base._store.register_buffer(buffer_ptr, buffer_size) + if ret != 0: + raise RuntimeError(f"MooncakeDistributedStore.register_buffer() failed with error code {ret}") + self.logger.debug(f"Registered buffer ptr=0x{buffer_ptr:x} size={buffer_size} bytes.") + else: + num_chunks = (buffer_size + max_mr_size - 1) // max_mr_size + self.logger.info( + f"Registering buffer of {buffer_size / 1024**3:.2f} GB in {num_chunks} chunks " + f"(max_mr_size={max_mr_size / 1024**3:.2f} GB per chunk)" + ) + for i in range(num_chunks): + chunk_ptr = buffer_ptr + i * max_mr_size + chunk_size = min(max_mr_size, buffer_size - i * max_mr_size) + ret = self._base._store.register_buffer(chunk_ptr, chunk_size) + if ret != 0: + raise RuntimeError( + f"MooncakeDistributedStore.register_buffer() chunk {i}/{num_chunks} failed " + f"with error code {ret}" + ) + + # ------------------------------------------------------------------ + # Single-key operations (delegates to batch for consistency) + # ------------------------------------------------------------------ + + def get(self, key: str, dst_ptr: int, size: int) -> bool: + """Get a single object via zero-copy into ``dst_ptr``.""" + if not self._connected or self._base._store is None: + return False + results = self._base._batch_get([key], [dst_ptr], [size]) + return results[0] > 0 - def delete(self, key: str) -> bool: - """Delete data from Mooncake storage.""" - if not self._connected or self._client is None: + def set(self, key: str, src_ptr: int, size: int) -> bool: + """Set a single object via zero-copy from ``src_ptr``.""" + if not self._connected or self._base._store is None: return False + results = self._base._batch_put([key], [src_ptr], [size]) + return results[0] == 0 + + # ------------------------------------------------------------------ + # Batch operations + # ------------------------------------------------------------------ + + def batch_get( + self, + keys: List[str], + dst_ptrs: List[int], + sizes: List[int], + ) -> List[bool]: + """ + Batch get multiple objects via zero-copy. + + Args: + keys: Storage keys to retrieve. + dst_ptrs: Destination memory pointers (must be registered for RDMA). + sizes: Expected sizes in bytes for each key. + + Returns: + List of booleans: True if the corresponding key was retrieved successfully. + """ + if not self._connected or self._base._store is None: + return [False] * len(keys) + if not keys: + return [] + if not (len(keys) == len(dst_ptrs) == len(sizes)): + raise ValueError("keys, dst_ptrs, and sizes must have the same length") + + results = self._base._batch_get(keys, dst_ptrs, sizes) + return [r > 0 for r in results] + + def batch_set( + self, + keys: List[str], + src_ptrs: List[int], + sizes: List[int], + ) -> List[bool]: + """ + Batch set multiple objects via zero-copy. + + Skips keys that already exist in the store to avoid redundant writes. + + Args: + keys: Storage keys. + src_ptrs: Source memory pointers (must be registered for RDMA). + sizes: Data sizes in bytes. + + Returns: + List of booleans: True if the corresponding key was stored successfully. + """ + if not self._connected or self._base._store is None: + return [False] * len(keys) + if not keys: + return [] + if not (len(keys) == len(src_ptrs) == len(sizes)): + raise ValueError("keys, src_ptrs, and sizes must have the same length") + + put_results = self._base._batch_put(keys, src_ptrs, sizes) + final_results = [r == 0 for r in put_results] + success = put_results.count(0) + total_bytes = sum(s for r, s in zip(put_results, sizes) if r == 0) + self.logger.debug( + f"batch_set {len(keys)} keys: " f"written={success}/{len(keys)}, " f"data={total_bytes / 1024**3:.4f} GB" + ) + + return final_results + + def batch_exists(self, keys: List[str]) -> List[bool]: + """Batch check key existence.""" + if not self._connected or self._base._store is None: + return [False] * len(keys) + if not keys: + return [] + results, _ = self._base._batch_exists(keys) + return [r == 1 for r in results] + + # ------------------------------------------------------------------ + # Delete / clear + # ------------------------------------------------------------------ + + def delete(self, key: str, timeout: int = 5) -> bool: + """ + Delete a key from the store, retrying up to ``timeout`` seconds. - # Placeholder implementation - # return self._client.delete(key) + Args: + key: Key to delete. + timeout: Retry window in seconds. + + Returns: + True if deletion succeeded within the timeout. + """ + if not self._connected or self._base._store is None: + return False + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + rc = self._base._store.remove(key) + if rc == 0: + return True + time.sleep(1) + self.logger.error(f"delete({key!r}) timed out after {timeout}s") return False - def clear(self, prefix: str = "") -> int: - """Clear data from Mooncake storage.""" - if not self._connected or self._client is None: - return 0 + def batch_delete(self, keys: List[str]) -> List[bool]: + """ + Delete multiple keys from the store (single attempt, no retry). - # Placeholder implementation - # return self._client.clear(prefix) - return 0 + Used for cleaning up partial writes where some kinds succeeded + and others failed. Returns per-key success flags. + """ + if not self._connected or self._base._store is None: + return [False] * len(keys) + results = [] + for key in keys: + rc = self._base._store.remove(key) + results.append(rc == 0) + return results + + def clear(self) -> int: + """ + Remove all objects from the store. + + Returns: + Number of objects removed (as reported by the store). + """ + if not self._connected or self._base._store is None: + return 0 + count: int = self._base._store.remove_all() + self.logger.info(f"Cleared {count} objects from Mooncake store.") + return count + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + @staticmethod + def _extract_extra_config(config: Any) -> Dict[str, Any]: + if config is None: + return {} + if isinstance(config, dict): + return config.get("kvcache_storage_config", config) + return getattr(config, "kvcache_storage_config", None) or {} diff --git a/fastdeploy/cache_manager/v1/transfer_manager.py b/fastdeploy/cache_manager/v1/transfer_manager.py index f4ed0bb6539..d6190dc1210 100644 --- a/fastdeploy/cache_manager/v1/transfer_manager.py +++ b/fastdeploy/cache_manager/v1/transfer_manager.py @@ -37,6 +37,7 @@ swap_cache_per_layer_async, # async per-layer op (no cudaStreamSynchronize) ) from fastdeploy.cache_manager.ops import swap_cache_all_layers +from fastdeploy.cache_manager.v1.cache_utils import storage_key_for_block from fastdeploy.cache_manager.v1.storage import create_storage_connector from fastdeploy.cache_manager.v1.transfer import create_transfer_connector @@ -127,9 +128,20 @@ def __init__( self._host_value_scales_ptrs: List[int] = [] # value scale pointers (fp8) # ============ Connectors (for future use) ============ - self._storage_connector = create_storage_connector(self.cache_config) + # connect() is deferred to set_host_block_shape() so that cpu_cache_size + # can be computed from the actual block shape before connecting. + self._storage_connector = create_storage_connector( + self.cache_config, + tp_rank=self._local_rank, + ) self._transfer_connector = create_transfer_connector(self.cache_config) + # ============ Host block stride (bytes per block per layer) ============ + # Set by set_host_block_shape() after host cache is allocated. + self._host_key_block_stride_bytes: int = 0 + self._host_value_block_stride_bytes: int = 0 + self._host_scale_block_stride_bytes: int = 0 + # ============ Cache Map Setters ============ @property @@ -157,10 +169,6 @@ def set_cache_kvs_map(self, cache_kvs_map: Dict[str, Any]) -> None: def _build_device_layer_indices(self) -> None: """Build layer-indexed Device cache lists from _cache_kvs_map.""" if not self._cache_kvs_map: - self._device_key_caches = [] - self._device_value_caches = [] - self._device_key_scales = [] - self._device_value_scales = [] return self._device_key_caches = [] @@ -199,6 +207,35 @@ def set_host_cache_kvs_map(self, host_cache_kvs_map: Dict[str, Any]) -> None: with self._lock: self._host_cache_kvs_map = host_cache_kvs_map self._build_host_layer_indices() + self._register_host_buffers() + + def _register_host_buffers(self) -> None: + """Register all per-layer host buffers with the storage connector for zero-copy RDMA.""" + if self._storage_connector is None: + return + if self._num_host_blocks <= 0 or not self._host_key_ptrs: + return + if self._host_key_block_stride_bytes <= 0: + return + + layer_total_bytes = self._num_host_blocks * self._host_key_block_stride_bytes + for layer_idx in range(len(self._host_key_ptrs)): + key_ptr = self._host_key_ptrs[layer_idx] + if key_ptr: + try: + self._storage_connector.register_buffer(key_ptr, layer_total_bytes) + except Exception as e: + logger.warning(f"[TransferManager] register_buffer key layer {layer_idx} failed: {e}") + if self._is_fp8_quantization(): + val_ptr = self._host_value_ptrs[layer_idx] if layer_idx < len(self._host_value_ptrs) else 0 + else: + val_ptr = self._host_value_ptrs[layer_idx] if layer_idx < len(self._host_value_ptrs) else 0 + if val_ptr: + val_total = self._num_host_blocks * self._host_value_block_stride_bytes + try: + self._storage_connector.register_buffer(val_ptr, val_total) + except Exception as e: + logger.warning(f"[TransferManager] register_buffer value layer {layer_idx} failed: {e}") def _build_host_layer_indices(self) -> None: """Build layer-indexed Host pointer lists from _host_cache_kvs_map.""" @@ -227,6 +264,75 @@ def _build_host_layer_indices(self) -> None: self._host_key_scales_ptrs.append(self._host_cache_kvs_map.get(key_scale_name, 0)) self._host_value_scales_ptrs.append(self._host_cache_kvs_map.get(val_scale_name, 0)) + # ============ Host Block Shape ============ + + def set_host_block_shape( + self, + key_shape: List[int], + value_shape: Optional[List[int]], + scale_shape: Optional[List[int]], + cache_item_bytes: int, + scale_item_bytes: int = 4, + ) -> None: + """ + Set per-layer host block shape for stride calculation. + + Must be called after host cache is allocated (initialize_swap_space) + so that prefetch_from_storage / backup_to_storage can compute the + correct byte offset for each block_id. + + Args: + key_shape: [num_host_blocks, dim1, dim2, dim3] per-layer key cache shape. + value_shape: [num_host_blocks, dim1, dim2, dim3] per-layer value cache shape, or None. + scale_shape: [num_host_blocks, dim1, dim2] per-layer scale shape (fp8), or None. + cache_item_bytes: Bytes per cache element (e.g. 2 for float16). + scale_item_bytes: Bytes per scale element (default 4 for float32). + """ + with self._lock: + # stride = elements per block per layer * bytes per element + # key_shape = [num_blocks, d1, d2, d3] → per-block stride = d1*d2*d3 * bytes + self._host_key_block_stride_bytes = ( + int(key_shape[1]) * int(key_shape[2]) * int(key_shape[3]) * cache_item_bytes + ) + if value_shape: + self._host_value_block_stride_bytes = ( + int(value_shape[1]) * int(value_shape[2]) * int(value_shape[3]) * cache_item_bytes + ) + else: + self._host_value_block_stride_bytes = self._host_key_block_stride_bytes + if scale_shape: + self._host_scale_block_stride_bytes = int(scale_shape[1]) * int(scale_shape[2]) * scale_item_bytes + else: + self._host_scale_block_stride_bytes = 0 + + # Connect storage connector now that block strides are known. + # cpu_cache_size = total pinned CPU memory across all layers + # (key + value, plus fp8 scales when present). + if self._storage_connector is not None and not self._storage_connector.is_connected(): + cpu_cache_size = ( + self._num_host_blocks + * self._num_layers + * (self._host_key_block_stride_bytes + self._host_value_block_stride_bytes) + ) + if self._is_fp8_quantization() and self._host_scale_block_stride_bytes > 0: + cpu_cache_size += ( + self._num_host_blocks + * self._num_layers + * self._host_scale_block_stride_bytes + * 2 # key scale + value scale + ) + + self._storage_connector._cpu_cache_size = cpu_cache_size + logger.info( + f"[TransferManager] Connecting storage connector: " + f"tp_rank={self._local_rank}, cpu_cache_size={cpu_cache_size / 1024**3:.3f} GB" + ) + self._storage_connector.connect() + # connect() completes RDMA initialization; now all three conditions for + # _register_host_buffers are satisfied (_host_key_ptrs set, strides > 0, + # connector connected), so register host pinned memory as RDMA MR. + self._register_host_buffers() + # ============ Metadata Properties ============ def _get_kv_cache_quant_type(self) -> Optional[str]: @@ -664,3 +770,241 @@ def get_stats(self) -> Dict[str, Any]: "has_host_cache": len(self._host_key_ptrs) > 0, "is_fp8": self._is_fp8_quantization(), } + + # ============ Storage Transfer API ============ + # + # Key format (per-layer): + # K cache: "{hash_value}_{local_rank}_key_{layer_idx}" + # V cache: "{hash_value}_{local_rank}_value_{layer_idx}" + # K scale: "{hash_value}_{local_rank}_key_scale_{layer_idx}" (fp8 only) + # V scale: "{hash_value}_{local_rank}_value_scale_{layer_idx}" (fp8 only) + # + # Each layer's data is stored independently with its own key, allowing + # direct zero-copy RDMA between per-layer host memory and remote storage + # via connector.batch_get / batch_set (backed by batch_get_into / batch_put_from). + + def _compute_layer_block_ptr(self, kind: str, layer_idx: int, cpu_block_id: int) -> int: + """Compute raw pointer for a specific layer+block of a given kind.""" + if kind == "key": + return self._host_key_ptrs[layer_idx] + cpu_block_id * self._host_key_block_stride_bytes + elif kind == "value": + return self._host_value_ptrs[layer_idx] + cpu_block_id * self._host_value_block_stride_bytes + elif kind == "key_scale": + return self._host_key_scales_ptrs[layer_idx] + cpu_block_id * self._host_scale_block_stride_bytes + elif kind == "value_scale": + return self._host_value_scales_ptrs[layer_idx] + cpu_block_id * self._host_scale_block_stride_bytes + return 0 + + def _get_layer_block_size(self, kind: str) -> int: + """Return byte size for a single layer's block of a given kind.""" + if kind == "key": + return self._host_key_block_stride_bytes + elif kind == "value": + return self._host_value_block_stride_bytes + elif kind in ("key_scale", "value_scale"): + return self._host_scale_block_stride_bytes + return 0 + + def _build_per_layer_io_args( + self, + hash_list: List[str], + cpu_block_list: List[int], + ) -> tuple: + """Build flat per-layer keys, pointers, and sizes for direct connector calls. + + Returns: + (flat_keys, flat_ptrs, flat_sizes, flat_block_indices, flat_kinds) + All lists have length = len(hash_list) * num_layers * num_kinds. + flat_block_indices maps flat_idx -> position in hash_list/cpu_block_list. + flat_kinds maps flat_idx -> kind string (for rollback). + """ + is_fp8 = self._is_fp8_quantization() + kinds = ["key", "value"] + if is_fp8 and self._host_scale_block_stride_bytes > 0: + kinds.extend(["key_scale", "value_scale"]) + + flat_keys: List[str] = [] + flat_ptrs: List[int] = [] + flat_sizes: List[int] = [] + flat_block_indices: List[int] = [] + flat_kinds: List[str] = [] + + for bi, (hash_val, cpu_block_id) in enumerate(zip(hash_list, cpu_block_list)): + for layer_idx in range(self._num_layers): + for kind in kinds: + flat_keys.append(storage_key_for_block(hash_val, self._local_rank, kind, layer_idx)) + flat_ptrs.append(self._compute_layer_block_ptr(kind, layer_idx, cpu_block_id)) + flat_sizes.append(self._get_layer_block_size(kind)) + flat_block_indices.append(bi) + flat_kinds.append(kind) + + return flat_keys, flat_ptrs, flat_sizes, flat_block_indices, flat_kinds + + def prefetch_from_storage( + self, + hash_list: List[str], + cpu_block_list: List[int], + ) -> List[bool]: + """ + Batch-prefetch KV cache blocks from remote storage into CPU host memory. + + Uses per-layer storage keys. Each layer's data is read directly from + storage into per-layer host buffers via zero-copy RDMA (batch_get_into). + + Storage key per block, per layer: + ``"{hash}_{rank}_key_{layer_idx}"`` / ``"{hash}_{rank}_value_{layer_idx}"`` + + Args: + hash_list: List of block hash values (one per block). + cpu_block_list: List of target CPU block IDs (same length as hash_list). + + Returns: + List[bool]: True for each block that was fully retrieved successfully. + """ + if self._storage_connector is None or not self._storage_connector.is_connected(): + logger.warning("[TransferManager] prefetch_from_storage: connector not ready") + return [False] * len(hash_list) + + if len(hash_list) != len(cpu_block_list): + raise ValueError("hash_list and cpu_block_list must have the same length") + + if not hash_list: + return [] + + if not self._host_key_ptrs or self._host_key_block_stride_bytes <= 0: + logger.warning( + "[TransferManager] prefetch_from_storage: host cache not ready " + "(call set_host_block_shape after initialize_swap_space)" + ) + return [False] * len(hash_list) + + flat_keys, flat_ptrs, flat_sizes, flat_block_indices, flat_kinds = self._build_per_layer_io_args( + hash_list, cpu_block_list + ) + results = self._storage_connector.batch_get(flat_keys, flat_ptrs, flat_sizes) + + num_blocks = len(hash_list) + block_success = [True] * num_blocks + for flat_idx, ok in enumerate(results): + if not ok: + block_success[flat_block_indices[flat_idx]] = False + + # Diagnostic: probe which per-layer keys are missing for failed blocks + failed_indices = [i for i, ok in enumerate(block_success) if not ok] + if failed_indices and self._storage_connector is not None: + probe_keys = [] + probe_labels = [] + for i in failed_indices: + for layer_idx in range(self._num_layers): + for kind in ["key", "value"]: + probe_keys.append(storage_key_for_block(hash_list[i], self._local_rank, kind, layer_idx)) + probe_labels.append((i, cpu_block_list[i], hash_list[i], layer_idx, kind)) + + try: + exist_flags = self._storage_connector.batch_exists(probe_keys) + block_diag: Dict[int, Dict] = {} + for (bi, cpu_bid, h, layer_idx, kind), ok in zip(probe_labels, exist_flags): + if bi not in block_diag: + block_diag[bi] = {"cpu_bid": cpu_bid, "hash": h, "missing": [], "existing": []} + label = f"{kind}_l{layer_idx}" + if ok: + block_diag[bi]["existing"].append(label) + else: + block_diag[bi]["missing"].append(label) + + partial_missing = {bi: v for bi, v in block_diag.items() if v["missing"]} + pure_transfer_err = {bi: v for bi, v in block_diag.items() if not v["missing"]} + + if partial_missing: + detail = [ + f"cpu_block={v['cpu_bid']} hash={v['hash'][:16]}.. " + f"missing={v['missing'][:4]}{'...' if len(v['missing']) > 4 else ''}" + for v in partial_missing.values() + ] + logger.warning( + f"[TransferManager] prefetch_from_storage: {len(partial_missing)} block(s) have missing per-layer keys — " + + "; ".join(detail) + ) + if pure_transfer_err: + detail = [f"cpu_block={v['cpu_bid']} hash={v['hash'][:16]}.." for v in pure_transfer_err.values()] + logger.warning( + f"[TransferManager] prefetch_from_storage: {len(pure_transfer_err)} block(s) keys exist but transfer failed — " + + ", ".join(detail) + ) + except Exception as e: + logger.warning(f"[TransferManager] prefetch_from_storage: failed to probe missing keys: {e}") + + return block_success + + def backup_to_storage( + self, + cpu_block_list: List[int], + hash_list: List[str], + ) -> List[bool]: + """ + Batch-backup KV cache blocks from CPU host memory to remote storage. + + Uses per-layer storage keys. Each layer's data is written directly + from per-layer host buffers to storage via zero-copy RDMA (batch_put_from). + + Storage key per block, per layer: + ``"{hash}_{rank}_key_{layer_idx}"`` / ``"{hash}_{rank}_value_{layer_idx}"`` + + Args: + cpu_block_list: List of source CPU block IDs. + hash_list: List of block hash values (same length as cpu_block_list). + + Returns: + List[bool]: True for each block that was fully stored successfully. + """ + if self._storage_connector is None or not self._storage_connector.is_connected(): + logger.warning("[TransferManager] backup_to_storage: connector not ready") + return [False] * len(cpu_block_list) + + if len(cpu_block_list) != len(hash_list): + raise ValueError("cpu_block_list and hash_list must have the same length") + + if not cpu_block_list: + return [] + + if not self._host_key_ptrs or self._host_key_block_stride_bytes <= 0: + logger.warning( + "[TransferManager] backup_to_storage: host cache not ready " + "(call set_host_block_shape after initialize_swap_space)" + ) + return [False] * len(cpu_block_list) + + flat_keys, flat_ptrs, flat_sizes, flat_block_indices, flat_kinds = self._build_per_layer_io_args( + hash_list, cpu_block_list + ) + results = self._storage_connector.batch_set(flat_keys, flat_ptrs, flat_sizes) + + num_blocks = len(cpu_block_list) + block_success = [True] * num_blocks + block_written_keys: Dict[int, List[str]] = {} + + for flat_idx, ok in enumerate(results): + bi = flat_block_indices[flat_idx] + if ok: + block_written_keys.setdefault(bi, []).append(flat_keys[flat_idx]) + else: + block_success[bi] = False + + # Rollback: delete successfully-written per-layer keys of partially-failed blocks. + # If a block failed but some of its layer-kind keys were written, delete those + # keys so the block appears fully absent in storage. + keys_to_rollback = [key for bi, keys in block_written_keys.items() if not block_success[bi] for key in keys] + if keys_to_rollback: + logger.warning( + f"[TransferManager] backup_to_storage: partial write on {len(keys_to_rollback)} key(s), rolling back" + ) + self._storage_connector.batch_delete(keys_to_rollback) + + failed = [(cpu_block_list[i], hash_list[i]) for i, ok in enumerate(block_success) if not ok] + if failed: + logger.warning( + f"[TransferManager] backup_to_storage: {len(failed)}/{len(cpu_block_list)} block(s) failed — " + + ", ".join(f"cpu_block={cb} hash={h[:16]}.." for cb, h in failed) + ) + + return block_success diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 1d931ece5d2..ee2c29722b0 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -1121,7 +1121,7 @@ def _fetch_request(): batch_request, error_tasks = self.resource_manager.schedule() # 3. Send to engine - if len(batch_request) > 0: + if batch_request.has_pending_work: if self.cfg.scheduler_config.splitwise_role == "decode": for task in batch_request: if task.task_type == RequestType.PREEMPTED: @@ -1177,6 +1177,12 @@ def _fetch_request(): task.metrics.decode_inference_start_time = time.time() elif not task.has_been_preempted_before: task.metrics.inference_start_time = time.time() + if batch_request.storage_prefetch_tasks: + self.llm_logger.info( + f"[Debug][StoragePrefetch][Dispatch] put_tasks with " + f"{len(batch_request.storage_prefetch_tasks)} prefetch tasks, " + f"{len(batch_request.requests)} inference requests" + ) self.engine_worker_queue.put_tasks((batch_request, self.resource_manager.real_bsz)) else: # When there are no actual tasks to schedule, send an empty task batch to EP workers. @@ -1194,7 +1200,7 @@ def _fetch_request(): continue self._send_error_response(request_id, failed) - if len(batch_request) <= 0 and not error_tasks: + if not batch_request.has_pending_work and not error_tasks: time.sleep(0.005) except RuntimeError as e: @@ -1321,7 +1327,6 @@ def _insert_zmq_task_to_scheduler(self): err_msg = None try: request = Request.from_dict(data) - request.metrics.scheduler_recv_req_time = time.time() main_process_metrics.requests_number.inc() trace_carrier = data.get("trace_carrier") diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index c17b8821ce2..e04efb2440e 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -34,7 +34,11 @@ from typing_extensions import TypeVar from fastdeploy import envs -from fastdeploy.cache_manager.v1.metadata import CacheLevel, CacheSwapMetadata +from fastdeploy.cache_manager.v1.metadata import ( + CacheLevel, + CacheSwapMetadata, + PendingPrefetch, +) from fastdeploy.engine.pooling_params import PoolingParams from fastdeploy.engine.sampling_params import SamplingParams from fastdeploy.entrypoints.openai.protocol import ( @@ -257,6 +261,10 @@ def prompt_hashes(self) -> list[str]: def match_result(self) -> Optional[MatchResult]: return self._match_result + @match_result.setter + def match_result(self, value: Optional[MatchResult]) -> None: + self._match_result = value + def set_block_hasher(self, block_hasher: callable): """Set the block hasher for dynamic hash computation.""" self._block_hasher = block_hasher @@ -619,6 +627,7 @@ def __init__(self): self.cache_swap_metadata: Optional[CacheSwapMetadata] = None self.cache_evict_metadata: Optional[CacheSwapMetadata] = None + self.storage_prefetch_tasks: Optional[List[PendingPrefetch]] = None def add_request(self, request): if hasattr(request, "cache_swap_metadata") and request.cache_swap_metadata: @@ -660,9 +669,17 @@ def append_evict_metadata(self, metadata: List[CacheSwapMetadata]): hash_values=meta.hash_values, ) + def append_prefetch_tasks(self, tasks: List[PendingPrefetch]): + if self.storage_prefetch_tasks is None: + self.storage_prefetch_tasks = [] + self.storage_prefetch_tasks.extend(tasks) + def __repr__(self): requests_repr = repr(self.requests) - return f"BatchRequest(requests={requests_repr}, swap_metadata={self.cache_swap_metadata}, evict_metadata={self.cache_evict_metadata})" + return ( + f"BatchRequest(requests={requests_repr}, swap_metadata={self.cache_swap_metadata}, " + f"evict_metadata={self.cache_evict_metadata}, prefetch_tasks={self.storage_prefetch_tasks})" + ) def __getstate__(self): state = self.__dict__.copy() @@ -691,12 +708,24 @@ def __getitem__(self, index): def __len__(self): return len(self.requests) + @property + def has_pending_work(self) -> bool: + """Whether there is any pending work (inference requests, prefetch/swap/evict tasks).""" + return ( + len(self.requests) > 0 + or bool(self.storage_prefetch_tasks) + or bool(self.cache_swap_metadata) + or bool(self.cache_evict_metadata) + ) + def append(self, batch_request: "BatchRequest"): self.requests.extend(batch_request.requests) if batch_request.cache_swap_metadata: self.append_swap_metadata([batch_request.cache_swap_metadata]) if batch_request.cache_evict_metadata: self.append_evict_metadata([batch_request.cache_evict_metadata]) + if batch_request.storage_prefetch_tasks: + self.append_prefetch_tasks(batch_request.storage_prefetch_tasks) def extend(self, batch_requests: list["BatchRequest"]): for br in batch_requests: diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index e3d20cc7d02..2df2c80e65b 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -22,10 +22,11 @@ from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field -from typing import List, Union +from typing import Dict, List, Set, Union import numpy as np import paddle +import zmq from fastdeploy import envs from fastdeploy.cache_manager.multimodal_cache_manager import ( @@ -44,6 +45,7 @@ from fastdeploy.engine.resource_manager import ResourceManager from fastdeploy.input.utils import IDS_TYPE_FLAG from fastdeploy.inter_communicator import IPCSignal +from fastdeploy.inter_communicator.zmq_server import ZmqIpcServer from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.multimodal.hasher import MultimodalHasher from fastdeploy.platforms import current_platform @@ -112,6 +114,28 @@ class ScheduledAbortTask(ScheduledTaskBase): task_type: RequestType = RequestType.ABORT +@dataclass +class PrefetchResult: + """Result of a completed storage prefetch, populated by the receiver thread.""" + + request_id: str = "" + host_block_ids: List[int] = field(default_factory=list) + success: bool = True + error: str = "" + + +@dataclass +class InflightPrefetch: + """Tracks an in-flight prefetch dispatched to TP workers.""" + + request_id: str = "" + host_block_ids: List[int] = field(default_factory=list) + expected_count: int = 0 + done_ranks: Set[int] = field(default_factory=set) + failed_ranks: Set[int] = field(default_factory=set) + dispatch_time: float = 0.0 + + class SignalConsumer: """ A class that consumes a signal value up to a specified limit. @@ -252,6 +276,30 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l # Scheduler-side requests that have not been moved into resource manager waiting queue yet. self.scheduler_unhandled_request_num = 0 + # ---- Storage Prefetch ZMQ channels (Scheduler side) ---- + # Initialized only when storage backend is configured. + # One PULL done socket per worker local_rank for receiving completion notifications. + # Prefetch commands are dispatched via batch_request (EngineWorkerQueue). + # local_rank = dp_rank * tp_size + tp_rank + self._prefetch_done_servers: Dict[int, ZmqIpcServer] = {} + + if self.config.cache_config.kvcache_storage_backend and self.enable_cache_manager_v1: + self._init_prefetch_zmq_servers() + + # ---- Storage Prefetch tracking ---- + self._inflight_prefetches: Dict[str, InflightPrefetch] = {} + self._inflight_lock = threading.Lock() + self._prefetch_results: Dict[str, PrefetchResult] = {} + self._prefetch_results_lock = threading.Lock() + + if self._prefetch_done_servers: + self._prefetch_receiver_thread = threading.Thread( + target=self._prefetch_receiver_loop, + daemon=True, + name="StoragePrefetchReceiver", + ) + self._prefetch_receiver_thread.start() + def allocated_slots(self, request: Request): return len(request.block_tables) * self.config.cache_config.block_size @@ -434,13 +482,13 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, batch_reques del self.requests[preempted_req.request_id] if preempted_req.request_id in self.req_dict: del self.req_dict[preempted_req.request_id] - if envs.FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST: + if envs.FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST and not self.enable_cache_manager_v1: if self.config.cache_config.kvcache_storage_backend: self.cache_manager.write_cache_to_storage_decode(preempted_req) self._free_blocks(preempted_req) llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}") else: - if envs.FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST: + if envs.FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST and not self.enable_cache_manager_v1: if self.config.cache_config.kvcache_storage_backend: self.cache_manager.write_cache_to_storage(preempted_req) self._free_blocks(preempted_req) @@ -1229,13 +1277,16 @@ def _allocate_decode_and_extend(): # Issue pending backup tasks to batch_request # This handles write_through_selective policy by attaching backup tasks # to the batch request, which will be processed by the worker - if self.enable_cache_manager_v1 and len(batch_request) > 0: + if self.enable_cache_manager_v1: + self.cache_manager.check_and_add_pending_backup() + evict_metadata = self.cache_manager.issue_pending_backup_to_batch_request() if evict_metadata: batch_request.append_evict_metadata([evict_metadata]) - if self.enable_cache_manager_v1: - self.cache_manager.check_and_add_pending_backup() + # Dispatch any pending storage prefetch tasks via batch_request + if self.config.cache_config.kvcache_storage_backend: + self._dispatch_pending_prefetches(batch_request) return batch_request, error_reqs @@ -1260,6 +1311,256 @@ def waiting_async_process(self, request: Request) -> None: def apply_async_preprocess(self, request: Request) -> None: request.async_process_futures.append(self.async_preprocess_pool.submit(self._download_features, request)) + if self.config.cache_config.kvcache_storage_backend: + request.async_process_futures.append( + self.async_preprocess_pool.submit(self._prefetch_storage_cache, request) + ) + + def _init_prefetch_zmq_servers(self) -> None: + """ + Initialize per-worker-rank ZMQ PULL sockets for storage prefetch done notification. + + Called once during __init__ when storage backend is enabled. + Creates: + - prefetch_done_server[local_rank]: PULL ← Worker (receive done notification) + + Prefetch commands are sent via batch_request (EngineWorkerQueue), not ZMQ. + local_rank = dp_rank * tp_size + tp_rank, covers all workers in this DP group. + """ + tp_size = self.config.parallel_config.tensor_parallel_size + dp_rank = self.config.parallel_config.local_data_parallel_id + port = self.config.parallel_config.local_engine_worker_queue_port + + for tp_rank in range(tp_size): + local_rank = dp_rank * tp_size + tp_rank + done_name = f"prefetch_done_rank{local_rank}_{port}" + self._prefetch_done_servers[local_rank] = ZmqIpcServer(done_name, zmq.PULL) + llm_logger.info(f"[StoragePrefetch] init ZMQ done server: {done_name}") + + def _prefetch_storage_cache(self, request: Request) -> None: + """ + Asynchronously prefetch KV cache blocks from storage to host memory. + + Phase 1: Calls cache_manager.prefetch_storage() to probe storage, + allocate host blocks, and enqueue prefetch info for dispatch. + Phase 2: Polls _prefetch_results dict until the receiver thread reports + that all TP workers have completed the transfer. + + The actual ZMQ dispatch happens in schedule() via _dispatch_pending_prefetches(), + and the ZMQ receive happens in the dedicated _prefetch_receiver_loop thread. + + Args: + request: The request to prefetch cache for. + """ + host_block_ids: List[int] = [] + try: + llm_logger.info( + f"[Debug][StoragePrefetch][Phase1] start prefetch_storage for request_id={request.request_id}" + ) + + has_prefetch = self.cache_manager.prefetch_storage(request) + if not has_prefetch: + llm_logger.info( + f"[Debug][StoragePrefetch][Phase1] no storage match for request_id={request.request_id}, skip" + ) + return + + llm_logger.info( + f"[Debug][StoragePrefetch][Phase1] enqueued pending, now polling results for request_id={request.request_id}" + ) + + if not self._prefetch_done_servers: + llm_logger.warning( + f"[Debug][StoragePrefetch][Phase1] no done servers, skip polling for request_id={request.request_id}" + ) + return + + # Poll _prefetch_results until receiver thread populates it + timeout = 60.0 + start = time.time() + poll_count = 0 + while True: + with self._prefetch_results_lock: + result = self._prefetch_results.pop(request.request_id, None) + if result is not None: + elapsed = time.time() - start + host_block_ids = result.host_block_ids + if result.success: + self.cache_manager.update_storage_blocks_to_host(host_block_ids) + llm_logger.info( + f"[Debug][StoragePrefetch][Phase1] request_id={request.request_id} done, " + f"updated {len(host_block_ids)} blocks to HOST, " + f"waited {elapsed:.3f}s, polled {poll_count} times" + ) + else: + llm_logger.warning( + f"[Debug][StoragePrefetch][Phase1] request_id={request.request_id} failed: {result.error}, " + f"waited {elapsed:.3f}s" + ) + self.cache_manager.abort_prefetch_blocks(host_block_ids) + return + + poll_count += 1 + if time.time() - start > timeout: + llm_logger.error( + f"[Debug][StoragePrefetch][Phase1] request_id={request.request_id} timeout after {timeout}s, " + f"polled {poll_count} times" + ) + self._cleanup_prefetch_on_timeout(request.request_id) + return + time.sleep(0.005) + + except Exception as e: + llm_logger.error(f"[StoragePrefetch] request_id={request.request_id} error: {e}") + if host_block_ids: + self.cache_manager.abort_prefetch_blocks(host_block_ids) + + def _dispatch_pending_prefetches(self, batch_request) -> None: + """ + Drain pending prefetch tasks from CacheManager and attach to batch_request. + + Called from schedule() under self.lock. Prefetch tasks are dispatched to + workers via the normal EngineWorkerQueue path (same as swap/evict metadata). + Completion notifications are still received via ZMQ in the receiver thread. + """ + pending_items = self.cache_manager.drain_pending_prefetches() + if not pending_items: + return + + llm_logger.info( + f"[Debug][StoragePrefetch][Phase2] drained {len(pending_items)} pending prefetch tasks, " + f"request_ids={[item.request_id for item in pending_items]}, " + f"attaching to batch_request (existing requests={len(batch_request.requests)})" + ) + + batch_request.append_prefetch_tasks(pending_items) + + expected_count = len(self._prefetch_done_servers) + for item in pending_items: + inflight = InflightPrefetch( + request_id=item.request_id, + host_block_ids=item.host_block_ids, + expected_count=expected_count, + done_ranks=set(), + failed_ranks=set(), + dispatch_time=time.time(), + ) + with self._inflight_lock: + self._inflight_prefetches[item.request_id] = inflight + + llm_logger.info( + f"[Debug][StoragePrefetch][Phase2] registered inflight: request_id={item.request_id}, " + f"host_block_ids={item.host_block_ids}, expected_workers={expected_count}" + ) + + def _prefetch_receiver_loop(self) -> None: + """ + Dedicated daemon thread that receives prefetch done messages from all TP workers. + + Uses zmq.Poller for efficient multiplexed receive. When all workers for a + given request_id have reported, stores the result in _prefetch_results dict + for the preprocess thread to pick up. + """ + poller = zmq.Poller() + rank_by_socket = {} + for local_rank, done_server in self._prefetch_done_servers.items(): + done_server._ensure_socket() + poller.register(done_server.socket, zmq.POLLIN) + rank_by_socket[done_server.socket] = local_rank + + llm_logger.info("[StoragePrefetch] receiver thread started") + + while True: + try: + events = dict(poller.poll(timeout=50)) + except Exception: + continue + + for socket, _ in events.items(): + local_rank = rank_by_socket[socket] + done_server = self._prefetch_done_servers[local_rank] + err, msg = done_server.receive_pyobj_once(block=False) + if err is not None or msg is None: + continue + + request_id = msg.get("request_id", "") + status = msg.get("status", "") + + llm_logger.info( + f"[Debug][StoragePrefetch][Phase3] received msg from rank={local_rank}: " + f"request_id={request_id}, status={status}" + ) + + with self._inflight_lock: + inflight = self._inflight_prefetches.get(request_id) + if inflight is None: + llm_logger.warning( + f"[Debug][StoragePrefetch][Phase3] receiver got stale msg for request_id={request_id}, discarding" + ) + continue + + if status == "ok": + inflight.done_ranks.add(local_rank) + else: + inflight.failed_ranks.add(local_rank) + llm_logger.warning( + f"[Debug][StoragePrefetch][Phase3] rank={local_rank} reported failure for " + f"request_id={request_id}: {msg.get('error', '')}" + ) + + total = len(inflight.done_ranks) + len(inflight.failed_ranks) + llm_logger.info( + f"[Debug][StoragePrefetch][Phase3] request_id={request_id} " + f"progress: {total}/{inflight.expected_count} " + f"(done_ranks={inflight.done_ranks}, failed_ranks={inflight.failed_ranks})" + ) + if total < inflight.expected_count: + continue + + # All workers reported -- produce result + self._inflight_prefetches.pop(request_id) + + success = len(inflight.failed_ranks) == 0 + result = PrefetchResult( + request_id=request_id, + host_block_ids=inflight.host_block_ids, + success=success, + error=f"failed_ranks={inflight.failed_ranks}" if not success else "", + ) + with self._prefetch_results_lock: + self._prefetch_results[request_id] = result + + elapsed = time.time() - inflight.dispatch_time + llm_logger.info( + f"[Debug][StoragePrefetch][Phase3] request_id={request_id} all workers done, " + f"success={success}, elapsed={elapsed:.3f}s, result stored" + ) + + def _cleanup_prefetch_on_timeout(self, request_id: str) -> None: + """ + Clean up prefetch state when a timeout occurs. + + Removes from inflight tracking (if dispatched) or from CacheManager's + pending list (if not yet dispatched), and aborts host blocks. + """ + host_block_ids: List[int] = [] + + # Check inflight first (already dispatched to workers) + with self._inflight_lock: + inflight = self._inflight_prefetches.pop(request_id, None) + if inflight is not None: + host_block_ids = inflight.host_block_ids + + # Also check pending list (not yet dispatched) + if not host_block_ids: + host_block_ids = self.cache_manager.cancel_pending_prefetch(request_id) + + if host_block_ids: + self.cache_manager.abort_prefetch_blocks(host_block_ids) + llm_logger.warning( + f"[StoragePrefetch] timeout cleanup: aborted {len(host_block_ids)} " + f"host blocks for request_id={request_id}" + ) def _has_features_info(self, task): inputs = task.multimodal_inputs @@ -1364,43 +1665,39 @@ def get_real_bsz(self) -> int: return self.real_bsz def _allocate_gpu_blocks(self, request: Request, num_blocks: int) -> List[int]: - llm_logger.debug(f"[allocate_gpu_blocks] request_id={request.request_id}, num_blocks={num_blocks}") + llm_logger.info(f"[DEBUG allocate_gpu_blocks] request_id={request.request_id}, num_blocks={num_blocks}") if self.enable_cache_manager_v1: return self.cache_manager.allocate_gpu_blocks(request, num_blocks) else: return self.cache_manager.allocate_gpu_blocks(num_blocks, request.request_id) - def _request_match_blocks(self, request: Request, skip_storage: bool = True): + def _request_match_blocks(self, request: Request): """ Prefixed cache manager v1 will match blocks for request and return common_block_ids. """ if self.enable_cache_manager_v1: - self.cache_manager.match_prefix(request, skip_storage) + self.cache_manager.match_prefix(request, skip_storage=True) match_result = request.match_result - if skip_storage: - common_block_ids = match_result.device_block_ids - matched_token_num = match_result.total_matched_blocks * self.config.cache_config.block_size - metrics = { - "gpu_match_token_num": match_result.matched_device_nums * self.config.cache_config.block_size, - "cpu_match_token_num": match_result.matched_host_nums * self.config.cache_config.block_size, - "storage_match_token_num": match_result.matched_storage_nums * self.config.cache_config.block_size, - "match_gpu_block_ids": common_block_ids, - "gpu_recv_block_ids": [], - "match_storage_block_ids": [], - "cpu_cache_prepare_time": 0, - "storage_cache_prepare_time": 0, - } - - no_cache_block_num = ( - request.need_prefill_tokens - matched_token_num + self.config.cache_config.block_size - 1 - ) // self.config.cache_config.block_size - request.cache_info = [len(common_block_ids), no_cache_block_num] - - return (common_block_ids, matched_token_num, metrics) - else: - # Prefetch cache from storage - pass + common_block_ids = match_result.device_block_ids + matched_token_num = match_result.total_matched_blocks * self.config.cache_config.block_size + metrics = { + "gpu_match_token_num": match_result.matched_device_nums * self.config.cache_config.block_size, + "cpu_match_token_num": match_result.matched_host_nums * self.config.cache_config.block_size, + "storage_match_token_num": match_result.matched_storage_nums * self.config.cache_config.block_size, + "match_gpu_block_ids": common_block_ids, + "gpu_recv_block_ids": [], + "match_storage_block_ids": [], + "cpu_cache_prepare_time": 0, + "storage_cache_prepare_time": 0, + } + + no_cache_block_num = ( + request.need_prefill_tokens - matched_token_num + self.config.cache_config.block_size - 1 + ) // self.config.cache_config.block_size + request.cache_info = [len(common_block_ids), no_cache_block_num] + + return (common_block_ids, matched_token_num, metrics) else: (common_block_ids, matched_token_num, metrics) = self.cache_manager.request_match_blocks( request, self.config.cache_config.block_size @@ -1707,13 +2004,16 @@ def finish_requests(self, request_ids: Union[str, Iterable[str]]): # Do not block the main thread here # Write cache to storage if kvcache_storage_backend is enabled - for req in need_postprocess_reqs: - if self.config.scheduler_config.splitwise_role == "decode": - # D instance uses simplified write method (does not rely on Radix Tree) - self.cache_manager.write_cache_to_storage_decode(req) - else: - # P instance / Mixed instance uses standard write method (relies on Radix Tree) - self.cache_manager.write_cache_to_storage(req) + # v1 CacheManager handles storage write-back inside request_finish() via RadixTree, + # so skip this block when enable_cache_manager_v1 is True. + if not self.enable_cache_manager_v1: + for req in need_postprocess_reqs: + if self.config.scheduler_config.splitwise_role == "decode": + # D instance uses simplified write method (does not rely on Radix Tree) + self.cache_manager.write_cache_to_storage_decode(req) + else: + # P instance / Mixed instance uses standard write method (relies on Radix Tree) + self.cache_manager.write_cache_to_storage(req) with self.lock: for req in need_postprocess_reqs: diff --git a/fastdeploy/inter_communicator/zmq_client.py b/fastdeploy/inter_communicator/zmq_client.py index 9c7495fba72..738f185ffe0 100644 --- a/fastdeploy/inter_communicator/zmq_client.py +++ b/fastdeploy/inter_communicator/zmq_client.py @@ -143,6 +143,28 @@ def recv_pyobj(self, flags: int = 0): return envelope["data"] return envelope + def receive_pyobj_once(self, block=False): + """ + Receive a single Pickle-serializable message from the socket. + + Args: + block: If True, block until a message arrives. If False, return immediately. + + Returns: + Tuple of (error, data). error is None on success, data is None if no message. + """ + self._ensure_socket() + if self.socket is None or self.socket.closed: + return "zmq socket has closed", None + try: + flags = 0 if block else zmq.NOBLOCK + return None, self.recv_pyobj(flags=flags) + except zmq.Again: + return None, None + except Exception as e: + llm_logger.warning(f"[ZmqClient] receive_pyobj_once error: {e}") + return str(e), None + @abstractmethod def close(self): pass diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 1f9b1902517..a136d0173bf 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -806,13 +806,7 @@ def insert_tasks_v1(self, req_dicts: BatchRequest, num_running_requests: int = N } if self.enable_mm: # Sort by idx to ensure attention mask offsets are filled in order during mm prefill - req_dicts.requests.sort(key=lambda r: r.idx) - if self.enable_cache_manager_v1: - # submit_swap_tasks handles: - # 1. Waiting for pending evict handlers before submitting new evict - # 2. write_back policy: waiting for evict to complete before submitting swap-in - # 3. Adding handlers to pending lists appropriately - self.cache_controller.submit_swap_tasks(req_dicts.cache_evict_metadata, req_dicts.cache_swap_metadata) + req_dicts = sorted(req_dicts, key=lambda r: r.idx) for i in range(req_len): request = req_dicts[i] diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 28a943cf9d4..9db64b2033a 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -20,9 +20,11 @@ import os import time import traceback +from concurrent.futures import ThreadPoolExecutor from typing import Tuple import numpy as np +import zmq from fastdeploy.logger.logger import intercept_paddle_loggers @@ -70,6 +72,7 @@ RearrangeExpertStatus, ) from fastdeploy.inter_communicator.fmq import FMQ +from fastdeploy.inter_communicator.zmq_client import ZmqIpcClient from fastdeploy.model_executor.layers.quantization import parse_quant_config from fastdeploy.model_executor.utils import v1_loader_support from fastdeploy.platforms import current_platform @@ -185,6 +188,111 @@ def init_control(self): logger.info(f"Init Control Output Queue: {queue_name}(producer)") self._ctrl_output = FMQ().queue(queue_name, "producer") + def init_prefetch_zmq_clients(self) -> None: + """ + Initialize ZMQ PUSH client for storage prefetch done notification. + + The prefetch commands are now received via batch_request (EngineWorkerQueue), + but completion notifications are still sent back to Scheduler via ZMQ PUSH. + + Only initialized when storage backend is configured and cache_manager_v1 is enabled. + """ + if not self.fd_config.cache_config.kvcache_storage_backend or not envs.ENABLE_V1_KVCACHE_MANAGER: + return + + port = self.parallel_config.local_engine_worker_queue_port + local_rank = self.local_rank + + done_name = f"prefetch_done_rank{local_rank}_{port}" + + self._prefetch_done_client = ZmqIpcClient(name=done_name, mode=zmq.PUSH) + self._prefetch_done_client.connect() + self._prefetch_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="StoragePrefetch") + + logger.info(f"[StoragePrefetch] rank={local_rank} ZMQ done client connected: done={done_name}") + + def _handle_prefetch_tasks(self, prefetch_tasks) -> None: + """ + Handle storage prefetch tasks received from batch_request. + + Submits each prefetch task to thread pool for async execution. + On completion, sends done notification back to scheduler via ZMQ PUSH. + + Args: + prefetch_tasks: List of PendingPrefetch from batch_request. + """ + for task in prefetch_tasks: + self._prefetch_executor.submit(self._execute_single_prefetch, task) + + def _execute_single_prefetch(self, task) -> None: + """ + Execute a single storage prefetch task and send done notification via ZMQ. + + Args: + task: PendingPrefetch object with request_id, metadata, host_block_ids. + """ + local_rank = self.local_rank + request_id = task.request_id + metadata = task.metadata + + logger.info( + f"[Debug][StoragePrefetch][Worker] rank={local_rank} executing prefetch for " + f"request_id={request_id}, block_ids={metadata.block_ids}, timeout={metadata.timeout}" + ) + + try: + cache_controller = self.worker.model_runner.cache_controller + handler = cache_controller.prefetch_from_storage(metadata) + + start_time = time.time() + completed = handler.wait(timeout=metadata.timeout) + elapsed = time.time() - start_time + + if completed and handler.error is None: + logger.info( + f"[Debug][StoragePrefetch][Worker] rank={local_rank} prefetch OK for " + f"request_id={request_id}, elapsed={elapsed:.3f}s" + ) + done_msg = { + "request_id": request_id, + "host_block_ids": metadata.block_ids, + "status": "ok", + } + else: + error_str = handler.error or "timeout" + logger.warning( + f"[Debug][StoragePrefetch][Worker] rank={local_rank} prefetch failed for " + f"request_id={request_id}: {error_str}, elapsed={elapsed:.3f}s" + ) + done_msg = { + "request_id": request_id, + "host_block_ids": metadata.block_ids, + "status": "error", + "error": error_str, + } + + self._prefetch_done_client.send_pyobj(done_msg) + logger.info( + f"[Debug][StoragePrefetch][Worker] rank={local_rank} sent done msg for " + f"request_id={request_id}, status={done_msg['status']}" + ) + + except Exception as e: + logger.error( + f"[StoragePrefetch] rank={local_rank} execute_prefetch exception for " + f"request_id={request_id}: {e}\n{traceback.format_exc()}" + ) + try: + done_msg = { + "request_id": request_id, + "host_block_ids": metadata.block_ids, + "status": "error", + "error": str(e), + } + self._prefetch_done_client.send_pyobj(done_msg) + except Exception: + pass + def init_health_status(self) -> None: """ Initialize the health status of the worker. @@ -570,6 +678,24 @@ def event_loop_normal(self) -> None: batch_request, control_reqs, max_occupied_batch_index = BatchRequest.from_tasks(tasks) + # Handle storage prefetch tasks from batch_request (async, non-blocking) + if batch_request.storage_prefetch_tasks: + logger.info( + f"[Debug][StoragePrefetch][Worker] rank={self.local_rank} received " + f"{len(batch_request.storage_prefetch_tasks)} prefetch tasks: " + f"request_ids={[t.request_id for t in batch_request.storage_prefetch_tasks]}" + ) + self._handle_prefetch_tasks(batch_request.storage_prefetch_tasks) + batch_request.storage_prefetch_tasks = None + + # Handle swap/evict tasks from batch_request + if batch_request.cache_evict_metadata or batch_request.cache_swap_metadata: + self.worker.model_runner.cache_controller.submit_swap_tasks( + batch_request.cache_evict_metadata, batch_request.cache_swap_metadata + ) + batch_request.cache_evict_metadata = None + batch_request.cache_swap_metadata = None + if len(control_reqs) > 0: logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.") for control_req in control_reqs: @@ -1341,6 +1467,10 @@ def run_worker_proc() -> None: # Initialize health status worker_proc.init_health_status() + # Initialize storage prefetch ZMQ clients and start prefetch_loop thread. + # Must be called after model/cache initialization so cache_controller is ready. + worker_proc.init_prefetch_zmq_clients() + worker_proc.start_task_queue_service() # Start event loop diff --git a/tests/cache_manager/v1/test_cache_controller.py b/tests/cache_manager/v1/test_cache_controller.py index 858dbf69b56..55974f24ced 100644 --- a/tests/cache_manager/v1/test_cache_controller.py +++ b/tests/cache_manager/v1/test_cache_controller.py @@ -572,29 +572,18 @@ def setUp(self): def test_prefetch_from_storage_returns_error_handler(self): """Test prefetch_from_storage returns error handler (not implemented).""" - from fastdeploy.cache_manager.v1.metadata import StorageMetadata - - mock_metadata = MagicMock(spec=StorageMetadata) + mock_metadata = MagicMock() + mock_metadata.hash_values = [] + mock_metadata.block_ids = [] handler = self.controller.prefetch_from_storage(mock_metadata) self.assertIsNotNone(handler) self.assertIsNotNone(handler.error) - def test_backup_device_to_storage_returns_error_handler(self): - """Test backup_device_to_storage returns error handler (not implemented).""" - from fastdeploy.cache_manager.v1.metadata import StorageMetadata - - mock_metadata = MagicMock(spec=StorageMetadata) - handler = self.controller.backup_device_to_storage([0, 1], mock_metadata) - - self.assertIsNotNone(handler) - self.assertIsNotNone(handler.error) - def test_backup_host_to_storage_returns_error_handler(self): - """Test backup_host_to_storage returns error handler (not implemented).""" - from fastdeploy.cache_manager.v1.metadata import StorageMetadata - - mock_metadata = MagicMock(spec=StorageMetadata) + """Test backup_host_to_storage returns error handler.""" + mock_metadata = MagicMock() + mock_metadata.hash_values = [] handler = self.controller.backup_host_to_storage([0, 1], mock_metadata) self.assertIsNotNone(handler) diff --git a/tests/cache_manager/v1/test_cache_manager.py b/tests/cache_manager/v1/test_cache_manager.py index 5ae7b4f3658..2bbd7cfe2d7 100644 --- a/tests/cache_manager/v1/test_cache_manager.py +++ b/tests/cache_manager/v1/test_cache_manager.py @@ -435,9 +435,9 @@ def test_match_prefix_updates_ref_count(self): ) cache_manager.match_prefix(req2) - # Ref count should be incremented, nodes not evictable + # Ref count not incremented in non-scheduling match_prefix (skip_storage=False by default) stats2 = cache_manager.radix_tree.get_stats() - self.assertEqual(stats2.evictable_device_count, 0) + self.assertEqual(stats2.evictable_device_count, 2) def test_insert_and_find_prefix(self): """Test inserting blocks and finding prefix.""" @@ -709,68 +709,83 @@ def test_storage_scheduler_none_by_default(self): _ = cache_manager.storage_scheduler -# --------------------------------------------------------------------------- -# offload_to_host -# --------------------------------------------------------------------------- +class TestUpdateStorageBlocksToHost(unittest.TestCase): + """Tests for CacheManager.update_storage_blocks_to_host().""" + def _make_node_with_status(self, cache_manager, status): + """Allocate a host block and return a BlockNode with the given CacheStatus.""" + from fastdeploy.cache_manager.v1.metadata import BlockNode -class TestCacheManagerOffloadToHost(unittest.TestCase): - """Tests for CacheManager.offload_to_host.""" + block_ids = cache_manager._host_pool.allocate(1) + self.assertIsNotNone(block_ids, "host pool exhausted") + block_id = block_ids[0] + node = BlockNode(block_id=block_id, hash_value="test_hash", cache_status=status) + return node, block_id - def test_offload_frees_device_blocks(self): - """After offload, device blocks should be released.""" - cm = create_cache_manager(total_block_num=20, num_cpu_blocks=20) - device_blocks = cm._device_pool.allocate(4) - self.assertIsNotNone(device_blocks) - free_before = cm.num_free_device_blocks + def test_update_transitions_loading_to_host(self): + """update_storage_blocks_to_host transitions LOADING_FROM_STORAGE → HOST.""" + from fastdeploy.cache_manager.v1.metadata import CacheStatus - success = cm.offload_to_host(device_blocks) + cache_manager = create_cache_manager(num_cpu_blocks=10) + node, block_id = self._make_node_with_status(cache_manager, CacheStatus.LOADING_FROM_STORAGE) + cache_manager._prefetch_node_map[block_id] = node - self.assertTrue(success) - self.assertEqual(cm.num_free_device_blocks, free_before + 4) + cache_manager.update_storage_blocks_to_host([block_id]) - def test_offload_allocates_host_blocks(self): - """After offload, host blocks should be consumed.""" - cm = create_cache_manager(total_block_num=20, num_cpu_blocks=20) - device_blocks = cm._device_pool.allocate(3) - free_host_before = cm.num_free_host_blocks + self.assertEqual(node.cache_status, CacheStatus.HOST) + self.assertNotIn(block_id, cache_manager._prefetch_node_map) - cm.offload_to_host(device_blocks) + def test_update_multiple_blocks(self): + """update_storage_blocks_to_host handles multiple blocks in one call.""" + from fastdeploy.cache_manager.v1.metadata import CacheStatus - self.assertEqual(cm.num_free_host_blocks, free_host_before - 3) + cache_manager = create_cache_manager(num_cpu_blocks=20) + nodes = [] + block_ids = [] + for i in range(5): + node, block_id = self._make_node_with_status(cache_manager, CacheStatus.LOADING_FROM_STORAGE) + cache_manager._prefetch_node_map[block_id] = node + nodes.append(node) + block_ids.append(block_id) - def test_offload_fails_when_no_host_blocks(self): - """Offload should return False when host pool is exhausted.""" - cm = create_cache_manager(total_block_num=20, num_cpu_blocks=0) - device_blocks = cm._device_pool.allocate(2) + cache_manager.update_storage_blocks_to_host(block_ids) - success = cm.offload_to_host(device_blocks) - self.assertFalse(success) + for node in nodes: + self.assertEqual(node.cache_status, CacheStatus.HOST) + for block_id in block_ids: + self.assertNotIn(block_id, cache_manager._prefetch_node_map) - def test_offload_copies_device_metadata_to_host(self): - """Metadata on device blocks should be copied to host blocks.""" - from fastdeploy.cache_manager.v1.metadata import CacheBlockMetadata + def test_update_unknown_block_id_is_ignored(self): + """Unknown block_id (not in prefetch_node_map) logs warning and continues.""" + cache_manager = create_cache_manager(num_cpu_blocks=10) + # Should not raise even if the block_id is not in the map + cache_manager.update_storage_blocks_to_host([9999]) - cm = create_cache_manager(total_block_num=20, num_cpu_blocks=20) - device_blocks = cm._device_pool.allocate(1) - block_id = device_blocks[0] - meta = CacheBlockMetadata(block_id=block_id, device_id=0, block_size=64, ref_count=5) - cm._device_pool.set_metadata(block_id, meta) - - cm.offload_to_host(device_blocks) - - # Find the newly used host block (last used) - used_host = list(cm._host_pool._used_blocks) - self.assertEqual(len(used_host), 1) - host_meta = cm._host_pool.get_metadata(used_host[0]) - self.assertIsNotNone(host_meta) - self.assertEqual(host_meta.ref_count, 5) - - def test_offload_empty_list_returns_true(self): - """Offloading empty list succeeds.""" - cm = create_cache_manager() - success = cm.offload_to_host([]) - self.assertTrue(success) + def test_update_empty_list_is_noop(self): + """Empty block_ids list is a no-op.""" + cache_manager = create_cache_manager(num_cpu_blocks=10) + # Should not raise or modify anything + cache_manager.update_storage_blocks_to_host([]) + + def test_update_wrong_status_does_not_change(self): + """Block already in HOST status is not re-processed.""" + from fastdeploy.cache_manager.v1.metadata import CacheStatus + + cache_manager = create_cache_manager(num_cpu_blocks=10) + node, block_id = self._make_node_with_status(cache_manager, CacheStatus.HOST) + cache_manager._prefetch_node_map[block_id] = node + + cache_manager.update_storage_blocks_to_host([block_id]) + + # Status must remain HOST (not changed to something else) + self.assertEqual(node.cache_status, CacheStatus.HOST) + # The node is still removed from the map + self.assertNotIn(block_id, cache_manager._prefetch_node_map) + + def test_prefetch_node_map_initially_empty(self): + """_prefetch_node_map is empty on a fresh CacheManager.""" + cache_manager = create_cache_manager() + self.assertEqual(len(cache_manager._prefetch_node_map), 0) # --------------------------------------------------------------------------- @@ -903,31 +918,64 @@ def test_issue_returns_none_when_host_cache_disabled(self): self.assertEqual(cm.get_pending_backup_count(), 0) -# --------------------------------------------------------------------------- -# prepare_prefetch_metadata -# --------------------------------------------------------------------------- +class TestPreparePrefixtMetadataStartNode(unittest.TestCase): + """Regression test for the start_node bug in prepare_prefetch_metadata. + Before the fix, prepare_prefetch_metadata called radix_tree.insert without + start_node, which inserted LOADING_FROM_STORAGE nodes as children of root + (using storage hashes h22..h29 at depth 1) instead of as extensions of the + existing device prefix chain (at depth 22..29). As a result, a subsequent + find_prefix on the full hash list would traverse root → h0 → ... → h21, + then fail to find h22 as a child of node(21), and stop at 22 nodes — never + reaching the HOST nodes even after update_storage_blocks_to_host. + """ -class TestCacheManagerPreparePrefetchMetadata(unittest.TestCase): - """Tests for CacheManager.prepare_prefetch_metadata.""" + def test_find_prefix_finds_host_blocks_after_prefetch(self): + """After prepare_prefetch_metadata + update_storage_blocks_to_host, + find_prefix must return all 30 nodes (22 DEVICE + 8 HOST).""" + from fastdeploy.cache_manager.v1.metadata import CacheStatus - def test_empty_hashes_returns_none(self): - cm = create_cache_manager() - result = cm.prepare_prefetch_metadata([]) - self.assertIsNone(result) + cm = create_cache_manager(total_block_num=50, num_cpu_blocks=20) + rt = cm._radix_tree - def test_returns_nodes_when_host_blocks_available(self): - cm = create_cache_manager(num_cpu_blocks=20) - hashes = ["hash_a", "hash_b"] - result = cm.prepare_prefetch_metadata(hashes) - # Should return a list (possibly empty if no host blocks or tree reuse) - self.assertIsInstance(result, list) - - def test_returns_empty_when_insufficient_host_blocks(self): - cm = create_cache_manager(total_block_num=20, num_cpu_blocks=0) - result = cm.prepare_prefetch_metadata(["h1", "h2"]) - # With no host blocks, should return empty or None - self.assertFalse(result) # None or [] + # Build 30 hashes: 22 for device, 8 for storage + all_hashes = [f"h{i}" for i in range(30)] + device_hashes = all_hashes[:22] + storage_hashes = all_hashes[22:] + + # Insert 22 device blocks into the radix tree + device_block_ids = cm._device_pool.allocate(22) + self.assertIsNotNone(device_block_ids) + device_nodes, _ = rt.insert( + blocks=list(zip(device_hashes, device_block_ids)), + cache_status=CacheStatus.DEVICE, + ) + self.assertEqual(len(device_nodes), 22) + + # The last device node is the correct start_node for the storage insertion + last_device_node = device_nodes[-1] + + # prepare_prefetch_metadata should attach storage nodes AFTER the last device node + storage_nodes = cm.prepare_prefetch_metadata(storage_hashes, start_node=last_device_node) + self.assertEqual(len(storage_nodes), 8) + for node in storage_nodes: + self.assertEqual(node.cache_status, CacheStatus.LOADING_FROM_STORAGE) + + # Simulate prefetch completion: transition LOADING_FROM_STORAGE → HOST + storage_block_ids = [n.block_id for n in storage_nodes] + for node in storage_nodes: + cm._prefetch_node_map[node.block_id] = node + cm.update_storage_blocks_to_host(storage_block_ids) + for node in storage_nodes: + self.assertEqual(node.cache_status, CacheStatus.HOST) + + # Now find_prefix on all 30 hashes must return 30 nodes + found = rt.find_prefix(all_hashes) + self.assertEqual(len(found), 30, f"Expected 30 nodes, got {len(found)}") + device_found = [n for n in found if n.is_on_device()] + host_found = [n for n in found if n.is_on_host()] + self.assertEqual(len(device_found), 22) + self.assertEqual(len(host_found), 8) if __name__ == "__main__": diff --git a/tests/cache_manager/v1/test_cache_utils.py b/tests/cache_manager/v1/test_cache_utils.py index aafd8206ef4..3a5356caab3 100644 --- a/tests/cache_manager/v1/test_cache_utils.py +++ b/tests/cache_manager/v1/test_cache_utils.py @@ -386,11 +386,6 @@ def test_item_end_equals_end_idx_fully_contained(self): self.assertIn("h-exact-end", keys) -# --------------------------------------------------------------------------- -# hash_block_tokens -# --------------------------------------------------------------------------- - - class TestHashBlockTokens(unittest.TestCase): """Direct tests for hash_block_tokens.""" diff --git a/tests/cache_manager/v1/test_transfer_manager.py b/tests/cache_manager/v1/test_transfer_manager.py index a5880b1be24..cf173f3684b 100644 --- a/tests/cache_manager/v1/test_transfer_manager.py +++ b/tests/cache_manager/v1/test_transfer_manager.py @@ -647,137 +647,105 @@ def test_get_stats_includes_expected_keys(self): self.assertTrue(stats["has_host_cache"]) -# --------------------------------------------------------------------------- -# _swap_single_layer – validation paths (no real GPU transfer needed) -# --------------------------------------------------------------------------- - - -class TestSwapSingleLayer(unittest.TestCase): - """Tests for CacheTransferManager._swap_single_layer validation paths.""" - - def setUp(self): - self.tm = create_transfer_manager(enable_prefix_caching=True, num_host_blocks=0) - - def test_returns_false_when_no_host_blocks(self): - """_swap_single_layer returns False when _num_host_blocks <= 0.""" - self.assertEqual(self.tm._num_host_blocks, 0) - result = self.tm._swap_single_layer( - layer_idx=0, - device_block_ids=[0, 1], - host_block_ids=[10, 11], - mode=0, - ) - self.assertFalse(result) - - def test_returns_false_when_empty_device_ids(self): - """_swap_single_layer returns False when device_block_ids is empty.""" - tm = create_transfer_manager(num_host_blocks=50) - result = tm._swap_single_layer( - layer_idx=0, - device_block_ids=[], - host_block_ids=[10], - mode=0, - ) - self.assertFalse(result) - - def test_returns_false_when_empty_host_ids(self): - """_swap_single_layer returns False when host_block_ids is empty.""" - tm = create_transfer_manager(num_host_blocks=50) - result = tm._swap_single_layer( - layer_idx=0, - device_block_ids=[0], - host_block_ids=[], - mode=0, - ) - self.assertFalse(result) +# ============================================================================ +# Storage Key Format Tests +# ============================================================================ - def test_returns_false_when_length_mismatch(self): - """_swap_single_layer returns False when lists have different lengths.""" - tm = create_transfer_manager(num_host_blocks=50) - result = tm._swap_single_layer( - layer_idx=0, - device_block_ids=[0, 1], - host_block_ids=[10], - mode=0, - ) - self.assertFalse(result) - def test_returns_false_when_no_device_cache(self): - """_swap_single_layer returns False when device cache map not set.""" - tm = create_transfer_manager(num_host_blocks=50) - # No cache map set → get_device_key_cache returns None - result = tm._swap_single_layer( - layer_idx=0, - device_block_ids=[0], - host_block_ids=[10], - mode=0, - ) - self.assertFalse(result) +class TestStorageKeyFormat(unittest.TestCase): + """Test storage_key_for_block produces per-layer keys.""" + def test_key_format_per_layer(self): + """Per-layer key: '{hash}_{rank}_key_{layer_idx}'.""" + from fastdeploy.cache_manager.v1.cache_utils import storage_key_for_block -# --------------------------------------------------------------------------- -# sync_input_stream / sync_output_stream -# --------------------------------------------------------------------------- + key = storage_key_for_block("abc123", 0, "key", 5) + self.assertEqual(key, "abc123_0_key_5") + def test_key_format_no_layer(self): + """Backward-compat: without layer_idx, key is '{hash}_{rank}_key'.""" + from fastdeploy.cache_manager.v1.cache_utils import storage_key_for_block -class TestSyncStreams(unittest.TestCase): - """Tests for sync_input_stream and sync_output_stream.""" + key = storage_key_for_block("abc123", 0, "key") + self.assertEqual(key, "abc123_0_key") - def test_sync_input_stream_no_stream_does_not_raise(self): - """When _input_stream is None, sync_input_stream should not raise.""" - tm = create_transfer_manager() - tm._input_stream = None - tm.sync_input_stream() # should not raise + def test_value_format_per_layer(self): + from fastdeploy.cache_manager.v1.cache_utils import storage_key_for_block - def test_sync_output_stream_no_stream_does_not_raise(self): - """When _output_stream is None, sync_output_stream should not raise.""" - tm = create_transfer_manager() - tm._output_stream = None - tm.sync_output_stream() # should not raise + key = storage_key_for_block("abc123", 1, "value", 3) + self.assertEqual(key, "abc123_1_value_3") - def test_sync_input_stream_with_mock_stream(self): - """sync_input_stream calls synchronize() on the stream.""" - from unittest.mock import MagicMock + def test_scale_format_per_layer(self): + from fastdeploy.cache_manager.v1.cache_utils import storage_key_for_block - tm = create_transfer_manager() - mock_stream = MagicMock() - tm._input_stream = mock_stream - tm.sync_input_stream() - mock_stream.synchronize.assert_called_once() + key = storage_key_for_block("abc123", 0, "key_scale", 0) + self.assertEqual(key, "abc123_0_key_scale_0") - def test_sync_output_stream_with_mock_stream(self): - """sync_output_stream calls synchronize() on the stream.""" - from unittest.mock import MagicMock - tm = create_transfer_manager() - mock_stream = MagicMock() - tm._output_stream = mock_stream - tm.sync_output_stream() - mock_stream.synchronize.assert_called_once() +# ============================================================================ +# Build Per-Layer IO Args Tests +# ============================================================================ -# --------------------------------------------------------------------------- -# record_input_stream_event -# --------------------------------------------------------------------------- +class TestBuildPerLayerIOArgs(unittest.TestCase): + """Test _build_per_layer_io_args helper.""" + def setUp(self): + self.manager = create_transfer_manager() + num_layers = self.manager._num_layers + host_cache = create_mock_host_cache_kvs_map(num_layers=num_layers) + self.manager.set_host_cache_kvs_map(host_cache) + self.manager._host_key_block_stride_bytes = 1024 + self.manager._host_value_block_stride_bytes = 1024 + + def test_basic_keys(self): + hash_list = ["h1", "h2"] + cpu_block_list = [10, 20] + flat_keys, flat_ptrs, flat_sizes, flat_indices, flat_kinds = self.manager._build_per_layer_io_args( + hash_list, cpu_block_list + ) -class TestRecordInputStreamEvent(unittest.TestCase): - """Tests for record_input_stream_event.""" + num_kinds = 2 # key, value + expected_len = len(hash_list) * self.manager._num_layers * num_kinds + self.assertEqual(len(flat_keys), expected_len) + self.assertEqual(len(flat_ptrs), expected_len) + self.assertEqual(len(flat_sizes), expected_len) + self.assertEqual(len(flat_indices), expected_len) + self.assertEqual(len(flat_kinds), expected_len) + + # First entry should be block 0, layer 0, kind "key" + self.assertEqual(flat_keys[0], "h1_0_key_0") + self.assertEqual(flat_sizes[0], 1024) + self.assertEqual(flat_indices[0], 0) + self.assertEqual(flat_kinds[0], "key") + + def test_block_indices(self): + hash_list = ["h1", "h2"] + cpu_block_list = [10, 20] + flat_keys, flat_ptrs, flat_sizes, flat_indices, flat_kinds = self.manager._build_per_layer_io_args( + hash_list, cpu_block_list + ) - def test_returns_none_when_no_cupy(self): - """When cupy unavailable (_input_stream is None), returns None.""" - tm = create_transfer_manager() - tm._input_stream = None - result = tm.record_input_stream_event() - self.assertIsNone(result) + per_block = self.manager._num_layers * 2 + # First per_block entries belong to block 0 + for i in range(per_block): + self.assertEqual(flat_indices[i], 0) + # Next per_block entries belong to block 1 + for i in range(per_block, 2 * per_block): + self.assertEqual(flat_indices[i], 1) + + def test_sizes_match_strides(self): + hash_list = ["h1"] + cpu_block_list = [0] + flat_keys, flat_ptrs, flat_sizes, flat_indices, flat_kinds = self.manager._build_per_layer_io_args( + hash_list, cpu_block_list + ) - def test_returns_none_when_input_stream_none(self): - """Explicitly set _input_stream to None → returns None.""" - tm = create_transfer_manager() - # Patch _HAS_CUPY via the module, or just verify None path works - tm._input_stream = None - result = tm.record_input_stream_event() - self.assertIsNone(result) + for i, kind in enumerate(flat_kinds): + if kind in ("key", "value"): + self.assertEqual(flat_sizes[i], 1024) + elif kind in ("key_scale", "value_scale"): + self.assertEqual(flat_sizes[i], 0) if __name__ == "__main__": diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py index 799212e1351..197473a9459 100644 --- a/tests/engine/test_common_engine.py +++ b/tests/engine/test_common_engine.py @@ -41,6 +41,7 @@ _read_latest_worker_traceback, ) from fastdeploy.engine.request import ( + BatchRequest, ControlRequest, ControlResponse, Request, @@ -325,7 +326,11 @@ def available_batch(self): def schedule(self): eng.running = False - return schedule_result + tasks, error_tasks = schedule_result + br = BatchRequest() + for t in tasks: + br.add_request(t) + return br, error_tasks def get_real_bsz(self): return self.real_bsz diff --git a/tests/multimodal/test_mm_warmup.py b/tests/multimodal/test_mm_warmup.py new file mode 100644 index 00000000000..2eee4a962da --- /dev/null +++ b/tests/multimodal/test_mm_warmup.py @@ -0,0 +1,526 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +测试多模态 warmup 相关逻辑: + - ErnieMM45DataProcessor.prepare_mm_split_fuse_fields + - Engine._build_mm_warmup_data +""" +import sys +import types +import unittest +from unittest.mock import MagicMock + +import numpy as np + +# --------------------------------------------------------------------------- +# 构造最小 Mock 模块,避免 import 重型依赖 +# --------------------------------------------------------------------------- + + +def _make_paddle_mock(): + """返回一个能满足 prepare_mm_split_fuse_fields 中所有调用的 paddle mock。""" + paddle = MagicMock(name="paddle") + + class _T: + """统一的轻量 Tensor stub,支持 cast/cpu/cumsum/concat/squeeze/repeat_interleave/numpy/tolist。""" + + def __init__(self, data): + self._data = np.array(data, dtype=np.float32) + + def cast(self, dtype): + return self + + def cpu(self): + return self + + def squeeze(self, dims=None): + return _T(self._data.squeeze()) + + def repeat_interleave(self, n, dim=None): + return _T(np.repeat(self._data.flatten(), n)) + + def numpy(self): + return self._data + + def tolist(self): + return self._data.tolist() + + def __len__(self): + return len(self._data) + + def __eq__(self, other): + # 按元素比较,返回 _T,模拟 paddle.Tensor == scalar + val = other._data if isinstance(other, _T) else other + return _T((self._data == val).astype(np.float32)) + + def to_tensor(data, dtype=None): + return _T(data) + + def zeros(shape, dtype=None): + return _T(np.zeros(shape)) + + def cumsum(x): + return _T(np.cumsum(x._data)) + + def concat(tensors): + arrays = [np.atleast_1d(t._data) for t in tensors] + return _T(np.concatenate(arrays)) + + def where(cond, x, y): + # cond 是 _T(比较结果),直接返回:0/1 值就是 is_image_token 的内容 + if isinstance(cond, _T): + return cond + # fallback:标量 bool + return _T(np.array([x if cond else y])) + + paddle.to_tensor = to_tensor + paddle.zeros = zeros + paddle.cumsum = cumsum + paddle.concat = concat + paddle.where = where + return paddle + + +def _setup_sys_mocks(): + """注入所有需要 Mock 的模块到 sys.modules。""" + # paddle + paddle_mock = _make_paddle_mock() + sys.modules.setdefault("paddle", paddle_mock) + + # server.engine.config + config_mod = types.ModuleType("server.engine.config") + + class VitMode: + VIT_INCOMPLETE = MagicMock(name="VIT_INCOMPLETE") + VIT_INCOMPLETE.name = "VIT_INCOMPLETE" + VIT_COMPLETED = MagicMock(name="VIT_COMPLETED") + VIT_COMPLETED.name = "VIT_COMPLETED" + + env_cfg = MagicMock() + env_cfg.image_patch_id = 151859 + env_cfg.split_fuse_size_image = 2048 + env_cfg.split_fuse_size = 1024 + env_cfg.ellm_dynamic_mode = False + env_cfg.enable_vpd_split = False + env_cfg.multi_modal_model_v45_turbo = True + + config_mod.VitMode = VitMode + config_mod.get_config = lambda: env_cfg + sys.modules["server"] = types.ModuleType("server") + sys.modules["server.engine"] = types.ModuleType("server.engine") + sys.modules["server.engine.config"] = config_mod + + # server.utils + utils_mod = types.ModuleType("server.utils") + utils_mod.data_processor_logger = MagicMock() + utils_mod.model_server_logger = MagicMock() + sys.modules["server.utils"] = utils_mod + + # server.data.base_processor + base_proc_mod = types.ModuleType("server.data.base_processor") + base_proc_mod.BaseDataProcessor = object + sys.modules["server.data"] = types.ModuleType("server.data") + sys.modules["server.data.base_processor"] = base_proc_mod + + # server.data.ernie_tokenizer + tok_mod = types.ModuleType("server.data.ernie_tokenizer") + tok_mod.ErnieBotTokenizer = MagicMock() + sys.modules["server.data.ernie_tokenizer"] = tok_mod + + # toolkit + toolkit_mod = types.ModuleType("toolkit") + toolkit_mod.ProcessedDataLoader = MagicMock() + sys.modules["toolkit"] = toolkit_mod + + # custom_setup_ops (get_mm_split_fuse 通过这里导入) + ops_mod = types.ModuleType("custom_setup_ops") + sys.modules["custom_setup_ops"] = ops_mod + + # server.data.data_processor.* + for submod in [ + "server.data.data_processor", + "server.data.data_processor.data_processor", + "server.data.data_processor.data_processor.utils", + "server.data.data_processor.data_processor.utils.argparser", + "server.data.data_processor.data_processor.steps", + "server.data.data_processor.data_processor.steps.end2end_processing", + ]: + sys.modules.setdefault(submod, types.ModuleType(submod)) + + argparser_mod = sys.modules["server.data.data_processor.data_processor.utils.argparser"] + argparser_mod.PdArgumentParser = MagicMock() + argparser_mod.get_config = MagicMock() + + e2e_mod = sys.modules["server.data.data_processor.data_processor.steps.end2end_processing"] + e2e_mod.End2EndProcessor = MagicMock() + e2e_mod.End2EndProcessorArguments = MagicMock() + + return env_cfg, VitMode + + +# --------------------------------------------------------------------------- +# 构造一个轻量的 ErnieMM45DataProcessor stub,仅含被测方法 +# --------------------------------------------------------------------------- + + +def _make_processor_stub(env_cfg, get_mm_split_fuse_fn): + """ + 返回一个只有 prepare_mm_split_fuse_fields 的最小 processor 实例。 + image_preprocess、patch_size、temporal_patch_size 全部手动注入。 + """ + + class _ImagePreprocess: + rescale_factor = 0.00392156862745098 # 1/255 + image_mean = [0.485, 0.456, 0.406] + image_std = [0.229, 0.224, 0.225] + image_mean_tensor = np.array(image_mean, dtype="float32").reshape(1, 3, 1, 1) + image_std_tensor = np.array(image_std, dtype="float32").reshape(1, 3, 1, 1) + + class _Processor: + patch_size = 14 + temporal_patch_size = 1 + image_preprocess = _ImagePreprocess() + + def prepare_mm_split_fuse_fields(self, data): + # 直接从 ernie_45mm_processor 中复制,但注入 mock 的 get_mm_split_fuse + import paddle as _paddle + + input_ids = _paddle.to_tensor(data["input_ids"]).cast("int64") + is_image_token = _paddle.where(input_ids == env_cfg.image_patch_id, 1, 0) + image_token_sum = _paddle.cumsum(is_image_token) + image_token_sum = _paddle.concat([_paddle.zeros([1], dtype="int64"), image_token_sum]) + grid_thw = _paddle.to_tensor(data.get("grid_thw_list", []), dtype="int64") + image_type_ids_tensor = _paddle.to_tensor(list(data["image_type_ids"])).cast("int32") + + image_chunk_selections_task, split_fuse_cur_seq_lens_task = get_mm_split_fuse_fn( + input_ids.cpu(), + image_type_ids_tensor.cpu(), + image_token_sum.cast("int32").cpu(), + grid_thw.cpu(), + env_cfg.image_patch_id, + len(data.get("grid_thw_list", [])), + 0, + len(data["input_ids"]), + env_cfg.split_fuse_size_image, + env_cfg.split_fuse_size, + 2048, + ) + data["image_chunk_selections_task"] = image_chunk_selections_task.numpy().tolist() + data["split_fuse_cur_seq_lens_task"] = split_fuse_cur_seq_lens_task.numpy().tolist() + data["split_fuse_chunk_num"] = len(split_fuse_cur_seq_lens_task) + + data["rescale_factor"] = self.image_preprocess.rescale_factor + if env_cfg.multi_modal_model_v45_turbo: + data["image_mean_tensor"] = ( + _paddle.to_tensor(self.image_preprocess.image_mean_tensor) + .squeeze([-2, -1]) + .repeat_interleave(self.patch_size**2 * self.temporal_patch_size, -1) + .numpy() + .tolist() + ) + data["image_std_tensor"] = ( + _paddle.to_tensor(self.image_preprocess.image_std_tensor) + .squeeze([-2, -1]) + .repeat_interleave(self.patch_size**2 * self.temporal_patch_size, -1) + .numpy() + .tolist() + ) + else: + data["image_mean_tensor"] = self.image_preprocess.image_mean_tensor.numpy().tolist() + data["image_std_tensor"] = self.image_preprocess.image_std_tensor.numpy().tolist() + data["image_batch"] = len(data["image_type_ids"]) + return data + + return _Processor() + + +# --------------------------------------------------------------------------- +# 辅助:构造合成 warmup data(模仿 _build_mm_warmup_data 的前半部分) +# --------------------------------------------------------------------------- + + +def _build_synthetic_warmup_data(image_patch_id): + T, H, W = 1, 4, 4 + merge_size = 2 + H_eff, W_eff = H // merge_size, W // merge_size + num_img_tokens = T * H_eff * W_eff # 4 + + prefix_ids = [5, 5, 5] + img_ids = [image_patch_id] * num_img_tokens + suffix_ids = [5, 5, 5] + input_ids = prefix_ids + img_ids + suffix_ids + + t = len(prefix_ids) + position_ids = [[i, i, i] for i in range(t)] + for h in range(H_eff): + for w in range(W_eff): + position_ids.append([t, t + h, t + w]) + next_pos = t + W_eff + for k in range(len(suffix_ids)): + position_ids.append([next_pos + k] * 3) + + return ( + { + "input_ids": input_ids, + "grid_thw": [[T, H, W]], + "grid_thw_list": [[T, H, W]], + "image_type_ids": [0], + "position_ids": position_ids, + "image_dict": {}, + "media_info": {}, + }, + T, + H, + W, + H_eff, + W_eff, + num_img_tokens, + ) + + +# --------------------------------------------------------------------------- +# 测试类 +# --------------------------------------------------------------------------- + + +class TestPrepareMmSplitFuseFields(unittest.TestCase): + """单元测试 prepare_mm_split_fuse_fields。""" + + def setUp(self): + self.env_cfg, self.VitMode = _setup_sys_mocks() + self.image_patch_id = self.env_cfg.image_patch_id + + # mock get_mm_split_fuse 返回值:1 个 image chunk(crop_num=1) + # 用 MagicMock 包装,避免直接覆盖 np.ndarray 的只读属性 + self._chunk_sel_ret = MagicMock() + self._chunk_sel_ret.numpy.return_value = MagicMock() + self._chunk_sel_ret.numpy.return_value.tolist.return_value = [1] + + self._seq_lens_ret = MagicMock() + self._seq_lens_ret.numpy.return_value = MagicMock() + self._seq_lens_ret.numpy.return_value.tolist.return_value = [10] + self._seq_lens_ret.__len__ = lambda self: 1 + + def fake_get_mm_split_fuse(*args, **kwargs): + return self._chunk_sel_ret, self._seq_lens_ret + + self.processor = _make_processor_stub(self.env_cfg, fake_get_mm_split_fuse) + self.data, self.T, self.H, self.W, self.H_eff, self.W_eff, self.num_img_tokens = _build_synthetic_warmup_data( + self.image_patch_id + ) + + def test_split_fuse_fields_populated(self): + """调用后 image_chunk_selections_task / split_fuse_cur_seq_lens_task 必须存在且为列表。""" + result = self.processor.prepare_mm_split_fuse_fields(self.data) + self.assertIn("image_chunk_selections_task", result) + self.assertIn("split_fuse_cur_seq_lens_task", result) + self.assertIn("split_fuse_chunk_num", result) + self.assertIsInstance(result["image_chunk_selections_task"], list) + self.assertIsInstance(result["split_fuse_cur_seq_lens_task"], list) + + def test_chunk_num_consistent(self): + """split_fuse_chunk_num 应等于 split_fuse_cur_seq_lens_task 的长度。""" + result = self.processor.prepare_mm_split_fuse_fields(self.data) + self.assertEqual(result["split_fuse_chunk_num"], len(result["split_fuse_cur_seq_lens_task"])) + + def test_rescale_factor_populated(self): + """rescale_factor 应为非 None 的浮点数。""" + result = self.processor.prepare_mm_split_fuse_fields(self.data) + self.assertIsNotNone(result["rescale_factor"]) + self.assertIsInstance(result["rescale_factor"], float) + + def test_image_mean_std_tensor_populated(self): + """v45_turbo 模式下 image_mean_tensor / image_std_tensor 应为非空列表。""" + result = self.processor.prepare_mm_split_fuse_fields(self.data) + self.assertIsNotNone(result["image_mean_tensor"]) + self.assertIsNotNone(result["image_std_tensor"]) + self.assertIsInstance(result["image_mean_tensor"], list) + self.assertIsInstance(result["image_std_tensor"], list) + self.assertGreater(len(result["image_mean_tensor"]), 0) + + def test_image_mean_std_length_matches_patch(self): + """ + v45_turbo: mean/std 经 repeat_interleave(patch_size^2 * temporal_patch_size) + 展开后长度应为 3 * patch_size^2 * temporal_patch_size。 + """ + patch_size = self.processor.patch_size # 14 + temporal = self.processor.temporal_patch_size # 1 + expected_len = 3 * patch_size**2 * temporal # 3*196*1 = 588 + result = self.processor.prepare_mm_split_fuse_fields(self.data) + self.assertEqual(len(result["image_mean_tensor"]), expected_len) + self.assertEqual(len(result["image_std_tensor"]), expected_len) + + def test_image_batch_equals_image_type_ids_len(self): + """image_batch 应等于 image_type_ids 长度。""" + result = self.processor.prepare_mm_split_fuse_fields(self.data) + self.assertEqual(result["image_batch"], len(self.data["image_type_ids"])) + + def test_returns_same_dict(self): + """方法应原地修改并返回同一个 dict 对象。""" + result = self.processor.prepare_mm_split_fuse_fields(self.data) + self.assertIs(result, self.data) + + +class TestBuildMmWarmupData(unittest.TestCase): + """ + 测试 _build_mm_warmup_data 生成的数据结构。 + + Engine 依赖太多重型模块,这里直接测试 warmup data 的构造逻辑, + 使用独立函数模拟 Engine._build_mm_warmup_data 的行为。 + """ + + def setUp(self): + self.env_cfg, self.VitMode = _setup_sys_mocks() + self.image_patch_id = self.env_cfg.image_patch_id + + # 构造固定返回的 mock prepare_mm_split_fuse_fields + def fake_prepare(data): + data["image_chunk_selections_task"] = [1] + data["split_fuse_cur_seq_lens_task"] = [len(data["input_ids"])] + data["split_fuse_chunk_num"] = 1 + data["rescale_factor"] = 1 / 255.0 + data["image_mean_tensor"] = [0.0] * 588 + data["image_std_tensor"] = [1.0] * 588 + data["image_batch"] = 1 + return data + + mock_dp = MagicMock() + mock_dp.prepare_mm_split_fuse_fields.side_effect = fake_prepare + self.mock_dp = mock_dp + + # 构造 base_data(Engine.warmup 传入的纯文基础字段) + self.base_data = { + "req_id": "warmup_req", + "input_ids": [1, 2, 3], + } + + def _run_build(self): + """执行 _build_mm_warmup_data 的逻辑(不依赖真实 Engine 实例)。""" + image_patch_id = self.image_patch_id + T, H, W = 1, 4, 4 + merge_size = 2 + H_eff = H // merge_size + W_eff = W // merge_size + num_img_tokens = T * H_eff * W_eff + + prefix_ids = [5, 5, 5] + img_ids = [image_patch_id] * num_img_tokens + suffix_ids = [5, 5, 5] + input_ids = prefix_ids + img_ids + suffix_ids + + t = len(prefix_ids) + position_ids = [[i, i, i] for i in range(t)] + for h in range(H_eff): + for w in range(W_eff): + position_ids.append([t, t + h, t + w]) + next_pos = t + W_eff + for k in range(len(suffix_ids)): + position_ids.append([next_pos + k] * 3) + + data = dict(self.base_data) + data["input_ids"] = input_ids + data["grid_thw"] = [[T, H, W]] + data["image_type_ids"] = [0] + data["position_ids"] = position_ids + data["enable_thinking"] = False + data["max_think_len"] = -1 + data["max_content_len"] = -1 + + # v45_turbo 分支 + data["grid_thw_list"] = [[T, H, W]] + data["vit_mode"] = "VIT_INCOMPLETE" + data["use_vpd_split"] = False + data["image_dict"] = {} + data["media_info"] = {} + self.mock_dp.prepare_mm_split_fuse_fields(data) + + return data, T, H, W, H_eff, W_eff, num_img_tokens + + def test_input_ids_structure(self): + """input_ids = prefix(3) + image_tokens(4) + suffix(3) = 10。""" + data, T, H, W, H_eff, W_eff, num_img_tokens = self._run_build() + self.assertEqual(len(data["input_ids"]), 3 + num_img_tokens + 3) + # image patch id 在中间 + img_slice = data["input_ids"][3 : 3 + num_img_tokens] + self.assertTrue(all(x == self.image_patch_id for x in img_slice)) + + def test_position_ids_length(self): + """position_ids 长度应与 input_ids 一致。""" + data, *_ = self._run_build() + self.assertEqual(len(data["position_ids"]), len(data["input_ids"])) + + def test_position_ids_text_tokens_are_1d(self): + """文本 token 的 position_ids 三个维度相等(1D 编码)。""" + data, T, H, W, H_eff, W_eff, num_img_tokens = self._run_build() + prefix_pos = data["position_ids"][:3] + suffix_pos = data["position_ids"][3 + num_img_tokens :] + for pos in prefix_pos + suffix_pos: + self.assertEqual(pos[0], pos[1], msg=f"text token pos not 1D: {pos}") + self.assertEqual(pos[1], pos[2], msg=f"text token pos not 1D: {pos}") + + def test_position_ids_image_tokens_3d(self): + """ + 图片 token 的 position_ids: + - pos[0](t 维)全部相等 + - 覆盖 H_eff × W_eff 不同的 (h, w) 坐标 + """ + data, T, H, W, H_eff, W_eff, num_img_tokens = self._run_build() + img_pos = data["position_ids"][3 : 3 + num_img_tokens] + t_val = img_pos[0][0] + for pos in img_pos: + self.assertEqual(pos[0], t_val, msg=f"image t dim varies: {pos}") + + hw_pairs = {(pos[1], pos[2]) for pos in img_pos} + self.assertEqual( + len(hw_pairs), H_eff * W_eff, msg=f"expected {H_eff*W_eff} unique (h,w) pairs, got {hw_pairs}" + ) + + def test_grid_thw_and_grid_thw_list(self): + """grid_thw 和 grid_thw_list 内容一致。""" + data, T, H, W, *_ = self._run_build() + self.assertEqual(data["grid_thw"], [[T, H, W]]) + self.assertEqual(data["grid_thw_list"], [[T, H, W]]) + + def test_prepare_mm_split_fuse_fields_called(self): + """_build_mm_warmup_data 必须调用 data_processor.prepare_mm_split_fuse_fields。""" + self._run_build() + self.mock_dp.prepare_mm_split_fuse_fields.assert_called_once() + + def test_split_fuse_fields_not_none(self): + """prepare_mm_split_fuse_fields 填充的字段不能为 None。""" + data, *_ = self._run_build() + for key in [ + "image_chunk_selections_task", + "split_fuse_cur_seq_lens_task", + "rescale_factor", + "image_mean_tensor", + "image_std_tensor", + ]: + self.assertIsNotNone(data[key], msg=f"{key} should not be None") + + def test_vit_mode_and_use_vpd_split(self): + """vit_mode 应为 VIT_INCOMPLETE,use_vpd_split 应为 False。""" + data, *_ = self._run_build() + self.assertEqual(data["vit_mode"], "VIT_INCOMPLETE") + self.assertFalse(data["use_vpd_split"]) + + def test_image_dict_and_media_info_empty(self): + """image_dict 和 media_info 应为空 dict。""" + data, *_ = self._run_build() + self.assertEqual(data["image_dict"], {}) + self.assertEqual(data["media_info"], {}) + + +if __name__ == "__main__": + unittest.main(verbosity=2)