Skip to content

Commit 52d12ec

Browse files
authored
[kvoffload] feat: make LMCache connecter work (#1589)
Signed-off-by: AlpinDale <alpindale@gmail.com>
1 parent 5c39abb commit 52d12ec

22 files changed

Lines changed: 1308 additions & 21 deletions

aphrodite/config/cache.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ class CacheConfig:
107107
num_cpu_blocks: int | None = field(default=None, init=False)
108108
"""The number of blocks to allocate for CPU memory."""
109109

110+
# Will be set after model loading.
111+
kv_bytes_per_block: int | None = field(default=None, init=False)
112+
"""The number of KV bytes per block, across all workers."""
113+
110114
kv_sharing_fast_prefill: bool = False
111115
"""This feature is work in progress and no prefill optimization takes place
112116
with this flag enabled currently.

aphrodite/distributed/kv_transfer/kv_connector/v1/lmcache_integration/aphrodite_v1_adapter.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import torch
1010
from lmcache import utils
1111
from lmcache.config import LMCacheEngineMetadata
12-
from lmcache.logging import init_logger
1312
from lmcache.observability import LMCStatsMonitor
1413
from lmcache.utils import _lmcache_nvtx_annotate
1514
from lmcache.v1.cache_engine import LMCacheEngine, LMCacheEngineBuilder
@@ -20,10 +19,6 @@
2019
VLLMPagedMemGPUConnectorV2,
2120
VLLMPagedMemLayerwiseGPUConnector)
2221
from lmcache.v1.internal_api_server.api_server import InternalAPIServer
23-
from lmcache.v1.lookup_client import LookupClientFactory
24-
from lmcache.v1.lookup_client.lmcache_async_lookup_client import (
25-
LMCacheAsyncLookupServer)
26-
from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer
2722
from lmcache.v1.plugin.plugin_launcher import PluginLauncher
2823

2924
from aphrodite.common.sampling_params import SamplingParams
@@ -35,11 +30,16 @@
3530
lmcache_get_or_create_config, mla_enabled)
3631
from aphrodite.distributed.parallel_state import (
3732
get_tensor_model_parallel_rank, get_tp_group)
33+
from aphrodite.logger import init_logger
3834
from aphrodite.utils.math_utils import cdiv
3935
from aphrodite.utils.torch_utils import get_kv_cache_torch_dtype
4036
from aphrodite.v1.core.sched.output import SchedulerOutput
4137
from aphrodite.version import __version__ as APHRODITE_VERSION
4238

39+
from .lookup_client import LookupClientFactory
40+
from .lookup_client.lmcache_async_lookup_client import LMCacheAsyncLookupServer
41+
from .offload_server.zmq_server import ZMQOffloadServer
42+
4343
if TYPE_CHECKING:
4444
from aphrodite.attention.backends.abstract import AttentionMetadata
4545
from aphrodite.forward_context import ForwardContext
@@ -819,7 +819,7 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
819819
slot_mapping = request.slot_mapping.cuda()
820820
assert len(tokens) == len(slot_mapping)
821821

822-
self._stats_monitor.update_interval_aphrodite_hit_tokens(
822+
self._stats_monitor.update_interval_vllm_hit_tokens(
823823
request.load_spec.aphrodite_cached_tokens
824824
)
825825
token_mask = torch.ones(len(tokens), dtype=torch.bool)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
from .abstract_client import LookupClientInterface
3+
from .factory import LookupClientFactory
4+
from .lmcache_lookup_client import LMCacheLookupClient, LMCacheLookupServer
5+
from .mooncake_lookup_client import MooncakeLookupClient
6+
7+
__all__ = [
8+
"LookupClientInterface",
9+
"LookupClientFactory",
10+
"MooncakeLookupClient",
11+
"LMCacheLookupClient",
12+
"LMCacheLookupServer",
13+
]
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# Standard
3+
import abc
4+
from typing import TYPE_CHECKING, Optional, Union
5+
6+
import torch
7+
8+
if TYPE_CHECKING:
9+
pass
10+
11+
12+
class LookupClientInterface(metaclass=abc.ABCMeta):
    """Contract that every KV-cache lookup client has to fulfil."""

    @abc.abstractmethod
    def lookup(
        self,
        token_ids: Union[torch.Tensor, list[int]],
        lookup_id: str,
        request_configs: Optional[dict] = None,
    ) -> Optional[int]:
        """Report how many of the given tokens can be loaded from cache.

        Args:
            token_ids: Token IDs to look up.
            lookup_id: Identifier to associate with this lookup.
            request_configs: Configs of the request; includes tags and the
                other request-level configs.

        Returns:
            The number of tokens that can be loaded from cache, or ``None``
            while the lookup/prefetch is still in progress.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def close(self) -> None:
        """Shut the lookup client down and release its resources."""
        raise NotImplementedError()

    def supports_producer_reuse(self) -> bool:
        """Tell whether this client supports producer KV cache reuse.

        Returns:
            ``True`` if producer reuse is supported, ``False`` otherwise.
            This base implementation always answers ``False``.
        """
        return False
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# Standard
3+
from typing import TYPE_CHECKING, Optional, Union
4+
5+
from lmcache.v1.cache_engine import LMCacheEngine
6+
from lmcache.v1.config import LMCacheEngineConfig
7+
8+
from aphrodite.logger import init_logger
9+
10+
from .abstract_client import LookupClientInterface
11+
from .hit_limit_lookup_client import HitLimitLookupClient
12+
from .mooncake_lookup_client import MooncakeLookupClient
13+
14+
if TYPE_CHECKING:
15+
from aphrodite.config import AphroditeConfig
16+
17+
from .lmcache_async_lookup_client import LMCacheAsyncLookupServer
18+
from .lmcache_lookup_client import LMCacheLookupServer
19+
20+
logger = init_logger(__name__)
21+
22+
23+
class LookupClientFactory:
    """Builds lookup clients and lookup servers from the LMCache configuration."""

    @staticmethod
    def create_lookup_client(
        aphrodite_config: "AphroditeConfig",
        config: LMCacheEngineConfig,
    ) -> LookupClientInterface:
        """Build the lookup client selected by ``config``.

        An external client is built when ``config.external_lookup_client`` is
        set; otherwise the sync or async LMCache client is chosen from
        ``config.enable_async_loading``. When ``config.hit_miss_ratio`` lies
        in ``[0, 1]`` the client is wrapped in a ``HitLimitLookupClient``.

        Args:
            aphrodite_config: The Aphrodite configuration.
            config: The LMCache engine configuration.

        Returns:
            A lookup client instance.

        Raises:
            ValueError: If async loading is requested together with an
                external lookup client.
        """
        if config.external_lookup_client is None:
            # Imported lazily, exactly where the built-in clients are needed.
            from .lmcache_async_lookup_client import LMCacheAsyncLookupClient
            from .lmcache_lookup_client import LMCacheLookupClient

            client_cls = (
                LMCacheAsyncLookupClient
                if config.enable_async_loading
                else LMCacheLookupClient
            )
            client = client_cls(aphrodite_config)
        else:
            if config.enable_async_loading:
                raise ValueError(
                    "Asynchronous loading is not supported for external lookup clients."
                )
            client = LookupClientFactory._create_external_lookup_client(
                config.external_lookup_client, aphrodite_config
            )

        ratio = config.hit_miss_ratio
        if ratio is not None and 0 <= ratio <= 1:
            return HitLimitLookupClient(client, config)
        return client

    @staticmethod
    def create_lookup_server(
        lmcache_engine: LMCacheEngine,
        aphrodite_config: "AphroditeConfig",
    ) -> Optional[Union["LMCacheLookupServer", "LMCacheAsyncLookupServer"]]:
        """Build the lookup server for this worker, if one is needed.

        No server is built when an external lookup client is configured, or
        when the ``create_lookup_server_only_on_worker_0_for_mla`` extra
        config (defaulting to the engine's ``use_mla`` flag) is truthy and
        this worker's id is not 0.

        Args:
            lmcache_engine: The LMCache engine instance.
            aphrodite_config: The Aphrodite configuration.

        Returns:
            A lookup server instance, or ``None`` if no server should be
            created.
        """
        config = lmcache_engine.config
        assert isinstance(config, LMCacheEngineConfig), (
            "LMCache v1 config is expected for lookup server and client"
        )

        metadata = lmcache_engine.metadata
        only_on_worker_0 = config.get_extra_config_value(
            "create_lookup_server_only_on_worker_0_for_mla",
            metadata.use_mla,
        )

        # External lookup clients never get a local lookup server.
        if config.external_lookup_client is not None:
            return None
        # When restricted to worker 0, every other worker skips the server.
        if only_on_worker_0 and not metadata.worker_id == 0:
            return None

        from .lmcache_async_lookup_client import LMCacheAsyncLookupServer
        from .lmcache_lookup_client import LMCacheLookupServer

        if config.enable_async_loading:
            return LMCacheAsyncLookupServer(lmcache_engine, aphrodite_config)
        return LMCacheLookupServer(lmcache_engine, aphrodite_config)

    @staticmethod
    def _create_external_lookup_client(
        external_lookup_uri: str,
        aphrodite_config: "AphroditeConfig",
    ) -> LookupClientInterface:
        """Build an external lookup client from a ``<scheme>://<address>`` URI.

        Args:
            external_lookup_uri: URI in format <scheme>://<address>.
            aphrodite_config: The Aphrodite configuration.

        Returns:
            A lookup client instance.

        Raises:
            ValueError: If the URI has no ``://`` separator or its scheme is
                not supported.
        """
        # Split on the first "://" only; the address may contain more.
        scheme, separator, address = external_lookup_uri.partition("://")
        if not separator:
            raise ValueError(
                f"Invalid external lookup client URI format: {external_lookup_uri}. "
                "Expected format: <scheme>://<address>"
            )

        if scheme != "mooncakestore":
            raise ValueError(
                f"Unsupported external lookup client scheme: {scheme}. "
                "Supported schemes: mooncakestore"
            )
        return LookupClientFactory._create_mooncake_lookup_client(
            address, aphrodite_config
        )

    @staticmethod
    def _create_mooncake_lookup_client(
        master_address: str,
        aphrodite_config: "AphroditeConfig",
    ) -> "MooncakeLookupClient":
        """Build a ``MooncakeLookupClient`` for ``master_address``."""
        from .mooncake_lookup_client import MooncakeLookupClient

        mooncake_client = MooncakeLookupClient(aphrodite_config, master_address)
        return mooncake_client
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# Standard
3+
from typing import Optional, Union
4+
5+
# Third Party
6+
import torch
7+
from lmcache.v1.config import LMCacheEngineConfig
8+
9+
# First Party
10+
from aphrodite.logger import init_logger
11+
12+
from .abstract_client import LookupClientInterface
13+
14+
logger = init_logger(__name__)
15+
16+
17+
"""
18+
HitLimitLookupClient is currently used for testing. When lookup is called, it computes the cache hit ratio:
19+
- if the hit ratio <= (1 - hit_miss_ratio), the result is returned unchanged
20+
- if the hit ratio > (1 - hit_miss_ratio), the result is capped at (1 - hit_miss_ratio) of the tokens, aligned down to chunk_size
21+
"""
22+
23+
24+
class HitLimitLookupClient(LookupClientInterface):
    """Lookup client wrapper that caps the reported cache-hit ratio.

    Every call is forwarded to the wrapped client. When the share of tokens
    the wrapped client reports as cached exceeds ``1 - hit_miss_ratio``, the
    reported hit count is lowered to that share of the request, rounded down
    to a multiple of ``chunk_size``. Intended for testing.
    """

    def __init__(
        self, actual_lookup_client: LookupClientInterface, config: LMCacheEngineConfig
    ):
        """Wrap ``actual_lookup_client`` with the limit taken from ``config``.

        Args:
            actual_lookup_client: The client that performs the real lookup.
            config: LMCache engine configuration; ``hit_miss_ratio`` must be
                set and lie in ``[0, 1]``, and ``chunk_size`` is used to align
                the capped hit count.
        """
        assert config.hit_miss_ratio is not None and 0 <= config.hit_miss_ratio <= 1
        self.actual_lookup_client = actual_lookup_client
        self.hit_ratio_upper = 1 - config.hit_miss_ratio
        self.chunk_size = config.chunk_size
        # The original message concatenated "upper" and "is" without a space.
        logger.info(
            "create HitLimitLookupClient succeed, the hit ratio upper "
            "is %s, chunk size is %s",
            self.hit_ratio_upper,
            self.chunk_size,
        )

    def lookup(
        self,
        token_ids: Union[torch.Tensor, list[int]],
        lookup_id: str,
        request_configs: Optional[dict] = None,
    ) -> Optional[int]:
        """Look up the tokens and cap the hit count at the configured ratio.

        Args:
            token_ids: The token IDs to look up.
            lookup_id: The lookup ID to associate with the lookup.
            request_configs: The configs of the request, passed through to
                the wrapped client.

        Returns:
            The (possibly reduced) number of tokens that can be loaded from
            cache, or ``None`` when the wrapped client returns ``None``.
        """
        # get real hit tokens
        result = self.actual_lookup_client.lookup(token_ids, lookup_id, request_configs)
        if result is None:
            # Nothing to limit yet; pass the wrapped client's answer through.
            return None

        total_tokens_length = len(token_ids)
        assert result <= total_tokens_length
        current_hit_ratio = 0.0
        if total_tokens_length > 0:
            current_hit_ratio = result / total_tokens_length

        # limit the hit tokens
        if current_hit_ratio > self.hit_ratio_upper:
            origin_result = result
            # align to chunk size
            new_result = (
                int(total_tokens_length * self.hit_ratio_upper)
                // self.chunk_size
                * self.chunk_size
            )
            # check again: never report more than the real hit count
            result = min(result, new_result)
            # Lazy %-args: the message is only formatted when debug is enabled.
            logger.debug(
                "hit ratio upper: %s is smaller than "
                "the real hit ratio %s, "
                "the origin result is %s, "
                "the new result is %s, the final result is %s",
                self.hit_ratio_upper,
                current_hit_ratio,
                origin_result,
                new_result,
                result,
            )
        return result

    def supports_producer_reuse(self) -> bool:
        """Return the wrapped client's answer on producer KV cache reuse."""
        return self.actual_lookup_client.supports_producer_reuse()

    def clear_lookup_status(self, lookup_id: str) -> None:
        """Clear lookup status for the given lookup_id.

        Delegates to the wrapped lookup client when it provides
        ``clear_lookup_status``; otherwise does nothing.
        """
        if hasattr(self.actual_lookup_client, 'clear_lookup_status'):
            self.actual_lookup_client.clear_lookup_status(lookup_id)

    def close(self) -> None:
        """Close the wrapped lookup client."""
        self.actual_lookup_client.close()

0 commit comments

Comments
 (0)