4646except ImportError :
4747 from sglang .srt .observability .metrics_collector import StorageMetrics
4848
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Values accepted for the "mode" key of the extra config; anything outside
# _VALID_MODES is rejected with a ValueError during construction.
MODE_LOCAL = "local"
MODE_DISTRIBUTED = "distributed"
_VALID_MODES = (MODE_LOCAL, MODE_DISTRIBUTED)

# Error operation labels (passed to the metrics collector's
# record_sglang_error when the corresponding batch call fails).
_OP_GET = "get"
_OP_SET = "set"
_OP_EXISTS = "exists"
61+
4962
5063# ---------------------------------------------------------------------------
5164# Helper: extract token_ids from extra_info
@@ -162,17 +175,17 @@ def __init__(
162175 self ._cache_config = CacheConfig (** cache_kwargs )
163176
164177 # Extract distributed mode configuration
165- self ._mode : str = extra .get ("mode" , "local" )
166- if self ._mode not in ( "local" , "distributed" ) :
167- raise ValueError (f"Invalid mode: { self ._mode } . Must be 'local' or 'distributed' " )
178+ self ._mode : str = extra .get ("mode" , MODE_LOCAL )
179+ if self ._mode not in _VALID_MODES :
180+ raise ValueError (f"Invalid mode: { self ._mode } . Must be one of { _VALID_MODES } " )
168181
169182 self ._redis_host : str = extra .get ("redis_host" , "127.0.0.1" )
170183 self ._redis_port : int = extra .get ("redis_port" , 6379 )
171184 self ._redis_password : Optional [str ] = extra .get ("redis_password" , None )
172185
173186 self ._prefetch_timeout : float = float (extra .get ("prefetch_timeout" , 5.0 ))
174187
175- if self ._mode == "distributed" and not self ._redis_host :
188+ if self ._mode == MODE_DISTRIBUTED and not self ._redis_host :
176189 raise ValueError ("redis_host is required when mode='distributed'" )
177190
178191 self ._should_backup : bool = (
@@ -184,9 +197,9 @@ def __init__(
184197 self ._kv_manager = None
185198 self ._cpu_cache_tensor = None # direct access to CPU cache (thread mode)
186199 self ._elements_per_block : int = 0
200+ self ._block_shape : tuple = () # set after KVManager init
187201 self ._page_size : int = self ._cache_config .tokens_per_block
188202 self ._mem_pool_host = mem_pool_host
189- self ._started = False
190203 self ._bytes_per_page : int = 0
191204 self ._gb_per_page : float = 0.0
192205
@@ -231,8 +244,7 @@ def _init_kv_manager(self):
231244 }
232245
233246 # Branch on mode: local vs distributed
234- if self ._mode == "distributed" :
235- # Distributed mode: enable cross-node KV Cache sharing
247+ if self ._mode == MODE_DISTRIBUTED :
236248 cache_config_kwargs .update ({
237249 "enable_remote" : True ,
238250 "enable_kv_sharing" : True ,
@@ -273,13 +285,18 @@ def _init_kv_manager(self):
273285 # Get direct access to CPU cache tensor (thread mode only)
274286 self ._cpu_cache_tensor = self ._kv_manager .get_cpu_cache_tensor ()
275287
276- # Compute elements per block for indexing
288+ # Compute elements per block and cache the block shape for indexing
277289 kv_dim = 1 if self ._model_config .use_mla else 2
278290 self ._elements_per_block = (
279291 self ._model_config .num_layers * kv_dim *
280292 self ._page_size * self ._model_config .num_kv_heads *
281293 self ._model_config .head_size
282294 )
295+ self ._block_shape = (
296+ self ._model_config .num_layers , kv_dim ,
297+ self ._page_size , self ._model_config .num_kv_heads ,
298+ self ._model_config .head_size
299+ )
283300
284301 # Compute bytes_per_page for bandwidth reporting
285302 dtype = getattr (self ._model_config , 'dtype' , None ) or torch .float16
@@ -346,12 +363,7 @@ def _get_block_view(self, block_id: int) -> torch.Tensor:
346363
def _get_block_shaped(self, block_id: int) -> torch.Tensor:
    """Get a CPU cache block reshaped to BLOCKFIRST: [L, kv_dim, T, H, D].

    The target shape is read from ``self._block_shape``, which is cached
    when the KV manager is initialised, so it is not rebuilt from the model
    config on every call.

    Args:
        block_id: Identifier of the block, forwarded to ``_get_block_view``.

    Returns:
        A view of the block with shape ``self._block_shape``.

    Raises:
        RuntimeError: If ``self._block_shape`` is still the empty
            placeholder, i.e. the KV manager has not been initialised.
    """
    flat_block = self._get_block_view(block_id)
    shape = self._block_shape
    if not shape:
        # The placeholder () would make .view() fail with an opaque shape
        # error, or silently yield a 0-d tensor for a one-element block.
        raise RuntimeError(
            "_block_shape is not set; the KV manager must be initialised "
            "before blocks can be reshaped"
        )
    return flat_block.view(shape)
355367
356368 def _fetch_remote_blocks (self , token_ids : np .ndarray ) -> bool :
357369 """Fetch remote blocks into local CPU cache via prefetch_async.
@@ -411,7 +423,6 @@ def register_mem_pool_host(self, mem_pool_host: Any) -> None:
411423 self ._page_size = sglang_page_size
412424
413425 self ._init_kv_manager ()
414- self ._started = True
415426
416427 # KVManager has started — global collector is now initialized
417428 try :
@@ -499,7 +510,7 @@ def batch_exists(
499510 tokens_per_block = page_size )
500511
501512 # In distributed mode, query both local and remote trees
502- if (self ._mode == "distributed"
513+ if (self ._mode == MODE_DISTRIBUTED
503514 and hasattr (cache_engine .cpu_cache_engine , 'match_all' )):
504515 match_result = cache_engine .cpu_cache_engine .match_all (seq_meta )
505516 else :
@@ -510,7 +521,7 @@ def batch_exists(
510521 except Exception :
511522 logger .exception ("batch_exists failed" )
512523 if self ._metrics :
513- self ._metrics .record_sglang_error ("exists" )
524+ self ._metrics .record_sglang_error (_OP_EXISTS )
514525 return 0
515526
516527 def batch_get_v1 (
@@ -550,7 +561,7 @@ def batch_get_v1(
550561 tokens_per_block = page_size )
551562
552563 # In distributed mode, query both local and remote trees
553- if (self ._mode == "distributed"
564+ if (self ._mode == MODE_DISTRIBUTED
554565 and hasattr (cache_engine .cpu_cache_engine , 'match_all' )):
555566 match_result = cache_engine .cpu_cache_engine .match_all (seq_meta )
556567
@@ -610,7 +621,7 @@ def batch_get_v1(
610621 except Exception :
611622 logger .exception ("batch_get_v1 failed" )
612623 if self ._metrics :
613- self ._metrics .record_sglang_error ("get" )
624+ self ._metrics .record_sglang_error (_OP_GET )
614625 return [False ] * len (keys )
615626
616627 def batch_set_v1 (
@@ -694,7 +705,7 @@ def batch_set_v1(
694705 except Exception :
695706 logger .exception ("batch_set_v1 failed" )
696707 if self ._metrics :
697- self ._metrics .record_sglang_error ("set" )
708+ self ._metrics .record_sglang_error (_OP_SET )
698709 return [False ] * len (keys )
699710
700711 # Legacy abstract methods
0 commit comments