feat(dist_reuse): KV cache sharing across TP/PP/CP + single-node radi…#169

Open
feiqiangs wants to merge 2 commits into feat/layerwise_rebase from kvcache_reuse_debug
@feiqiangs
Collaborator

…x match

Squashes 5 commits (20a65d5 + d97db4d + 72632ec + b9230c6 + 7cd04ad) into a single landed feature. This is the full dist_reuse stack on top of PR #165 (RankInfo refactor), validated end-to-end on a 2-machine GPU setup (gpu-146.56.224.46 master / gpu-129.211.162.213 peer):

  S1 (single-node TP=1)   cached_ratio 99.65%   PASS
  S2 (single-node TP=2)   cached_ratio 99.65%   PASS
  S3 (cross-node TP=1)    master cold->warm 99.63%, peer crosshit 99.63%, storage=272, backend=FlexKVConnector (PEERH2H @ 6.22 GB/s via mooncake/RDMA)   PASS x3
  357/357 unit tests on both nodes   PASS

== Original commits (in chronological order) ==

[20a65d5] feat(dist_reuse): KV cache sharing across TP/PP/CP + single-node radix match
Initial dist_reuse stack: master coordinator, sharing-domain key,
aggregate radix, redis-meta namespace, multi-node policy, P2P
transfer types (PEERH2H/H2PEERH/PEERSSD2H/H2PEERSSD), failure
detector, four S{1..4} sglang+FlexKV e2e benchmark scripts.

[d97db4d] fix(dist_reuse): unblock cross-instance KV cache sharing on s3_cross_node_tp1
Three runtime bugs blocked the s3 (master prime / peer crosshit) flow:
1) GPUCPUTransferWorker._transfer_impl had positional-arg drift
on the transfer_kv_blocks pybind: C++ added 'start_layer_id'
between 'chunk_size_in_bytes' and 'num_layers' (transfer.cu
2025-07-10), which silently mapped is_h2d=False onto
transfer_num_cta and launched D2H kernels with gridDim(0)
-> cudaErrorInvalidConfiguration on every D2H.
Fix: bind every value to the C++ pybind name with kwargs and
add 'start_layer_id=0' explicitly.
2) GlobalCacheEngine._maybe_attach_multi_sd_peerh2h_ops carried a
dead 'layer_num' parameter which the only caller in
_get_impl_local passed undefined -> NameError on first
cross-instance reuse hit. Fix: drop the dead parameter and 6
call sites in tests/test_d3_filter_and_get_clones.py.
3) merge_to_batch_graph raised NotImplementedError on PEERH2H /
H2PEERH / PEERSSD2H / H2PEERSSD as soon as a real cross-instance
hit produced a P2P op. Fix: whitelist the four types as P2P
passthrough (preserves per-block src_block_node_ids and
per-op target_node_ids from D-3 multi-SD broadcast clones),
wire dependencies on merged_h2d_op (GET) / merged_d2h_op (PUT).
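
The positional-arg drift in item 1 can be reproduced with a small pure-Python stand-in. This is only a sketch: the parameter list below is an assumption chosen to match the description (with `start_layer_id` inserted after `chunk_size_in_bytes`), not the real `transfer_kv_blocks` pybind signature.

```python
def transfer_kv_blocks(chunk_size_in_bytes, start_layer_id, num_layers,
                       transfer_num_cta, is_h2d=True):
    """Hypothetical stand-in for the binding after start_layer_id was added."""
    return {
        "chunk_size_in_bytes": chunk_size_in_bytes,
        "start_layer_id": start_layer_id,
        "num_layers": num_layers,
        # A bool silently converts to an int: False becomes 0 CTAs.
        "transfer_num_cta": int(transfer_num_cta),
        "is_h2d": is_h2d,
    }


# Caller still written against the old order
# (chunk_size_in_bytes, num_layers, transfer_num_cta, is_h2d):
# every value after chunk_size_in_bytes shifts one slot to the left.
drifted = transfer_kv_blocks(4096, 32, 4, False)

# Binding every value by name makes the call immune to reordering.
fixed = transfer_kv_blocks(
    chunk_size_in_bytes=4096,
    start_layer_id=0,
    num_layers=32,
    transfer_num_cta=4,
    is_h2d=False,
)

print(drifted["transfer_num_cta"], fixed["transfer_num_cta"])
```

In the drifted call, `is_h2d=False` lands on `transfer_num_cta` and becomes 0, which corresponds to the `gridDim(0)` launch described in item 1.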

[72632ec] fix(memory_handle): propagate _import_tensor_handle exceptions
Previously _import_tensor_handle logged the error and returned
torch.empty(0) on import failure, which silently dropped the wrapper
into a 0-element tensor and surfaced as an unrelated IndexError later
in worker.py::_get_layer_ptrs (layer_blocks[lay_id][0] out of range).
Now always re-raise, keeping the original traceback so cross-node
CUDA IPC handle device-id mismatches surface at their source.
Consistent with _import_cuda_ipc_handle which never swallowed.
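
The difference between swallowing and propagating the import failure can be sketched without torch. All names below (`_open_handle`, `HandleImportError`, the `device_id` check) are hypothetical; an empty list plays the role of `torch.empty(0)`.

```python
import logging

logger = logging.getLogger("memory_handle_sketch")


class HandleImportError(RuntimeError):
    """Hypothetical error raised when a handle cannot be imported."""


def _open_handle(handle):
    # Hypothetical import step: fails on a device-id mismatch.
    if handle.get("device_id") != 0:
        raise HandleImportError(f"device-id mismatch: {handle.get('device_id')}")
    return [[10, 11], [20, 21]]  # stand-in for per-layer block pointers


def import_swallowing(handle):
    """Old behaviour: log and return an empty placeholder."""
    try:
        return _open_handle(handle)
    except HandleImportError:
        logger.error("import failed, returning empty placeholder")
        return []  # failure resurfaces later as an unrelated IndexError


def import_propagating(handle):
    """New behaviour: log with traceback, then re-raise at the source."""
    try:
        return _open_handle(handle)
    except HandleImportError:
        logger.exception("import failed")
        raise  # bare raise keeps the original traceback
```

With the swallowing variant, a later `layer_blocks[0]` lookup fails with `IndexError` far from the real cause; with the propagating variant the caller sees `HandleImportError` immediately.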

[b9230c6] fix(config): move tp_node_idx from ModelConfig to RankInfo
PR #165 removed tp_rank from ModelConfig but
ModelConfig.tp_node_idx still referenced self.tp_rank, raising
AttributeError. Two pre-existing test_cache_config_batch_b.py
cases failed because of this.
Fix: remove ModelConfig.tp_node_idx (replaced with a migration
comment); add RankInfo.tp_node_idx (tp_rank // tp_size_per_node)
to complement RankInfo.tp_rank_per_node (tp_rank %
tp_size_per_node); update the two TP-node-count tests to
construct a RankInfo for tp_node_idx assertions.
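
The two complementary properties follow directly from the formulas quoted above. The class below is a minimal sketch, not the real `RankInfo`: it stores `tp_size_per_node` as a plain field for self-containment, which is an assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RankInfoSketch:
    """Illustrative stand-in for RankInfo (field layout is assumed)."""
    tp_rank: int
    tp_size_per_node: int

    @property
    def tp_node_idx(self) -> int:
        # Which node this TP rank lives on.
        return self.tp_rank // self.tp_size_per_node

    @property
    def tp_rank_per_node(self) -> int:
        # Position of this TP rank within its node.
        return self.tp_rank % self.tp_size_per_node


info = RankInfoSketch(tp_rank=5, tp_size_per_node=4)
print(info.tp_node_idx, info.tp_rank_per_node)
```

Together the two values recover the global rank: `tp_node_idx * tp_size_per_node + tp_rank_per_node == tp_rank`.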

[7cd04ad] docs(monitoring): document the new flexkv_py_dist_reuse_* metrics
Added user-facing documentation for the 5 cross-instance reuse
metrics in docs/monitoring/README_{en,zh}.md (kept in sync):
* §2.3 'Cross-instance Reuse Metrics' table with type, labels, severity and KNOWN_ISSUE-derived alert thresholds.
* 'Instrumentation status' subtable that flags the two metrics (lease_meta_nullptr_total / about_to_evict_total) whose Python collector hooks are ready but whose C++ master-side trigger has not yet landed, with a callout that '0' on these two does NOT mean 'system healthy'.
* §1.1 environment variable table now documents PROMETHEUS_MULTIPROC_DIR (the sample dir used by prometheus_client across sglang TP/PP workers, KVManager subprocess and transfer workers).
* §3.5 'Multiprocess Scrape Notes' explaining the MultiProcessCollector aggregation path and the recommended /dev/shm/flexkv_prom tmpfs override.
* §3.6 'Recommended PromQL alerts' section with 4 ready-to-paste Prometheus alert rules:
  - FlexKVDistReuseLeaseMetaNullptr (critical, any positive)
  - FlexKVDistReusePeerReadFailureRate (critical, > 0.1%)
  - FlexKVDistReusePeerReadP99High (warning, > 500ms)
  - FlexKVDistReuseEvictPressure (warning, ratio > 10)
* The /metrics curl verification snippet now also greps flexkv_py_dist_reuse_.

Comment thread flexkv/integration/multinode_policy.py Outdated


@dataclass(frozen=True)
class RankTopology:
Contributor

This should use the existing RankInfo.

Collaborator Author

multinode_policy.py is old code that was written earlier for unit tests and experiments. It is no longer used, so I deleted it outright.

Comment thread flexkv/common/config.py
return self.tp_rank % self.model_config.tp_size_per_node

@property
def tp_node_idx(self) -> int:
Contributor

Many of these newly added properties are redundant; existing properties can be used instead.

Collaborator Author

Adopted. The per-rank fields newly added to ModelConfig (pp_rank / tp_node_idx, etc.) exactly duplicate RankInfo, and the control-plane stage cannot obtain the real rank anyway.

I changed it as follows:

All information the control plane needs goes through the existing split of ModelConfig (per-cluster) + RankInfo (per-rank);

ModelConfig keeps only the per-cluster fields pp_size, tp_size, cp_size, nnodes, is_nsa;

is_nsa comes directly from model_config.is_nsa and is no longer reverse-derived from enable_nsa_prefill_context_parallel.

"pp_size": int(model_config.pp_size),
"tp_node_idx": _tp_node_idx,
"tp_node_count": int(getattr(model_config, "tp_node_count", 1)),
"is_nsa": bool(_resolve_is_nsa(model_config)),
Contributor

@zhjc1124 May 20, 2026

SharingDomainKey.from_model_config is executed by the KVManager of each instance. This belongs to the control plane, which cannot obtain real rank information, so every rank is 0. I think the two should still be kept separate.
KVManager does not hold rank information, so constructing a SharingDomainKey there is meaningless. You could fold pp_size/tp_size/cp_size etc. into model_id, and then iterate over the ranks in enumerate_peers to generate the SharingDomainKeys.

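Collaborator Author

Adopted.
The control plane constructs only the self-SD: KVManager no longer uses rank in the control-plane stage. SharingDomainKey.from_model_config(model_config, rank_info=...) is called only on the data plane (where each worker holds a real RankInfo); if the control plane has no rank_info, it takes the default() fallback.

Cross-SD topology is derived through enumerate_peers(): after the Master has constructed its own self-SD, it calls self_sd.enumerate_peers() to expand the Cartesian product into pp_node_count × tp_node_count SDs, with no dependence on per-rank information. This is exactly the semantics you described as "iterate over the ranks in enumerate_peers to generate".

As mentioned above, pp_size/tp_size/cp_size are already folded into model_id: derive_model_id(...) writes the model architecture + dtype + page_size + parallelism degrees etc. into the fingerprint, so different parallelism degrees naturally land in different SD namespaces, which avoids mismatched pairing.

The control plane only needs total_sd_count = pp_node_count × tp_node_count, without any rank.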
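
A minimal sketch of the two ideas in this reply, the parallelism-aware fingerprint and the Cartesian peer expansion. The field names, the hashing scheme and the `pp_node_idx` field are assumptions made for illustration; only the names `derive_model_id` and `enumerate_peers` come from the discussion.

```python
import hashlib
from dataclasses import dataclass, replace
from itertools import product


def derive_model_id_sketch(arch, dtype, page_size, pp_size, tp_size, cp_size):
    """Hypothetical fingerprint: parallelism degrees are part of the id."""
    payload = "|".join(str(v) for v in
                       (arch, dtype, page_size, pp_size, tp_size, cp_size))
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16]


@dataclass(frozen=True)
class SharingDomainKeySketch:
    """Illustrative stand-in for SharingDomainKey (fields are assumed)."""
    model_id: str
    pp_node_count: int
    tp_node_count: int
    pp_node_idx: int = 0
    tp_node_idx: int = 0

    def enumerate_peers(self):
        # Cartesian expansion over node indices; needs no per-rank input.
        return [
            replace(self, pp_node_idx=p, tp_node_idx=t)
            for p, t in product(range(self.pp_node_count),
                                range(self.tp_node_count))
        ]


model_id = derive_model_id_sketch("example-arch", "bf16", 64, 2, 16, 1)
self_sd = SharingDomainKeySketch(model_id, pp_node_count=2, tp_node_count=2)
peers = self_sd.enumerate_peers()
total_sd_count = self_sd.pp_node_count * self_sd.tp_node_count
print(len(peers), total_sd_count)
```

Because the parallelism degrees feed the hash, two deployments of the same model with different `tp_size` get different `model_id` values and therefore never share an SD namespace in this sketch.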

Comment thread flexkv/integration/config.py Outdated
# extra indexer K cache buffer. Sourced from sglang's
# ``enable_nsa_prefill_context_parallel`` server arg, but in dist_reuse
# context the flag represents the *layout*, not whether CP is on.
is_nsa = getattr(server_args, 'enable_nsa_prefill_context_parallel', False)
Contributor

enable_nsa_prefill_context_parallel should not be used to decide whether the model is NSA; NSA can also run without CP.
That said, is_nsa does not appear to be used anywhere here, so it should be removed.

Collaborator Author

Adopted. enable_nsa_prefill_context_parallel is the CP switch and is orthogonal to the NSA model architecture, so deriving is_nsa from it is wrong.
is_nsa now comes directly from model_config.is_nsa (a model-architecture attribute) and is decoupled from enable_nsa_prefill_context_parallel;

the bypass in flexkv/integration/config.py that derived NSA from the CP switch has been deleted;

is_nsa is in fact a field of SharingDomainKey. Doc §3.2 states explicitly that "NSA and non-NSA blocks have different physical layouts and must be isolated", so the field is useful and cannot be deleted; only its source needed correcting.

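# Notes:
# * If this instance spans nodes only via CP (CP cross-machine, TP/PP not
# cross-machine), the script automatically puts machines with node_rank>0
# on the CP_PEER_REGISTRATION_ONLY path (does not start
# TransferManagerOnRemote; the sglang connector side follows multinode_policy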
Contributor

There is no need to change the existing TransferManagerOnRemote creation logic; the logic written here has poor compatibility.
TransferManagerOnRemote should be per-node, which is compatible automatically.

Collaborator Author

Adopted. TransferManagerOnRemote is itself a per-node process; starting one per node is enough, and there is no need to launch extra ones according to the number of SDs.

Comment thread flexkv/integration/multinode_policy.py Outdated
no, rank 0 box False False NO_REMOTE
no, rank 0 box any any MASTER
no, off-master box is_multinode_tp any SD_REMOTE_FULL
no, off-master box False is_multinode_cp CP_PEER_REGISTRATION_ONLY
Contributor

Checking is_multinode_tp / is_multinode_cp should be unnecessary; there is a problem here.

Collaborator Author

Adopted. The entire multinode_policy.py has been deleted.

@feiqiangs feiqiangs force-pushed the kvcache_reuse_debug branch from 6cdc3d8 to b22a969 Compare May 20, 2026 12:07
@feiqiangs feiqiangs force-pushed the kvcache_reuse_debug branch 3 times, most recently from 3822898 to bca6ccc Compare May 20, 2026 14:00
…cleanup

Squashes four prior commits (bca6ccc, c8a5a2a, e8da1b8, 0974b1c) addressing
reviewer comments on the dist_reuse feature commit (22bc183), plus the
follow-up dead-code/docs/test cleanups discovered during review.

F1. is_nsa source: read directly from model_config.is_nsa instead of
    reverse-deriving from enable_nsa_prefill_context_parallel (the latter
    is a CP toggle, orthogonal to NSA architecture).

F2. Control-plane / rank-plane separation: SharingDomainKey.from_model_config
    now takes an optional rank_info=RankInfo argument and reads
    pp_rank / tp_node_idx from it; the control plane (KVManager) only
    constructs self-SD via default() and enumerates peers via
    enumerate_peers() — no fake rank fabrication.

F3. RankTopology factory dropped; reuse the existing RankInfo end-to-end.
    Integration adapters (vLLM v1 / TRT-LLM / SGLang) plumb the real
    rank_info through.

F4. Shell / TransferManagerOnRemote decoupling: revert
    start_dist_reuse_serving.sh changes; TransferManagerOnRemote stays
    per-node and is tagged via set_target_sd_key on each handle.

F5. Delete unused flexkv.integration.multinode_policy module and its
    is_multinode_tp / is_multinode_cp / is_multinode_pp helpers; CP never
    participates in sd_key (attention all-gather makes per-cp_rank pools
    bit-wise identical), and TP-cross-node is encoded in
    SharingDomainKey.tp_node_count directly. Verified no external
    references in the SGLang FlexKVConnector codebase.
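
F1 and F2 together can be sketched as follows. The class layouts and the `default()` signature are assumptions for illustration; only the names `from_model_config`, `default`, `rank_info`, `pp_rank`, `tp_node_idx` and `is_nsa` appear in the commit text.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Optional


@dataclass(frozen=True)
class RankInfoSketch:
    """Illustrative per-rank info held by data-plane workers (assumed)."""
    pp_rank: int
    tp_rank: int
    tp_size_per_node: int

    @property
    def tp_node_idx(self) -> int:
        return self.tp_rank // self.tp_size_per_node


@dataclass(frozen=True)
class SharingDomainKeySketch:
    """Illustrative stand-in for SharingDomainKey (fields are assumed)."""
    pp_rank: int
    tp_node_idx: int
    is_nsa: bool

    @classmethod
    def default(cls, is_nsa: bool = False):
        # Control-plane fallback: no rank values are fabricated.
        return cls(pp_rank=0, tp_node_idx=0, is_nsa=is_nsa)

    @classmethod
    def from_model_config(cls, model_config,
                          rank_info: Optional[RankInfoSketch] = None):
        # F1: is_nsa is read from the model config itself; the
        # enable_nsa_prefill_context_parallel CP toggle is never consulted.
        is_nsa = bool(getattr(model_config, "is_nsa", False))
        if rank_info is None:
            return cls.default(is_nsa=is_nsa)
        # F2: rank-dependent fields come only from a real RankInfo.
        return cls(rank_info.pp_rank, rank_info.tp_node_idx, is_nsa)


cfg = SimpleNamespace(is_nsa=True, enable_nsa_prefill_context_parallel=False)
control_plane_key = SharingDomainKeySketch.from_model_config(cfg)
worker_key = SharingDomainKeySketch.from_model_config(
    cfg, rank_info=RankInfoSketch(pp_rank=1, tp_rank=5, tp_size_per_node=4))
print(control_plane_key, worker_key)
```

Keeping `rank_info` optional lets one constructor serve both planes while making the absence of rank data explicit instead of passing placeholder ranks.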

Dead-code sweep across dist_reuse layer (formerly c8a5a2a):
    - drop unused helpers / accessors that no production path references
    - remove the matching dead unit tests (test_iter_dunder, ...)

Docs (formerly e8da1b8):
    - drop stale phase tag in coordination_protocol module docstring.

Tests: full 18-suite dist_reuse subset (332/332) passes on both GPU
executors (gpu-146.56.224.46 and gpu-129.211.162.213).
@feiqiangs feiqiangs force-pushed the kvcache_reuse_debug branch from 0974b1c to 909e658 Compare May 20, 2026 15:05
Comment thread flexkv/server/server.py
gpu_register_port,
redis_meta=self.redis_meta_client,
rank_info=rank_info,
)
Contributor

This is wrong: KVServer is global and cannot obtain rank_info.

Contributor

There is only one KVServer globally.
