Skip to content

Commit bca6ccc

Browse files
author
phaedonsun
committed
review-fixes(dist_reuse): consolidated reviewer feedback
Squashes three prior review-fix commits into one. Addresses reviewer comments on the dist_reuse feature commit (22bc183): F1. is_nsa source: read directly from model_config.is_nsa instead of reverse-deriving from enable_nsa_prefill_context_parallel (the latter is a CP toggle, orthogonal to NSA architecture). F2. Control-plane / rank-plane separation: SharingDomainKey.from_model_config now takes an optional rank_info=RankInfo argument and reads pp_rank / tp_node_idx from it; the control plane (KVManager) only constructs self-SD via default() and enumerates peers via enumerate_peers() — no fake rank fabrication. F3. RankTopology factory dropped; reuse the existing RankInfo end-to-end. Integration adapters (vLLM v1 / TRT-LLM / SGLang) plumb the real rank_info through. F4. shell / TransferManagerOnRemote decoupling: revert start_dist_reuse_serving.sh changes; TransferManagerOnRemote stays per-node and is tagged via set_target_sd_key on each handle. F5. delete unused flexkv.integration.multinode_policy module and its is_multinode_tp / is_multinode_cp / is_multinode_pp helpers; CP never participates in sd_key (attention all-gather makes per-cp_rank pools bit-wise identical), and TP-cross-node is encoded in SharingDomainKey.tp_node_count directly. Verified no external references in the SGLang FlexKVConnector codebase. Tests: full dist_reuse suite (363/363) passes on both GPU executors.
1 parent 22bc183 commit bca6ccc

12 files changed

Lines changed: 243 additions & 450 deletions

File tree

