From 2ccf5438a914e5a43b7296cf486ef233c4b35c79 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Tue, 4 Nov 2025 09:53:31 +0000 Subject: [PATCH] [kvoffload] feat: make LMCache connecter work Signed-off-by: AlpinDale --- aphrodite/config/cache.py | 4 + .../aphrodite_v1_adapter.py | 12 +- .../lookup_client/__init__.py | 13 + .../lookup_client/abstract_client.py | 51 +++ .../lookup_client/factory.py | 151 ++++++++ .../lookup_client/hit_limit_lookup_client.py | 82 +++++ .../lmcache_async_lookup_client.py | 328 ++++++++++++++++++ .../lookup_client/lmcache_lookup_client.py | 249 +++++++++++++ .../lookup_client/mooncake_lookup_client.py | 90 +++++ .../offload_server/__init__.py | 14 + .../offload_server/abstract_server.py | 36 ++ .../offload_server/message.py | 29 ++ .../offload_server/zmq_server.py | 73 ++++ .../v1/lmcache_integration/rpc_utils.py | 120 +++++++ aphrodite/v1/core/kv_cache_utils.py | 16 +- aphrodite/v1/engine/core.py | 2 + aphrodite/v1/kv_cache_interface.py | 2 + aphrodite/v1/kv_offload/cpu.py | 39 ++- aphrodite/v1/worker/gpu_model_runner.py | 3 + aphrodite/v1/worker/tpu_model_runner.py | 2 + benchmarks/benchmark_prefix_caching.py | 11 + requirements/kv_connectors.txt | 2 +- 22 files changed, 1308 insertions(+), 21 deletions(-) create mode 100644 aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/__init__.py create mode 100644 aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/abstract_client.py create mode 100644 aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/factory.py create mode 100644 aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/hit_limit_lookup_client.py create mode 100644 aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/lmcache_async_lookup_client.py create mode 100644 aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/lmcache_lookup_client.py 
create mode 100644 aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/mooncake_lookup_client.py create mode 100644 aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/offload_server/__init__.py create mode 100644 aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/offload_server/abstract_server.py create mode 100644 aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/offload_server/message.py create mode 100644 aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/offload_server/zmq_server.py create mode 100644 aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/rpc_utils.py diff --git a/aphrodite/config/cache.py b/aphrodite/config/cache.py index 15906e5f89..2ccda4264f 100644 --- a/aphrodite/config/cache.py +++ b/aphrodite/config/cache.py @@ -107,6 +107,10 @@ class CacheConfig: num_cpu_blocks: int | None = field(default=None, init=False) """The number of blocks to allocate for CPU memory.""" + # Will be set after model loading. + kv_bytes_per_block: int | None = field(default=None, init=False) + """The number of KV bytes per block, across all workers.""" + kv_sharing_fast_prefill: bool = False """This feature is work in progress and no prefill optimization takes place with this flag enabled currently. 
diff --git a/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/aphrodite_v1_adapter.py b/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/aphrodite_v1_adapter.py index f911dd0a5a..e3be3add1a 100644 --- a/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/aphrodite_v1_adapter.py +++ b/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/aphrodite_v1_adapter.py @@ -9,7 +9,6 @@ import torch from lmcache import utils from lmcache.config import LMCacheEngineMetadata -from lmcache.logging import init_logger from lmcache.observability import LMCStatsMonitor from lmcache.utils import _lmcache_nvtx_annotate from lmcache.v1.cache_engine import LMCacheEngine, LMCacheEngineBuilder @@ -20,10 +19,6 @@ VLLMPagedMemGPUConnectorV2, VLLMPagedMemLayerwiseGPUConnector) from lmcache.v1.internal_api_server.api_server import InternalAPIServer -from lmcache.v1.lookup_client import LookupClientFactory -from lmcache.v1.lookup_client.lmcache_async_lookup_client import ( - LMCacheAsyncLookupServer) -from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer from lmcache.v1.plugin.plugin_launcher import PluginLauncher from aphrodite.common.sampling_params import SamplingParams @@ -35,11 +30,16 @@ lmcache_get_or_create_config, mla_enabled) from aphrodite.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tp_group) +from aphrodite.logger import init_logger from aphrodite.utils.math_utils import cdiv from aphrodite.utils.torch_utils import get_kv_cache_torch_dtype from aphrodite.v1.core.sched.output import SchedulerOutput from aphrodite.version import __version__ as APHRODITE_VERSION +from .lookup_client import LookupClientFactory +from .lookup_client.lmcache_async_lookup_client import LMCacheAsyncLookupServer +from .offload_server.zmq_server import ZMQOffloadServer + if TYPE_CHECKING: from aphrodite.attention.backends.abstract import AttentionMetadata from aphrodite.forward_context import 
ForwardContext @@ -819,7 +819,7 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: slot_mapping = request.slot_mapping.cuda() assert len(tokens) == len(slot_mapping) - self._stats_monitor.update_interval_aphrodite_hit_tokens( + self._stats_monitor.update_interval_vllm_hit_tokens( request.load_spec.aphrodite_cached_tokens ) token_mask = torch.ones(len(tokens), dtype=torch.bool) diff --git a/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/__init__.py b/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/__init__.py new file mode 100644 index 0000000000..8105b3b2f0 --- /dev/null +++ b/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/__init__.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +from .abstract_client import LookupClientInterface +from .factory import LookupClientFactory +from .lmcache_lookup_client import LMCacheLookupClient, LMCacheLookupServer +from .mooncake_lookup_client import MooncakeLookupClient + +__all__ = [ + "LookupClientInterface", + "LookupClientFactory", + "MooncakeLookupClient", + "LMCacheLookupClient", + "LMCacheLookupServer", +] diff --git a/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/abstract_client.py b/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/abstract_client.py new file mode 100644 index 0000000000..241b1121ca --- /dev/null +++ b/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/abstract_client.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +import abc +from typing import TYPE_CHECKING, Optional, Union + +import torch + +if TYPE_CHECKING: + pass + + +class LookupClientInterface(metaclass=abc.ABCMeta): + """Abstract interface for lookup clients.""" + + @abc.abstractmethod + def lookup( + self, + token_ids: Union[torch.Tensor, list[int]], + lookup_id: str, + 
request_configs: Optional[dict] = None, + ) -> Optional[int]: + """ + Perform lookup for the given token IDs. + + Args: + token_ids: The token IDs to lookup + + lookup_id: The lookup ID to associate with the lookup + + request_configs: The configs of the request, + includes tags and the other configs + + Returns: + The number of tokens that can be loaded from cache. + None indicates the lookup/prefetch is in progress. + """ + raise NotImplementedError + + @abc.abstractmethod + def close(self) -> None: + """Close the lookup client and clean up resources.""" + raise NotImplementedError + + def supports_producer_reuse(self) -> bool: + """ + Return whether this lookup client supports producer KV cache reuse. + + Returns: + True if producer reuse is supported, False otherwise + """ + return False diff --git a/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/factory.py b/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/factory.py new file mode 100644 index 0000000000..feb6e352ea --- /dev/null +++ b/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/factory.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from typing import TYPE_CHECKING, Optional, Union + +from lmcache.v1.cache_engine import LMCacheEngine +from lmcache.v1.config import LMCacheEngineConfig + +from aphrodite.logger import init_logger + +from .abstract_client import LookupClientInterface +from .hit_limit_lookup_client import HitLimitLookupClient +from .mooncake_lookup_client import MooncakeLookupClient + +if TYPE_CHECKING: + from aphrodite.config import AphroditeConfig + + from .lmcache_async_lookup_client import LMCacheAsyncLookupServer + from .lmcache_lookup_client import LMCacheLookupServer + +logger = init_logger(__name__) + + +class LookupClientFactory: + """Factory for creating lookup clients and servers based on configuration.""" + + @staticmethod + def create_lookup_client( 
+ aphrodite_config: "AphroditeConfig", + config: LMCacheEngineConfig, + ) -> LookupClientInterface: + """ + Create a lookup client based on the configuration. + + Args: + aphrodite_config: The Aphrodite configuration + config: The LMCache engine configuration + + Returns: + A lookup client instance + """ + + # Check if external_lookup_client is configured + if config.external_lookup_client is not None: + if config.enable_async_loading: + raise ValueError( + "Asynchronous loading is not supported for external lookup clients." + ) + client = LookupClientFactory._create_external_lookup_client( + config.external_lookup_client, aphrodite_config + ) + else: + from .lmcache_async_lookup_client import LMCacheAsyncLookupClient + from .lmcache_lookup_client import LMCacheLookupClient + + if config.enable_async_loading: + client = LMCacheAsyncLookupClient(aphrodite_config) + else: + client = LMCacheLookupClient(aphrodite_config) + + if config.hit_miss_ratio is not None and 0 <= config.hit_miss_ratio <= 1: + return HitLimitLookupClient(client, config) + return client + + @staticmethod + def create_lookup_server( + lmcache_engine: LMCacheEngine, + aphrodite_config: "AphroditeConfig", + ) -> Optional[Union["LMCacheLookupServer", "LMCacheAsyncLookupServer"]]: + """ + Create a lookup server based on the configuration. 
+ + Args: + lmcache_engine: The LMCache engine instance + aphrodite_config: The Aphrodite configuration + + Returns: + A lookup server instance, or None if no server should be created + """ + config = lmcache_engine.config + assert isinstance(config, LMCacheEngineConfig), ( + "LMCache v1 config is expected for lookup server and client" + ) + + # Only create the KV lookup API server on worker rank 0 + # when there are multiple workers and when not using external lookup client + create_lookup_server_only_on_worker_0_for_mla = config.get_extra_config_value( + "create_lookup_server_only_on_worker_0_for_mla", + lmcache_engine.metadata.use_mla, + ) + + if config.external_lookup_client is None and ( + not create_lookup_server_only_on_worker_0_for_mla + or lmcache_engine.metadata.worker_id == 0 + ): + from .lmcache_async_lookup_client import LMCacheAsyncLookupServer + from .lmcache_lookup_client import LMCacheLookupServer + + if config.enable_async_loading: + return LMCacheAsyncLookupServer(lmcache_engine, aphrodite_config) + else: + return LMCacheLookupServer(lmcache_engine, aphrodite_config) + + return None + + @staticmethod + def _create_external_lookup_client( + external_lookup_uri: str, + aphrodite_config: "AphroditeConfig", + ) -> LookupClientInterface: + """ + Create an external lookup client based on the URI format. + + Args: + external_lookup_uri: URI in format <scheme>://<address>
+ aphrodite_config: The Aphrodite configuration + + Returns: + A lookup client instance + + Raises: + ValueError: If the URI format is unsupported + """ + # Parse URI scheme and address + if "://" not in external_lookup_uri: + raise ValueError( + f"Invalid external lookup client URI format: {external_lookup_uri}. " + "Expected format: <scheme>://<address>
" + ) + + scheme, address = external_lookup_uri.split("://", 1) + + # Route to appropriate client based on scheme + if scheme == "mooncakestore": + return LookupClientFactory._create_mooncake_lookup_client( + address, aphrodite_config + ) + else: + raise ValueError( + f"Unsupported external lookup client scheme: {scheme}. " + "Supported schemes: mooncakestore" + ) + + @staticmethod + def _create_mooncake_lookup_client( + master_address: str, + aphrodite_config: "AphroditeConfig", + ) -> "MooncakeLookupClient": + """Create a MooncakeLookupClient instance.""" + from .mooncake_lookup_client import MooncakeLookupClient + + return MooncakeLookupClient(aphrodite_config, master_address) diff --git a/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/hit_limit_lookup_client.py b/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/hit_limit_lookup_client.py new file mode 100644 index 0000000000..fe0b494c3a --- /dev/null +++ b/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/hit_limit_lookup_client.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +from typing import Optional, Union + +# Third Party +import torch +from lmcache.v1.config import LMCacheEngineConfig + +# First Party +from aphrodite.logger import init_logger + +from .abstract_client import LookupClientInterface + +logger = init_logger(__name__) + + +""" +HitLimitLookupClient now is used for test, when lookup is called, cal the cache hit, +- if the cache hit <= (1 - hit_miss_ratio), direct return the result +- if the cache hit > (1 - hit_miss_ratio), re-compute the result by hit_miss_ratio +""" + + +class HitLimitLookupClient(LookupClientInterface): + def __init__( + self, actual_lookup_client: LookupClientInterface, config: LMCacheEngineConfig + ): + assert config.hit_miss_ratio is not None and 0 <= config.hit_miss_ratio <= 1 + self.actual_lookup_client = actual_lookup_client + 
self.hit_ratio_upper = 1 - config.hit_miss_ratio + self.chunk_size = config.chunk_size + logger.info( + f"create HitLimitLookupClient succeed, the hit ratio upper " + f"is {self.hit_ratio_upper}, chunk size is {self.chunk_size}" + ) + + def lookup( + self, + token_ids: Union[torch.Tensor, list[int]], + lookup_id: str, + request_configs: Optional[dict] = None, + ) -> Optional[int]: + # get real hit tokens + result = self.actual_lookup_client.lookup(token_ids, lookup_id, request_configs) + if result is not None: + total_tokens_length = len(token_ids) + assert result <= total_tokens_length + current_hit_ratio = 0.0 + if total_tokens_length > 0: + current_hit_ratio = result / total_tokens_length + # limit the hit tokens + if current_hit_ratio > self.hit_ratio_upper: + origin_result = result + # align to chunk size + new_result = ( + int(total_tokens_length * self.hit_ratio_upper) + // self.chunk_size + * self.chunk_size + ) + # check again + result = min(result, new_result) + logger.debug( + f"hit ratio upper: {self.hit_ratio_upper} is smaller than " + f"the real hit ratio {current_hit_ratio}, " + f"the origin result is {origin_result}, " + f"the new result is {new_result}, the final result is {result}" + ) + return result + + def supports_producer_reuse(self) -> bool: + return self.actual_lookup_client.supports_producer_reuse() + + def clear_lookup_status(self, lookup_id: str) -> None: + """Clear lookup status for the given lookup_id. + + Delegates to the wrapped lookup client. 
+ """ + if hasattr(self.actual_lookup_client, 'clear_lookup_status'): + self.actual_lookup_client.clear_lookup_status(lookup_id) + + def close(self) -> None: + self.actual_lookup_client.close() diff --git a/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/lmcache_async_lookup_client.py b/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/lmcache_async_lookup_client.py new file mode 100644 index 0000000000..ab49f4d7e0 --- /dev/null +++ b/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/lmcache_async_lookup_client.py @@ -0,0 +1,328 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +import threading +import time +from typing import TYPE_CHECKING, Optional, Union + +import msgspec +import torch +import zmq +from lmcache.v1.cache_engine import LMCacheEngine + +from aphrodite.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils import ( + create_lmcache_metadata, mla_enabled) +from aphrodite.logger import init_logger +# Third Party +from aphrodite.utils.network_utils import make_zmq_socket + +from ..rpc_utils import get_zmq_rpc_path_lmcache +from .abstract_client import LookupClientInterface + +if TYPE_CHECKING: + from aphrodite.config import AphroditeConfig + +logger = init_logger(__name__) + + +# NOTE(Jiayi): Prefetch could load extra redundant cache if multiple +# workers has different hit tokens. +class LMCacheAsyncLookupClient(LookupClientInterface): + """ + ZMQ-based lookup client that communicates with a lookup server. + + Related extra_config: + - create_lookup_server_only_on_worker_0_for_mla: + is a flag to control whether to create lookup server only on worker 0. 
+ """ + + def __init__( + self, + aphrodite_config: "AphroditeConfig", + ): + metadata, config = create_lmcache_metadata(aphrodite_config) + + self.encoder = msgspec.msgpack.Encoder() + self.ctx = zmq.Context() # type: ignore[attr-defined] + rpc_port = aphrodite_config.kv_transfer_config.get_from_extra_config( + "lmcache_rpc_port", 0 + ) + self.tensor_parallel_size = aphrodite_config.parallel_config.tensor_parallel_size + use_mla = mla_enabled(aphrodite_config.model_config) + self.create_lookup_server_only_on_worker_0_for_mla = ( + config.get_extra_config_value( + "create_lookup_server_only_on_worker_0_for_mla", use_mla + ) + ) + ranks = self.tensor_parallel_size + self.push_sockets = [] + if self.create_lookup_server_only_on_worker_0_for_mla: + ranks = 1 + for tp_rank in range(ranks): + worker_socket_path = get_zmq_rpc_path_lmcache( + aphrodite_config, "lookup_worker", rpc_port, tp_rank + ) + logger.info( + f"lmcache lookup client connect to tp_rank {tp_rank} " + f"with worker socket path {worker_socket_path}" + ) + + push_socket = make_zmq_socket( + self.ctx, + worker_socket_path, + zmq.PUSH, # type: ignore[attr-defined] + bind=False, + ) + + self.push_sockets.append(push_socket) + + scheduler_socket_path = get_zmq_rpc_path_lmcache( + aphrodite_config, "lookup_scheduler", rpc_port, 0 + ) + self.pull_socket = make_zmq_socket( + self.ctx, + scheduler_socket_path, + zmq.PULL, # type: ignore[attr-defined] + bind=True, + ) + logger.info( + f"lmcache lookup client connect to scheduler " + f"with socket path {scheduler_socket_path}" + ) + + # First Party + from lmcache.v1.token_database import (ChunkedTokenDatabase, + SegmentTokenDatabase, + TokenDatabase) + + self.token_database: TokenDatabase + if config.enable_blending: + self.token_database = SegmentTokenDatabase(config, metadata) + else: + self.token_database = ChunkedTokenDatabase(config, metadata) + + # A lock is needed since we need another thread to pull + # responses from the lookup_and_prefetch server + # 
(e.g., worker process). + self.lock = threading.Lock() + + # map from lookup_id to req's status. + # None indicates ongoing. + # int indicates number of hit tokens. + self.reqs_status: dict[str, Optional[int]] = {} + + # map from lookup_id to number of hit tokens for each worker + self.res_for_each_worker: dict[str, list[int]] = {} + + # The two parts are [lookup_id, num_hit_tokens] + self.num_parts = 2 + + self.running = True + + self.thread = threading.Thread( + target=self.process_responses_from_workers, daemon=True + ) + self.thread.start() + + # default backoff time + self.lookup_backoff_time = 0.01 + if config.extra_config is not None: + self.lookup_backoff_time = float( + config.extra_config.get("lookup_backoff_time", self.lookup_backoff_time) + ) + + # TODO(Jiayi): Consider batching here + def lookup( + self, + token_ids: Union[torch.Tensor, list[int]], + lookup_id: str, + request_configs: Optional[dict] = None, + ) -> Optional[int]: + with self.lock: + # -1 indicates not found; None indicates ongoing. 
+ req_status = self.reqs_status.get(lookup_id, -1) + if req_status is None: + time.sleep(self.lookup_backoff_time) + return None + elif req_status != -1: + self.reqs_status.pop(lookup_id) + return req_status + self.reqs_status[lookup_id] = None + hashes = [] + offsets = [] + for start, end, hash_val in self.token_database.process_tokens( + token_ids, make_key=False + ): + hashes.append(hash_val) + offsets.append(end - start) + hash_buf = self.encoder.encode(hashes) + offset_buf = self.encoder.encode(offsets) + + lookup_id_buf = lookup_id.encode("utf-8") + request_configs_str = "" + if request_configs is not None and len(request_configs) != 0: + request_configs_str = "@".join( + [f"{k}%{v}" for k, v in request_configs.items()] + ) + request_configs_buf = request_configs_str.encode("utf-8") + + msg_buf = [ + lookup_id_buf, + hash_buf, + offset_buf, + request_configs_buf, + ] + + ranks = self.tensor_parallel_size + if self.create_lookup_server_only_on_worker_0_for_mla: + ranks = 1 + for i in range(ranks): + self.push_sockets[i].send_multipart(msg_buf, copy=False) + time.sleep(self.lookup_backoff_time) + return None + + def process_responses_from_workers(self): + while self.running: + frames = self.pull_socket.recv_multipart(copy=False) + assert len(frames) == self.num_parts + lookup_id = frames[0].bytes.decode("utf-8") + res = int.from_bytes(frames[1], "big") + + with self.lock: + if lookup_id not in self.res_for_each_worker: + self.res_for_each_worker[lookup_id] = [res] + else: + self.res_for_each_worker[lookup_id].append(res) + all_res = self.res_for_each_worker[lookup_id] + + if len(all_res) == self.tensor_parallel_size or ( + self.create_lookup_server_only_on_worker_0_for_mla + and len(all_res) == 1 + ): + self.res_for_each_worker.pop(lookup_id) + + # NOTE: it is possible that the number of hit + # tokens is different across TP ranks, so we + # can use the minimum value as the number of + # hit tokens. 
+ self.reqs_status[lookup_id] = min(all_res) + + def clear_lookup_status(self, lookup_id: str) -> None: + """Clear lookup status for the given lookup_id. + + This removes the lookup_id from both reqs_status and res_for_each_worker. + """ + with self.lock: + self.reqs_status.pop(lookup_id, None) + self.res_for_each_worker.pop(lookup_id, None) + + def supports_producer_reuse(self) -> bool: + """Return True as LMCacheLookupClient supports producer kvcache reuse""" + return True + + def close(self): + self.running = False + try: + if self.thread.is_alive(): + self.thread.join(timeout=1.0) + for s in self.push_sockets: + s.close(linger=0) # type: ignore[arg-type] + self.pull_socket.close(linger=0) # type: ignore[arg-type] + self.ctx.term() + except Exception as e: + logger.warning(f"Failed to join thread during close: {e}") + + +class LMCacheAsyncLookupServer: + """ZMQ-based async lookup server that handles lookup and prefetch + requests using LMCacheEngine.""" + + def __init__(self, lmcache_engine: LMCacheEngine, aphrodite_config: "AphroditeConfig"): + self.decoder = msgspec.msgpack.Decoder() + self.ctx = zmq.Context() # type: ignore[attr-defined] + rpc_port = aphrodite_config.kv_transfer_config.get_from_extra_config( + "lmcache_rpc_port", 0 + ) + worker_socket_path = get_zmq_rpc_path_lmcache( + aphrodite_config, "lookup_worker", rpc_port, aphrodite_config.parallel_config.rank + ) + scheduler_socket_path = get_zmq_rpc_path_lmcache( + aphrodite_config, "lookup_scheduler", rpc_port, 0 + ) + self.push_socket = make_zmq_socket( + self.ctx, + scheduler_socket_path, + zmq.PUSH, # type: ignore[attr-defined] + bind=False, + ) + self.pull_socket = make_zmq_socket( + self.ctx, + worker_socket_path, + zmq.PULL, # type: ignore[attr-defined] + bind=True, + ) + + self.lmcache_engine = lmcache_engine + self.running = True + + logger.info( + "lmcache lookup server start with" + f" scheduler socket path {scheduler_socket_path}, " + f"worker socket path {worker_socket_path}" + ) + 
# Frames per request: [lookup_id, hash, offset, request_configs]; set before thread start. + self.num_parts = 4 + + self.thread = threading.Thread( + target=self.process_requests_from_scheduler, daemon=True + ) + self.thread.start() + + def process_requests_from_scheduler(self): + while self.running: + frames = self.pull_socket.recv_multipart(copy=False) + num_frames = len(frames) + assert num_frames % self.num_parts == 0 + for i in range(0, num_frames, self.num_parts): + lookup_id = frames[i].bytes.decode("utf-8") + + hash_frame = frames[i + 1] + hashes = self.decoder.decode(hash_frame) + + offset_frame = frames[i + 2] + offsets = self.decoder.decode(offset_frame) + + request_configs_str = frames[i + 3].bytes.decode("utf-8") + request_configs = None + if request_configs_str != "": + request_configs = {} + request_configs_list = request_configs_str.split("@") + for kv in request_configs_list: + kvs = kv.split("%", 1) + if len(kvs) != 2: + raise ValueError(f"Unexpected tags_str: {kvs}") + request_configs[kvs[0]] = kvs[1] + + self.lmcache_engine.async_lookup_and_prefetch( + lookup_id=lookup_id, + hashes=hashes, + offsets=offsets, + pin=True, + request_configs=request_configs, + ) + + def send_response_to_scheduler(self, lookup_id: str, num_hit_tokens: int): + lookup_id_buf = lookup_id.encode("utf-8") + num_hit_tokens_buf = num_hit_tokens.to_bytes(4, "big") + self.push_socket.send_multipart([lookup_id_buf, num_hit_tokens_buf], copy=False) + + def close(self): + self.running = False + try: + if self.thread.is_alive(): + self.thread.join(timeout=1.0) + # The server owns one PUSH socket; only the client keeps a list of them. + self.push_socket.close(linger=0) # type: ignore[arg-type] + self.pull_socket.close(linger=0) # type: ignore[arg-type] + self.ctx.term() + except Exception as e: + logger.warning(f"Failed to join thread during close: {e}") diff --git a/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/lmcache_lookup_client.py 
b/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/lmcache_lookup_client.py new file mode 100644 index 0000000000..fbd26b8de7 --- /dev/null +++ b/aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/lookup_client/lmcache_lookup_client.py @@ -0,0 +1,249 @@ +# SPDX-License-Identifier: Apache-2.0 +# Standard +import json +import threading + +import msgspec +import torch +import zmq +from lmcache.v1.cache_engine import LMCacheEngine + +from aphrodite.config import AphroditeConfig +from aphrodite.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils import ( + create_lmcache_metadata, mla_enabled) +from aphrodite.logger import init_logger +from aphrodite.utils.network_utils import make_zmq_socket + +from ..rpc_utils import get_zmq_rpc_path_lmcache +from .abstract_client import LookupClientInterface + +logger = init_logger(__name__) + + +class LMCacheLookupClient(LookupClientInterface): + """ + ZMQ-based lookup client that communicates with a lookup server. + + Related extra_config: + - create_lookup_server_only_on_worker_0_for_mla: + is a flag to control whether to create lookup server only on worker 0. 
+ """ + + def __init__( + self, + aphrodite_config: "AphroditeConfig", + ): + metadata, config = create_lmcache_metadata(aphrodite_config) + + self.encoder = msgspec.msgpack.Encoder() + self.ctx = zmq.Context() # type: ignore[attr-defined] + self.config = config + rpc_port = aphrodite_config.kv_transfer_config.get_from_extra_config( + "lmcache_rpc_port", 0 + ) + self.tensor_parallel_size = aphrodite_config.parallel_config.tensor_parallel_size + use_mla = mla_enabled(aphrodite_config.model_config) + self.create_lookup_server_only_on_worker_0_for_mla = ( + config.get_extra_config_value( + "create_lookup_server_only_on_worker_0_for_mla", use_mla + ) + ) + ranks = self.tensor_parallel_size + self.sockets = [] + if self.create_lookup_server_only_on_worker_0_for_mla: + ranks = 1 + + # Set timeout values from config + timeout_ms = config.lookup_timeout_ms + + for tp_rank in range(ranks): + socket_path = get_zmq_rpc_path_lmcache( + aphrodite_config, "lookup", rpc_port, tp_rank + ) + logger.info( + f"lmcache lookup client connect to tp_rank {tp_rank} " + f"with socket path {socket_path}" + ) + socket = make_zmq_socket( + self.ctx, + socket_path, + zmq.REQ, # type: ignore[attr-defined] + bind=False, + ) + + # Set socket timeout during initialization + socket.setsockopt(zmq.RCVTIMEO, timeout_ms) + socket.setsockopt(zmq.SNDTIMEO, timeout_ms) + + self.sockets.append(socket) + + # First Party + from lmcache.v1.token_database import (ChunkedTokenDatabase, + SegmentTokenDatabase, + TokenDatabase) + + self.enable_blending = config.enable_blending + self.token_database: TokenDatabase + if self.enable_blending: + self.token_database = SegmentTokenDatabase(config, metadata) + else: + self.token_database = ChunkedTokenDatabase(config, metadata) + + # FIXME(Jiayi): Cacheblend need token ids + def lookup( + self, + token_ids: torch.Tensor | list[int], + lookup_id: str, + request_configs: dict | None = None, + ) -> int | None: + lookup_id_buf = lookup_id.encode("utf-8") + 
request_configs_str = "" + if request_configs is not None and len(request_configs) != 0: + request_configs_str = json.dumps(request_configs) + request_configs_buf = request_configs_str.encode("utf-8") + ranks = self.tensor_parallel_size + if self.create_lookup_server_only_on_worker_0_for_mla: + ranks = 1 + + # NOTE(Jiayi): We cannot only send hashes when blending enabled + # because the blender need the input embedding. + if not self.enable_blending: + hashes = [] + offsets = [] + for start, end, key in self.token_database.process_tokens( + token_ids, make_key=False + ): + hashes.append(key) + offsets.append(end - start) + hash_buf = self.encoder.encode(hashes) + offset_buf = self.encoder.encode(offsets) + msg_buf = [ + hash_buf, + offset_buf, + lookup_id_buf, + request_configs_buf, + ] + else: + tokens_buf = self.encoder.encode(token_ids) + msg_buf = [ + tokens_buf, + lookup_id_buf, + request_configs_buf, + ] + + results = [] + try: + for i in range(ranks): + self.sockets[i].send_multipart(msg_buf, copy=False) + + # TODO(Jiayi): we can use zmq poll to optimize a bit + for i in range(ranks): + resp = self.sockets[i].recv() + result = int.from_bytes(resp, "big") + results.append(result) + except zmq.Again: + logger.error(f"Timeout occurred for rank {i}") + return 0 + except zmq.ZMQError as e: + logger.error(f"ZMQ error for rank {i}: {str(e)}") + return 0 + + assert len(results) == ranks + if len(set(results)) > 1: + logger.warning( + f"Lookup results (number of hit tokens) differ " + f"across tensor parallel ranks: {results}." + ) + # NOTE: it is possible that the number of hit tokens is different + # across TP ranks, so we can use the minimum value as the + # number of hit tokens. + return min(results) + + def supports_producer_reuse(self) -> bool: + """Return True as LMCacheLookupClient supports producer kvcache reuse""" + return True + + def clear_lookup_status(self, lookup_id: str) -> None: + """Clear lookup status for the given lookup_id. 
class LMCacheLookupServer:
    """ZMQ-based lookup server that handles lookup requests using LMCacheEngine.

    A REP socket is bound at construction time and served by a daemon thread.
    Every request is a multipart message whose last two frames are the UTF-8
    lookup id and the JSON-encoded request configs (an empty frame when there
    are none). The leading frames are the msgpack-encoded ``hashes`` and
    ``offsets``, or a single ``tokens`` frame when blending is enabled. The
    reply is the engine's lookup result encoded as a 4-byte big-endian integer.
    """

    # How long (in ms) the worker waits for a request before re-checking
    # ``self.running``; this bounds how long ``close()`` has to wait.
    _POLL_INTERVAL_MS = 100

    def __init__(self, lmcache_engine: "LMCacheEngine", aphrodite_config: "AphroditeConfig"):
        """Bind the lookup socket for this rank and start the worker thread.

        Args:
            lmcache_engine: Engine whose ``lookup`` method answers requests.
            aphrodite_config: Supplies the KV-transfer extra config
                (``lmcache_rpc_port``) and the parallel rank used to derive
                the socket path.
        """
        self.decoder = msgspec.msgpack.Decoder()
        self.ctx = zmq.Context()  # type: ignore[attr-defined]
        rpc_port = aphrodite_config.kv_transfer_config.get_from_extra_config(
            "lmcache_rpc_port", 0
        )
        socket_path = get_zmq_rpc_path_lmcache(
            aphrodite_config, "lookup", rpc_port, aphrodite_config.parallel_config.rank
        )
        self.socket = make_zmq_socket(
            self.ctx,
            socket_path,
            zmq.REP,  # type: ignore[attr-defined]
            bind=True,
        )

        self.lmcache_engine = lmcache_engine
        self.running = True

        self.enable_blending = lmcache_engine.config.enable_blending

        logger.info(f"lmcache lookup server start on {socket_path}")
        self.thread = threading.Thread(target=self._serve_forever, daemon=True)
        self.thread.start()

    def _handle_request(self, frames) -> int:
        """Decode one multipart request and return the engine's lookup result."""
        lookup_id = frames[-2].bytes.decode("utf-8")
        request_configs_str = frames[-1].bytes.decode("utf-8")
        request_configs = None
        if request_configs_str != "":
            request_configs = json.loads(request_configs_str)

        if self.enable_blending:
            # With blending enabled the raw tokens are sent instead of hashes.
            tokens = self.decoder.decode(frames[0])
            return self.lmcache_engine.lookup(
                tokens=tokens,
                lookup_id=lookup_id,
                pin=True,
                request_configs=request_configs,
            )

        hashes = self.decoder.decode(frames[0])
        offsets = self.decoder.decode(frames[1])
        return self.lmcache_engine.lookup(
            hashes=hashes,
            offsets=offsets,
            lookup_id=lookup_id,
            pin=True,
            request_configs=request_configs,
        )

    def _serve_forever(self) -> None:
        """Worker loop: answer lookup requests until ``self.running`` is cleared."""
        while self.running:
            # Poll with a timeout instead of blocking in recv so that close()
            # can stop the loop simply by clearing ``self.running``.
            if not self.socket.poll(self._POLL_INTERVAL_MS):
                continue
            frames = self.socket.recv_multipart(copy=False)
            try:
                response = self._handle_request(frames).to_bytes(4, "big")
            except Exception:
                # A REP socket must answer every request before it can receive
                # the next one; reply with "0" rather than letting the worker
                # die and leaving the requester waiting.
                logger.exception("lmcache lookup request failed; replying with 0")
                response = (0).to_bytes(4, "big")
            self.socket.send(response)

    def close(self):
        """Stop the worker thread, then release the socket and the context."""
        # Stop the worker before touching the socket so the socket is never
        # closed underneath a thread that is still using it.
        self.running = False
        if self.thread.is_alive() and self.thread is not threading.current_thread():
            self.thread.join()
        self.socket.close(linger=0)
        try:
            self.ctx.term()
        except Exception as e:
            logger.warning(f"Error terminating ZMQ context: {e}")
def lookup(
    self,
    token_ids: Union[torch.Tensor, list[int]],
    lookup_id: Optional[str] = None,
    request_configs: Optional[dict] = None,
) -> Optional[int]:
    """Return how many leading tokens of ``token_ids`` are present in the store.

    The tokens are split into chunks by the token database; every chunk key
    is checked in the store with one batched call. The result is the end
    position of the last chunk before the first one that is missing (or
    errored), or 0 when the very first chunk is not found.

    ``lookup_id`` and ``request_configs`` are accepted for interface
    compatibility and are not used here.
    """
    chunk_keys = []
    chunk_ends = []
    for _start, chunk_end, chunk_key in self.token_database.process_tokens(token_ids):
        assert isinstance(chunk_key, CacheEngineKey)
        chunk_keys.append(chunk_key.to_string())
        chunk_ends.append(chunk_end)

    # One status per key: 1 = found, 0 = not found, -1 = error.
    statuses = self.store.batch_is_exist(chunk_keys)

    # Index of the first chunk that is not a confirmed hit, if any.
    first_miss = next(
        (idx for idx, status in enumerate(statuses) if status != 1),
        None,
    )
    if first_miss is None:
        # Every chunk was found: the hit covers up to the last chunk end.
        return chunk_ends[-1] if chunk_ends else 0
    if first_miss > 0:
        # The hit prefix ends where the previous (found) chunk ends.
        return chunk_ends[first_miss - 1]
    return 0
class OffloadServerInterface(metaclass=abc.ABCMeta):
    """Abstract contract that every offload server implementation must fulfil.

    Concrete subclasses have to provide both ``offload`` and ``close``; the
    class cannot be instantiated until they do.
    """

    @abc.abstractmethod
    def offload(
        self,
        hashes: List[int],
        slot_mapping: List[int],
        offsets: List[int],
    ) -> bool:
        """
        Offload the KV data identified by the given hashes and slots.

        Args:
            hashes: The hashes to offload.
            slot_mapping: The slot ids to offload.
            offsets: Number of tokens in each block.

        Returns:
            Whether the offload was successful.

        Raises:
            NotImplementedError: Always, in this abstract base method.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def close(self) -> None:
        """Shut the offload server down and release its resources.

        Raises:
            NotImplementedError: Always, in this abstract base method.
        """
        raise NotImplementedError
class ZMQOffloadServer(OffloadServerInterface):
    """Offload server that receives offload requests over a ZMQ REP socket.

    A daemon thread decodes each msgpack-encoded ``OffloadMsg``, hands it to
    ``offload`` and answers with a msgpack-encoded ``OffloadRetMsg``.
    """

    # How long (in ms) the worker waits for a request before re-checking
    # ``self.running``; this bounds how long ``close()`` has to wait.
    _POLL_INTERVAL_MS = 100

    def __init__(
        self,
        lmcache_engine: "LMCacheEngine",
        aphrodite_config: "AphroditeConfig",
        tp_rank: int,
    ):
        """Bind the offload socket for ``tp_rank`` and start the worker thread.

        The base RPC port is read from the ``LMCACHE_OFFLOAD_RPC_PORT``
        environment variable (default 100).
        """
        self.ctx = zmq.Context()  # type: ignore[attr-defined]
        offload_rpc_port = int(os.environ.get("LMCACHE_OFFLOAD_RPC_PORT", 100))
        socket_path = get_zmq_rpc_path_lmcache(
            aphrodite_config, "offload", offload_rpc_port, tp_rank
        )
        self.socket = make_zmq_socket(
            self.ctx,
            socket_path,
            zmq.REP,  # type: ignore[attr-defined]
            bind=True,
        )

        self.lmcache_engine = lmcache_engine
        self.running = True

        self.thread = threading.Thread(target=self._serve_forever, daemon=True)
        self.thread.start()

    def _serve_forever(self) -> None:
        """Worker loop: answer offload requests until ``self.running`` is cleared."""
        import logging

        while self.running:
            # Poll with a timeout instead of blocking in recv so that close()
            # can stop the loop simply by clearing ``self.running``.
            if not self.socket.poll(self._POLL_INTERVAL_MS):
                continue
            frame = self.socket.recv(copy=False)
            try:
                offload_msg = msgspec.msgpack.decode(frame, type=OffloadMsg)
                success = self.offload(
                    offload_msg.hashes,
                    offload_msg.slot_mapping,
                    offload_msg.offsets,
                )
            except Exception:
                # A REP socket must answer every request before it can receive
                # the next one; report the failure instead of letting the
                # worker die and leaving the requester waiting.
                logging.getLogger(__name__).exception("offload request failed")
                success = False
            self.socket.send(msgspec.msgpack.encode(OffloadRetMsg(success=success)))

    def offload(
        self,
        hashes: List[int],
        slot_mapping: List[int],
        offsets: List[int],
    ) -> bool:
        """Store the given hashes/slots/offsets through the LMCache engine.

        Returns:
            ``True`` once ``lmcache_engine.store`` has returned.
        """
        self.lmcache_engine.store(
            hashes=hashes, slot_mapping=slot_mapping, offsets=offsets
        )
        return True

    def close(self) -> None:
        """Stop the worker thread, then release the socket and the context."""
        # Stop the worker before touching the socket so the socket is never
        # closed underneath a thread that is still using it, and so the join
        # below cannot wait on a thread that is stuck in a blocking receive.
        self.running = False
        if self.thread.is_alive() and self.thread is not threading.current_thread():
            self.thread.join()
        self.socket.close(linger=0)
        self.ctx.term()
def get_zmq_socket(
    context, socket_path: str, protocol: str, role, bind_or_connect: str
):
    """
    Create a ZeroMQ socket with the specified protocol and role.

    :param context: ZMQ context used to create the socket.
    :param socket_path: Address part after ``<protocol>://``.
    :param protocol: Transport prefix placed before ``://``.
    :param role: Socket type passed to ``context.socket``.
    :param bind_or_connect: Either ``"bind"`` or ``"connect"``.
    :return: The bound or connected socket.
    :raises ValueError: If ``bind_or_connect`` is neither ``"bind"`` nor
        ``"connect"``. No socket is created in that case.
    """
    # Validate before creating anything so a bad argument cannot leak a socket.
    if bind_or_connect not in ("bind", "connect"):
        raise ValueError(f"Invalid bind_or_connect: {bind_or_connect}")

    socket_addr = f"{protocol}://{socket_path}"
    sock = context.socket(role)
    try:
        if bind_or_connect == "bind":
            sock.bind(socket_addr)
        else:
            sock.connect(socket_addr)
    except Exception:
        # Do not hand back (or leak) a socket that failed to bind/connect.
        sock.close(linger=0)
        raise

    return sock
def get_zmq_rpc_path_lmcache(
    aphrodite_config: Optional["AphroditeConfig"] = None,
    service_name: "ServiceKind" = "lookup",
    rpc_port: int = 0,
    tp_rank: int = 0,
) -> str:
    """Get the ZMQ RPC path for LMCache lookup and offload communication.

    Args:
        aphrodite_config: Config whose ``kv_transfer_config.engine_id`` is
            embedded in the path.
        service_name: One of ``"lookup"``, ``"offload"``, ``"lookup_worker"``
            or ``"lookup_scheduler"``.
        rpc_port: Base port. An integer has ``tp_rank`` added to it; a string
            has ``str(tp_rank)`` appended to it.
        tp_rank: Rank used to make the path unique per worker.

    Returns:
        An ``ipc://`` path under ``APHRODITE_RPC_BASE_PATH``.

    Raises:
        ValueError: If the config or its ``kv_transfer_config`` is missing,
            or if ``service_name`` is not one of the accepted names.
    """
    if aphrodite_config is None or aphrodite_config.kv_transfer_config is None:
        raise ValueError("A valid kv_transfer_config with engine_id is required.")

    valid_services = ("lookup", "offload", "lookup_worker", "lookup_scheduler")
    if service_name not in valid_services:
        # List every accepted name so the message matches the actual check.
        raise ValueError(
            f"service_name must be one of {valid_services}, got {service_name!r}"
        )

    # Function-scope import; done after validation so bad arguments fail fast.
    import aphrodite.envs as envs

    base_url = envs.APHRODITE_RPC_BASE_PATH

    engine_id = aphrodite_config.kv_transfer_config.engine_id

    if isinstance(rpc_port, str):
        rpc_port = rpc_port + str(tp_rank)
    else:
        rpc_port += tp_rank

    logger.debug(
        "Base URL: %s, Engine: %s, Service Name: %s, RPC Port: %s",
        base_url,
        engine_id,
        service_name,
        rpc_port,
    )

    socket_path = (
        f"ipc://{base_url}/engine_{engine_id}_service_{service_name}_"
        f"lmcache_rpc_port_{rpc_port}"
    )

    return socket_path
Allocate different amount of memory for each @@ -977,13 +979,11 @@ def get_kv_cache_config_from_groups( num_blocks = available_memory // kv_cache_groups[0].kv_cache_spec.page_size_bytes num_blocks = may_override_num_blocks(aphrodite_config, num_blocks) per_layer_specs = kv_cache_groups[0].kv_cache_spec.kv_cache_specs - kv_cache_tensors = [ - KVCacheTensor( - size=per_layer_specs[layer_name].page_size_bytes * num_blocks, - shared_by=[layer_name], - ) - for layer_name in kv_cache_groups[0].layer_names - ] + kv_cache_tensors = [] + for layer_name in kv_cache_groups[0].layer_names: + page_size_bytes = per_layer_specs[layer_name].page_size_bytes + kv_bytes_per_block += page_size_bytes + kv_cache_tensors.append(KVCacheTensor(size=page_size_bytes * num_blocks, shared_by=[layer_name])) else: # General case: # We will have group_size memory pools, each is shared by one layer from @@ -1005,11 +1005,13 @@ def get_kv_cache_config_from_groups( if i < len(kv_cache_groups[j].layer_names): shared_by.append(kv_cache_groups[j].layer_names[i]) kv_cache_tensors.append(KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by)) + kv_bytes_per_block += page_size * len(shared_by) return KVCacheConfig( num_blocks=num_blocks, kv_cache_tensors=kv_cache_tensors, kv_cache_groups=kv_cache_groups, + kv_bytes_per_block=kv_bytes_per_block, ) diff --git a/aphrodite/v1/engine/core.py b/aphrodite/v1/engine/core.py index 044841859f..c387f59f6d 100644 --- a/aphrodite/v1/engine/core.py +++ b/aphrodite/v1/engine/core.py @@ -231,6 +231,8 @@ def _initialize_kv_caches(self, aphrodite_config: AphroditeConfig) -> tuple[int, # Initialize kv cache and warmup the execution self.modeling.initialize_from_config(kv_cache_configs) + aphrodite_config.cache_config.kv_bytes_per_block = scheduler_kv_cache_config.kv_bytes_per_block + elapsed = time.time() - start logger.info_once( ("init engine (profile, create kv cache, warmup model) took %.2f seconds"), diff --git a/aphrodite/v1/kv_cache_interface.py 
b/aphrodite/v1/kv_cache_interface.py index 4791e7bdeb..2a627dce5e 100644 --- a/aphrodite/v1/kv_cache_interface.py +++ b/aphrodite/v1/kv_cache_interface.py @@ -348,3 +348,5 @@ class KVCacheConfig: see `_get_kv_cache_config_uniform_page_size` for more details. """ kv_cache_groups: list[KVCacheGroupSpec] + """The number of KV bytes per block, across all workers""" + kv_bytes_per_block: int diff --git a/aphrodite/v1/kv_offload/cpu.py b/aphrodite/v1/kv_offload/cpu.py index b20312be9a..9defb3e8f7 100644 --- a/aphrodite/v1/kv_offload/cpu.py +++ b/aphrodite/v1/kv_offload/cpu.py @@ -18,10 +18,27 @@ class CPUOffloadingSpec(OffloadingSpec): def __init__(self, aphrodite_config: AphroditeConfig): super().__init__(aphrodite_config) - num_cpu_blocks = self.extra_config.get("num_cpu_blocks") - if not num_cpu_blocks: - raise Exception("num_cpu_blocks must be specified in kv_connector_extra_config") - self.num_cpu_blocks: int = num_cpu_blocks + swap_space_bytes = self.extra_config.get("swap_space_bytes") + if not swap_space_bytes: + # Try to auto-calculate from kv_bytes_per_rank if available + kv_bytes_per_rank = self.extra_config.get("kv_bytes_per_rank") + if kv_bytes_per_rank is not None: + swap_space_bytes = int(kv_bytes_per_rank) + else: + # Fallback: calculate from kv_offloading_size + kv_offloading_size = aphrodite_config.cache_config.kv_offloading_size + if kv_offloading_size is not None: + num_kv_ranks = ( + aphrodite_config.parallel_config.tensor_parallel_size + * aphrodite_config.parallel_config.pipeline_parallel_size + ) + swap_space_bytes = int(kv_offloading_size * (1 << 30) / num_kv_ranks) + else: + raise Exception( + "swap_space_bytes must be specified in kv_connector_extra_config, " + "or kv_offloading_size must be set in CacheConfig" + ) + self.swap_space_bytes: int = swap_space_bytes # scheduler-side self._manager: OffloadingManager | None = None @@ -31,11 +48,14 @@ def __init__(self, aphrodite_config: AphroditeConfig): def get_manager(self) -> OffloadingManager: if 
not self._manager: + kv_bytes_per_offloaded_block = self.aphrodite_config.cache_config.kv_bytes_per_block * ( + self.offloaded_block_size // self.gpu_block_size + ) + num_blocks = self.swap_space_bytes // kv_bytes_per_offloaded_block kv_events_config = self.aphrodite_config.kv_events_config enable_events = kv_events_config is not None and kv_events_config.enable_kv_cache_events self._manager = LRUOffloadingManager( - CPUBackend(block_size=self.offloaded_block_size, num_blocks=self.num_cpu_blocks), - enable_events=enable_events, + CPUBackend(block_size=self.offloaded_block_size, num_blocks=num_blocks), enable_events=enable_events ) return self._manager @@ -50,11 +70,16 @@ def get_handlers( layers = get_layers_from_aphrodite_config(self.aphrodite_config, AttentionLayerBase, layer_names) attn_backends = {layer_name: layers[layer_name].get_attn_backend() for layer_name in layer_names} + kv_bytes_per_offloaded_block = self.aphrodite_config.cache_config.kv_bytes_per_block * ( + self.offloaded_block_size // self.gpu_block_size + ) + num_blocks = self.swap_space_bytes // kv_bytes_per_offloaded_block + self._handler = CpuGpuOffloadingHandler( attn_backends=attn_backends, gpu_block_size=self.gpu_block_size, cpu_block_size=self.offloaded_block_size, - num_cpu_blocks=self.num_cpu_blocks, + num_cpu_blocks=num_blocks, gpu_caches=kv_caches, ) diff --git a/aphrodite/v1/worker/gpu_model_runner.py b/aphrodite/v1/worker/gpu_model_runner.py index b773c5692b..3d9ebda994 100644 --- a/aphrodite/v1/worker/gpu_model_runner.py +++ b/aphrodite/v1/worker/gpu_model_runner.py @@ -4278,6 +4278,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config + + self.cache_config.kv_bytes_per_block = kv_cache_config.kv_bytes_per_block + self.may_add_encoder_only_layers_to_kv_cache_config() self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) 
self.initialize_attn_backend(kv_cache_config) diff --git a/aphrodite/v1/worker/tpu_model_runner.py b/aphrodite/v1/worker/tpu_model_runner.py index c66bb73c34..deaac95ced 100644 --- a/aphrodite/v1/worker/tpu_model_runner.py +++ b/aphrodite/v1/worker/tpu_model_runner.py @@ -1650,6 +1650,8 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: else: raise NotImplementedError + self.cache_config.kv_bytes_per_block = kv_cache_config.kv_bytes_per_block + # Set up cross-layer KV cache sharing if needed self.maybe_setup_cross_layer_kv_sharing(kv_caches, kv_cache_config) diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index eec7e88da8..489ff4d5db 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -24,6 +24,17 @@ --num-prompts 20 \ --repeat-count 5 \ --input-length-range 128:256 + +CPU Offloading example usage: + python benchmark_prefix_caching.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-prompts 1 \ + --repeat-count 100 \ + --input-length-range 512:2048 \ + --kv-offloading-size 16 \ + --kv-offloading-backend native \ + --disable-hybrid-kv-cache-manager """ import dataclasses diff --git a/requirements/kv_connectors.txt b/requirements/kv_connectors.txt index 45cae90601..9516d541dc 100644 --- a/requirements/kv_connectors.txt +++ b/requirements/kv_connectors.txt @@ -1,2 +1,2 @@ -lmcache == 0.3.9 +lmcache == 0.3.7 nixl >= 0.6.0 # Required for disaggregated prefill \ No newline at end of file