Skip to content

Commit d7ec982

Browse files
staryxchen and claude committed
refactor(sglang): replace magic strings with constants and cache block shape
- Extract mode literals ("local"/"distributed") into module-level constants MODE_LOCAL, MODE_DISTRIBUTED, _VALID_MODES to prevent typo-induced bugs
- Extract error operation labels ("get"/"set"/"exists") into _OP_GET, _OP_SET, _OP_EXISTS constants for consistent Prometheus label usage
- Cache block shape tuple as self._block_shape at init time instead of recomputing kv_dim and constructing the shape on every _get_block_shaped() call (hot path in batch_get_v1 per-block loop)
- Remove dead field _started (set but never read)
- Update tests to import and use the new constants

Signed-off-by: staryxchen <staryxchen@tencent.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 3caf32e commit d7ec982

2 files changed

Lines changed: 35 additions & 25 deletions

File tree

flexkv/integration/sglang/hicache_storage_adapter.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,19 @@
4646
except ImportError:
4747
from sglang.srt.observability.metrics_collector import StorageMetrics
4848

49+
# ---------------------------------------------------------------------------
50+
# Constants
51+
# ---------------------------------------------------------------------------
52+
53+
MODE_LOCAL = "local"
54+
MODE_DISTRIBUTED = "distributed"
55+
_VALID_MODES = (MODE_LOCAL, MODE_DISTRIBUTED)
56+
57+
# Error operation labels (used in metrics recording)
58+
_OP_GET = "get"
59+
_OP_SET = "set"
60+
_OP_EXISTS = "exists"
61+
4962