flexkv/common/config.py

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -104,16 +104,16 @@ def freeze(self) -> None:
104104
f"[ModelConfig] cannot derive gpus_per_node: "
105105
f"total_gpus={self.total_gpus} not divisible by nnodes={self.nnodes}"
106106
)
107-
if self.nnodes_per_tp_group > 2:
107+
if self.tp_node_count > 2:
108108
raise ValueError(
109109
f"[ModelConfig] only support 2-nodes TP for now, but got "
110-
f"nnodes_per_tp_group={self.nnodes_per_tp_group} "
110+
f"tp_node_count={self.tp_node_count} "
111111
f"(tp_size={self.tp_size}, gpus_per_node={self.gpus_per_node})"
112112
)
113-
if self.tp_size % self.nnodes_per_tp_group != 0:
113+
if self.tp_size % self.tp_node_count != 0:
114114
raise ValueError(
115115
f"[ModelConfig] tp_size={self.tp_size} not divisible by "
116-
f"nnodes_per_tp_group={self.nnodes_per_tp_group}"
116+
f"tp_node_count={self.tp_node_count}"
117117
)
118118
if self.instance_num < 1:
119119
raise ValueError(
@@ -160,13 +160,23 @@ def nnodes_per_pp_rank(self) -> int:
160160

161161
@property
162162
def nnodes_per_tp_group(self) -> int:
163-
"""Number of nodes spanned by one TP group."""
164-
return self.nnodes_per_pp_rank
163+
"""Number of nodes spanned by one TP group.
164+
165+
.. deprecated::
166+
Kept as a stable alias of :pyattr:`tp_node_count` for
167+
backwards compatibility with adapter code that pre-dates
168+
the SD-key naming convention. New code should read
169+
``tp_node_count`` directly — that property carries the
170+
authoritative semantic ("the TP-axis node-count entering
171+
``SharingDomainKey``") and is the value tracked in the
172+
redis schema (``docs/dist_reuse/redis_schema.md``).
173+
"""
174+
return self.tp_node_count
165175

166176
@property
167177
def tp_size_per_node(self) -> int:
168178
"""Number of TP ranks on this node within one TP group."""
169-
return self.tp_size // self.nnodes_per_tp_group
179+
return self.tp_size // self.tp_node_count
170180

171181
@property
172182
def attn_dp_size(self) -> int:
@@ -183,7 +193,7 @@ def attn_tp_size(self) -> int:
183193
@property
184194
def attn_tp_size_per_node(self) -> int:
185195
"""Attention-level TP size per node."""
186-
return self.attn_tp_size // self.nnodes_per_tp_group
196+
return self.attn_tp_size // self.tp_node_count
187197

188198
@property
189199
def attn_cp_size_per_node(self) -> int:
@@ -242,6 +252,43 @@ def is_multinode_tp(self) -> bool:
242252
"""
243253
return self.tp_node_count > 1
244254

255+
@property
256+
def is_multinode_pp(self) -> bool:
257+
"""PP is the dimension that makes *this instance* cross node boundaries.
258+
259+
Returns True iff:
260+
261+
* ``pp_size > 1`` — PP is actually deployed,
262+
* ``nnodes > 1`` — the instance occupies more than one node,
263+
* ``tp_node_count == 1`` — TP **does not** cross nodes
264+
(otherwise TP-multinode is the dominant axis and would
265+
already drive the SD-Remote decision; classifying the same
266+
deployment as "multinode-PP" too would double-count).
267+
268+
This is the missing third axis next to :pyattr:`is_multinode_tp`
269+
and :pyattr:`is_multinode_cp`. It exists so the connector's
270+
runtime launch logic can stop folding "PP-only crosses nodes"
271+
into the off-master fall-through branch.
272+
273+
Worked examples:
274+
275+
* ``pp=4, nnodes=2, tp=8, gpus_per_node=8`` → True
276+
(PP=4 stages × tp=8 = 32 GPUs across 2 nodes; each node
277+
owns 2 PP stages; TP stays inside one node).
278+
* ``pp=1, nnodes=2, tp=16`` → False
279+
(PP single-stage; TP is the multinode axis).
280+
* ``pp=2, nnodes=2, tp=16`` → False
281+
(TP already crosses; PP is *not* the dominant axis here \u2014
282+
we leave the multinode-TP branch to handle this).
283+
* ``pp=2, nnodes=1`` → False
284+
(single node; PP fits in-host).
285+
"""
286+
return (
287+
self.pp_size > 1
288+
and self.nnodes > 1
289+
and self.tp_node_count == 1
290+
)
291+
245292
@property
246293
def is_multinode_cp(self) -> bool:
247294
"""CP > 1 *and* the CP group spans more than one physical node.
@@ -300,9 +347,17 @@ def num_kv_heads_per_node(self) -> int:
300347
# ------------------------------------------------------------------
301348
@property
302349
def tp_node_count(self) -> int:
303-
"""Number of physical nodes one TP group spans (=
304-
``nnodes_per_tp_group``). ``1`` when TP fits on a single node."""
305-
return self.nnodes_per_tp_group
350+
"""Number of physical nodes one TP group spans.
351+
352+
Authoritative source for the TP-axis node-count used in
353+
:class:`SharingDomainKey` and ``docs/dist_reuse/redis_schema.md``.
354+
``1`` when TP fits on a single node. Deprecated alias:
355+
:pyattr:`nnodes_per_tp_group`.
356+
"""
357+
# PP and TP groups share the same per-rank node assignment in the
358+
# current topology (one TP group sits on the same set of nodes as
359+
# one PP stage), so ``nnodes_per_tp_group == nnodes_per_pp_rank``.
360+
return self.nnodes_per_pp_rank
306361

307362
# NOTE: ``tp_node_idx`` is a per-rank concept and was moved to
308363
# ``RankInfo`` in PR #165 (separate-per-rank-state-into-RankInfo).

flexkv/common/dist_reuse/sharing_domain.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,39 @@ def from_model_config(
360360
getattr(model_config, "tp_node_idx", 0))
361361
)
362362
else:
363+
# Legacy path: PR #165 moved ``pp_rank`` / ``tp_node_idx``
364+
# off ``ModelConfig`` onto :class:`RankInfo`. These
365+
# ``getattr(..., 0)`` reads therefore now return ``0``
366+
# for any post-#165 ``ModelConfig`` instance — i.e. the
367+
# caller silently gets the master-position SD even if it
368+
# is actually on (pp_rank>0, tp_node_idx>0). This is
369+
# only safe for unit-test fakes that explicitly set
370+
# ``pp_rank`` / ``tp_node_idx`` on the stub ModelConfig
371+
# (see ``tests/test_sharing_domain_key.py``); production
372+
# callers should pass ``rank_info=`` explicitly. We log
373+
# a one-time warning when the heuristic is exercised on
374+
# a multi-rank topology so the error surfaces during
375+
# bring-up instead of silently corrupting Redis keys.
363376
_pp_rank = int(getattr(model_config, "pp_rank", 0))
364377
_tp_node_idx = int(getattr(model_config, "tp_node_idx", 0))
378+
_has_pp_rank = hasattr(model_config, "pp_rank")
379+
_has_tp_node_idx = hasattr(model_config, "tp_node_idx")
380+
if (
381+
pp_size > 1 or int(getattr(model_config, "tp_node_count", 1)) > 1
382+
) and not (_has_pp_rank and _has_tp_node_idx):
383+
# Local import to avoid pulling logger into module
384+
# import-time graph (sharing_domain.py is imported
385+
# very early by config.py via ``derive_model_id``).
386+
from flexkv.common.debug import flexkv_logger
387+
flexkv_logger.warning(
388+
"SharingDomainKey.from_model_config: called without "
389+
"rank_info on a multi-rank topology (pp_size=%d, "
390+
"tp_node_count=%s); per-rank fields default to 0, "
391+
"which only matches the master SD. Pass rank_info="
392+
"<RankInfo> so the per-rank position is honoured.",
393+
pp_size,
394+
getattr(model_config, "tp_node_count", 1),
395+
)
365396

