@@ -40,14 +40,8 @@ class ModelConfig:
4040 # ------------------------------------------------------------------
4141 # Attention-level parallel configs
4242 # ------------------------------------------------------------------
43- # enable_dp_attention: whether DP-attention is enabled (sglang
44- # ``--enable-dp-attention`` or TRT-LLM ``enable_attention_dp``).
45- # When True, the physical TP group is split into
46- # attn_tp × attn_cp × attn_dp.
47- enable_dp_attention : bool = False
48-
49- # attn_cp_size: context-parallel size (global).
50- attn_cp_size : int = 1
43+ # cp_size: context-parallel size (global), default 1.
44+ cp_size : int = 1
5145
5246 # ------------------------------------------------------------------
5347 # Topology configs (global)
@@ -94,17 +88,12 @@ def freeze(self) -> None:
9488 f"[ModelConfig] cannot derive gpus_per_node: "
9589 f"total_gpus={ self .total_gpus } not divisible by nnodes={ self .nnodes } "
9690 )
97- if self .nnodes_per_tp_group > 2 :
91+ if self .nnodes_per_pp_rank > 2 :
9892 raise ValueError (
9993 f"[ModelConfig] only support 2-nodes TP for now, but got "
100- f"nnodes_per_tp_group ={ self .nnodes_per_tp_group } "
94+ f"nnodes_per_pp_rank ={ self .nnodes_per_pp_rank } "
10195 f"(tp_size={ self .tp_size } , gpus_per_node={ self .gpus_per_node } )"
10296 )
103- if self .tp_size % self .nnodes_per_tp_group != 0 :
104- raise ValueError (
105- f"[ModelConfig] tp_size={ self .tp_size } not divisible by "
106- f"nnodes_per_tp_group={ self .nnodes_per_tp_group } "
107- )
10897 if self .instance_num < 1 :
10998 raise ValueError (
11099 f"[ModelConfig] instance_num must be >= 1, got { self .instance_num } "
@@ -119,8 +108,8 @@ def __setattr__(self, name: str, value) -> None:
119108 raise AttributeError (
120109 f"ModelConfig is frozen — cannot set '{ name } '. "
121110 f"All primitive fields must be set during post_init_from_*(), "
122- f"after which freeze() is called. Derived fields (attn_tp_size , "
123- f"tp_size_per_node) are @property "
111+ f"after which freeze() is called. Derived fields (effective_tp_size , "
112+ f"tp_size_per_node, cp_size_per_node, nnodes_per_pp_rank ) are @property "
124113 f"and cannot be set at all."
125114 )
126115 object .__setattr__ (self , name , value )
@@ -130,8 +119,10 @@ def __setattr__(self, name: str, value) -> None:
130119 # ------------------------------------------------------------------
@property
def total_gpus(self) -> int:
    """Return the GPU worker registration slots of one FlexKV instance.

    The count spans all nodes and follows the unified formula
    dp_size × tp_size × cp_size × pp_size.
    """
    slots = self.dp_size
    # Fold the remaining parallel dimensions in the same left-to-right order.
    for factor in (self.tp_size, self.cp_size, self.pp_size):
        slots = slots * factor
    return slots
135126
136127 @property
137128 def total_clients (self ) -> int :
@@ -140,62 +131,43 @@ def total_clients(self) -> int:
140131
@property
def gpus_per_node(self) -> int:
    """Return the GPU worker registration slots hosted on one node.

    ``total_gpus`` (all DP shards, PP stages and TP groups) is split over
    ``nnodes`` with floor division; any remainder is discarded here.
    """
    slots_per_node, _leftover = divmod(self.total_gpus, self.nnodes)
    return slots_per_node
145136
@property
def nnodes_per_pp_rank(self) -> int:
    """Return how many nodes a single PP stage spans (never less than 1)."""
    nodes = self.nnodes // self.pp_size
    # Several PP stages sharing one node floor-divide to 0; clamp to 1.
    return nodes if nodes > 1 else 1
150141
151- @property
152- def nnodes_per_tp_group (self ) -> int :
153- """Number of nodes spanned by one TP group."""
154- return self .nnodes_per_pp_rank
155-
@property
def tp_size_per_node(self) -> int:
    """Return the number of TP ranks of one TP group that live on this node.

    ``tp_size`` is floor-divided by ``nnodes_per_pp_rank`` and the result
    is clamped so that it is never below 1.
    """
    ranks = self.tp_size // self.nnodes_per_pp_rank
    return ranks if ranks > 1 else 1
165146
@property
def cp_size_per_node(self) -> int:
    """Return the CP size on this node for a single PP stage.

    Relevant for multi-node scenarios where the CP group spans several
    nodes: ``cp_size`` is floor-divided by ``nnodes_per_pp_rank`` and
    clamped so that it is never below 1.
    """
    local_cp = self.cp_size // self.nnodes_per_pp_rank
    return local_cp if local_cp > 1 else 1
182154
@property
def effective_tp_size(self) -> int:
    """Return the number of CPU block slices, i.e. tp_size × cp_size.

    Each factor is clamped to at least 1 before multiplying.
    """
    tp = self.tp_size if self.tp_size > 1 else 1
    cp = self.cp_size if self.cp_size > 1 else 1
    return tp * cp
187159
@property
def effective_tp_size_per_node(self) -> int:
    """Return the per-node counterpart of :py:attr:`effective_tp_size`.

    It is the product of ``tp_size_per_node`` and ``cp_size_per_node``.
    """
    # NOTE(review): both factors are divided by nnodes_per_pp_rank, so when
    # tp_size and cp_size are each >= nnodes_per_pp_rank > 1 the product is
    # effective_tp_size / nnodes_per_pp_rank**2 — confirm this is intended.
    local_tp = self.tp_size_per_node
    local_cp = self.cp_size_per_node
    return local_tp * local_cp
192164
@property
def num_kv_heads_per_node(self) -> int:
    """Return the number of KV heads visible to a single node.

    With ``use_mla`` set, ``num_kv_heads`` is returned as-is. Otherwise the
    heads are scaled by the share of TP ranks on this node
    (``tp_size_per_node`` out of ``tp_size``, the latter clamped to >= 1).
    """
    if not self.use_mla:
        scaled_heads = self.num_kv_heads * self.tp_size_per_node
        return scaled_heads // max(1, self.tp_size)
    return self.num_kv_heads
199171
200172 @property
201173 def kv_dim (self ) -> int :
@@ -218,7 +190,8 @@ def __str__(self) -> str:
218190 f", head_size={ self .head_size } , use_mla={ self .use_mla } "
219191 f", dtype={ self .dtype } "
220192 f", tp_size={ self .tp_size } , pp_size={ self .pp_size } , dp_size={ self .dp_size } "
221- f", attn_cp_size={ self .attn_cp_size } "
193+ f", cp_size={ self .cp_size } "
194+ f", total_gpus={ self .total_gpus } "
222195 f", nnodes={ self .nnodes } , master_host={ self .master_host !r} "
223196 f", instance_num={ self .instance_num } "
224197 )
@@ -230,11 +203,12 @@ class RankInfo:
230203 tp_rank : int = 0
231204 pp_rank : int = 0
232205 dp_rank : int = 0
233- attn_cp_rank : int = 0
206+ cp_rank : int = 0
234207 node_rank : int = 0
235208 instance_id : int = 0
236209 pp_start_layer : int = 0
237210 pp_end_layer : int = - 1
211+ local_rank : int = - 1
238212 @property
239213 def tp_rank_per_node (self ) -> int :
240214 """TP rank index within the local node (within one TP group)."""
@@ -252,18 +226,10 @@ def dp_client_id(self) -> int:
252226 """
253227 return self .instance_id * self .model_config .dp_size + self .dp_rank
254228
255- @property
256- def attn_tp_rank (self ) -> int :
257- """Attention-level TP rank derived from tp_rank / attn_tp_size."""
258- return self .tp_rank % max (1 , self .model_config .attn_tp_size )
259-
@property
def effective_tp_rank(self) -> int:
    """Return this rank's index in the *data-plane* segmentation space.

    Ranks are laid out CP-major: ``cp_rank`` strides over a full TP group
    (``tp_size`` clamped to at least 1) and ``tp_rank`` is the offset
    inside that group.
    """
    cfg = self.model_config
    stride = cfg.tp_size if cfg.tp_size > 1 else 1
    return self.tp_rank + stride * self.cp_rank
267233
268234 @property
269235 def pp_size_per_node (self ) -> int :
@@ -276,25 +242,6 @@ def pp_rank_per_node(self) -> int:
276242 """This rank's PP index *within* its node."""
277243 return self .pp_rank % self .pp_size_per_node
278244
279- @property
280- def dp_size_per_node (self ) -> int :
281- """Number of DP replicas co-located on a single node."""
282- model_config = self .model_config
283- return model_config .gpus_per_node // (self .pp_size_per_node * model_config .tp_size_per_node )
284-
285- @property
286- def dp_rank_per_node (self ) -> int :
287- """This rank's DP index *within* its node (non-DP-attention layout)."""
288- return self .dp_rank % self .dp_size_per_node
289-
290- @property
291- def local_rank (self ) -> int :
292- model_config = self .model_config
293- if model_config .enable_dp_attention :
294- return self .pp_rank_per_node * model_config .tp_size_per_node + self .tp_rank_per_node
295- return (self .dp_rank_per_node * self .pp_size_per_node + self .pp_rank_per_node ) \
296- * model_config .tp_size_per_node + self .tp_rank_per_node
297-
298245 @property
299246 def num_layers_per_pp_stage (self ) -> int :
300247 """Number of layers managed by this PP stage."""
@@ -315,8 +262,9 @@ def __str__(self) -> str:
315262 """
316263 return (
317264 f"RankInfo(tp_rank={ self .tp_rank } , pp_rank={ self .pp_rank } "
318- f", dp_rank={ self .dp_rank } , attn_cp_rank ={ self .attn_cp_rank } "
265+ f", dp_rank={ self .dp_rank } , cp_rank ={ self .cp_rank } "
319266 f", node_rank={ self .node_rank } , instance_id={ self .instance_id } "
267+ f", local_rank={ self .local_rank } , effective_tp_rank={ self .effective_tp_rank } "
320268 )
321269
322270
@@ -540,14 +488,49 @@ def convert_to_block_num(size_in_GB: float, block_size_in_bytes: int) -> int:
def update_default_config_from_user_config(rank_info: RankInfo,
                                           cache_config: CacheConfig,
                                           user_config: UserConfig) -> None:
    """Apply user-supplied cache settings to ``cache_config`` in place.

    The CPU and SSD budgets given in GB are converted into block counts
    using the byte size of one block on this PP stage: the main KV bytes
    plus, when ``cache_config.indexer`` is set, the indexer bytes for the
    same block of tokens.

    Args:
        rank_info: supplies ``token_size_in_bytes_per_pp_stage`` and
            ``num_layers_per_pp_stage`` for the sizing.
        cache_config: mutated in place (block counts, SSD/GDS settings).
        user_config: source of ``cpu_cache_gb``, ``ssd_cache_gb``,
            ``ssd_cache_dir`` and ``enable_gds``.

    Raises:
        AssertionError: if ``cpu_cache_gb`` is not positive or
            ``ssd_cache_gb`` is negative.
    """
    tokens_per_block = cache_config.tokens_per_block
    main_block_size_in_bytes = (
        rank_info.token_size_in_bytes_per_pp_stage * tokens_per_block
    )

    indexer_block_size_in_bytes = 0
    if cache_config.indexer is not None:
        indexer_cfg = cache_config.indexer
        # Indexer is MLA-style (single shared head set, no TP head split),
        # so per-token bytes = num_kv_heads * head_size * dtype.itemsize.
        indexer_bytes_per_token_per_layer = (
            indexer_cfg.num_kv_heads
            * indexer_cfg.head_size
            * indexer_cfg.dtype.itemsize
        )
        # Scale by tokens_per_block so this is a per-*block* size like
        # main_block_size_in_bytes; without it a per-token size would be
        # added to a per-block size and the GB budget would be overshot.
        indexer_block_size_in_bytes = (
            rank_info.num_layers_per_pp_stage
            * indexer_bytes_per_token_per_layer
            * tokens_per_block
        )
    block_size_in_bytes = main_block_size_in_bytes + indexer_block_size_in_bytes

    assert user_config.cpu_cache_gb > 0
    assert user_config.ssd_cache_gb >= 0

    cache_config.num_cpu_blocks = convert_to_block_num(user_config.cpu_cache_gb, block_size_in_bytes)
    cache_config.num_ssd_blocks = convert_to_block_num(user_config.ssd_cache_gb, block_size_in_bytes)

    if cache_config.indexer is not None:
        flexkv_logger.info(
            f"[CacheConfig] GB->blocks conversion (with indexer): "
            f"main_block_size={main_block_size_in_bytes}B, "
            f"indexer_block_size={indexer_block_size_in_bytes}B, "
            f"total_block_size={block_size_in_bytes}B; "
            f"cpu_cache_gb={user_config.cpu_cache_gb} -> num_cpu_blocks={cache_config.num_cpu_blocks}, "
            f"ssd_cache_gb={user_config.ssd_cache_gb} -> num_ssd_blocks={cache_config.num_ssd_blocks}"
        )
    else:
        flexkv_logger.info(
            f"[CacheConfig] GB->blocks conversion: "
            f"block_size={block_size_in_bytes}B; "
            f"cpu_cache_gb={user_config.cpu_cache_gb} -> num_cpu_blocks={cache_config.num_cpu_blocks}, "
            f"ssd_cache_gb={user_config.ssd_cache_gb} -> num_ssd_blocks={cache_config.num_ssd_blocks}"
        )

    cache_config.ssd_cache_dir = user_config.ssd_cache_dir
    cache_config.enable_ssd = user_config.ssd_cache_gb > 0
    cache_config.enable_gds = user_config.enable_gds
0 commit comments