5063
# ---------------------------------------------------------------------------
5164
# Helper: extract token_ids from extra_info
@@ -162,17 +175,17 @@ def __init__(
162175
self._cache_config = CacheConfig(**cache_kwargs)
163176

164177
# Extract distributed mode configuration
165-
self._mode: str = extra.get("mode", "local")
166-
if self._mode not in ("local", "distributed"):
167-
raise ValueError(f"Invalid mode: {self._mode}. Must be 'local' or 'distributed'")
178+
self._mode: str = extra.get("mode", MODE_LOCAL)
179+
if self._mode not in _VALID_MODES:
180+
raise ValueError(f"Invalid mode: {self._mode}. Must be one of {_VALID_MODES}")
168181

169182
self._redis_host: str = extra.get("redis_host", "127.0.0.1")
170183
self._redis_port: int = extra.get("redis_port", 6379)
171184
self._redis_password: Optional[str] = extra.get("redis_password", None)
172185

173186
self._prefetch_timeout: float = float(extra.get("prefetch_timeout", 5.0))
174187

175-
if self._mode == "distributed" and not self._redis_host:
188+
if self._mode == MODE_DISTRIBUTED and not self._redis_host:
176189
raise ValueError("redis_host is required when mode='distributed'")
177190

178191
self._should_backup: bool = (
@@ -184,9 +197,9 @@ def __init__(
184197
self._kv_manager = None
185198
self._cpu_cache_tensor = None # direct access to CPU cache (thread mode)
186199
self._elements_per_block: int = 0
200+
self._block_shape: tuple = () # set after KVManager init
187201
self._page_size: int = self._cache_config.tokens_per_block
188202
self._mem_pool_host = mem_pool_host
189-
self._started = False
190203
self._bytes_per_page: int = 0
191204
self._gb_per_page: float = 0.0
192205

@@ -231,8 +244,7 @@ def _init_kv_manager(self):
231244
}
232245

233246
# Branch on mode: local vs distributed
234-
if self._mode == "distributed":
235-
# Distributed mode: enable cross-node KV Cache sharing
247+
if self._mode == MODE_DISTRIBUTED:
236248
cache_config_kwargs.update({
237249
"enable_remote": True,
238250
"enable_kv_sharing": True,
@@ -273,13 +285,18 @@ def _init_kv_manager(self):
273285
# Get direct access to CPU cache tensor (thread mode only)
274286
self._cpu_cache_tensor = self._kv_manager.get_cpu_cache_tensor()
275287

276-
# Compute elements per block for indexing
288+
# Compute elements per block and cache the block shape for indexing
277289
kv_dim = 1 if self._model_config.use_mla else 2
278290
self._elements_per_block = (
279291
self._model_config.num_layers * kv_dim *
280292
self._page_size * self._model_config.num_kv_heads *
281293
self._model_config.head_size
282294
)
295+
self._block_shape = (
296+
self._model_config.num_layers, kv_dim,
297+
self._page_size, self._model_config.num_kv_heads,
298+
self._model_config.head_size
299+
)
283300

284301
# Compute bytes_per_page for bandwidth reporting
285302
dtype = getattr(self._model_config, 'dtype', None) or torch.float16
@@ -346,12 +363,7 @@ def _get_block_view(self, block_id: int) -> torch.Tensor:
346363

347364
def _get_block_shaped(self, block_id: int) -> torch.Tensor:
348365
"""Get a CPU cache block reshaped to BLOCKFIRST: [L, kv_dim, T, H, D]."""
349-
kv_dim = 1 if self._model_config.use_mla else 2
350-
return self._get_block_view(block_id).view(
351-
self._model_config.num_layers, kv_dim,
352-
self._page_size, self._model_config.num_kv_heads,
353-
self._model_config.head_size
354-
)
366+
return self._get_block_view(block_id).view(self._block_shape)
355367

356368
def _fetch_remote_blocks(self, token_ids: np.ndarray) -> bool:
357369
"""Fetch remote blocks into local CPU cache via prefetch_async.
@@ -411,7 +423,6 @@ def register_mem_pool_host(self, mem_pool_host: Any) -> None:
411423
self._page_size = sglang_page_size
412424

413425
self._init_kv_manager()
414-
self._started = True
415426

416427
# KVManager has started — global collector is now initialized
417428
try:
@@ -499,7 +510,7 @@ def batch_exists(
499510
tokens_per_block=page_size)
500511

501512
# In distributed mode, query both local and remote trees
502-
if (self._mode == "distributed"
513+
if (self._mode == MODE_DISTRIBUTED
503514
and hasattr(cache_engine.cpu_cache_engine, 'match_all')):
504515
match_result = cache_engine.cpu_cache_engine.match_all(seq_meta)
505516
else:
@@ -510,7 +521,7 @@ def batch_exists(
510521
except Exception:
511522
logger.exception("batch_exists failed")
512523
if self._metrics:
513-
self._metrics.record_sglang_error("exists")
524+
self._metrics.record_sglang_error(_OP_EXISTS)
514525
return 0
515526

516527
def batch_get_v1(
@@ -550,7 +561,7 @@ def batch_get_v1(
550561
tokens_per_block=page_size)
551562

552563
# In distributed mode, query both local and remote trees
553-
if (self._mode == "distributed"
564+
if (self._mode == MODE_DISTRIBUTED
554565
and hasattr(cache_engine.cpu_cache_engine, 'match_all')):
555566
match_result = cache_engine.cpu_cache_engine.match_all(seq_meta)
556567

@@ -610,7 +621,7 @@ def batch_get_v1(
610621
except Exception:
611622
logger.exception("batch_get_v1 failed")
612623
if self._metrics:
613-
self._metrics.record_sglang_error("get")
624+
self._metrics.record_sglang_error(_OP_GET)
614625
return [False] * len(keys)
615626

616627
def batch_set_v1(
@@ -694,7 +705,7 @@ def batch_set_v1(
694705
except Exception:
695706
logger.exception("batch_set_v1 failed")
696707
if self._metrics:
697-
self._metrics.record_sglang_error("set")
708+
self._metrics.record_sglang_error(_OP_SET)
698709
return [False] * len(keys)
699710

700711
# Legacy abstract methods

flexkv/integration/sglang/test_hicache_storage_adapter.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from unittest.mock import MagicMock
5050

5151
from flexkv.integration.sglang.hicache_storage_adapter import (
52-
FlexKVHiCacheStorage, _get_token_ids,
52+
FlexKVHiCacheStorage, _get_token_ids, MODE_LOCAL, MODE_DISTRIBUTED,
5353
)
5454
from sglang.srt.mem_cache.hicache_storage import (
5555
HiCacheStorageConfig, HiCacheStorageExtraInfo,
@@ -230,7 +230,7 @@ def test_no_token_ids_degradation():
230230
def test_default_local_mode():
231231
"""Default mode is 'local' when not specified."""
232232
backend = FlexKVHiCacheStorage(_make_config())
233-
assert backend._mode == "local", f"Expected 'local', got '{backend._mode}'"
233+
assert backend._mode == MODE_LOCAL, f"Expected '{MODE_LOCAL}', got '{backend._mode}'"
234234

235235

236236
def test_explicit_local_mode():
@@ -247,8 +247,7 @@ def test_explicit_local_mode():
247247
}
248248
)
249249
backend = FlexKVHiCacheStorage(config)
250-
assert backend._mode == "local"
251-
250+
assert backend._mode == MODE_LOCAL
252251

253252
def test_distributed_mode_config():
254253
"""Distributed mode stores redis config correctly."""
@@ -267,7 +266,7 @@ def test_distributed_mode_config():
267266
}
268267
)
269268
backend = FlexKVHiCacheStorage(config)
270-
assert backend._mode == "distributed"
269+
assert backend._mode == MODE_DISTRIBUTED
271270
assert backend._redis_host == "redis.example.com"
272271
assert backend._redis_port == 6380
273272
assert backend._redis_password == "test_password"

0 commit comments

Comments
 (0)