366397
_pp_node_idx = _pp_rank // pp_per_node
367398

flexkv/integration/config.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,11 @@ def post_init_from_sglang_config(
218218
sglang_config: sglang.srt.configs.model_config.ModelConfig-like object
219219
server_args: sglang ServerArgs — source of tp_size, dp_size,
220220
nnodes, node_rank, enable_dp_attention, attn_cp_size,
221-
is_nsa (read from server_args.enable_nsa_prefill_context_parallel),
222-
kv_cache_dtype, dist_init_addr
221+
kv_cache_dtype, dist_init_addr. ``is_nsa`` is **not**
222+
read from server_args: see the body below — it is
223+
derived from ``sglang_config.index_head_dim`` instead,
224+
because NSA is a model-layout property orthogonal to
225+
whether CP-prefill is enabled.
223226
page_size: KV block size (tokens per block) used by sglang
224227
tp_rank: physical tensor parallel rank (runtime, from process group)
225228
pp_rank: pipeline parallel rank (runtime, from process group)
@@ -234,11 +237,24 @@ def post_init_from_sglang_config(
234237
node_rank = server_args.node_rank
235238
enable_dp_attention = server_args.enable_dp_attention
236239
attn_cp_size = server_args.attn_cp_size
237-
# ``is_nsa`` (NSA model layout flag): True when the model has an
238-
# extra indexer K cache buffer. Sourced from sglang's
239-
# ``enable_nsa_prefill_context_parallel`` server arg, but in dist_reuse
240-
# context the flag represents the *layout*, not whether CP is on.
241-
is_nsa = getattr(server_args, 'enable_nsa_prefill_context_parallel', False)
240+
# ``is_nsa`` (NSA model layout flag): True when the model itself has
241+
# an extra indexer K cache buffer. This is a *layout* property of
242+
# the model architecture, **independent** of whether CP is enabled
243+
# at runtime — an NSA model with cp_size=1 still has the indexer K
244+
# cache and must therefore be isolated from non-NSA models in the
245+
# cross-instance reuse namespace (it lives in
246+
# ``SharingDomainKey.serialize`` as the ``nsa<0|1>`` segment).
247+
#
248+
# Detection rule: an NSA/DSA model exposes a positive
249+
# ``index_head_dim`` attribute on its sglang ModelConfig (the same
250+
# signal already consulted ~25 lines below to size the indexer
251+
# head buffer). Falling back to
252+
# ``server_args.enable_nsa_prefill_context_parallel`` was incorrect
253+
# because it conflates the *runtime CP toggle* with the *static
254+
# model layout* — a deployment can run an NSA model with CP=1
255+
# (no prefill-CP) and still need NSA-isolated namespaces.
256+
index_head_dim = getattr(sglang_config, "index_head_dim", None)
257+
is_nsa = bool(index_head_dim) and int(index_head_dim) > 0
242258
kv_cache_dtype = getattr(server_args, 'kv_cache_dtype', None)
243259
dp_rank = 0 if dp_rank is None else int(dp_rank)
244260

flexkv/integration/multinode_policy.py

Lines changed: 0 additions & 185 deletions
This file was deleted.

0 commit comments

Comments
 (0)