Skip to content

Commit e0e523e

Browse files
author
zittozhang
committed
feat: support sglang dp attention and drop sglang-specific attn_* fields
1 parent 00cc828 commit e0e523e

6 files changed

Lines changed: 311 additions & 250 deletions

File tree

flexkv/cache/redis_meta.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,6 @@ def register_node(self) -> Optional[int]:
283283
"uuid": self.uuid,
284284
"status": "active",
285285
"timestamp": str(int(time.time())),
286-
"pp_rank": str(getattr(self, 'pp_rank', 0)),
287-
"pp_size": str(getattr(self, 'pp_size', 1)),
288286
})
289287

290288
# Set TTL so the key auto-expires if the process crashes

flexkv/common/config.py

Lines changed: 64 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,8 @@ class ModelConfig:
4040
# ------------------------------------------------------------------
4141
# Attention-level parallel configs
4242
# ------------------------------------------------------------------
43-
# enable_dp_attention: whether DP-attention is enabled (sglang
44-
# ``--enable-dp-attention`` or TRT-LLM ``enable_attention_dp``).
45-
# When True, the physical TP group is split into
46-
# attn_tp × attn_cp × attn_dp.
47-
enable_dp_attention: bool = False
48-
49-
# attn_cp_size: context-parallel size (global).
50-
attn_cp_size: int = 1
43+
# cp_size: context-parallel size (global), default 1.
44+
cp_size: int = 1
5145

5246
# ------------------------------------------------------------------
5347
# Topology configs (global)
@@ -94,17 +88,12 @@ def freeze(self) -> None:
9488
f"[ModelConfig] cannot derive gpus_per_node: "
9589
f"total_gpus={self.total_gpus} not divisible by nnodes={self.nnodes}"
9690
)
97-
if self.nnodes_per_tp_group > 2:
91+
if self.nnodes_per_pp_rank > 2:
9892
raise ValueError(
9993
f"[ModelConfig] only support 2-nodes TP for now, but got "
100-
f"nnodes_per_tp_group={self.nnodes_per_tp_group} "
94+
f"nnodes_per_pp_rank={self.nnodes_per_pp_rank} "
10195
f"(tp_size={self.tp_size}, gpus_per_node={self.gpus_per_node})"
10296
)
103-
if self.tp_size % self.nnodes_per_tp_group != 0:
104-
raise ValueError(
105-
f"[ModelConfig] tp_size={self.tp_size} not divisible by "
106-
f"nnodes_per_tp_group={self.nnodes_per_tp_group}"
107-
)
10897
if self.instance_num < 1:
10998
raise ValueError(
11099
f"[ModelConfig] instance_num must be >= 1, got {self.instance_num}"
@@ -119,8 +108,8 @@ def __setattr__(self, name: str, value) -> None:
119108
raise AttributeError(
120109
f"ModelConfig is frozen — cannot set '{name}'. "
121110
f"All primitive fields must be set during post_init_from_*(), "
122-
f"after which freeze() is called. Derived fields (attn_tp_size, "
123-
f"tp_size_per_node) are @property "
111+
f"after which freeze() is called. Derived fields (effective_tp_size, "
112+
f"tp_size_per_node, cp_size_per_node, nnodes_per_pp_rank) are @property "
124113
f"and cannot be set at all."
125114
)
126115
object.__setattr__(self, name, value)
@@ -130,8 +119,10 @@ def __setattr__(self, name: str, value) -> None:
130119
# ------------------------------------------------------------------
131120
@property
132121
def total_gpus(self) -> int:
133-
"""Total GPUs across all nodes for one FlexKV instance."""
134-
return self.dp_size * self.tp_size * self.pp_size
122+
"""Total GPU worker registration slots across all nodes for one FlexKV instance.
123+
124+
Unified formula: dp_size × tp_size × cp_size × pp_size."""
125+
return self.dp_size * self.tp_size * self.cp_size * self.pp_size
135126

136127
@property
137128
def total_clients(self) -> int:
@@ -140,62 +131,43 @@ def total_clients(self) -> int:
140131

141132
@property
142133
def gpus_per_node(self) -> int:
143-
"""Total GPUs on this node (across all DP, PP stages and TP groups)."""
134+
"""GPU worker registration slots on this node (across all DP shards, PP stages and TP groups)."""
144135
return self.total_gpus // self.nnodes
145136

146137
@property
147138
def nnodes_per_pp_rank(self) -> int:
148139
"""Number of nodes spanned by one PP stage."""
149140
return max(self.nnodes // self.pp_size, 1)
150141

151-
@property
152-
def nnodes_per_tp_group(self) -> int:
153-
"""Number of nodes spanned by one TP group."""
154-
return self.nnodes_per_pp_rank
155-
156142
@property
157143
def tp_size_per_node(self) -> int:
158144
"""Number of TP ranks on this node within one TP group."""
159-
return self.tp_size // self.nnodes_per_tp_group
160-
161-
@property
162-
def attn_dp_size(self) -> int:
163-
"""Attention-level DP size (= dp_size when enable_dp_attention else 1)."""
164-
return max(1, self.dp_size) if self.enable_dp_attention else 1
145+
return max(1, self.tp_size // self.nnodes_per_pp_rank)
165146

166147
@property
167-
def attn_tp_size(self) -> int:
168-
"""Attention-level TP size derived from tp / attn_dp / attn_cp."""
169-
attn_dp = self.attn_dp_size
170-
cp = max(1, self.attn_cp_size)
171-
return max(1, max(1, self.tp_size) // (attn_dp * cp))
148+
def cp_size_per_node(self) -> int:
149+
"""CP size on this node for a single PP stage.
172150
173-
@property
174-
def attn_tp_size_per_node(self) -> int:
175-
"""Attention-level TP size per node."""
176-
return self.attn_tp_size // self.nnodes_per_tp_group
177-
178-
@property
179-
def attn_cp_size_per_node(self) -> int:
180-
"""Attention-level CP size on this node for a single pp stage. """
181-
return max(1, self.attn_cp_size // self.nnodes_per_pp_rank)
151+
Used for multi-node scenarios where the CP group spans multiple nodes.
152+
"""
153+
return max(1, self.cp_size // self.nnodes_per_pp_rank)
182154

183155
@property
184156
def effective_tp_size(self) -> int:
185-
"""Effective tp-group size used for *data-plane* CPU slicing."""
186-
return max(1, self.attn_tp_size) * max(1, self.attn_cp_size)
157+
"""Number of CPU block slices = tp_size × cp_size."""
158+
return max(1, self.tp_size) * max(1, self.cp_size)
187159

188160
@property
189161
def effective_tp_size_per_node(self) -> int:
190162
"""Per-node counterpart of :pyattr:`effective_tp_size`."""
191-
return self.attn_tp_size_per_node * self.attn_cp_size_per_node
163+
return self.tp_size_per_node * self.cp_size_per_node
192164

193165
@property
194166
def num_kv_heads_per_node(self) -> int:
195167
"""Number of KV heads visible to a single node."""
196168
if self.use_mla:
197169
return self.num_kv_heads
198-
return self.num_kv_heads * self.tp_size_per_node // max(1, self.attn_tp_size)
170+
return self.num_kv_heads * self.tp_size_per_node // max(1, self.tp_size)
199171

200172
@property
201173
def kv_dim(self) -> int:
@@ -218,7 +190,8 @@ def __str__(self) -> str:
218190
f", head_size={self.head_size}, use_mla={self.use_mla}"
219191
f", dtype={self.dtype}"
220192
f", tp_size={self.tp_size}, pp_size={self.pp_size}, dp_size={self.dp_size}"
221-
f", attn_cp_size={self.attn_cp_size}"
193+
f", cp_size={self.cp_size}"
194+
f", total_gpus={self.total_gpus}"
222195
f", nnodes={self.nnodes}, master_host={self.master_host!r}"
223196
f", instance_num={self.instance_num}"
224197
)
@@ -230,11 +203,12 @@ class RankInfo:
230203
tp_rank: int = 0
231204
pp_rank: int = 0
232205
dp_rank: int = 0
233-
attn_cp_rank: int = 0
206+
cp_rank: int = 0
234207
node_rank: int = 0
235208
instance_id: int = 0
236209
pp_start_layer: int = 0
237210
pp_end_layer: int = -1
211+
local_rank: int = -1
238212
@property
239213
def tp_rank_per_node(self) -> int:
240214
"""TP rank index within the local node (within one TP group)."""
@@ -252,18 +226,10 @@ def dp_client_id(self) -> int:
252226
"""
253227
return self.instance_id * self.model_config.dp_size + self.dp_rank
254228

255-
@property
256-
def attn_tp_rank(self) -> int:
257-
"""Attention-level TP rank derived from tp_rank / attn_tp_size."""
258-
return self.tp_rank % max(1, self.model_config.attn_tp_size)
259-
260229
@property
261230
def effective_tp_rank(self) -> int:
262231
"""Effective tp-rank in the *data-plane* segmentation space."""
263-
if self.model_config.use_mla:
264-
return self.attn_tp_rank
265-
attn_tp_size = max(1, self.model_config.attn_tp_size)
266-
return self.attn_cp_rank * attn_tp_size + self.attn_tp_rank
232+
return self.cp_rank * max(1, self.model_config.tp_size) + self.tp_rank
267233

268234
@property
269235
def pp_size_per_node(self) -> int:
@@ -276,25 +242,6 @@ def pp_rank_per_node(self) -> int:
276242
"""This rank's PP index *within* its node."""
277243
return self.pp_rank % self.pp_size_per_node
278244

279-
@property
280-
def dp_size_per_node(self) -> int:
281-
"""Number of DP replicas co-located on a single node."""
282-
model_config = self.model_config
283-
return model_config.gpus_per_node // (self.pp_size_per_node * model_config.tp_size_per_node)
284-
285-
@property
286-
def dp_rank_per_node(self) -> int:
287-
"""This rank's DP index *within* its node (non-DP-attention layout)."""
288-
return self.dp_rank % self.dp_size_per_node
289-
290-
@property
291-
def local_rank(self) -> int:
292-
model_config = self.model_config
293-
if model_config.enable_dp_attention:
294-
return self.pp_rank_per_node * model_config.tp_size_per_node + self.tp_rank_per_node
295-
return (self.dp_rank_per_node * self.pp_size_per_node + self.pp_rank_per_node) \
296-
* model_config.tp_size_per_node + self.tp_rank_per_node
297-
298245
@property
299246
def num_layers_per_pp_stage(self) -> int:
300247
"""Number of layers managed by this PP stage."""
@@ -315,8 +262,9 @@ def __str__(self) -> str:
315262
"""
316263
return (
317264
f"RankInfo(tp_rank={self.tp_rank}, pp_rank={self.pp_rank}"
318-
f", dp_rank={self.dp_rank}, attn_cp_rank={self.attn_cp_rank}"
265+
f", dp_rank={self.dp_rank}, cp_rank={self.cp_rank}"
319266
f", node_rank={self.node_rank}, instance_id={self.instance_id}"
267+
f", local_rank={self.local_rank}, effective_tp_rank={self.effective_tp_rank}"
320268
)
321269

322270

@@ -540,14 +488,49 @@ def convert_to_block_num(size_in_GB: float, block_size_in_bytes: int) -> int:
540488
def update_default_config_from_user_config(rank_info: RankInfo,
541489
cache_config: CacheConfig,
542490
user_config: UserConfig) -> None:
543-
block_size_in_bytes = rank_info.token_size_in_bytes_per_pp_stage * cache_config.tokens_per_block
491+
main_block_size_in_bytes = (
492+
rank_info.token_size_in_bytes_per_pp_stage * cache_config.tokens_per_block
493+
)
494+
indexer_block_size_in_bytes = 0
495+
if cache_config.indexer is not None:
496+
indexer_cfg = cache_config.indexer
497+
# Indexer is MLA-style (single shared head set, no TP head split),
498+
# so per-token bytes = num_kv_heads * head_size * dtype.itemsize.
499+
indexer_bytes_per_token_per_layer = (
500+
indexer_cfg.num_kv_heads
501+
* indexer_cfg.head_size
502+
* indexer_cfg.dtype.itemsize
503+
)
504+
indexer_block_size_in_bytes = (
505+
rank_info.num_layers_per_pp_stage
506+
* indexer_bytes_per_token_per_layer
507+
* 1
508+
)
509+
block_size_in_bytes = main_block_size_in_bytes + indexer_block_size_in_bytes
544510

545511
assert user_config.cpu_cache_gb > 0
546512
assert user_config.ssd_cache_gb >= 0
547513

548514
cache_config.num_cpu_blocks = convert_to_block_num(user_config.cpu_cache_gb, block_size_in_bytes)
549515
cache_config.num_ssd_blocks = convert_to_block_num(user_config.ssd_cache_gb, block_size_in_bytes)
550516

517+
if cache_config.indexer is not None:
518+
flexkv_logger.info(
519+
f"[CacheConfig] GB->blocks conversion (with indexer): "
520+
f"main_block_size={main_block_size_in_bytes} B, "
521+
f"indexer_block_size={indexer_block_size_in_bytes} B, "
522+
f"total_block_size={block_size_in_bytes} B; "
523+
f"cpu_cache_gb={user_config.cpu_cache_gb} -> num_cpu_blocks={cache_config.num_cpu_blocks}, "
524+
f"ssd_cache_gb={user_config.ssd_cache_gb} -> num_ssd_blocks={cache_config.num_ssd_blocks}"
525+
)
526+
else:
527+
flexkv_logger.info(
528+
f"[CacheConfig] GB->blocks conversion: "
529+
f"block_size={block_size_in_bytes} B; "
530+
f"cpu_cache_gb={user_config.cpu_cache_gb} -> num_cpu_blocks={cache_config.num_cpu_blocks}, "
531+
f"ssd_cache_gb={user_config.ssd_cache_gb} -> num_ssd_blocks={cache_config.num_ssd_blocks}"
532+
)
533+
551534
cache_config.ssd_cache_dir = user_config.ssd_cache_dir
552535
cache_config.enable_ssd = user_config.ssd_cache_gb > 0
553536
cache_config.enable_gds = user_config.enable_gds

0 commit comments

Comments
 (0)