From 22bc183de9d79694dcdfd31c6d8d3ac854161ea2 Mon Sep 17 00:00:00 2001 From: phaedonsun Date: Tue, 19 May 2026 15:19:43 +0800 Subject: [PATCH 1/2] feat(dist_reuse): KV cache sharing across TP/PP/CP + single-node radix match Squashes 5 commits (20a65d5 + d97db4d + 72632ec + b9230c6 + 7cd04ad) into a single landed feature. This is the full dist_reuse stack on top of PR #165 (RankInfo refactor), validated end-to-end on a 2-machine GPU setup (gpu-146.56.224.46 master / gpu-129.211.162.213 peer): S1 (single-node TP=1) cached_ratio 99.65% PASS S2 (single-node TP=2) cached_ratio 99.65% PASS S3 (cross-node TP=1) master cold->warm 99.63%, peer crosshit 99.63% storage=272 backend=FlexKVConnector (PEERH2H @ 6.22 GB/s via mooncake/RDMA) PASS x3 357/357 unit tests on both nodes PASS == Original commits (in chronological order) == [20a65d5] feat(dist_reuse): KV cache sharing across TP/PP/CP + single-node radix match Initial dist_reuse stack: master coordinator, sharing-domain key, aggregate radix, redis-meta namespace, multi-node policy, P2P transfer types (PEERH2H/H2PEERH/PEERSSD2H/H2PEERSSD), failure detector, four S{1..4} sglang+FlexKV e2e benchmark scripts. [d97db4d] fix(dist_reuse): unblock cross-instance KV cache sharing on s3_cross_node_tp1 Three runtime bugs blocked the s3 (master prime / peer crosshit) flow: 1) GPUCPUTransferWorker._transfer_impl had positional-arg drift on the transfer_kv_blocks pybind: C++ added 'start_layer_id' between 'chunk_size_in_bytes' and 'num_layers' (transfer.cu 2025-07-10), which silently mapped is_h2d=False onto transfer_num_cta and launched D2H kernels with gridDim(0) -> cudaErrorInvalidConfiguration on every D2H. Fix: bind every value to the C++ pybind name with kwargs and add 'start_layer_id=0' explicitly. 2) GlobalCacheEngine._maybe_attach_multi_sd_peerh2h_ops carried a dead 'layer_num' parameter which the only caller in _get_impl_local passed undefined -> NameError on first cross-instance reuse hit. 
Fix: drop the dead parameter and 6 call sites in tests/test_d3_filter_and_get_clones.py. 3) merge_to_batch_graph raised NotImplementedError on PEERH2H / H2PEERH / PEERSSD2H / H2PEERSSD as soon as a real cross-instance hit produced a P2P op. Fix: whitelist the four types as P2P passthrough (preserves per-block src_block_node_ids and per-op target_node_ids from D-3 multi-SD broadcast clones), wire dependencies on merged_h2d_op (GET) / merged_d2h_op (PUT). [72632ec] fix(memory_handle): propagate _import_tensor_handle exceptions Previously _import_tensor_handle logged the error and returned torch.empty(0) on import failure, which silently dropped the wrapper into a 0-element tensor and surfaced as an unrelated IndexError later in worker.py::_get_layer_ptrs (layer_blocks[lay_id][0] out of range). Now always re-raise, keeping the original traceback so cross-node CUDA IPC handle device-id mismatches surface at their source. Consistent with _import_cuda_ipc_handle which never swallowed. [b9230c6] fix(config): move tp_node_idx from ModelConfig to RankInfo PR #165 removed tp_rank from ModelConfig but ModelConfig.tp_node_idx still referenced self.tp_rank, raising AttributeError. Two pre-existing test_cache_config_batch_b.py cases failed because of this. Fix: remove ModelConfig.tp_node_idx (replaced with a migration comment); add RankInfo.tp_node_idx (tp_rank // tp_size_per_node) to complement RankInfo.tp_rank_per_node (tp_rank % tp_size_per_node); update the two TP-node-count tests to construct a RankInfo for tp_node_idx assertions. [7cd04ad] docs(monitoring): document the new flexkv_py_dist_reuse_* metrics Added user-facing documentation for the 5 cross-instance reuse metrics in docs/monitoring/README_{en,zh}.md (kept in sync): * §2.3 'Cross-instance Reuse Metrics' table with type, labels, severity and KNOWN_ISSUE-derived alert thresholds. 
* 'Instrumentation status' subtable that flags the two metrics (lease_meta_nullptr_total / about_to_evict_total) whose Python collector hooks are ready but whose C++ master-side trigger has not yet landed, with a callout that '0' on these two does NOT mean 'system healthy'. * §1.1 environment variable table now documents PROMETHEUS_MULTIPROC_DIR (the sample dir used by prometheus_client across sglang TP/PP workers, KVManager subprocess and transfer workers). * §3.5 'Multiprocess Scrape Notes' explaining the MultiProcessCollector aggregation path and the recommended /dev/shm/flexkv_prom tmpfs override. * §3.6 'Recommended PromQL alerts' section with 4 ready-to-paste Prometheus alert rules: - FlexKVDistReuseLeaseMetaNullptr (critical, any positive) - FlexKVDistReusePeerReadFailureRate (critical, > 0.1%) - FlexKVDistReusePeerReadP99High (warning, > 500ms) - FlexKVDistReuseEvictPressure (warning, ratio > 10) * The /metrics curl verification snippet now also greps flexkv_py_dist_reuse_. 
--- .../dist_benchmark/TWONODE_DIRECT_README.md | 125 +++ .../dist_benchmark/benchmark_dist_direct.py | 8 + .../benchmark_dist_reuse_smoke.py | 370 +++++++ csrc/bindings.cpp | 63 +- csrc/dist/distributed_radix_tree.cpp | 49 +- csrc/dist/redis_meta_channel.cpp | 155 ++- csrc/dist/redis_meta_channel.h | 46 +- csrc/radix_tree.cpp | 57 +- csrc/radix_tree.h | 24 +- docs/dist_reuse/METRICS_dist_reuse.md | 270 +++++ docs/dist_reuse/OPS_QUICK_REF_p2p.md | 146 +++ ...euse_with_cp_pp_multinode_tp_simplified.md | 390 ++++++++ docs/dist_reuse/redis_schema.md | 391 ++++++++ docs/monitoring/README_en.md | 108 ++ docs/monitoring/README_zh.md | 104 ++ flexkv/cache/cache_engine.py | 937 +++++++++++++++++- flexkv/cache/hie_cache_engine.py | 45 +- flexkv/cache/radixtree.py | 43 +- flexkv/cache/redis_meta.py | 833 ++++++++++------ flexkv/common/config.py | 318 ++++++ flexkv/common/dist_reuse/__init__.py | 98 ++ flexkv/common/dist_reuse/aggregate_radix.py | 408 ++++++++ .../dist_reuse/coordination_protocol.py | 184 ++++ flexkv/common/dist_reuse/failure_detector.py | 388 ++++++++ .../common/dist_reuse/master_coordinator.py | 621 ++++++++++++ flexkv/common/dist_reuse/remote_init.py | 250 +++++ flexkv/common/dist_reuse/sharing_domain.py | 453 +++++++++ .../dist_reuse/sharing_domain_namespace.py | 185 ++++ flexkv/common/memory_handle.py | 9 +- flexkv/common/transfer.py | 77 +- flexkv/common/type.py | 6 +- flexkv/integration/config.py | 10 +- flexkv/integration/multinode_policy.py | 185 ++++ flexkv/kvmanager.py | 1 + flexkv/kvtask.py | 273 ++++- flexkv/metrics/collector.py | 231 ++++- flexkv/metrics/server.py | 35 +- flexkv/server/server.py | 1 + flexkv/transfer/worker.py | 141 ++- flexkv/transfer_manager.py | 475 ++++++++- .../multi-nodes/start_dist_reuse_serving.sh | 253 +++++ tests/_dist_reuse_fakes.py | 285 ++++++ .../test_cross_instance_reuse_e2e.py | 312 ++++++ tests/test_aggregate_radix.py | 306 ++++++ tests/test_cache_config_batch_b.py | 194 ++++ tests/test_cache_engine.py | 66 ++ 
tests/test_cache_engine_dist_reuse_gate.py | 490 +++++++++ tests/test_cext_evict_refcount_guard.py | 206 ++++ tests/test_coord_protocol.py | 128 +++ tests/test_d3_filter_and_get_clones.py | 622 ++++++++++++ tests/test_dist_reuse_launcher.py | 163 +++ tests/test_evict_refcount_guard.py | 162 +++ tests/test_failure_detector.py | 327 ++++++ tests/test_flexkv_redis_db.py | 247 +++++ tests/test_master_coordinator.py | 422 ++++++++ tests/test_metrics_dist_reuse.py | 153 +++ tests/test_multinode_flags.py | 185 ++++ tests/test_multinode_role_policy.py | 186 ++++ tests/test_phase2_combinations.py | 200 ++++ tests/test_redis_db_integration.py | 202 ++++ tests/test_redis_integration.py | 533 ++++++++++ tests/test_redis_meta_namespace.py | 438 ++++++++ tests/test_redis_metachannel_sd_prefix.py | 137 +++ tests/test_sd_enumerate_max.py | 249 +++++ tests/test_sharing_domain_key.py | 416 ++++++++ tests/test_sharing_domain_namespace.py | 136 +++ tests/test_single_node_match.py | 335 +++++++ 67 files changed, 15393 insertions(+), 473 deletions(-) create mode 100644 benchmarks/dist_benchmark/TWONODE_DIRECT_README.md create mode 100644 benchmarks/dist_benchmark/benchmark_dist_reuse_smoke.py create mode 100644 docs/dist_reuse/METRICS_dist_reuse.md create mode 100644 docs/dist_reuse/OPS_QUICK_REF_p2p.md create mode 100644 docs/dist_reuse/dist_reuse_with_cp_pp_multinode_tp_simplified.md create mode 100644 docs/dist_reuse/redis_schema.md create mode 100644 flexkv/common/dist_reuse/__init__.py create mode 100644 flexkv/common/dist_reuse/aggregate_radix.py create mode 100644 flexkv/common/dist_reuse/coordination_protocol.py create mode 100644 flexkv/common/dist_reuse/failure_detector.py create mode 100644 flexkv/common/dist_reuse/master_coordinator.py create mode 100644 flexkv/common/dist_reuse/remote_init.py create mode 100644 flexkv/common/dist_reuse/sharing_domain.py create mode 100644 flexkv/common/dist_reuse/sharing_domain_namespace.py create mode 100644 
flexkv/integration/multinode_policy.py create mode 100644 scripts/multi-nodes/start_dist_reuse_serving.sh create mode 100644 tests/_dist_reuse_fakes.py create mode 100644 tests/multinode/test_cross_instance_reuse_e2e.py create mode 100644 tests/test_aggregate_radix.py create mode 100644 tests/test_cache_config_batch_b.py create mode 100644 tests/test_cache_engine_dist_reuse_gate.py create mode 100644 tests/test_cext_evict_refcount_guard.py create mode 100644 tests/test_coord_protocol.py create mode 100644 tests/test_d3_filter_and_get_clones.py create mode 100644 tests/test_dist_reuse_launcher.py create mode 100644 tests/test_evict_refcount_guard.py create mode 100644 tests/test_failure_detector.py create mode 100644 tests/test_flexkv_redis_db.py create mode 100644 tests/test_master_coordinator.py create mode 100644 tests/test_metrics_dist_reuse.py create mode 100644 tests/test_multinode_flags.py create mode 100644 tests/test_multinode_role_policy.py create mode 100644 tests/test_phase2_combinations.py create mode 100644 tests/test_redis_db_integration.py create mode 100644 tests/test_redis_integration.py create mode 100644 tests/test_redis_meta_namespace.py create mode 100644 tests/test_redis_metachannel_sd_prefix.py create mode 100644 tests/test_sd_enumerate_max.py create mode 100644 tests/test_sharing_domain_key.py create mode 100644 tests/test_sharing_domain_namespace.py create mode 100644 tests/test_single_node_match.py diff --git a/benchmarks/dist_benchmark/TWONODE_DIRECT_README.md b/benchmarks/dist_benchmark/TWONODE_DIRECT_README.md new file mode 100644 index 0000000000..c8438ecef7 --- /dev/null +++ b/benchmarks/dist_benchmark/TWONODE_DIRECT_README.md @@ -0,0 +1,125 @@ +# FlexKV Two-Node Direct-Mode e2e — How to Run + +This harness validates **distributed KV cache sharing** across two physical +hosts (146 ↔ 129) **without touching sglang** — it drives FlexKV directly +via ``KVManager`` + ``KVTPClient`` inside ``benchmark_dist_direct.py``. 
+ +## Why direct mode (not ``benchmark_dist_kvcache.py``) + +* No ``KVServer`` subprocess → simpler lifecycle, no residual IPC socket at + ``/tmp/flexkv_server`` to clean up. +* Uses the same ``get_match`` / ``put_async`` main path that §2.1 wires + through to ``_sharing_domain_gate_get`` and ``_notify_sd_ready_on_put``, + so any breakage here is a real breakage. +* GPU footprint is tiny (~40 MB / GPU) — safe to run alongside a live + sglang serving that already owns most of the memory. + +## Conflict isolation with the running sglang process + +| Resource | sglang (GLM-5-FP8) | this benchmark | +|---|---|---| +| Mooncake engine TCP port | **5555** (sglang Transfer) | **5556 on 146** / **5557 on 129** | +| Redis logical DB | 0 (mooncake keys ``mooncake/*``) | 0 (mooncake) + **DB 1 for flexkv keys** | +| Redis key prefixes (DB 1) | — | ``sd:*``, ``instance:*``, ``node:*`` … | +| GPU VRAM | most of it | one GPU, ~40 MB | +| IPC sockets | ``/tmp/flexkv_server`` | **none** (direct mode) | + +So the only shared resources are the physical Redis server (different DB) +and the RDMA NICs (different QP). Neither overlaps in state. + +## Quick checklist before launching + +1. **Redis reachable + password OK**: + ```bash + redis-cli -h 10.206.0.9 -p 6379 -a 123456 PING # → PONG + ``` +2. **DB 1 is clean (first time only)**: + ```bash + redis-cli -h 10.206.0.9 -p 6379 -a 123456 -n 1 DBSIZE # → (integer) 0 + # If non-zero and you know it's our leftover state, wipe it: + # redis-cli -h 10.206.0.9 -p 6379 -a 123456 -n 1 FLUSHDB + # Do NOT touch DB 0 — mooncake + sglang live there. + ``` +3. **Mooncake ports 5556 / 5557 free** on the respective hosts: + ```bash + ss -ltn | grep -E ':5556|:5557' # should be empty + ``` +4. 
**FlexKV built with FLEXKV_ENABLE_P2P=1** on both hosts: + ```bash + python3 -c 'from flexkv.cache.redis_meta import dist_available; print(dist_available())' + # → True + ``` + +## Run + +### Step A — on host **146** (10.206.0.9), start **PUT-only** + +```bash +cd /data1/phaedonsun/flexkv/FlexKV + +export PYTHONPATH=/data1/phaedonsun/flexkv/FlexKV +export LD_LIBRARY_PATH=/data1/phaedonsun/flexkv/FlexKV/build/lib +export CUDA_VISIBLE_DEVICES=0 # pick any single free GPU + +python3 benchmarks/dist_benchmark/benchmark_dist_direct.py \ + --config benchmarks/dist_benchmark/twonode_direct_146.yml \ + --mode put-only \ + --batch-size 1 --sequence-length 256 \ + --seed 42 \ + --rebuild-interval-ms 20 +``` + +Expected final line before the process idles: +``` +Data published to Redis. Press Enter to shutdown (keep running for other nodes to GET)... +``` + +### Step B — on host **129** (10.206.0.13), start **GET-only with same seed** + +```bash +cd /data1/phaedonsun/flexkv/FlexKV + +export PYTHONPATH=/data1/phaedonsun/flexkv/FlexKV +export LD_LIBRARY_PATH=/data1/phaedonsun/flexkv/FlexKV/build/lib +export CUDA_VISIBLE_DEVICES=0 # different physical GPU, same index is fine + +python3 benchmarks/dist_benchmark/benchmark_dist_direct.py \ + --config benchmarks/dist_benchmark/twonode_direct_129.yml \ + --mode get-only \ + --batch-size 1 --sequence-length 256 \ + --seed 42 \ + --rebuild-interval-ms 20 +``` + +## Success criteria + +In the 129 log look for: + +``` +--- GET Phase --- + GET: 256/256 tokens, data_size: 0.000 GB, cache_ratio: 100.00% ... +``` + +A non-zero ``cache_ratio`` means the 129 instance: +1. Found the 146 instance via the shared Redis (``instance:*`` discovery) +2. Resolved the 146 peer SD's ``node_id`` from the aggregate radix +3. Issued a Mooncake RDMA read against 146's mooncake engine @ 5556 +4. Received KV data that matches byte-for-byte what 146 PUT + +If ``cache_ratio: 0.00%``: +* Check the 129 log for ``[DistReuse]`` lines — the §2.1 gate ruled it out. 
+* Check ``KEYS sd:*`` in Redis DB 1 — the 146 side should have published + ``sd:<…>:block::`` keys. +* Check Mooncake connectivity by running the ``transfer_engine_bench`` + binary between 146:5556 and 129:5557. + +## Teardown + +* 146: press Ctrl-C in the PUT-only terminal (the Ctrl-C handler calls + ``kvmanager.shutdown()`` which releases Mooncake + Redis state). +* 129: the GET-only run exits on its own; its ``atexit`` hook tears + KVManager down. +* Optionally wipe Redis DB 1 between runs: + ```bash + redis-cli -h 10.206.0.9 -p 6379 -a 123456 -n 1 FLUSHDB + ``` diff --git a/benchmarks/dist_benchmark/benchmark_dist_direct.py b/benchmarks/dist_benchmark/benchmark_dist_direct.py index a5e12048cb..e146db6852 100644 --- a/benchmarks/dist_benchmark/benchmark_dist_direct.py +++ b/benchmarks/dist_benchmark/benchmark_dist_direct.py @@ -132,6 +132,14 @@ def load_dist_direct_config(config_path: str): user_config.local_ip = config["local_ip"] if "redis_password" in config: user_config.redis_password = config["redis_password"] + # Optional: pick a non-default Redis logical DB so FlexKV keys don't collide + # with other tenants (e.g. Mooncake meta, or another running FlexKV / + # sglang instance on the same physical Redis). The flexkv keys (``sd:*``, + # ``instance:*``, ``node:*`` …) all live in the selected DB; the mooncake + # backend continues to use whatever DB its ``metadata_server`` URL + # implies (default 0). 
+ if "flexkv_redis_db" in config: + user_config.flexkv_redis_db = int(config["flexkv_redis_db"]) # Auto-generate mooncake config JSON and set MOONCAKE_CONFIG_PATH if P2P is enabled if config.get("enable_p2p_cpu", False) or config.get("enable_p2p_ssd", False): diff --git a/benchmarks/dist_benchmark/benchmark_dist_reuse_smoke.py b/benchmarks/dist_benchmark/benchmark_dist_reuse_smoke.py new file mode 100644 index 0000000000..082ad7d4a0 --- /dev/null +++ b/benchmarks/dist_benchmark/benchmark_dist_reuse_smoke.py @@ -0,0 +1,370 @@ +"""Single-machine, two-instance dist_reuse smoke benchmark. + +Goal: exercise the dist_reuse control-plane end-to-end on ONE machine +(no RDMA, no Mooncake) so we can verify the §2.1 main-path wiring +lands correctly *before* we move to the two-GPU-box harness. + +Why this is valuable: + * Confirms two FlexKV instances sharing one Redis can register + their ``flexkv:instance::sd_nodes`` maps and discover each + other via ``RedisSessionClient`` heart-beats. + * Confirms ``MasterCoordinator`` spins up, its + ``AggregateRadixTree`` accepts self-SD acks from the PUT hook, + and ``_sharing_domain_gate_get`` short-circuits correctly for + the single-SD degenerate case. + * Confirms the new ``set_evict_guard`` refcount predicate keeps + in-flight coord-GET blocks pinned. + * Confirms the PUT transfer-callback calls ``_notify_sd_ready_on_put`` + and the self-SD ack lands in ``aggregate_radix`` under the right + prefix hash. 
+ +What this script DOES NOT cover (→ requires 2-machine harness): + * Real Mooncake P2P read (cross-instance data plane) + * PP/TP cross-node SD barrier (``total_sd_count > 1``) + * Cross-instance peer-lost failure detection under network + partition + +Usage:: + + # Prereq: a reachable redis with password 123456 (env overridable) + export FLEXKV_SMOKE_REDIS_HOST=10.206.0.9 + export FLEXKV_SMOKE_REDIS_PORT=6379 + export FLEXKV_SMOKE_REDIS_PASSWORD=123456 + + python benchmarks/dist_benchmark/benchmark_dist_reuse_smoke.py \ + --num-instances 2 --num-blocks 32 +""" +from __future__ import annotations + +import argparse +import os +import sys +import time +import uuid +from pathlib import Path +from typing import Any, Dict, List, Optional + +# Allow running as a script from anywhere in the repo. +_REPO_ROOT = Path(__file__).resolve().parents[2] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from flexkv.common.dist_reuse.master_coordinator import MasterCoordinator # noqa: E402 +from flexkv.common.dist_reuse.sharing_domain import SharingDomainKey # noqa: E402 + + +# --------------------------------------------------------------------------- +# Redis client helpers +# --------------------------------------------------------------------------- +def _build_redis_client(host: str, port: int, password: Optional[str], db: int = 0): + import redis + return redis.Redis( + host=host, port=port, password=password, db=db, + socket_connect_timeout=3.0, decode_responses=True, + ) + + +def _ensure_redis_reachable(host: str, port: int, password: Optional[str]) -> None: + client = _build_redis_client(host, port, password) + if not client.ping(): + raise RuntimeError(f"Redis {host}:{port} unreachable") + + +# --------------------------------------------------------------------------- +# Minimal RedisMeta shim — we only need ``_client()`` + ``register_instance_sd_nodes`` +# for this smoke test. 
Importing the real ``flexkv.cache.redis_meta`` requires +# a full FlexKV install (c_ext loaded, etc.), which IS available in the +# flexkv_distreuse container but we want the script to also run against a +# Python-only install. +# --------------------------------------------------------------------------- +class _LightRedisMeta: + def __init__(self, host: str, port: int, password: Optional[str]): + self._client_ = _build_redis_client(host, port, password) + + def _client(self): + return self._client_ + + def register_instance_sd_nodes(self, instance_id: str, sd_to_nid: Dict[str, int]): + """Mirror RedisMeta's method signature — write a hash at the key the + design doc specifies so other instances can discover us.""" + key = f"flexkv:instance:{instance_id}:sd_nodes" + pipe = self._client_.pipeline() + pipe.delete(key) + if sd_to_nid: + pipe.hset(key, mapping={str(k): str(v) for k, v in sd_to_nid.items()}) + pipe.expire(key, 300) # 5-min TTL — smoke test teardown will del anyway + pipe.execute() + + +# --------------------------------------------------------------------------- +# Smoke harness +# --------------------------------------------------------------------------- +def _mk_instance( + *, + instance_id: str, + model_id: str, + self_node_id: int, + redis_host: str, + redis_port: int, + redis_password: Optional[str], + master_zmq_addr: str, + ttl_seconds: int, +) -> MasterCoordinator: + sd = SharingDomainKey( + model_id=model_id, + pp_rank=0, pp_size=1, + tp_node_idx=0, tp_node_count=1, + is_nsa=False, + ) + mc = MasterCoordinator( + self_sd=sd, + instance_id=instance_id, + ) + # Single-SD instance — no remote ACKs to expect. + mc.expect_remotes(0) + # Mark "all remotes ready" by explicitly skipping (single-SD case + # has no remotes). This lets register_instance_discoverables run. 
+ # For the smoke harness we don't call register_instance_discoverables + # directly (it requires a real RedisMeta type); we mimic its side + # effect manually: + meta = _LightRedisMeta(redis_host, redis_port, redis_password) + sd_to_nid = {sd.serialize(): self_node_id} + meta.register_instance_sd_nodes(instance_id, sd_to_nid) + + # Hook the session client so we show up in the Layer-1 failure + # detector's view on the OTHER instance. We're skipping the full + # MasterCoordinator.register_instance_discoverables plumbing + # (which needs a real RedisMeta); the script below just manually + # maintains a heartbeat on flexkv:session:. + return mc + + +def _register_heartbeat( + redis_client, + instance_id: str, + epoch: str, + ttl_seconds: int, + master_zmq_addr: str, +) -> None: + key = f"flexkv:session:{instance_id}" + payload = { + "epoch": epoch, + "master_zmq_addr": master_zmq_addr, + "ts": str(int(time.time())), + } + pipe = redis_client.pipeline() + pipe.delete(key) + pipe.hset(key, mapping=payload) + pipe.expire(key, ttl_seconds) + pipe.execute() + + +def _read_heartbeat(redis_client, instance_id: str) -> Optional[Dict[str, str]]: + key = f"flexkv:session:{instance_id}" + val = redis_client.hgetall(key) + return val or None + + +def _cleanup_instance(redis_client, instance_id: str) -> None: + keys = [] + keys.append(f"flexkv:instance:{instance_id}:sd_nodes") + keys.append(f"flexkv:session:{instance_id}") + redis_client.delete(*keys) + + +# --------------------------------------------------------------------------- +# Scenarios +# --------------------------------------------------------------------------- +def scenario_peer_discovery(args) -> Dict[str, Any]: + """Two single-SD instances register their sd_nodes, ensure each + can see the other's entry via Redis (simulates what + DistributedRadixTree.remote_tree_refresh does).""" + print("\n[scenario] peer_discovery") + + redis_client = _build_redis_client(args.redis_host, args.redis_port, args.redis_password) + # 
Create two instances with distinct node_ids. + inst_a_id = f"smoke-a-{uuid.uuid4().hex[:8]}" + inst_b_id = f"smoke-b-{uuid.uuid4().hex[:8]}" + mc_a = _mk_instance( + instance_id=inst_a_id, model_id=args.model_id, self_node_id=1000, + redis_host=args.redis_host, redis_port=args.redis_port, + redis_password=args.redis_password, + master_zmq_addr="127.0.0.1:30001", ttl_seconds=args.session_ttl, + ) + mc_b = _mk_instance( + instance_id=inst_b_id, model_id=args.model_id, self_node_id=2000, + redis_host=args.redis_host, redis_port=args.redis_port, + redis_password=args.redis_password, + master_zmq_addr="127.0.0.1:30002", ttl_seconds=args.session_ttl, + ) + try: + # Heartbeat both. + _register_heartbeat(redis_client, inst_a_id, mc_a.session_epoch, + args.session_ttl, "127.0.0.1:30001") + _register_heartbeat(redis_client, inst_b_id, mc_b.session_epoch, + args.session_ttl, "127.0.0.1:30002") + + # A reads B's sd_nodes + session. + b_sd_nodes = redis_client.hgetall(f"flexkv:instance:{inst_b_id}:sd_nodes") + b_session = _read_heartbeat(redis_client, inst_b_id) + assert len(b_sd_nodes) == 1, f"A cannot see B's sd_nodes: {b_sd_nodes}" + assert b_session is not None, "A cannot see B's session heartbeat" + print(f" [A] sees [B] sd_nodes = {b_sd_nodes}") + print(f" [A] sees [B] session = {b_session}") + + a_sd_nodes = redis_client.hgetall(f"flexkv:instance:{inst_a_id}:sd_nodes") + a_session = _read_heartbeat(redis_client, inst_a_id) + assert len(a_sd_nodes) == 1 + assert a_session is not None + print(f" [B] sees [A] sd_nodes = {a_sd_nodes}") + print(f" [B] sees [A] session = {a_session}") + + # Cross-verify node_ids don't collide. + a_nid = int(list(a_sd_nodes.values())[0]) + b_nid = int(list(b_sd_nodes.values())[0]) + assert a_nid != b_nid, "two instances got the same node_id — collision!" 
+ print(f" distinct node_ids: A={a_nid} B={b_nid}") + return {"status": "ok", "a_nid": a_nid, "b_nid": b_nid} + finally: + _cleanup_instance(redis_client, inst_a_id) + _cleanup_instance(redis_client, inst_b_id) + + +def scenario_aggregate_radix_put_hook(args) -> Dict[str, Any]: + """Simulate PUTs on instance A: drive the AggregateRadixTree via + ``_notify_sd_ready_on_put`` semantics and assert that match_fully_ready + returns the entry (single-SD degenerate case).""" + print("\n[scenario] aggregate_radix_put_hook") + + mc = _mk_instance( + instance_id=f"smoke-agg-{uuid.uuid4().hex[:8]}", model_id=args.model_id, + self_node_id=3000, + redis_host=args.redis_host, redis_port=args.redis_port, + redis_password=args.redis_password, + master_zmq_addr="127.0.0.1:30003", ttl_seconds=args.session_ttl, + ) + # Simulate N put-then-get cycles. + results = {"ok": 0, "miss": 0} + for i in range(args.num_blocks): + prefix_hash = hash((args.model_id, i)) & 0xFFFFFFFFFFFFFFFF + block_ids = [10 + 3 * i, 11 + 3 * i, 12 + 3 * i] + # PUT hook → self-SD ack. + mc.mark_sd_ready( + prefix_hash=prefix_hash, + sd_key_str=mc.self_sd.serialize(), + block_ids=block_ids, + ) + # GET hook → fully_ready check (total_sd=1 so a single ack is enough). + entry = mc.match_fully_ready(prefix_hash) + if entry is not None: + results["ok"] += 1 + # Refcount protection: the aggregate pins blocks when acquired. 
+ mc.pin_blocks_for_coord_get(block_ids) + for b in block_ids: + assert not mc.is_evictable(b), \ + f"block {b} evictable while pinned — refcount guard broken" + mc.unpin_blocks_for_coord_get(block_ids) + for b in block_ids: + assert mc.is_evictable(b), \ + f"block {b} NOT evictable after unpin — refcount stuck" + else: + results["miss"] += 1 + + print(f" put→get cycles: ok={results['ok']} miss={results['miss']}") + assert results["miss"] == 0, "single-SD instance must have every prefix fully_ready" + return results + + +def scenario_cross_instance_reuse_readiness(args) -> Dict[str, Any]: + """Both instances PUT overlapping block hashes. Verify the aggregate + radix on instance A still reports fully_ready for its own prefix even + when instance B is also active (isolation by instance_id).""" + print("\n[scenario] cross_instance_reuse_readiness") + + inst_a_id = f"smoke-a-{uuid.uuid4().hex[:8]}" + inst_b_id = f"smoke-b-{uuid.uuid4().hex[:8]}" + mc_a = _mk_instance( + instance_id=inst_a_id, model_id=args.model_id, self_node_id=4001, + redis_host=args.redis_host, redis_port=args.redis_port, + redis_password=args.redis_password, + master_zmq_addr="127.0.0.1:30010", ttl_seconds=args.session_ttl, + ) + mc_b = _mk_instance( + instance_id=inst_b_id, model_id=args.model_id, self_node_id=4002, + redis_host=args.redis_host, redis_port=args.redis_port, + redis_password=args.redis_password, + master_zmq_addr="127.0.0.1:30011", ttl_seconds=args.session_ttl, + ) + try: + prefix_hash = 0xC0FFEE + block_ids = [1, 2, 3] + # Both instances mark the prefix ready on their own SD. + mc_a.mark_sd_ready( + prefix_hash=prefix_hash, + sd_key_str=mc_a.self_sd.serialize(), + block_ids=block_ids, + ) + mc_b.mark_sd_ready( + prefix_hash=prefix_hash, + sd_key_str=mc_b.self_sd.serialize(), + block_ids=block_ids, + ) + # Each aggregate is instance-local, so both should see fully_ready. 
+ assert mc_a.match_fully_ready(prefix_hash) is not None + assert mc_b.match_fully_ready(prefix_hash) is not None + print(" both instances independently report fully_ready — ok") + + # Simulate peer_lost on A (e.g., B's session TTL expired). + n_invalidated = mc_a.invalidate_by_peer_instance(inst_b_id) + # mc_a's aggregate was only acked by mc_a itself, never by + # ``contributing_peer=inst_b_id``, so invalidate_by_peer_instance + # should be a no-op. + assert n_invalidated == 0 + assert mc_a.match_fully_ready(prefix_hash) is not None, \ + "A's prefix wrongly invalidated by B's peer-lost signal" + print(" A's prefix survives B's peer-lost signal — isolation ok") + + return {"status": "ok", "invalidated": n_invalidated} + finally: + _cleanup_instance(_build_redis_client(args.redis_host, args.redis_port, args.redis_password), + inst_a_id) + _cleanup_instance(_build_redis_client(args.redis_host, args.redis_port, args.redis_password), + inst_b_id) + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--redis-host", default=os.environ.get("FLEXKV_SMOKE_REDIS_HOST", "127.0.0.1")) + parser.add_argument("--redis-port", type=int, default=int(os.environ.get("FLEXKV_SMOKE_REDIS_PORT", "6379"))) + parser.add_argument("--redis-password", default=os.environ.get("FLEXKV_SMOKE_REDIS_PASSWORD", None)) + parser.add_argument("--model-id", default="dist-reuse-smoke-model") + parser.add_argument("--num-blocks", type=int, default=8, + help="How many put→get cycles to run in the aggregate scenario") + parser.add_argument("--session-ttl", type=int, default=30) + parser.add_argument("--scenario", default="all", + choices=["all", "peer_discovery", "aggregate", "cross_instance"]) + args = parser.parse_args() + + _ensure_redis_reachable(args.redis_host, args.redis_port, 
args.redis_password) + print(f"Redis reachable at {args.redis_host}:{args.redis_port}") + + results: Dict[str, Any] = {} + if args.scenario in ("all", "peer_discovery"): + results["peer_discovery"] = scenario_peer_discovery(args) + if args.scenario in ("all", "aggregate"): + results["aggregate"] = scenario_aggregate_radix_put_hook(args) + if args.scenario in ("all", "cross_instance"): + results["cross_instance"] = scenario_cross_instance_reuse_readiness(args) + + print("\n=== SMOKE RESULTS ===") + for name, res in results.items(): + print(f" {name}: {res}") + print("\nALL SCENARIOS PASSED ✅") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index 659802cdd8..778a7da337 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -640,6 +641,19 @@ PYBIND11_MODULE(c_ext, m) { &flexkv::CRadixTreeIndex::evict), py::arg("evicted_blocks"), py::arg("evicted_block_hashes"), py::arg("num_evicted"), py::call_guard()) + // 4-arg overload: dist-reuse refcount guard (§2.2). The callback + // is invoked once per candidate physical block id; if it returns + // False the block id is dropped from the returned eviction set. + // NOTE: we do *not* release the GIL here because the predicate + // runs Python code — pybind11 would have to re-acquire on every + // call. The evict loop is O(#evicted blocks) so the GIL impact + // is negligible compared to the Python callback cost itself. 
+ .def("evict", + py::overload_cast>( + &flexkv::CRadixTreeIndex::evict), + py::arg("evicted_blocks"), py::arg("evicted_block_hashes"), + py::arg("num_evicted"), py::arg("is_evictable_fn")) .def("total_cached_blocks", &flexkv::CRadixTreeIndex::total_cached_blocks) .def("total_unready_blocks", &flexkv::CRadixTreeIndex::total_unready_blocks) @@ -660,11 +674,18 @@ PYBIND11_MODULE(c_ext, m) { py::class_>( m, "CMatchResult") .def(py::init()) + torch::Tensor, int32_t>(), + py::arg("num_ready_matched_blocks"), + py::arg("num_matched_blocks"), + py::arg("last_node_matched_length"), + py::arg("last_ready_node"), + py::arg("last_node"), + py::arg("physical_blocks"), + py::arg("matched_node_id") = -1) .def_readonly("last_ready_node", &flexkv::CMatchResult::last_ready_node) .def_readonly("last_node", &flexkv::CMatchResult::last_node) .def_readonly("physical_blocks", &flexkv::CMatchResult::physical_blocks) - .def_readonly("block_node_ids", &flexkv::CMatchResult::block_node_ids) + .def_readonly("matched_node_id", &flexkv::CMatchResult::matched_node_id) .def_readonly("num_ready_matched_blocks", &flexkv::CMatchResult::num_ready_matched_blocks) .def_readonly("num_matched_blocks", @@ -748,11 +769,13 @@ PYBIND11_MODULE(c_ext, m) { py::class_(m, "RedisMetaChannel") .def(py::init(), + const std::string &, const std::string &, int>(), py::arg("host"), py::arg("port"), py::arg("node_id"), py::arg("local_ip"), py::arg("blocks_key") = std::string("blocks"), - py::arg("password") = std::string("")) + py::arg("password") = std::string(""), + py::arg("db") = 0) .def("connect", &flexkv::RedisMetaChannel::connect) + .def("get_db", &flexkv::RedisMetaChannel::get_db) .def("get_node_id", &flexkv::RedisMetaChannel::get_node_id) .def("get_local_ip", &flexkv::RedisMetaChannel::get_local_ip) .def("make_block_key", &flexkv::RedisMetaChannel::make_block_key, @@ -831,12 +854,38 @@ PYBIND11_MODULE(c_ext, m) { .def( "load_metas_by_keys", [](flexkv::RedisMetaChannel &ch, - const std::vector &keys) { + const 
std::vector &keys, + size_t batch_size) { std::vector out; - ch.load_metas_by_keys(keys, out); + if (batch_size == 0) { + ch.load_metas_by_keys(keys, out); + } else { + ch.load_metas_by_keys(keys, out, batch_size); + } return out; }, - py::arg("keys")) + py::arg("keys"), py::arg("batch_size") = 0) + .def( + "list_all_block_keys", + [](flexkv::RedisMetaChannel &ch) { + std::vector keys; + ch.list_all_block_keys(keys); + return keys; + }) + .def( + "load_instance_sd_nodes", + [](flexkv::RedisMetaChannel &ch, const std::string &instance_id) { + std::unordered_map out; + ch.load_instance_sd_nodes(instance_id, out); + // Convert to py::dict explicitly — default caster maps uint32_t + // fine but we want a stable ordering for tests. + py::dict d; + for (auto &kv : out) { + d[py::cast(kv.first)] = py::cast(kv.second); + } + return d; + }, + py::arg("instance_id")) .def( "update_block_state_batch", [](flexkv::RedisMetaChannel &ch, uint32_t node_id, diff --git a/csrc/dist/distributed_radix_tree.cpp b/csrc/dist/distributed_radix_tree.cpp index 1831896861..27cd012ca0 100644 --- a/csrc/dist/distributed_radix_tree.cpp +++ b/csrc/dist/distributed_radix_tree.cpp @@ -268,10 +268,25 @@ RefRadixTree* DistributedRadixTree::remote_tree_refresh() { const std::string &k = node_keys[i]; if (k.size() <= 5) continue; if (ips.size() <= i) continue; - // parse nid + // Parse node id from the tail of the key. + // + // Two layouts are in use today: + // * legacy: "node:" + // * sharing-domain: "sd::node:" + // + // The old code hard-coded ``substr(5)`` which worked for the legacy + // layout but yielded ``"default__:..."`` under SD → ``stoul`` threw → + // every node was silently dropped → the aggregate radix ended up empty + // → cross-instance GET always missed (cache_ratio=0%). + // + // We now take the suffix after the final ':'. This matches both the + // legacy and SD layouts in one line and is bug-compatible with the + // single-SD degenerate namespace. 
+ size_t pos = k.rfind(':'); + if (pos == std::string::npos || pos + 1 >= k.size()) continue; uint32_t nid = 0; try { - nid = (uint32_t)std::stoul(k.substr(5)); + nid = (uint32_t)std::stoul(k.substr(pos + 1)); } catch (const std::exception&) { continue; } @@ -354,8 +369,7 @@ std::shared_ptr DistributedRadixTree::match_prefix( if (idx == nullptr) { // Remote index not yet built - this is normal at startup auto empty_i64 = torch::empty({0}, torch::dtype(torch::kInt64)); - auto empty_u32 = torch::empty({0}, torch::dtype(torch::kInt32)); - return std::make_shared(0, 0, 0, nullptr, nullptr, empty_i64, empty_u32); + return std::make_shared(0, 0, 0, nullptr, nullptr, empty_i64); } // Safely increment reference count while holding the lock @@ -563,8 +577,7 @@ std::shared_ptr RefRadixTree::match_prefix( if (root == nullptr) { auto empty_i64 = torch::empty({0}, torch::dtype(torch::kInt64)); - auto empty_u32 = torch::empty({0}, torch::dtype(torch::kInt32)); - return std::make_shared(0, 0, 0, nullptr, nullptr, empty_i64, empty_u32); + return std::make_shared(0, 0, 0, nullptr, nullptr, empty_i64); } auto current_node = root; @@ -578,10 +591,10 @@ std::shared_ptr RefRadixTree::match_prefix( auto block_hashes_ptr = block_hashes.data_ptr(); HashType child_hash; - // node ids stored as int32 tensor (PyTorch lacks uint32 dtype) - auto node_ids_tensor = torch::empty({num_blocks}, torch::dtype(torch::kInt32)); - auto *ni_out = node_ids_tensor.data_ptr(); - int32_t ni_write = 0; + // Single-node matching constraint: all matched blocks must come from the + // same peer node_id. We lock the node_id on the first valid block and + // stop matching when a different node_id is encountered. 
+ int32_t matched_node_id = -1; // -1 = not yet determined // now in ms struct timeval now_tv; gettimeofday(&now_tv, nullptr); @@ -638,9 +651,20 @@ std::shared_ptr RefRadixTree::match_prefix( if (bnis == nullptr || bnis->size() != pbs.size()) break; + // Single-node constraint: stop at the first block whose node_id + // differs from the already-locked matched_node_id. + int actually_copied = 0; for (int i = 0; i < matched; ++i) { + int32_t block_nid = static_cast((*bnis)[i]); + if (matched_node_id == -1) { + matched_node_id = block_nid; // lock the first node_id + } else if (block_nid != matched_node_id) { + // Different node_id encountered - stop matching here + matched = actually_copied; + break; + } pb_out[pb_write++] = pbs[i]; - ni_out[ni_write++] = (*bnis)[i]; + actually_copied++; } if (current_node->is_ready()) { @@ -672,10 +696,9 @@ std::shared_ptr RefRadixTree::match_prefix( } auto physical_blocks = physical_blocks_tensor.narrow(0, 0, pb_write); - auto node_ids = node_ids_tensor.narrow(0, 0, ni_write); return std::make_shared(prefix_blocks_num, prefix_blocks_num, last_node_matched_length, - last_ready_node, current_node, physical_blocks, node_ids); + last_ready_node, current_node, physical_blocks, matched_node_id); } // Helper function to clean up an orphan tree (not attached to main tree) diff --git a/csrc/dist/redis_meta_channel.cpp b/csrc/dist/redis_meta_channel.cpp index c0ff69bf58..7128da07b5 100644 --- a/csrc/dist/redis_meta_channel.cpp +++ b/csrc/dist/redis_meta_channel.cpp @@ -13,17 +13,19 @@ namespace flexkv { -RedisHiredisClient::RedisHiredisClient() : context_(nullptr), port_(0), timeout_ms_(3000), password_("") {} +RedisHiredisClient::RedisHiredisClient() : context_(nullptr), port_(0), timeout_ms_(3000), password_(""), db_(0) {} RedisHiredisClient::~RedisHiredisClient() { close(); } -bool RedisHiredisClient::connect(const std::string &host, int port, int timeout_ms, const std::string &password) { +bool RedisHiredisClient::connect(const 
std::string &host, int port, int timeout_ms, + const std::string &password, int db) { host_ = host; port_ = port; timeout_ms_ = timeout_ms; password_ = password; + db_ = db; // Create connection with timeout struct timeval timeout = { timeout_ms / 1000, (timeout_ms % 1000) * 1000 }; @@ -56,7 +58,26 @@ bool RedisHiredisClient::connect(const std::string &host, int port, int timeout_ return false; } } - + + // Switch to the configured logical db. db == 0 is the Redis default so + // we skip the SELECT round-trip in that case. + if (db_ != 0) { + redisReply* reply = (redisReply*)redisCommand(context_, "SELECT %d", db_); + if (!reply) { + redisFree(context_); + context_ = nullptr; + return false; + } + bool select_ok = (reply->type == REDIS_REPLY_STATUS && + strcmp(reply->str, "OK") == 0); + freeReplyObject(reply); + if (!select_ok) { + redisFree(context_); + context_ = nullptr; + return false; + } + } + return true; } @@ -187,12 +208,14 @@ bool RedisHiredisClient::parse_reply(redisReply* reply, std::vector RedisMetaChannel::RedisMetaChannel(const std::string &h, int p, uint32_t node_id, const std::string &lip, const std::string &bk, - const std::string &pwd) - : host(h), port(p), node_id(node_id), blocks_key(bk), local_ip(lip), password(pwd) { + const std::string &pwd, + int db_) + : host(h), port(p), node_id(node_id), blocks_key(bk), local_ip(lip), + password(pwd), db(db_) { } bool RedisMetaChannel::connect() { - return client.connect(host, port, 3000, password); + return client.connect(host, port, 3000, password, db); } std::string RedisMetaChannel::make_block_key(uint32_t node_id, uint64_t hash) const { @@ -436,7 +459,32 @@ bool RedisMetaChannel::list_keys(const std::string &pattern, std::vector &keys) { - return list_keys("node:*", keys); + // Per-SD node pattern. ``blocks_key`` carries the full SD prefix + // (plus an optional trailing ``:`` component) — see + // ``_channel_blocks_key`` on the Python side. 
We strip the device + // suffix (if any) and append ``:node:*``. + // + // Layout examples (simplified design — CP not in sd_key): + // blocks_key = "sd::CPUB" → scan "sd::node:*" + // blocks_key = "sd:" → scan "sd::node:*" + // blocks_key = "blocks" (legacy) → scan "node:*" (backward compat) + if (blocks_key.compare(0, 3, "sd:") != 0) { + return list_keys("node:*", keys); + } + // Count the ':' separators to distinguish + // sd::pp<>:tpn<>:nsa<> (4 colons) — SD only + // sd::pp<>:tpn<>:nsa<>: (5 colons) — SD + device + // Strip the last ':' part only when we see > 4 colons. + size_t colons = 0; + for (char c : blocks_key) if (c == ':') ++colons; + std::string sd_prefix; + if (colons > 4) { + size_t pos = blocks_key.find_last_of(':'); + sd_prefix = blocks_key.substr(0, pos); + } else { + sd_prefix = blocks_key; + } + return list_keys(sd_prefix + ":node:*", keys); } bool RedisMetaChannel::list_block_keys(uint32_t node_id, std::vector &keys) { @@ -444,6 +492,13 @@ bool RedisMetaChannel::list_block_keys(uint32_t node_id, std::vector &keys) { + // Global SCAN over every block in this SD/device namespace. Used by the + // optimized ``remote_tree_refresh`` in Phase 1-F (design doc §4.7.1.2). + std::string pattern = blocks_key + ":block:*"; + return list_keys(pattern, keys); +} + bool RedisMetaChannel::hmget_field_for_keys(const std::vector &keys, const std::string &field, std::vector &values) { @@ -507,47 +562,83 @@ bool RedisMetaChannel::hmget_two_fields_for_keys(const std::vector size_t RedisMetaChannel::load_metas_by_keys(const std::vector &keys, std::vector &out) { + // Preserve original single-shot behaviour for backward compatibility. 
+ return load_metas_by_keys(keys, out, keys.size()); +} + +size_t RedisMetaChannel::load_metas_by_keys(const std::vector &keys, + std::vector &out, + size_t batch_size) { out.clear(); if (keys.empty()) return 0; - - // Batch HMGET for all fields - std::vector> batch; - batch.reserve(keys.size()); - - for (const auto& key : keys) { - batch.push_back({"HMGET", key, "ph", "pb", "nid", "hash", "lt", "state"}); - } - - std::vector> replies; - if (!client.pipeline(batch, replies)) return 0; - - // Parse replies into BlockMeta objects - for (size_t i = 0; i < replies.size() && i < keys.size(); ++i) { - const auto& reply = replies[i]; - if (reply.size() == 6) { + if (batch_size == 0) batch_size = 500; + + out.reserve(keys.size()); + + size_t idx = 0; + const size_t total = keys.size(); + while (idx < total) { + const size_t end = std::min(idx + batch_size, total); + + std::vector> batch; + batch.reserve(end - idx); + for (size_t i = idx; i < end; ++i) { + batch.push_back({"HMGET", keys[i], "ph", "pb", "nid", "hash", "lt", "state"}); + } + + std::vector> replies; + if (!client.pipeline(batch, replies)) { + out.clear(); + return 0; + } + + for (size_t i = 0; i < replies.size(); ++i) { + const auto& reply = replies[i]; BlockMeta meta; - if (reply[0].empty() || reply[1].empty() || reply[2].empty() - || reply[3].empty() || reply[4].empty() || reply[5].empty()) { - meta.state = NODE_STATE_EVICTED; - } else { + if (reply.size() == 6 && + !reply[0].empty() && !reply[1].empty() && !reply[2].empty() && + !reply[3].empty() && !reply[4].empty() && !reply[5].empty()) { meta.ph = std::stoll(reply[0]); meta.pb = std::stoll(reply[1]); meta.nid = std::stoul(reply[2]); meta.hash = std::stoll(reply[3]); meta.lt = std::stoul(reply[4]); meta.state = std::stoi(reply[5]); + } else { + meta.state = NODE_STATE_EVICTED; } out.push_back(meta); - } else { - BlockMeta meta; - meta.state = NODE_STATE_EVICTED; - out.push_back(meta); } + idx = end; } - return out.size(); } +bool 
RedisMetaChannel::load_instance_sd_nodes(const std::string &instance_id, + std::unordered_map &out) { + out.clear(); + if (instance_id.empty()) return false; + std::vector resp; + const std::string key = "flexkv:instance:" + instance_id + ":sd_nodes"; + if (!client.command({"HGETALL", key}, resp)) return false; + // HGETALL replies are a flat [field0, value0, field1, value1, ...] array. + if (resp.size() % 2 != 0) return false; + for (size_t i = 0; i + 1 < resp.size(); i += 2) { + const std::string &sd_key = resp[i]; + const std::string &nid_str = resp[i + 1]; + if (sd_key.empty() || nid_str.empty()) continue; + try { + // Intentionally use stoul — node_id is stored as an unsigned int on + // the Python side. stoi would silently truncate overflow values. + out[sd_key] = static_cast(std::stoul(nid_str)); + } catch (const std::exception &) { + // Skip malformed entries but continue collecting the rest. + continue; + } + } + return true; +} + static std::string key_for_block(RedisMetaChannel* ch, uint32_t node_id, int64_t hash) { return ch->make_block_key(node_id, (uint64_t)hash); } diff --git a/csrc/dist/redis_meta_channel.h b/csrc/dist/redis_meta_channel.h index 1bd960541f..46675c6e98 100644 --- a/csrc/dist/redis_meta_channel.h +++ b/csrc/dist/redis_meta_channel.h @@ -12,6 +12,8 @@ #include #include #include +#include +#include #include "block_meta.h" @@ -29,12 +31,17 @@ class RedisHiredisClient { int port_; int timeout_ms_; std::string password_; + // Logical Redis database number (0..15 by default). Populated by + // ``connect()`` and re-applied after any implicit reconnect inside + // ``command()`` / ``pipeline()`` via a ``SELECT `` command. 
+ int db_; public: RedisHiredisClient(); ~RedisHiredisClient(); - bool connect(const std::string &host, int port, int timeout_ms = 3000, const std::string &password = ""); + bool connect(const std::string &host, int port, int timeout_ms = 3000, + const std::string &password = "", int db = 0); void close(); // Sends a RESP array command and parses a single reply into raw components. @@ -61,15 +68,24 @@ class RedisMetaChannel { std::string host; int port; uint32_t node_id; - std::string blocks_key; // legacy, unused for list storage + // Full key-prefix namespace. Format: ``sd:[:]`` or + // legacy bare ``blocks`` / ``CPUB`` / ... . Used verbatim by + // ``make_block_key`` and inspected by ``list_node_keys`` to derive the + // per-SD node-scan pattern. See design doc §4.7 / §4.7.1. + std::string blocks_key; std::string local_ip; std::string password; + // Logical Redis db. Matches ``CacheConfig.flexkv_redis_db`` and is + // forwarded to ``RedisHiredisClient::connect`` so every command this + // channel issues lands on the configured db. + int db; public: RedisMetaChannel(const std::string &host, int port, uint32_t node_id, const std::string &local_ip, const std::string &blocks_key = "blocks", - const std::string &password = ""); + const std::string &password = "", + int db = 0); bool connect(); // Build Redis block key: :block:: @@ -88,6 +104,8 @@ class RedisMetaChannel { // Returns the global node id assigned to this process, or UINT32_MAX if uninitialized. 
uint32_t get_node_id() const; const std::string &get_local_ip() const { return local_ip; } + const std::string &get_blocks_key() const { return blocks_key; } + int get_db() const { return db; } // Batch update state for given hashes belonging to node_id bool update_block_state_batch(uint32_t node_id, @@ -103,10 +121,16 @@ class RedisMetaChannel { // Generic helpers for metadata queries bool list_keys(const std::string &pattern, std::vector &keys); - // List node keys: SCAN node:* + // List node keys **scoped to this channel's SD**. The SD prefix is + // derived from ``blocks_key`` by stripping the trailing device-prefix + // component (if any) — see implementation. Legacy bare ``blocks_key`` + // values collapse to the bare ``node:*`` pattern for backward-compat. bool list_node_keys(std::vector &keys); // List block keys for a specific node: SCAN :block::* bool list_block_keys(uint32_t node_id, std::vector &keys); + // Global SCAN over *every* block key in this SD (design doc §4.7.1.2). + // Produces ``:block:*``. + bool list_all_block_keys(std::vector &keys); // Pipeline HMGET for a single field over many keys. values.size()==keys.size() on success bool hmget_field_for_keys(const std::vector &keys, @@ -119,9 +143,21 @@ class RedisMetaChannel { const std::string &field2, std::vector> &out); - // Load BlockMeta for provided keys via HMGET ph pb nid hash lt state + // Load BlockMeta for provided keys via HMGET ph pb nid hash lt state. + // ``batch_size`` controls the pipeline size (design doc §4.7.1.2); pass 0 + // for the default (500). The original no-batch-size overload is kept for + // backward compatibility. size_t load_metas_by_keys(const std::vector &keys, std::vector &out); + size_t load_metas_by_keys(const std::vector &keys, + std::vector &out, + size_t batch_size); + + // Read the ``flexkv:instance::sd_nodes`` Hash + // (design doc §4.7.1.5). Returns false when the key does not exist or + // a Redis error occurs; ``out`` is cleared on failure. 
+ bool load_instance_sd_nodes(const std::string &instance_id, + std::unordered_map &out); }; } // namespace flexkv diff --git a/csrc/radix_tree.cpp b/csrc/radix_tree.cpp index 04d3429f6b..ddd685ca30 100644 --- a/csrc/radix_tree.cpp +++ b/csrc/radix_tree.cpp @@ -322,10 +322,33 @@ int CRadixTreeIndex::evict(torch::Tensor &evicted_blocks, int num_evicted) { } int CRadixTreeIndex::evict(torch::Tensor &evicted_blocks, torch::Tensor &evicted_block_hashes, int num_evicted) { + // Delegate to the 4-arg overload with an empty (null) predicate — the + // "evict everything that's LRU-picked" legacy behaviour. + return evict(evicted_blocks, evicted_block_hashes, num_evicted, + std::function{}); +} + +int CRadixTreeIndex::evict(torch::Tensor &evicted_blocks, + torch::Tensor &evicted_block_hashes, + int num_evicted, + std::function is_evictable_fn) { int64_t *evicted_blocks_ptr = evicted_blocks.data_ptr(); int64_t *evicted_block_hashes_ptr = evicted_block_hashes.data_ptr(); int has_evicted = 0; + // Defensive: a buggy predicate must not wedge the allocator. Any + // exception is swallowed and the block is treated as evictable + // (same semantics as the Python path in + // ``flexkv/cache/radixtree.py::RadixTreeIndex.evict``). + auto block_ok = [&is_evictable_fn](int64_t bid) -> bool { + if (!is_evictable_fn) return true; + try { + return is_evictable_fn(bid); + } catch (...) 
{ + return true; + } + }; + // Optimization: Batch build the priority queue to reduce overhead from O(N log N) to O(N) std::vector candidates; candidates.reserve(leaf_list.size()); @@ -345,12 +368,18 @@ int CRadixTreeIndex::evict(torch::Tensor &evicted_blocks, torch::Tensor &evicted if (node->size() > num_evicted - has_evicted) { auto [blocks, block_hashes] = node->shrink(num_evicted - has_evicted); - auto _has_evicted(has_evicted); // Shadow index - for (auto it = blocks->begin(); it != blocks->end(); it++, _has_evicted++) { - evicted_blocks_ptr[_has_evicted] = *it; - } - for (auto it = block_hashes->begin(); it != block_hashes->end(); it++, has_evicted++) { - evicted_block_hashes_ptr[has_evicted] = *it; + // Dist-reuse refcount guard (§2.2): drop any block id the + // predicate says is pinned. The block has already been + // physically detached via ``shrink`` — we just omit it from + // the returned eviction set so the allocator doesn't recycle + // the slot while a coord GET is in flight. 
+ auto b_it = blocks->begin(); + auto h_it = block_hashes->begin(); + for (; b_it != blocks->end() && h_it != block_hashes->end(); ++b_it, ++h_it) { + if (!block_ok(*b_it)) continue; + evicted_blocks_ptr[has_evicted] = *b_it; + evicted_block_hashes_ptr[has_evicted] = *h_it; + has_evicted++; } delete blocks; delete block_hashes; @@ -362,12 +391,13 @@ int CRadixTreeIndex::evict(torch::Tensor &evicted_blocks, torch::Tensor &evicted assert(parent != nullptr); parent->remove_child(node->get_head_hash()); - auto _has_evicted(has_evicted); // Shadow index - for (auto it = blocks.begin(); it != blocks.end(); it++, _has_evicted++) { - evicted_blocks_ptr[_has_evicted] = *it; - } - for (auto it = block_hashes.begin(); it != block_hashes.end(); it++, has_evicted++) { - evicted_block_hashes_ptr[has_evicted] = *it; + auto b_it = blocks.begin(); + auto h_it = block_hashes.begin(); + for (; b_it != blocks.end() && h_it != block_hashes.end(); ++b_it, ++h_it) { + if (!block_ok(*b_it)) continue; + evicted_blocks_ptr[has_evicted] = *b_it; + evicted_block_hashes_ptr[has_evicted] = *h_it; + has_evicted++; } if (parent->is_leaf() && !is_root(parent)) { @@ -520,9 +550,8 @@ CRadixTreeIndex::match_prefix(torch::Tensor &block_hashes, int num_blocks, } auto physical_blocks = physical_blocks_tensor.narrow(0, 0, pb_write); - auto empty_uint32 = torch::Tensor(); return std::make_shared(ready_prefix_blocks_num, prefix_blocks_num, last_node_matched_length, - last_ready_node, current_node, physical_blocks, empty_uint32); + last_ready_node, current_node, physical_blocks); } } // namespace flexkv diff --git a/csrc/radix_tree.h b/csrc/radix_tree.h index 65bad5dc12..59ea018aa9 100644 --- a/csrc/radix_tree.h +++ b/csrc/radix_tree.h @@ -1,6 +1,7 @@ #pragma once #include #include +#include #include #include #include @@ -227,17 +228,18 @@ class CMatchResult { CRadixNode *last_ready_node; CRadixNode *last_node; torch::Tensor physical_blocks; - torch::Tensor block_node_ids; + int32_t matched_node_id; // 
single node_id for all matched blocks (-1 = no match) CMatchResult(int _num_ready_matched_blocks, int _num_matched_blocks, int _last_node_matched_length, CRadixNode *_last_ready_node, CRadixNode *_last_node, torch::Tensor blocks, - torch::Tensor block_node_ids = torch::Tensor()) + int32_t matched_node_id = -1) : num_ready_matched_blocks(_num_ready_matched_blocks), num_matched_blocks(_num_matched_blocks), last_node_matched_length(_last_node_matched_length), last_ready_node(_last_ready_node), last_node(_last_node), - physical_blocks(blocks), block_node_ids(block_node_ids) {} + physical_blocks(blocks), + matched_node_id(matched_node_id) {} ~CMatchResult() {} }; @@ -407,6 +409,22 @@ class CRadixTreeIndex { virtual int evict(torch::Tensor &evicted_blocks, int num_evicted); virtual int evict(torch::Tensor &evicted_blocks, torch::Tensor &evicted_block_hashes, int num_evicted); + // 4-arg overload for dist-reuse refcount guard (§2.2 of + // docs/dist_reuse/implementation_gap_*.md). ``is_evictable_fn`` is a + // per-block_id predicate -- blocks for which it returns ``false`` are + // treated as refcount-pinned: they are *still* detached from the tree + // (so the LRU keeps making progress), but their block_ids are dropped + // from the returned ``evicted_blocks`` array so the upstream allocator + // never recycles them. Behaviour mirrors the Python + // ``RadixTreeIndex.evict(..., is_evictable_fn=...)`` path in + // flexkv/cache/radixtree.py. + // + // When ``is_evictable_fn`` is a null std::function (default), this + // overload is byte-identical to the 3-arg version above -- no + // regression for callers that don't wire dist-reuse. 
+ virtual int evict(torch::Tensor &evicted_blocks, + torch::Tensor &evicted_block_hashes, int num_evicted, + std::function is_evictable_fn); virtual std::shared_ptr match_prefix(torch::Tensor &block_hashes, int num_blocks, bool update_cache_info = true); diff --git a/docs/dist_reuse/METRICS_dist_reuse.md b/docs/dist_reuse/METRICS_dist_reuse.md new file mode 100644 index 0000000000..11d49dd45b --- /dev/null +++ b/docs/dist_reuse/METRICS_dist_reuse.md @@ -0,0 +1,270 @@ +# Dist-Reuse 监控指标接入清单 + +> 配套 `KNOWN_ISSUE_p2p_refcount_2026-05-14.md` §4 的具体落地。 +> 本文档列出**已经在代码里埋点**的指标 vs **需要 C++ build 才能加**的指标。 + +## 启用方式 + +```bash +export FLEXKV_ENABLE_METRICS=1 +export FLEXKV_PY_METRICS_PORT=8080 # 默认 8080 +# PROMETHEUS_MULTIPROC_DIR 会自动创建为 /tmp/flexkv_prom_<pid> +# 如需自定义可显式 export +``` + +启动 FlexKV 后访问 `http://<host>:8080/metrics` 即可看到所有指标。 + +## ⚠️ 多进程聚合(已自动启用) + +FlexKV 的数据面 worker(`PEER2CPUTransferWorker` 等)跑在 `mp.Process` +子进程里。**朴素的 `prometheus_client` 不会跨进程聚合**——子进程递增的 +counter 不会出现在父进程的 `/metrics` 端点。 + +本项目已通过 prometheus_client 官方的 `PROMETHEUS_MULTIPROC_DIR` + +`MultiProcessCollector` 机制解决: + +- `flexkv/metrics/collector.py::_bootstrap_multiproc_dir` 在 import + `prometheus_client` **之前**自动创建一个 multiproc dir,并设置环境变量 +- `flexkv/metrics/server.py::start_metrics_server` 检测到 multiproc dir 后 + 使用 `MultiProcessCollector(registry, path=multiproc_dir)` 包装 registry +- 所有子进程通过继承父进程的 env,自动写入同一 multiproc dir +- HTTP `/metrics` 端点每次被 scrape 时实时聚合所有进程的样本 + +**集成测试**:父进程 + spawn 子进程各自递增计数器,从 `/metrics` 端点 +读出聚合值正确(5/5 断言通过,见 `_test_metrics_multiproc.py`)。 + +**运维注意**: +- multiproc dir 默认在 `/tmp/flexkv_prom_<pid>`,pid 不复用就不 + 会冲突;docker 容器重启自动清掉 +- 如运行很久(pid 漂移),mmap 文件可能累积(每个 worker pid 一个文件, + ~几十 KB),可定期清理过时 pid 的文件 +- 子进程崩溃不会自动清理它的 mmap 文件,但 `MultiProcessCollector` + 会继续读这些样本作为"最后已知值"——不影响监控正确性,只占少量磁盘 + +--- + +## 已就绪指标(Python 侧) + +### 1.
`flexkv_py_dist_reuse_peer_mooncake_read_seconds` (Histogram) + +**含义**:peer 端 mooncake `transfer_sync_read` 调用延迟(含失败路径)。 + +**埋点位置**:`flexkv/transfer/worker.py::PEER2CPUTransferWorker._batch_transfer_impl` +(PEERH2H 分支,包了 mooncake 调用全程) + +**告警规则**(PromQL): + +```promql +# P99 mooncake_read 延迟 > 500ms 持续 5 分钟 +histogram_quantile(0.99, + rate(flexkv_py_dist_reuse_peer_mooncake_read_seconds_bucket[5m]) +) > 0.5 +``` + +KNOWN_ISSUE §4.2:剩余 lease 缓冲 < 10x 时风险升高。 + +--- + +### 2. `flexkv_py_dist_reuse_peer_mooncake_read_failures_total` (Counter, label=reason) + +**含义**:peer 端 mooncake_read 失败计数,按 reason 分类: +- `mooncake_error`:mooncake 返回 ret != 0 +- `zero_byte_transfer`:data_lens 全 0(这正是 2026-05-14 P0 bug 的特征) + +**埋点位置**:同上。 + +**告警规则**: + +```promql +# 失败率 > 0.1% +sum(rate(flexkv_py_dist_reuse_peer_mooncake_read_failures_total[5m])) +/ +( + sum(rate(flexkv_py_dist_reuse_peer_mooncake_read_failures_total[5m])) + + sum(rate(flexkv_py_dist_reuse_peer_mooncake_read_success_total[5m])) +) > 0.001 +``` + +任何 `reason="zero_byte_transfer"` 出现都应**立即 P0 oncall** —— P0 bug 复发或回归。 + +--- + +### 3. `flexkv_py_dist_reuse_peer_mooncake_read_success_total` (Counter) + +**含义**:peer 端 mooncake_read 成功计数。`#2` 的分母。 + +**埋点位置**:同 #1。 + +--- + +### 4. master 端 CPU pool 利用率(已有指标派生) + +利用率不需要新指标,PromQL 直接算: + +```promql +# CPU pool 利用率 (KNOWN_ISSUE §4.1 入口指标) +1 - ( + flexkv_py_mempool_free_blocks{device="cpu"} + / + flexkv_py_mempool_total_blocks{device="cpu"} +) +``` + +**告警规则**: + +```promql +# 利用率 > 95% 持续 60s — 可能进入场景 D(KNOWN_ISSUE §2) +( + 1 - ( + flexkv_py_mempool_free_blocks{device="cpu"} + / + flexkv_py_mempool_total_blocks{device="cpu"} + ) +) > 0.95 +``` + +--- + +## 待 C++ build 后才能加的指标 + +> 这两个指标**必须在 C++ 内部计数**才能准确,Python 侧无法窥探到。 +> 等下一次 FlexKV 容器 build 时一起加。 + +### 5. 
`flexkv_py_dist_reuse_lease_meta_nullptr_total` (Counter, label=device) + +**Python collector 已就绪**(`record_dist_reuse_lease_nullptr`),缺 C++ 侧的 trigger。 + +**待加位置**:`csrc/dist/local_radix_tree.cpp::publish_node_blocks` 的 +`set_lease_meta(nullptr)` 分支(约 L164-167)。 + +**预期改动**: + +```cpp +// csrc/dist/local_radix_tree.cpp L164-167 +if ((current_block_count + new_node->size()) > (max_num_blocks - swap_block_threshold)) { + new_node->set_lease_meta(nullptr); + + // [METRICS] expose to Python via a stats counter accessor + this->_metrics_lease_nullptr_count += new_node->size(); +} +``` + +加 Python 侧采集(在 `GlobalCacheEngine._update_mempool_metrics` 同节奏轮询): + +```python +# flexkv/cache/cache_engine.py — 新增方法 +def _update_dist_reuse_metrics(self): + if self._metrics_collector is None: + return + for device_type, engine in self.cache_engines.items(): + if hasattr(engine, '_radix_tree_stats'): + stats = engine._radix_tree_stats() # 新 API + self._metrics_collector.record_dist_reuse_lease_nullptr( + DEVICE_TYPE[device_type].lower(), + stats.lease_nullptr_count_delta + ) +``` + +**告警规则**(生效后): + +```promql +# 任何 lease_meta=nullptr 都是 CRITICAL +increase(flexkv_py_dist_reuse_lease_meta_nullptr_total[1m]) > 0 +``` + +KNOWN_ISSUE §5 trigger #1:必须立即升级到方案 A/B。 + +--- + +### 6. `flexkv_py_dist_reuse_about_to_evict_total` (Counter, label=device) + +**Python collector 已就绪**(`record_dist_reuse_about_to_evict`),缺 C++ 侧的 trigger。 + +**待加位置**:`csrc/dist/local_radix_tree.cpp::evict` L612-650 +(fresh-branch 加入 `about_to_evict_q` 时计数)。 + +**告警规则**(生效后): + +```promql +# 健康比例:fresh 标记 / 真实 evict <= 1 +# 持续 > 10:1 说明 master 在死撑 evict +sum(rate(flexkv_py_dist_reuse_about_to_evict_total[5m])) +/ +sum(rate(flexkv_py_evicted_blocks_total[5m])) +> 10 +``` + +--- + +## 业务侧仍需自建的指标 + +### 7. `cross_instance_hit_text_garbage_rate` + +**这个不能由 FlexKV 自动采集**,必须**业务层(sglang)加抽样工具**: + +``` +对跨实例命中的请求,按 1% 抽样: +1. 完整跑一遍 prefill(不命中 cache) +2. 比较 token-id 序列与原命中结果 +3. 
不一致比例 > 0.01% → P0 oncall +``` + +KNOWN_ISSUE §5 trigger #2:业务层观察到生成质量退化。 + +--- + +## Grafana Dashboard 推荐 panel + +| Panel | 数据源 | 阈值 | +|---|---|---| +| CPU pool 利用率 | `1 - free/total` | warning > 80%, critical > 95% | +| 跨实例 mooncake_read P99 | histogram_quantile(0.99, ...) | warning > 500ms | +| 跨实例 mooncake_read 失败率 | failures / (failures + success) | warning > 0.1%, critical > 1% | +| zero_byte_transfer 计数 | `failures_total{reason="zero_byte_transfer"}` | critical > 0 | +| lease_nullptr 计数(待 C++ build) | `lease_meta_nullptr_total` | critical > 0 | +| fresh/expired evict 比 | `about_to_evict / evicted_blocks` | warning > 5, critical > 10 | + +--- + +## 测试 + +### 验证 metrics 启动 + +```bash +# 容器内 +export FLEXKV_ENABLE_METRICS=1 +export FLEXKV_PY_METRICS_PORT=8080 +python3 -c ' +from flexkv.metrics import init_global_collector +c = init_global_collector() +print("collector enabled:", c.enabled) +' + +# 应输出: +# [FlexKV PyMetrics] Prometheus metrics collector initialized +# collector enabled: True +``` + +### 触发并查看指标 + +```bash +# 跑一次 P2P 跨实例 e2e(按你昨天的双机 harness) +# 然后: +curl -s http://localhost:8080/metrics | grep dist_reuse + +# 应能看到: +# flexkv_py_dist_reuse_peer_mooncake_read_seconds_bucket{le="..."} ... +# flexkv_py_dist_reuse_peer_mooncake_read_success_total ... +``` + +--- + +## 升级路线 + +1. **立即可做**(本次落盘):5 个 Python 指标已埋好;CPU pool 利用率走 PromQL 派生 +2. **下次 C++ build 时**:加 `lease_meta_nullptr` + `about_to_evict` 的 C++ counter +3. 
**业务立项时**:sglang 侧加抽样工具采 `cross_instance_hit_text_garbage_rate` + +--- + +*文档生成时间:2026-05-14* diff --git a/docs/dist_reuse/OPS_QUICK_REF_p2p.md b/docs/dist_reuse/OPS_QUICK_REF_p2p.md new file mode 100644 index 0000000000..21f270ea28 --- /dev/null +++ b/docs/dist_reuse/OPS_QUICK_REF_p2p.md @@ -0,0 +1,146 @@ +# P2P 跨实例 KV 复用 — 运维快速参考 + +> 给运维 / SRE 的 30 秒速读版。详细技术背景见 +> `KNOWN_ISSUE_p2p_refcount_2026-05-14.md`。 + +## ⚠️ 一句话风险 + +启用 `enable_p2p_cpu=True` 后,跨实例 KV 复用依赖 **lease 时间窗口** 防止脏读, +**不依赖** 显式 refcount。配置不当会导致 LLM 输出乱码。 + +## ✅ 安全部署 4 步走 + +### 1. 强制环境变量 + +```bash +export FLEXKV_LEASE_TTL_MS=30000 # 必须 >= 10000 +export FLEXKV_SAFETY_TTL_MS=100 +export FLEXKV_RENEW_LEASE_MS=4000 # 必须 <= lease_ttl / 5 +export FLEXKV_REBUILD_INTERVAL_MS=100 +``` + +> ⚠️ **如果你设置 `FLEXKV_LEASE_TTL_MS < 10000`,FlexKV 会拒绝启动并报错**—— +> 这是有意为之的硬约束,不要绕过。 + +### 2. 容量规划 + +master 端 CPU pool 必须能容纳: + +``` +peak_concurrent_requests × avg_seq_tokens / tokens_per_block × 2 +``` + +**实测参考**(H800 + Qwen3-8B + tokens_per_block=16): +- 并发 64 + 平均 200 blocks/req → **至少 25600 blocks** +- 不够会触发"高水位 evict"路径,lease 防线失效 + +### 3. 监控指标(必接) + +| 指标 | 阈值 | 触发动作 | +|---|---|---| +| `cpu_pool_utilization` | > 95% 持续 60s | 告警,扩容 | +| `lease_meta_nullptr_count` | > 0 | **致命**,立即停 P2P | +| `mooncake_read_p99_latency_ms` | > 500ms | 告警 | +| `mooncake_read_failure_rate` | > 0.1% | 告警 | +| `cross_instance_hit_text_garbage_rate` | > 0.01% | **致命**,立即停 P2P | + +### 4. 灰度策略 + +第一次上线建议: + +1. **第 1 周**:开启 P2P 但只让 5% 流量走,业务侧加输出健康度抽样对比 +2. **第 2 周**:流量提升到 50%,监控数据稳定无告警则继续 +3. 
**第 3 周**:全量 + +--- + +## 🚨 出问题怎么办 + +### 现象 1:LLM 输出乱码 / 重复 token / 完全无关内容 + +**99% 是 lease 防线被击穿**。立即: + +```bash +# 紧急停用 P2P,回退到本地命中 +export FLEXKV_ENABLE_P2P_CPU=0 +# 或在 sglang yaml 里 enable_p2p_cpu: false +# 重启服务 +``` + +然后查 `lease_meta_nullptr_count` 监控,如果 > 0,说明已进入高水位强压路径, +**必须扩容 master 端 CPU pool 或启动 refcount glue 实现**(见 known-issue 文档 §5)。 + +### 现象 2:跨实例命中率突然降到 0 + +不是数据正确性问题,是 lease 过期或 Redis 心跳断了。查: + +```bash +# 看 master 端 +docker exec <container> python3 -c ' +import redis +r = redis.Redis(host="<redis_host>", port=6379, password="<redis_password>", db=2) +keys = r.keys("sd:*:node:*") +print("active nodes:", [(k.decode(), r.ttl(k)) for k in keys]) +' +# TTL < 5s 说明心跳要断了 +``` + +### 现象 3:FlexKV 启动报错 `FLEXKV_LEASE_TTL_MS=... is below the safety floor` + +按报错提示把环境变量调大到 >= 10000。**不要**改源码绕过这个 check。 + +--- + +## 🔍 健康度自检脚本 + +```bash +#!/bin/bash +# flexkv_p2p_healthcheck.sh — 在 master / peer 任意一台跑 + +echo "=== FlexKV P2P 健康度自检 ===" + +# 1. 配置检查 +LEASE_TTL=${FLEXKV_LEASE_TTL_MS:-30000} +if [ "$LEASE_TTL" -lt 10000 ]; then + echo "❌ FLEXKV_LEASE_TTL_MS=$LEASE_TTL < 10000 (UNSAFE)" + exit 1 +fi +echo "✅ FLEXKV_LEASE_TTL_MS=$LEASE_TTL" + +# 2. Redis 心跳 +docker exec flexkv_distreuse python3 -c " +import redis +r = redis.Redis(host='${REDIS_HOST:-127.0.0.1}', port=6379, password='${REDIS_PWD}', db=2) +keys = r.keys('sd:*:node:*') +if not keys: + print('❌ no active nodes in Redis') + exit(1) +for k in sorted(keys): + ttl = r.ttl(k) + status = '✅' if ttl > 5 else '⚠️ ' if ttl > 0 else '❌' + print(f'{status} {k.decode()}: TTL={ttl}s') +" + +# 3. CPU pool 利用率(需要先实现 metrics 暴露端点) +# curl -s http://localhost:8080/metrics | grep -E 'cpu_pool_utilization|lease_meta_nullptr' + +echo "=== 自检完成 ===" +``` + +--- + +## 📞 升级到无 lease 漏洞的版本 + +满足下面任一条件,立即联系 FlexKV 团队启动 **方案 A(refcount handshake)**: + +1. `lease_meta_nullptr_count > 0` 在生产出现 +2. `cross_instance_hit_text_garbage_rate > 0.01%` +3. 业务需要 `lease_ttl_ms < 10000` +4. PP > 1 或 tp_node_count > 1 的 multi-SD 部署 +5.
多 peer 并发量 > 4 + +详细决策矩阵见 `KNOWN_ISSUE_p2p_refcount_2026-05-14.md` §5。 + +--- + +*最后更新:2026-05-14* diff --git a/docs/dist_reuse/dist_reuse_with_cp_pp_multinode_tp_simplified.md b/docs/dist_reuse/dist_reuse_with_cp_pp_multinode_tp_simplified.md new file mode 100644 index 0000000000..99bd929399 --- /dev/null +++ b/docs/dist_reuse/dist_reuse_with_cp_pp_multinode_tp_simplified.md @@ -0,0 +1,390 @@ +# FlexKV 分布式 KVCache 共享原理 + +> **本文目的**:给用户讲清楚 FlexKV 的 `dist_reuse`(跨实例 KVCache 共享)是怎么工作的——什么样的实例之间能共享 KV、Master/Remote 各自做什么、跨实例 P2P 是怎么发起的。 +> +> 阅读对象:FlexKV 用户、运维、上层框架(如 sglang)的接入开发者。 + +--- + +## 1. 一句话概括 + +**两个 FlexKV 实例只有在"KV 物理切片形态完全一致"时才能直接 P2P 复用对方的 KV——所谓"切片形态一致",就是节点级别的 layer 段相同、KV head 段相同、模型 layout 相同。** + +由此引申出三个关键概念: + +| 概念 | 含义 | +|---|---| +| **共享域**(Sharing Domain,SD) | 一组"KV 切片形态完全一致"的节点的集合。同 SD 节点之间可直接 P2P 互拷 block。 | +| **sd_key** | 共享域的字符串身份。两个节点 sd_key 相同 ⇔ 处于同一共享域 ⇔ 可 P2P。 | +| **Master / Remote** | 一个 FlexKV 实例由 1 个 Master + 多个 Remote 组成。Master 是控制面唯一事实来源,Remote 只搬数据。 | + +下面分别展开。 + +--- + +## 2. 
哪些维度会影响"切片形态" + +``` +┌──────────────────┬──────────────┬──────────────────────────┐ +│ 维度 │ 切的是什么 │ 是否影响节点 KV 物理形态?│ +├──────────────────┼──────────────┼──────────────────────────┤ +│ 跨节点 PP │ layer 维度 │ ✅ 影响——节点持有的 layer 段不同 │ +│ 跨机 TP │ KV head 维度 │ ✅ 影响——节点持有的 KV head 段不同 │ +│ CP(普通 / NSA) │ 序列计算量 │ ❌ 不影响——attention all-gather 后各 cp_rank 的 KV pool bit-wise 一致 │ +│ 模型 / dtype │ 整个 layout │ ✅ 影响——layout 不同,block 物理 size 不同 │ +│ NSA vs 非 NSA │ block layout │ ✅ 影响——NSA 多了一份 indexer K cache buffer │ +└──────────────────┴──────────────┴──────────────────────────┘ +``` + +### 关于 CP 的关键事实 + +CP(Context Parallelism)只是把序列拆给不同 cp_rank 算 query;attention 层做了 all-gather 之后,**每个 cp_rank 自己的 KV pool 写入的都是完整全序列的 KV,且各 cp_rank 之间 bit-wise 一致**。 + +代码事实: +- 普通 CP:`flashattention_backend.py::cp_allgather_and_save_kv_cache` 完成 all-gather +- NSA CP:`deepseek_v2.py::rebuild_cp_kv_cache` 在 attention 之前做 all-gather;`nsa_indexer.py:1333-1347` 额外对 indexer K 做 `cp_all_gather_rerange_output` +- KV head 切分由 `attn_tp_size` 单独承担(`parallel_state.py:1860-1862`),CP 不切 head + +所以 CP 维度**不参与共享域划分**——同一节点上所有 cp_rank 的 KV pool 物理对等,跨实例 cp=i ↔ cp=j 直接互拷在数据层面合法。 + +--- + +## 3. 
sd_key 格式与含义 + +### 3.1 序列化形式 + +``` +<model_id>:ppn<pp_node_idx>/<pp_node_count>:tpn<tp_node_idx>/<tp_node_count>:nsa<0|1> +``` + +例: + +``` +c3a2f91d0bcdef01:ppn0/1:tpn0/1:nsa0 — 单机 PP=1 部署 +c3a2f91d0bcdef01:ppn0/2:tpn0/1:nsa0 — 跨节点 PP=2 的第 0 节点 +c3a2f91d0bcdef01:ppn1/2:tpn0/1:nsa0 — 跨节点 PP=2 的第 1 节点 +c3a2f91d0bcdef01:ppn0/1:tpn0/2:nsa0 — 跨机 TP=2 的第 0 节点 +c3a2f91d0bcdef01:ppn0/1:tpn1/2:nsa0 — 跨机 TP=2 的第 1 节点 +c3a2f91d0bcdef01:ppn0/1:tpn0/1:nsa1 — 单机 PP=1 NSA 模型 +``` + +### 3.2 字段含义 + +| 字段 | 含义 | +|---|---| +| `model_id` | 模型 + 数值精度 + page_size 等的指纹。同模型同配置才能复用。 | +| `pp_node_idx` | 本节点是 PP 维度上的第几台**物理节点**(0 起)。 | +| `pp_node_count` | PP 维度跨了几台**物理节点**。 | +| `tp_node_idx` | 本节点是 TP 维度上的第几台**物理节点**(0 起)。 | +| `tp_node_count` | TP 维度跨了几台**物理节点**。 | +| `nsa` | 是否 NSA 模型(NSA 与非 NSA block 物理 layout 不同,必须隔离)。 | + +### 3.3 字段派生公式 + +```python +pp_node_count = max(min(pp_size, nnodes), 1) # PP 维度跨了几台节点 +pp_node_idx = pp_rank // max(pp_size // nnodes, 1) # 本节点在 PP 维度的位置 + +tp_node_count = nnodes_per_tp_group # TP 维度跨了几台节点 +tp_node_idx = tp_rank // tp_size_per_node # 本节点在 TP 维度的位置 +``` + +### 3.4 不变量 + +``` +pp_node_count × tp_node_count == nnodes +``` + +也就是说:**共享域数量 = 物理节点数**——每个物理节点对应唯一一个 sd_key,同实例不同节点之间 sd_key 不同。 + +### 3.5 sd_key 的核心语义 + +> **sd_key 相同 ⇒ 节点 KV 物理切片形态完全相同 ⇒ 节点间 block 可直接 P2P。** + +注意 sd_key 中**不包含 cp_rank**——这是对 §2"CP 不影响切片形态"事实的直接体现。同一节点上不同 cp_rank 的 worker 都属于同一个 sd_key,CPU pool 共用一份。 + +--- + +## 4. 共享域数量的几个典型例子 + +| 部署 | nnodes | pp_size | tp_size | sd_key 集合 | SD 数量 | +|------|---|---|---|---|---| +| 单机 PP=1 TP=1/2/4/8 | 1 | 1 | 1~8 | `ppn0/1:tpn0/1` | **1** | +| 跨节点 PP=2(每节点 PP=1) | 2 | 2 | 任意 | `ppn0/2:tpn0/1` + `ppn1/2:tpn0/1` | **2** | +| 跨机 TP=16(PP=1, 2 节点) | 2 | 1 | 16 | `ppn0/1:tpn0/2` + `ppn0/1:tpn1/2` | **2** | +| 跨节点 PP=2 × 跨机 TP=2(4 节点) | 4 | 2 | 16 | 4 个 ppn × tpn 笛卡尔积 | **4** | + +CP=4 / CP=8 不会增加 SD 数量,但 CP 内部的多张 GPU 仍要通过 H2D 接收同一份 CPU pool 的数据。 + +--- + +## 5. 
Master / Remote 架构 + +### 5.1 角色分工 + +``` +┌──────────────────────────────────────────────────────────┐ +│ Master (pp_node_idx=0, tp_node_idx=0, cp_rank=0) │ +│ │ +│ 控制面(唯一事实来源): │ +│ - KVManager + CacheEngine(唯一的 LocalRadixTree) │ +│ - 跨 SD 聚合层 radix(多 SD 时的 fully-ready 判定) │ +│ - get_match / put_match / insert / evict 决策 │ +│ - block 级 refcount(保护在途 block) │ +│ - Redis 元数据同步 │ +│ - 跨 instance 的 remote hit 判定 │ +│ │ +│ 决策后通过两条通道下发: │ +│ - 维度内(CP / TP / PP):sglang 现有 broadcast / scatter │ +│ - 跨 SD:FlexKV TransferOpGraph 派发 │ +├──────────────────────────────────────────────────────────┤ +│ Remote (pp_node_idx>0 / tp_node_idx>0) │ +│ │ +│ 数据面: │ +│ - TransferManagerOnRemote + RedisMeta + Mooncake │ +│ - 接收 Master 派发的 TransferOpGraph │ +│ - 本地过滤出归本节点的 op,执行 GPU↔CPU / P2P 传输 │ +│ - 完成后通过 CompletedOp 回报 │ +│ - 不维护任何 radix 索引,不做任何缓存决策 │ +└──────────────────────────────────────────────────────────┘ +``` + +### 5.2 Remote 的两种类型 + +按 sd_key 维度区分: + +| Remote 类型 | 触发条件 | 持有的 KV 切片 | +|---|---|---| +| **PP-Remote** | `pp_node_idx > 0` | layer 切片的后半段 | +| **TP-Remote** | `tp_node_idx > 0` | KV head 切片的后半段 | + +> CP 维度的 cp_rank > 0 worker 在 dist_reuse 视角下与 Master 处于同一 SD(CP 不进 sd_key)。它们仍然要把 GPU buffer 注册到 FlexKV 以便接收 H2D 指令,但**不向 Redis 单独注册 SD 节点身份**——CPU pool 中数据由 sync_leader(cp_rank=0)那一份代表整组。 + +### 5.3 同实例 Master 如何统一控制 Remote + +``` +sglang sync_leader (cp_rank=0) + │ + ▼ + Master KVManager ─── TransferOpGraph 派发 ──► PP-Remote / TP-Remote + (本节点 ZMQ 进程) (其他节点 ZMQ 进程) +``` + +- Master 在内部有一个 in-process handle、对每个 Remote 各有一个 ZMQ handle +- 一次 `_launch_task` 把同一份 `TransferOpGraph` 同时投递给所有 handle +- 每个 handle 收到后: + 1. 按 op 上的 `target_node_ids` 过滤出归自己执行的 op + 2. 对归自己的 op 调用本地 TransferEngine 执行 + 3. 完成后回报 `CompletedOp`(带 sd_key + contributing_node_id) +- Master 收齐所有 `CompletedOp` → 标记任务完成 + +--- + +## 6. 跨实例 KV 共享流程 + +下面以两个 FlexKV 实例(同模型同配置)之间的 KV 复用为例,说明 dist_reuse 是怎么发起的。 + +### 6.1 启动期:互相注册到 Redis + +每个实例启动时: + +1. 
Master 和每个 Remote 各自在 Redis 上 `INCR global:node_id` 拿到全局唯一的 `node_id` +2. 在自己 sd_key 命名空间下注册 `sd::node:` 心跳 key(带 TTL) +3. 在 `sd::meta:` 写入自己的 ZMQ 地址、Mooncake CPU buffer 指针等 +4. Master 收齐所有 Remote ready 后,把"sd_key → node_id"映射汇总到 `flexkv:instance::sd_nodes` + +这之后,**任何一个 instance 都能通过扫描 Redis 知道其他 instance 在每个 SD 上的 node_id**。 + +### 6.2 PUT:把推理产生的 KV 存到分布式视角 + +``` +1. 推理结束 → sync_leader 在全序列 token_ids 上做 put_match +2. Master 决策:哪些 block 需要 D2H 落到 CPU pool +3. Master 构造 TransferOpGraph: + 每个 SD 各挂一个 D2H op,target_node_ids 指向该 SD 自己的节点 +4. 通过 _launch_task 把同一份 graph 派发给所有 handle +5. 每个 SD(master in-proc 或 remote)执行各自归属的 D2H op +6. D2H 完成 → 触发 post_complete_callback: + - 把本批 block 的元数据 publish 到 sd::block:: (Redis) + - 给 Master 回 CompletedOp(sd_key, contributing_node_id) +7. Master 收齐所有 SD 的 CompletedOp → mark_sd_ready + 全 SD ready → 该前缀进入 fully-ready,对外可被 reuse +``` + +PUT 的副作用:本实例的 KV 元数据通过 Redis 让其他 instance 可见。 + +### 6.3 GET:从其他实例拉取已有的 KV + +``` +1. 新请求到达 → sync_leader 在全序列 token_ids 上做 get_match +2. 命中本实例 fully-ready 前缀 → 直接 reuse(local hit) +3. 若 miss,查 DistributedRadixTree(基于其他 instance 同步过来的 Redis 元数据): + - 命中 → 知道哪个 peer instance 的哪个 node_id 持有该 block +4. Master 构造跨实例 GET TransferOpGraph,每个 SD 各挂一个 PEERH2H op: + src_block_node_ids = [peer_inst.该 SD 上的 node_id] + target_node_ids = [self.该 SD 上的 node_id] +5. 通过 _launch_task 派发给所有 handle +6. 每个 SD 上的 worker: + - 按 src_block_node_ids 分组 + - 从 Redis 查到 peer_node 的 zmq_addr / Mooncake addr + - 发起 Mooncake transfer_sync_read,把 peer 的 CPU block 直接 RDMA 到本节点 CPU pool +7. 全部 op 完成 → Master 触发后续 H2D,把 CPU pool 的内容刷到 GPU +8. CP 维度通过 sglang 现有 broadcast/scatter 把 H2D 结果分发到所有 cp_rank 的 GPU +``` + +GET 的关键点: +- 跨实例 P2P 完全走 Mooncake RDMA,不经过 Master 中转 +- 单 Node 匹配约束(§7.4)保证一次 match 结果只来自一个 peer instance,避免多 peer 并发导致的复杂度 + +--- + +## 7. 
几个关键设计简化 + +### 7.1 跨 SD 协调走统一的 graph 派发链路 + +Master 不为 dist_reuse 单独引入协议层;所有跨 SD 协调(PUT/GET)都表达成 `TransferOpGraph` 上挂多个带节点身份标签的 op,复用现有的跨节点 PP / 跨机 TP graph 派发链路。 + +`TransferOp` 上的三个关键字段: + +| 字段 | 含义 | +|---|---| +| `src_block_node_ids` | 每个 src block 来自哪个 distributed_node_id(worker 内部按 peer 分组发 RDMA) | +| `target_node_ids` | 这条 op 归哪些 SD 的节点执行;Remote 在 `_handle_submit` 阶段按 `target_node_ids` 过滤掉不属于自己的 op | +| `post_complete_callback` | op 完成后在 master 进程上下文执行的回调(如 redis publish) | + +`pp_rank` 字段保持原义(路由到本地 PP worker),跟 `target_node_ids` 正交。 + +### 7.2 CompletedOp 携带 sd_key 标签 + +每个 op 完成后,Remote 通过 result_socket 把 sd_key + contributing_node_id 跟 `CompletedOp` 一起发回 Master。Master polling worker 用 sd_key 路由到对应的 SD-ready 处理逻辑(如 `mark_sd_ready`)。 + +### 7.3 跨 SD 聚合一致性 + +> 仅 SD 数量 > 1 时(`nnodes > 1`)涉及。 + +一个请求的"完整 KV reuse"要求**所有共享域都命中**(缺一不可)。 + +Master 维护一个**跨 SD 聚合层 radix tree**:每个 block 的状态从 "ready / not-ready" 扩展为 "ready on SD(0) / ready on SD(1) / ... / fully ready"。只有 fully ready 的 block 才对外表现为"可 reuse"。 + +``` +PUT 流程:所有 SD 都通过 CompletedOp 回报后才标记 fully ready +GET 流程:只对 fully ready 的前缀返回 hit +EVICT 流程:Master 单方面 evict(跳过 refcount > 0 的 block),不通知 Remote +``` + +**Master 单方面 evict 的合理性**:Master 的 radix tree 是唯一的索引——Master evict 后不再有任何请求会去读那些 block,Remote 上的孤儿数据不影响正确性,最终会被新数据自然覆盖。 + +### 7.4 DistributedRadixTree 单 Node 匹配约束 + +`DistributedRadixTree::match_prefix` 一次匹配的所有 block 限定来自单个 peer Node(同一个 `node_id`)。 +- 匹配过程中锁定第一个有效 block 的 `node_id`,后续 block `node_id` 不同则停止匹配 +- 命中率影响极小(同一请求的 KV 通常整体写入到同一 Node) +- 让跨 SD GET 可以直接确定唯一 peer instance,构图大幅简化 + +--- + +## 8. 故障模型 + +基于"共命运"假设:同 instance 的 Master 和所有 Remote 在进程生命周期上共命运,任一 rank crash 会导致整 instance 秒级内全部退出。 + +故障只剩两类: + +1. **同 instance 整体退出/重启**: + - Master 维护 `flexkv:instance:<instance_id>:session` 这个 TTL key + epoch + - peer instance 的 `FailureDetector` 观察到 key 消失或 epoch 变化 → 批量 invalidate + +2. 
**跨 instance 链路故障**(Mooncake P2P read 失败): + - Worker 通过 `FailureReportMsg` 异步上报到 Master + - Master 单前缀 invalidate + fallback 到正常 prefill + +--- + +## 9. 部署形态速查 + +### 9.1 端口拓扑 + +跟跨机 TP/PP 现有部署一致——`FLEXKV_MASTER_HOST` + `FLEXKV_MASTER_PORTS=5556,5557,5558`,所有 Remote 通过 ZMQ identity 区分。**不需要 per-SD 端口、不需要 sglang launcher 做 endpoint 发现**。 + +### 9.2 共享域数量上限 + +``` +SD 数量 = pp_node_count × tp_node_count = nnodes +``` + +当前部署上限 `nnodes ≤ 2`,所以 SD 数量 ≤ 2;保留对未来 4 节点(`pp_node_count=2 × tp_node_count=2`)的扩展支持。 + +### 9.3 跨实例配对规则 + +只有满足下面**所有条件**的两个实例之间才能 P2P 复用 KV: + +- 同 `model_id`(模型 / dtype / page_size 一致) +- 同 `pp_node_count` 和 `tp_node_count`(节点切片维度一致) +- 同 `is_nsa`(block 物理 layout 一致) + +且具体配对发生在节点级别——`inst1.ppn=i:tpn=j` 只与 `inst2.ppn=i:tpn=j` 互拷,节点身份必须严格对齐。 + +--- + +## 10. 一张图总结 + +``` + Instance 1 Instance 2 +┌─────────────────────────────────────────┐ ┌─────────────────────────────────────────┐ +│ │ │ │ +│ Master (ppn=0, tpn=0, cp=0 = sync_leader)│ │ Master (ppn=0, tpn=0, cp=0) │ +│ ┌─────────────────────────────────┐ │ │ ┌─────────────────────────────────┐ │ +│ │ KVManager + CacheEngine │ │ │ │ KVManager + CacheEngine │ │ +│ │ + LocalRadixTree (唯一索引) │ │ │ │ + LocalRadixTree │ │ +│ │ + 跨 SD 聚合层 radix │◄───┼────┼──┤ + 跨 SD 聚合层 radix │ │ +│ │ + RedisMeta + Mooncake │ │ │ │ + RedisMeta + Mooncake │ │ +│ └────────────┬────────────────────┘ │ │ └─────────────────────────────────┘ │ +│ │ TransferOpGraph 派发 │ │ │ +│ │ │ │ │ +│ ▼ │ │ │ +│ PP-Remote (ppn=1) ◄── Mooncake P2P ──►│ │ PP-Remote (ppn=1) │ +│ TP-Remote (tpn=1) ◄── Mooncake P2P ──►│ │ TP-Remote (tpn=1) │ +│ │ │ │ +│ cp=1..N-1(同 SD,仅 GPU 注册) │ │ cp=1..N-1 │ +└─────────────────────────────────────────┘ └─────────────────────────────────────────┘ + +同 sd_key 的节点之间通过 Mooncake P2P 互拷 KV block +(如 inst1.ppn=1 ↔ inst2.ppn=1,但 inst1.ppn=0 ↮ inst2.ppn=1) +跨 SD 不允许 P2P(layer / KV head 切片不同) +CP 维度不参与 SD 划分;各 cp_rank 的 KV pool 由 attention all-gather 保证 bit-wise 一致 +Master 是控制面唯一事实来源,Remote 只搬数据 +跨 SD 协调统一通过 
TransferOpGraph 派发完成 +``` + +--- + +## 附录:sd_key 字段速查 + +``` +sd_key 文本格式: + <model_id>:ppn<pp_node_idx>/<pp_node_count>:tpn<tp_node_idx>/<tp_node_count>:nsa<0|1> + +字段含义: + model_id —— 模型 + dtype + page_size 的指纹 + pp_node_idx —— 本节点是 PP 维度的第几台节点(0 起) + pp_node_count —— PP 维度跨了几台物理节点 + tp_node_idx —— 本节点是 TP 维度的第几台节点(0 起) + tp_node_count —— TP 维度跨了几台物理节点 + nsa —— 是否 NSA 模型(NSA 与非 NSA 必须隔离) + +不变量: + pp_node_count × tp_node_count == nnodes + +派生: + pp_node_count = max(min(pp_size, nnodes), 1) + pp_node_idx = pp_rank // max(pp_size // nnodes, 1) + tp_node_count = nnodes_per_tp_group + tp_node_idx = tp_rank // tp_size_per_node + +示例: + 单机 PP=1 TP=8 c3a2:ppn0/1:tpn0/1:nsa0 + 跨节点 PP=2(每节点 PP=1)节点 0 c3a2:ppn0/2:tpn0/1:nsa0 + 跨节点 PP=2(每节点 PP=1)节点 1 c3a2:ppn1/2:tpn0/1:nsa0 + 跨机 TP=16(PP=1)节点 0 c3a2:ppn0/1:tpn0/2:nsa0 + 跨机 TP=16(PP=1)节点 1 c3a2:ppn0/1:tpn1/2:nsa0 + NSA 单机 PP=1 TP=8 c3a2:ppn0/1:tpn0/1:nsa1 +``` diff --git a/docs/dist_reuse/redis_schema.md b/docs/dist_reuse/redis_schema.md new file mode 100644 index 0000000000..f2e28ce59a --- /dev/null +++ b/docs/dist_reuse/redis_schema.md @@ -0,0 +1,391 @@ +# FlexKV Dist-Reuse Redis Schema 手册 + +> **本文目的**:列清楚 FlexKV `dist_reuse` 用到的所有 Redis key 的命名规则、字段含义、典型读写时序,方便用户做容量规划、运维诊断、故障排查。 +> +> Redis 在 `dist_reuse` 中只承担两类职责: +> 1. **集群发现与心跳**:每个节点(Master / Remote)注册自己的 ZMQ 地址、Mooncake CPU buffer 指针,方便 peer instance 知道"哪个 SD 在哪台节点"。 +> 2. **block 元数据广播**:每个 block 的 `(parent_hash, hash, lease_time, state)` 等元数据 publish 到 Redis,让 peer instance 的 `DistributedRadixTree` 能重建跨 instance 的索引。 +> +> 跨 instance 的 KV 数据本身不走 Redis,走 Mooncake P2P RDMA。 +> +> 关于 sd_key 格式与 dist_reuse 整体原理见 +> [`dist_reuse_with_cp_pp_multinode_tp_simplified.md`](./dist_reuse_with_cp_pp_multinode_tp_simplified.md)。 + +--- + +## 0. 
sd_key 速记 + +``` + = ":ppn/:tpn/:nsa<0|1>" +``` + +例: + +``` +c3a2f91d0bcdef01:ppn0/1:tpn0/1:nsa0 — 单机 PP=1 部署 +c3a2f91d0bcdef01:ppn0/2:tpn0/1:nsa0 — 跨节点 PP=2 第 0 节点 +c3a2f91d0bcdef01:ppn1/2:tpn0/1:nsa0 — 跨节点 PP=2 第 1 节点 +c3a2f91d0bcdef01:ppn0/1:tpn0/2:nsa0 — 跨机 TP=2 第 0 节点 +c3a2f91d0bcdef01:ppn0/1:tpn1/2:nsa0 — 跨机 TP=2 第 1 节点 +``` + +不变量:`pp_node_count × tp_node_count == nnodes`,即 **SD 数量 = 物理节点数**。 + +逻辑 db:由 `CacheConfig.flexkv_redis_db` 指定(默认 0,建议生产环境用独立的 db,如 15)。 + +--- + +## 1. 命名空间一览 + +| 命名空间 | 作用域 | key 数量 | +|---|---|---| +| `sd::*` | 每个 SD 独立 | 每实例 `pp_node_count × tp_node_count = nnodes` 份 | +| `flexkv:instance::*` | 每个 FlexKV 实例独立 | 跨 SD 共享 | +| `global:node_id` | 全局(跨实例) | 单条计数器 | +| `flexkv_node_id_updated:` | Pub/Sub channel | 每 SD 一个(非 key) | + +--- + +## 2. SD 维度 key(每节点 1 份,每实例 `nnodes` 份) + +### 2.1 `sd::node:` — 节点心跳 + +| 属性 | 值 | +|---|---| +| 类型 | **Hash + TTL** | +| TTL | `CacheConfig.instance_session_ttl_seconds`(默认 8s) | +| 维护方 | `RedisNodeInfo._heartbeat_worker` 以 TTL/3 频率发 `EXPIRE` | +| 生命周期 | 进程启动时 `register_node` 创建;`atexit` / `SIGINT` 时 `unregister_node` | + +**Hash 字段**: + +| 字段 | 类型 | 含义 | +|---|---|---| +| `node_id` | int | 从 `global:node_id` INCR 取到(全局唯一) | +| `ip` / `local_ip` | str | 本节点监听 IP | +| `uuid` | str | 进程 UUID(防同 IP 重启后留下"鬼节点") | +| `status` | str | `"active"` | +| `timestamp` | int | 注册时的 Unix 时间戳(秒) | +| `sd_key` | str | 冗余存本 SD 的序列化形式,便于运维排查 | + +--- + +### 2.2 `sd::meta:` — 节点地址元信息 + +| 属性 | 值 | +|---|---| +| 类型 | **Hash** | +| TTL | 无(生命周期跟随 `node:` 的 TTL) | +| 维护方 | `RedisMeta.regist_node_meta(...)` | + +**Hash 字段**: + +| 字段 | 含义 | +|---|---| +| `node_id` | int | +| `addr` | 节点 IP | +| `zmq_addr` | `tcp://ip:port`;Master 派发 `TransferOpGraph` 时使用 | +| `cpu_buffer_ptr` | Mooncake P2P 读取的 CPU block 池首地址 | +| `ssd_buffer_ptr` | SSD block 池首地址(如启用 SSD) | + +--- + +### 2.3 `sd::buffer::` — Mooncake 注册缓冲区 + +| 属性 | 值 | +|---|---| +| 类型 | **Hash** | +| TTL | 无 | +| 维护方 | `RedisMeta.regist_buffer([(ptr, size), ...])` | + 
+**Hash 字段**: + +| 字段 | 含义 | +|---|---| +| `buffer_size` | int,buffer 字节数 | +| 自定义字段 | 可扩展 `rdma_port` / `nic_name` 等 Mooncake 附加信息 | + +--- + +### 2.4 `sd::block::` — Block 元信息 + +> **最热的 key,数量最多**(量级:每 SD 1k~100k) + +| 属性 | 值 | +|---|---| +| 类型 | **Hash** | +| TTL | 无(生命周期由 `lt`/`state` 管理) | +| 维护方 | C++ `RedisMetaChannel::publish` / `update_block_state_batch` / `delete_blockmeta_batch` | + +**Hash 字段**(固定 6 个): + +| 字段 | 类型 | 含义 | +|---|---|---| +| `ph` | int64 | parent hash(构造 radix 链) | +| `pb` | int64 | parent block_node_id | +| `nid` | uint32 | 写入者 node_id | +| `hash` | int64 | 自身 hash | +| `lt` | uint32 | lease time(续租时间戳) | +| `state` | int | 0=READY / 1=EVICTED | + +**全局 SCAN pattern**:`sd::block:*`(由 `RedisMetaChannel::list_all_block_keys` 使用)。 + +--- + +### 2.5 `sd::aggregate:` — 跨 SD 聚合标记(预留) + +| 属性 | 值 | +|---|---| +| 类型 | **未启用**(`SharingDomainNamespace.aggregate_key(...)` 已提供构造器) | + +预留供未来把 `MasterCoordinator` 的跨 SD 聚合状态持久化到 Redis(用于 Master 重启恢复)。现阶段 `AggregateRadixTree` 只在内存。 + +--- + +### 2.6 `sd::pcfs:` — PCFS 文件节点索引 + +| 属性 | 值 | +|---|---| +| 类型 | **List** | +| 维护方 | `RedisMeta.add_pcfs_file_nodeids` / `load_pcfs_file_nodeids` | +| 含义 | 记录本节点能读到的 PCFS 文件对应的 node_id 列表(用于 3rd remote) | + +--- + +## 3. 
Instance 维度 key(每实例共享一份) + +### 3.1 `flexkv:instance::session` — 实例会话(故障检测) + +| 属性 | 值 | +|---|---| +| 类型 | **JSON string + TTL** | +| TTL | `CacheConfig.instance_session_ttl_seconds`(默认 8s) | +| 维护方 | `RedisSessionClient.register` / `renew` / `unregister` | +| 读取方 | `FailureDetector.poll_once()`(peer instance 跨实例扫描) | + +**JSON payload**: + +```json +{ + "instance_id": "", + "epoch": "", + "master_zmq_addr": "tcp://ip:port", + "node_ids": [123, 124, 125, ...], + "mooncake_addrs_by_sd": {"": "tcp://ip:port", ...} +} +``` + +**故障判定**: +- Peer 观察到 key 消失(TTL 到期)→ 触发 `on_peer_lost(peer_instance_id)` +- Peer 观察到 `epoch` 字段变化 → 视为重启事件 +- **即使 session 漏报**,数据面的 Mooncake P2P 失败会兜底(通过 `FailureReportMsg`) + +--- + +### 3.2 `flexkv:instance::sd_nodes` — 实例 SD→节点映射 + +| 属性 | 值 | +|---|---| +| 类型 | **Hash** | +| TTL | 无 | +| 维护方 | Master 启动时 `RedisMeta.register_instance_sd_nodes(instance_id, sd_to_nid)` 写入一次 | +| 读取方 | 其他实例的 `DistributedRadixTree.remote_tree_refresh` | + +**Hash 字段**(field = sd_key 字符串,value = 该 SD 所在节点的 node_id): + +跨节点 PP=2 的例子: + +``` +"c3a2f91d0bcdef01:ppn0/2:tpn0/1:nsa0" -> 50 # 第 0 节点 +"c3a2f91d0bcdef01:ppn1/2:tpn0/1:nsa0" -> 51 # 第 1 节点 +``` + +跨机 TP=2 的例子: + +``` +"c3a2f91d0bcdef01:ppn0/1:tpn0/2:nsa0" -> 60 # 第 0 节点 +"c3a2f91d0bcdef01:ppn0/1:tpn1/2:nsa0" -> 61 # 第 1 节点 +``` + +> 在 sd_key 不变量 `pp_node_count × tp_node_count == nnodes` 的约束下,每个 sd_key 对应**唯一的物理节点**。同节点上的 cp_rank>0 worker 不在这里出现(CP 不进 sd_key),CPU pool 内容由 sync_leader 那一份代表。 + +--- + +## 4. 全局 key + +### 4.1 `global:node_id` — 全局计数器 + +| 属性 | 值 | +|---|---| +| 类型 | **String 计数器(INCR)** | +| 作用域 | **所有 SD / 所有实例共用** | +| 维护方 | `RedisNodeInfo.register_node` 里 `INCR global:node_id` | + +`node_id` 全局唯一保证 `BlockMeta.nid` 在 Redis 跨 SD 查询时不会歧义。 + +--- + +### 4.2 `flexkv_node_id_updated:` — Pub/Sub channel + +| 属性 | 值 | +|---|---| +| 类型 | **Pub/Sub channel**(非 key) | +| 作用域 | 每 SD 一个 | +| 用途 | SD 内其他节点订阅此 channel,实时得到"新节点加入"事件 | + +--- + +## 5. 
读写时序速查 + +### 5.1 节点启动(Master 或 Remote 都走这条) + +``` +INCR global:node_id → 取到 nid +HSET sd::node: ip=... uuid=... status=active sd_key= ... +EXPIRE sd::node: +HSET sd::meta: addr=... zmq_addr=... cpu_buffer_ptr=... +HSET sd::buffer:: buffer_size=... (1 次 per buffer) +PUBLISH flexkv_node_id_updated: → 通知同 SD 其他节点 +``` + +`` 形如 `c3a2:ppn0/2:tpn0/1:nsa0`。 + +### 5.2 Master 收齐 Remote ready 后(启动最后一步) + +``` +HSET flexkv:instance::sd_nodes + ppn0/2:tpn0/1:nsa0 nid_master + ppn1/2:tpn0/1:nsa0 nid_remote_pp + ... +SET flexkv:instance::session EX → 启动心跳线程 +``` + +### 5.3 KVCache PUT(block 就绪) + +``` +HSET sd::block:: + ph=... pb=... nid=... hash=... lt=... state=0 +``` + +PUT 阶段每个 block D2H 完成后由 Master 通过 `post_complete_callback` 触发上述 publish。 + +### 5.4 跨 SD 聚合(多 SD 部署下) + +跨 SD 协调本身**不读写 Redis**——它走 ZMQ + `TransferOpGraph` 派发链路。Redis 只承担 §5.3 的 block 元数据 publish + 启动期的 ready handshake。 + +``` +Master 端 (kvtask.py::_launch_task) + for handle in transfer_handles: # master in-proc + N 个 remote handle + handle.submit(transfer_graph, ...) # 同一份 graph 广播给所有 SD + +Remote 端 (transfer_manager.py::_handle_submit) + 按 target_node_ids 过滤掉不归本节点的 op + rebind 把 op.pp_rank 改写到本地 + TransferEngine.submit(graph) # 提交本地执行 + +Worker 端: + - D2H clone:完成后回 CompletedOp(sd_key, contributing_node_id) + - PEERH2H clone:按 src_block_node_ids 分组 → get_node_meta(peer_node) + → mooncake.transfer_sync_read → 完成后回 CompletedOp + +Master 端 polling worker + 收到 CompletedOp(sd_key=...) → MasterCoordinator.mark_sd_ready(...) +``` + +### 5.5 远端 radix 重建(`DistributedRadixTree.remote_tree_refresh`) + +``` +HGETALL flexkv:instance::sd_nodes → sd_key → nid map +for sd in map: + SCAN sd::block:* + pipeline HMGET (ph pb nid hash lt state) batch 500 +``` + +接口 `RedisMetaChannel::list_all_block_keys` / `load_metas_by_keys(batch_size)` 已就位。 + +--- + +## 6. 
典型部署下的 key 量级估算 + +以 `CP=8, 跨节点 PP=2, tp_node_count=1`(共 2 个物理节点)为例: + +| key 种类 | 量级 | +|---|---| +| `sd:*:node:*` | **2 条**(每节点 1 个 SD) | +| `sd:*:meta:*` | 2 条 | +| `sd:*:buffer:*` | 2~6 条(取决于每 SD 注册的 buffer 数) | +| `sd:*:block:*` | **1k~100k/SD × 2 SD ≈ 2k~200k**(主要数据) | +| `flexkv:instance:*:*` | 2 条 per 实例(session + sd_nodes) | +| `global:node_id` | 1 条 | + +**单实例峰值估算**:block key 约 **20 万量级**,远低于 Redis 单实例百万级的舒适区。 + +> CP 维度由于不进 sd_key 而被折叠到同一 SD,相同物理资源下 SD 数量大幅减少(直接等于物理节点数 `nnodes`)。 + +--- + +## 7. 运维清单 + +### 7.1 清空本实例所有 key(推荐) + +```bash +# 前提:CacheConfig.flexkv_redis_db = 15(建议 FlexKV 独占一个 db) +redis-cli -n 15 FLUSHDB +``` + +### 7.2 只清某个实例(实例级隔离) + +```bash +# 清 instance 级 key +redis-cli --scan --pattern "flexkv:instance::*" | xargs redis-cli DEL + +# 清该实例下所有 SD 的 key(先拿到 sd_key list) +for sd in $(redis-cli HKEYS flexkv:instance::sd_nodes); do + redis-cli --scan --pattern "sd:${sd}:*" | xargs redis-cli DEL +done +``` + +### 7.3 清某个 SD 的所有 key + +```bash +redis-cli --scan --pattern "sd::*" | xargs redis-cli DEL +# 例:redis-cli --scan --pattern "sd:c3a2*:ppn0/2:tpn0/1:nsa0:*" | xargs redis-cli DEL +``` + +### 7.4 诊断:看某个实例的健康度 + +```bash +# 1. session 是否活着(TTL > 0) +redis-cli TTL flexkv:instance::session + +# 2. 有多少个 SD 已注册(应当等于 nnodes) +redis-cli HLEN flexkv:instance::sd_nodes + +# 3. 所有 SD 是否都有节点在线 +for sd in $(redis-cli HKEYS flexkv:instance::sd_nodes); do + nid=$(redis-cli HGET flexkv:instance::sd_nodes "$sd") + echo -n "SD=$sd node=$nid node_ttl=" + redis-cli TTL "sd:$sd:node:$nid" +done +``` + +--- + +## 8. 
常见问题 + +**Q1:我想让 FlexKV 用独立 db 不影响其他服务。** +设置 `CacheConfig.flexkv_redis_db = 15`,所有 FlexKV key 都落在 db=15;运维 `redis-cli -n 15 FLUSHDB` 一把清。 +Python 端和 C++ 端都会真实发 `SELECT `,详见 `flexkv/common/dist_reuse/failure_detector.py::make_redis_client_from_cache_config`。 + +**Q2:block key 太多导致 SCAN 卡。** +`RedisMetaChannel::list_all_block_keys` 用全局 SCAN(按 `sd::block:*` 模式扫描)+ 大批量 pipeline (batch=500) 加载 metadata,避免逐 node 单条 round-trip。 + +**Q3:TTL 过期了,但数据仍被读出?** +Redis TTL 到期不保证立刻被后台清理(惰性失效 + 定期扫描两种策略叠加)。 +FlexKV 在 `_cleanup_stale_nodes_by_ip` 里用 `uuid` 字段区分同 IP 重启前后的节点,避免误读老数据。 + +**Q4:sd_key 中没有 `pp_rank` 字段,怎么知道某个节点上具体跑哪个 PP rank?** +sd_key 描述的是"节点 KV 物理切片形态",并不直接编码 PP rank。具体 PP rank 由 sglang launcher 通过启动参数决定,FlexKV 只关心"本节点 KV 物理切片对应的 ppn{idx}/{count}"。 + +**Q5:跨节点 PP=2 实例与单机 PP=1 实例之间能 P2P 复用 KV 吗?** +不能。跨节点 PP=2 节点 0 的 sd_key 是 `ppn0/2:tpn0/1`,单机 PP=1 的 sd_key 是 `ppn0/1:tpn0/1`,二者 `pp_node_count` 不同,sd_key 字符串不相等 → 不在同一共享域。物理上前者 CPU pool 只装前半 layer,后者装完整 L 层 layer,block 物理 size 也不兼容。 diff --git a/docs/monitoring/README_en.md b/docs/monitoring/README_en.md index 8ce06baa57..7ee91f3798 100644 --- a/docs/monitoring/README_en.md +++ b/docs/monitoring/README_en.md @@ -13,6 +13,7 @@ FlexKV integrates a [Prometheus](https://prometheus.io/)-based runtime metrics m | `FLEXKV_ENABLE_METRICS` | `0` | Enable metrics collection (set to `1` to enable, disabled by default) | | `FLEXKV_PY_METRICS_PORT` | `8080` | Python metrics HTTP server port | | `FLEXKV_CPP_METRICS_PORT` | `8081` | C++ metrics HTTP server port | +| `PROMETHEUS_MULTIPROC_DIR` | *(auto)* | Directory for `prometheus_client` per-process sample files. Required when FlexKV runs across multiple Python processes (sglang TP/PP workers, KVManager subprocess, transfer workers). The collector auto-bootstraps a writable temp directory if unset; explicitly set it to a tmpfs path to override. 
| ### 1.2 Configuration @@ -60,6 +61,41 @@ C++ metrics are managed by the `MetricsManager` singleton, primarily instrumente --- +### 2.3 Cross-instance Reuse Metrics (`flexkv_py_dist_reuse_*`) + +These metrics observe the **distributed KV-cache reuse** path (master / peer +instances coordinated through Redis-meta + Mooncake P2P CPU pulls). They are +the primary signals for the lease-based safety guarantee that protects +cross-instance reads from racing master-side eviction. The 5 metrics live +alongside the existing `flexkv_py_*` set on the Python metrics endpoint +(`/metrics`, port `FLEXKV_PY_METRICS_PORT`). + +| Metric Name | Type | Labels | Severity | Description | +|---|---|---|---|---| +| `flexkv_py_dist_reuse_lease_meta_nullptr_total` | Counter | `device` | **CRITICAL** | Master-side blocks inserted with `lease_meta=nullptr` because the pool exceeded `swap_block_threshold`. Such blocks become evictable immediately and break the lease-based P2P safety guarantee — **any positive value in production should page oncall**. | +| `flexkv_py_dist_reuse_about_to_evict_total` | Counter | `device` | **WARN** | Blocks marked `ABOUT_TO_EVICT` (the *fresh*-branch evict path: lease still valid but the slot was needed anyway). Used together with `flexkv_py_evicted_blocks_total` to compute the `fresh / expired` evict ratio — sustained ratio > 10 means master is fighting eviction pressure and the lease-based safety margin is shrinking. | +| `flexkv_py_dist_reuse_peer_mooncake_read_seconds` | Histogram | — | **OPS** | Latency of peer-side `mooncake.transfer_sync_read` calls (P2P CPU pull from a master instance). Buckets: `0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0` seconds. **P99 > 500 ms** means the remaining lease window is shrinking toward exhaustion. | +| `flexkv_py_dist_reuse_peer_mooncake_read_failures_total` | Counter | `reason` | **CRITICAL** | Peer-side mooncake read failures. 
The `reason` label is one of `mooncake_error` (non-zero ret), `zero_byte_transfer` (the P0-bug symptom from 2026-05-14: ret==0 but no bytes moved), `node_meta_missing` (peer discovery breakdown), `timeout`. **Sustained failure rate > 0.1% warrants oncall**. | +| `flexkv_py_dist_reuse_peer_mooncake_read_success_total` | Counter | — | — | Peer-side mooncake read successes. Used as the denominator when computing the failure rate from `_failures_total`. | + +**Instrumentation status (as of this commit):** + +| Metric | Production call site | Notes | +|---|---|---| +| `flexkv_py_dist_reuse_peer_mooncake_read_seconds` | `flexkv/transfer/worker.py` (`PEER2CPUTransferWorker`) | Wired end-to-end. | +| `flexkv_py_dist_reuse_peer_mooncake_read_failures_total` | same as above | Wired end-to-end. | +| `flexkv_py_dist_reuse_peer_mooncake_read_success_total` | same as above | Wired end-to-end. | +| `flexkv_py_dist_reuse_lease_meta_nullptr_total` | *(collector hook ready, business-side trigger pending)* | The Python helper `record_dist_reuse_lease_nullptr` is ready in `flexkv/metrics/collector.py`; the C++ master-side eviction path that should call it is tracked in `docs/dist_reuse/METRICS_dist_reuse.md`. The metric will read `0` until that hook lands. | +| `flexkv_py_dist_reuse_about_to_evict_total` | *(collector hook ready, business-side trigger pending)* | Same status as above (`record_dist_reuse_about_to_evict`). | + +> The two pending metrics are intentionally exposed (with value `0`) so +> Prometheus scrape configs and Grafana panels can be wired ahead of +> time; they will start emitting non-zero values automatically once the +> C++ trigger lands. **Do not interpret a `0` as "system healthy"** — a +> `0` value on these two specifically means "not yet instrumented". + +--- + ## 3. 
Monitoring Stack Deployment ### 3.1 Directory Structure @@ -118,6 +154,78 @@ curl -s http://localhost:8080/metrics | grep flexkv_py_ # Verify C++ metrics endpoint curl -s http://localhost:8081/metrics | grep flexkv_cpp_ + +# Verify the new cross-instance reuse metrics (must be present even if value=0) +curl -s http://localhost:8080/metrics | grep flexkv_py_dist_reuse_ +``` + +### 3.5 Multiprocess Scrape Notes + +FlexKV runs the Python control plane across several processes (the +sglang scheduler, one transfer-engine subprocess per `WorkerKey`, and a +background KVManager). `prometheus_client` writes one sample file per +process into `PROMETHEUS_MULTIPROC_DIR`; the metrics HTTP server then +aggregates them on every scrape via `MultiProcessCollector`. + +* The collector auto-bootstraps a writable temp dir if + `PROMETHEUS_MULTIPROC_DIR` is unset, so a single-process workflow + (e.g. unit tests) needs no setup. +* For long-running deployments, point `PROMETHEUS_MULTIPROC_DIR` at a + tmpfs path (e.g. `/dev/shm/flexkv_prom`) to avoid disk wear and to + ensure the directory is wiped on container restart. +* The HTTP server writes the standard `Content-Type: + text/plain; version=0.0.4` header expected by Prometheus; no extra + scrape config is needed beyond pointing Prometheus at + `:/metrics`. + +### 3.6 Recommended PromQL alerts for the dist_reuse metrics + +Use these as a starting point in Prometheus / Alertmanager: + +```yaml +groups: +- name: flexkv_dist_reuse + rules: + # CRITICAL — any nullptr lease insert means the safety guarantee is broken. + - alert: FlexKVDistReuseLeaseMetaNullptr + expr: increase(flexkv_py_dist_reuse_lease_meta_nullptr_total[5m]) > 0 + for: 1m + labels: { severity: critical } + annotations: + summary: "FlexKV master inserted lease_meta=nullptr blocks (device={{ $labels.device }})" + + # CRITICAL — peer mooncake_read failure rate > 0.1% sustained for 5m. 
+ - alert: FlexKVDistReusePeerReadFailureRate + expr: | + sum by (reason) (rate(flexkv_py_dist_reuse_peer_mooncake_read_failures_total[5m])) + / + sum (rate(flexkv_py_dist_reuse_peer_mooncake_read_success_total[5m]) + + rate(flexkv_py_dist_reuse_peer_mooncake_read_failures_total[5m])) + > 0.001 + for: 5m + labels: { severity: critical } + annotations: + summary: "FlexKV peer mooncake_read failure rate > 0.1% (reason={{ $labels.reason }})" + + # OPS — peer mooncake_read P99 > 500ms. + - alert: FlexKVDistReusePeerReadP99High + expr: histogram_quantile(0.99, sum by (le) (rate(flexkv_py_dist_reuse_peer_mooncake_read_seconds_bucket[5m]))) > 0.5 + for: 5m + labels: { severity: warning } + annotations: + summary: "FlexKV peer mooncake_read P99 > 500ms (lease margin shrinking)" + + # WARN — fresh / expired evict ratio > 10 sustained. + - alert: FlexKVDistReuseEvictPressure + expr: | + sum by (device) (rate(flexkv_py_dist_reuse_about_to_evict_total[5m])) + / + clamp_min(sum by (device) (rate(flexkv_py_evicted_blocks_total[5m])), 1e-9) + > 10 + for: 10m + labels: { severity: warning } + annotations: + summary: "FlexKV master eviction is fighting lease pressure (device={{ $labels.device }})" ``` ### 3.4 Accessing Grafana Dashboards diff --git a/docs/monitoring/README_zh.md b/docs/monitoring/README_zh.md index 8f8f6f2184..c4a98e6998 100644 --- a/docs/monitoring/README_zh.md +++ b/docs/monitoring/README_zh.md @@ -13,6 +13,7 @@ FlexKV 集成了基于 [Prometheus](https://prometheus.io/) 的运行时指标 | `FLEXKV_ENABLE_METRICS` | `0` | 启用指标收集(设为 `1` 启用,默认禁用) | | `FLEXKV_PY_METRICS_PORT` | `8080` | Python 指标 HTTP 服务端口 | | `FLEXKV_CPP_METRICS_PORT` | `8081` | C++ 指标 HTTP 服务端口 | +| `PROMETHEUS_MULTIPROC_DIR` | *(自动)* | `prometheus_client` 多进程样本文件目录。FlexKV 会在多个 Python 进程(sglang TP/PP worker、KVManager 子进程、transfer worker)中分别写入采样数据,HTTP server 在抓取时通过 `MultiProcessCollector` 聚合。未设置时 collector 会自动初始化一个可写临时目录;建议在长时间运行场景中显式指向 tmpfs 路径(如 `/dev/shm/flexkv_prom`)。 | ### 1.2 配置方式 @@ -60,6 +61,39 @@ C++ 指标由 
`MetricsManager` 单例管理,主要在 RadixTree 缓存操作 --- +### 2.3 跨实例复用指标 (`flexkv_py_dist_reuse_*`) + +这组指标观测**分布式 KV-cache 复用**路径(master / peer 实例通过 +Redis-meta 协调 + Mooncake P2P CPU 拉取),是 lease 安全机制的核心信号 —— +用于保证跨实例读取不会与 master 端 evict 发生竞争。这 5 个指标与现有 +`flexkv_py_*` 共同暴露在 Python 指标端点(`/metrics`,端口 +`FLEXKV_PY_METRICS_PORT`)。 + +| 指标名称 | 类型 | 标签 | 严重级别 | 描述 | +|---|---|---|---|---| +| `flexkv_py_dist_reuse_lease_meta_nullptr_total` | Counter | `device` | **CRITICAL** | Master 端因池容量超过 `swap_block_threshold` 而以 `lease_meta=nullptr` 插入的 block 数。这类 block 立即可被 evict,破坏了 lease 保护的 P2P 安全性 — **生产环境出现任何正值都应立即告警 oncall**。 | +| `flexkv_py_dist_reuse_about_to_evict_total` | Counter | `device` | **WARN** | 进入 *fresh*-branch evict 路径的 block 数(lease 仍有效,但池压力强行需要这个槽位)。与 `flexkv_py_evicted_blocks_total` 一起计算 `fresh / expired` evict 比值 — 持续 > 10 表示 master 在与 evict 压力对抗,lease 安全余量正在收缩。 | +| `flexkv_py_dist_reuse_peer_mooncake_read_seconds` | Histogram | — | **OPS** | Peer 端 `mooncake.transfer_sync_read` 调用耗时(P2P CPU 拉取 master 实例数据)。Bucket:`0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0` 秒。**P99 > 500ms** 表示 lease 余量正逼近耗尽。 | +| `flexkv_py_dist_reuse_peer_mooncake_read_failures_total` | Counter | `reason` | **CRITICAL** | Peer 端 mooncake 读取失败计数。`reason` 取值:`mooncake_error`(非零返回)、`zero_byte_transfer`(2026-05-14 修复的 P0 bug 症状:ret==0 但实际未传输字节)、`node_meta_missing`(peer 节点发现失败)、`timeout`。**持续失败率 > 0.1% 应告警 oncall**。 | +| `flexkv_py_dist_reuse_peer_mooncake_read_success_total` | Counter | — | — | Peer 端 mooncake 读取成功计数,作为 `_failures_total` 计算失败率时的分母。 | + +**埋点接入状态(截至本 commit):** + +| 指标 | 生产侧调用位置 | 说明 | +|---|---|---| +| `flexkv_py_dist_reuse_peer_mooncake_read_seconds` | `flexkv/transfer/worker.py`(`PEER2CPUTransferWorker`) | 端到端已接入。 | +| `flexkv_py_dist_reuse_peer_mooncake_read_failures_total` | 同上 | 端到端已接入。 | +| `flexkv_py_dist_reuse_peer_mooncake_read_success_total` | 同上 | 端到端已接入。 | +| `flexkv_py_dist_reuse_lease_meta_nullptr_total` | *(collector 钩子已就绪,业务侧 trigger 待补)* | Python 
辅助方法 `record_dist_reuse_lease_nullptr` 已在 `flexkv/metrics/collector.py` 就绪;C++ master 端 evict 路径的对应触发逻辑见内部计划文档 `docs/dist_reuse/METRICS_dist_reuse.md`。该 trigger 落地前指标值会一直为 `0`。 | +| `flexkv_py_dist_reuse_about_to_evict_total` | *(collector 钩子已就绪,业务侧 trigger 待补)* | 状态同上(对应方法 `record_dist_reuse_about_to_evict`)。 | + +> 上述两个未接入指标也会照常暴露(值恒为 `0`),便于 Prometheus 抓取配置和 +> Grafana 面板提前接入;C++ trigger 落地后会自动开始上报非零值。 +> **不要把这两个指标的 `0` 值理解为「系统健康」** — 它们的 `0` 当前表示 +> 「埋点尚未接入」。 + +--- + ## 三、监控组件部署说明 ### 3.1 目录结构 @@ -118,6 +152,76 @@ curl -s http://localhost:8080/metrics | grep flexkv_py_ # Verify C++ metrics endpoint curl -s http://localhost:8081/metrics | grep flexkv_cpp_ + +# Verify the new cross-instance reuse metrics (must be present even if value=0) +curl -s http://localhost:8080/metrics | grep flexkv_py_dist_reuse_ +``` + +### 3.5 多进程抓取说明 + +FlexKV 的 Python 控制面运行在多个进程中(sglang scheduler、每个 +`WorkerKey` 一个 transfer-engine 子进程、后台 KVManager)。 +`prometheus_client` 在 `PROMETHEUS_MULTIPROC_DIR` 中按进程写入采样文件, +HTTP server 在每次抓取时通过 `MultiProcessCollector` 聚合这些采样。 + +* 若未显式设置 `PROMETHEUS_MULTIPROC_DIR`,collector 会自动初始化一个 + 可写临时目录,因此单进程场景(如单测)无需额外配置。 +* 长期运行的部署建议把 `PROMETHEUS_MULTIPROC_DIR` 指向 tmpfs 路径 + (如 `/dev/shm/flexkv_prom`),既避免磁盘磨损,也能保证容器重启时 + 目录被清理。 +* HTTP server 会输出 Prometheus 标准的 + `Content-Type: text/plain; version=0.0.4` 响应头,Prometheus 端只需配置 + `:/metrics` 抓取地址即可,不需要额外 + scrape 参数。 + +### 3.6 dist_reuse 指标推荐 PromQL 告警 + +以下规则可作为 Prometheus / Alertmanager 的起点配置: + +```yaml +groups: +- name: flexkv_dist_reuse + rules: + # CRITICAL — 任何 nullptr lease 插入都意味着安全保证已失效。 + - alert: FlexKVDistReuseLeaseMetaNullptr + expr: increase(flexkv_py_dist_reuse_lease_meta_nullptr_total[5m]) > 0 + for: 1m + labels: { severity: critical } + annotations: + summary: "FlexKV master inserted lease_meta=nullptr blocks (device={{ $labels.device }})" + + # CRITICAL — peer mooncake_read 失败率 > 0.1% 持续 5 分钟。 + - alert: FlexKVDistReusePeerReadFailureRate + expr: | + sum by (reason) 
(rate(flexkv_py_dist_reuse_peer_mooncake_read_failures_total[5m])) + / + sum (rate(flexkv_py_dist_reuse_peer_mooncake_read_success_total[5m]) + + rate(flexkv_py_dist_reuse_peer_mooncake_read_failures_total[5m])) + > 0.001 + for: 5m + labels: { severity: critical } + annotations: + summary: "FlexKV peer mooncake_read failure rate > 0.1% (reason={{ $labels.reason }})" + + # OPS — peer mooncake_read P99 > 500ms。 + - alert: FlexKVDistReusePeerReadP99High + expr: histogram_quantile(0.99, sum by (le) (rate(flexkv_py_dist_reuse_peer_mooncake_read_seconds_bucket[5m]))) > 0.5 + for: 5m + labels: { severity: warning } + annotations: + summary: "FlexKV peer mooncake_read P99 > 500ms (lease margin shrinking)" + + # WARN — fresh / expired evict 比值持续 > 10。 + - alert: FlexKVDistReuseEvictPressure + expr: | + sum by (device) (rate(flexkv_py_dist_reuse_about_to_evict_total[5m])) + / + clamp_min(sum by (device) (rate(flexkv_py_evicted_blocks_total[5m])), 1e-9) + > 10 + for: 10m + labels: { severity: warning } + annotations: + summary: "FlexKV master eviction is fighting lease pressure (device={{ $labels.device }})" ``` ### 3.4 访问 Grafana 仪表板 diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index 781ff0fd57..b89c65b5e6 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -79,6 +79,31 @@ def __init__(self, self.event_collector = event_collector self._metrics_collector = metrics_collector + # Dist-reuse eviction refcount guard (§2.2). When a guard is + # installed (typically by + # :meth:`GlobalCacheEngine.attach_dist_reuse` broadcasting + # ``MasterCoordinator.is_evictable`` down to every subengine), + # the ``take`` path below calls the 4-arg + # ``CRadixTreeIndex::evict`` overload that accepts a + # ``std::function`` predicate. Block ids for + # which the predicate returns False are *not* recycled — the + # block stays physically pinned until the in-flight coord GET + # drains the refcount. 
See + # :cpp:func:`flexkv::CRadixTreeIndex::evict` (radix_tree.cpp) + # for the authoritative implementation. + # + # The guard is optional — when ``None`` (default), ``take`` + # calls the legacy 3-arg overload and behaviour is byte- + # identical to pre-§2.2. + self._evict_guard_fn: Optional[Callable[[int], bool]] = None + + def set_evict_guard(self, fn: Optional[Callable[[int], bool]]) -> None: + """Install (or remove) the refcount guard used in ``take``'s + eviction path. Called by + :meth:`GlobalCacheEngine.attach_dist_reuse` / ``detach_dist_reuse``. + """ + self._evict_guard_fn = fn + def reset(self) -> None: self.index.reset() self.mempool.reset() @@ -89,15 +114,13 @@ def match(self, sequence_meta: SequenceMeta) -> MatchResultAccel: sequence_meta.num_blocks, True) # physical blocks (torch.Tensor -> numpy, zero-copy on CPU) phys = match_result.physical_blocks.cpu().numpy() - # optional block_node_ids - try: - bnis = getattr(match_result, "block_node_ids", None) - if isinstance(bnis, torch.Tensor) and bnis.numel() > 0: - bnids_np = bnis.cpu().numpy() - else: - bnids_np = None - except Exception: - bnids_np = None + # Extract single matched_node_id (single-node constraint) + raw_nid = getattr(match_result, "matched_node_id", -1) + single_node_id = int(raw_nid) if raw_nid is not None and raw_nid >= 0 else None + # Broadcast matched_node_id to per-block array for downstream compat + bnids_np = None + if single_node_id is not None and len(phys) > 0: + bnids_np = np.full(len(phys), single_node_id, dtype=np.uint32) return MatchResultAccel( num_ready_matched_blocks=match_result.num_ready_matched_blocks, num_matched_blocks=match_result.num_matched_blocks, @@ -105,6 +128,7 @@ def match(self, sequence_meta: SequenceMeta) -> MatchResultAccel: last_node=match_result.last_node, last_node_matched_length=match_result.last_node_matched_length, physical_blocks=phys, + matched_node_id=single_node_id, block_node_ids=bnids_np, matched_pos="remote" if self.device_type == 
DeviceType.REMOTE else "local", ) @@ -177,7 +201,26 @@ def take(self, if evict_block_num > 0: target_blocks = torch.zeros(evict_block_num, dtype=torch.int64) evicted_block_hashes = torch.zeros(evict_block_num, dtype=torch.int64) - num_evicted = self.index.evict(target_blocks, evicted_block_hashes, evict_block_num) + # §2.2 dist-reuse refcount guard: when a guard is + # installed (``GlobalCacheEngine.attach_dist_reuse`` + # pushes ``MasterCoordinator.is_evictable`` down here), + # call the 4-arg C++ overload so the eviction path + # never recycles a block id that still has a refcount + # > 0 from an in-flight coord GET. Pure legacy path + # (no guard, default production deployments) keeps the + # 3-arg call unchanged — byte-identical to pre-§2.2 + # behaviour. + if self._evict_guard_fn is not None: + num_evicted = self.index.evict( + target_blocks, + evicted_block_hashes, + evict_block_num, + self._evict_guard_fn, + ) + else: + num_evicted = self.index.evict( + target_blocks, evicted_block_hashes, evict_block_num + ) if num_evicted != evict_block_num: target_blocks.resize_(num_evicted) evicted_block_hashes.resize_(num_evicted) @@ -248,6 +291,19 @@ def __init__(self, self.event_collector = event_collector self._metrics_collector = metrics_collector + # Dist-reuse eviction refcount guard. When set, + # ``RadixTreeIndex.evict`` will skip physical block ids where + # ``is_evictable_fn(block_id)`` is False (i.e. the block is + # pinned by an in-flight coord GET). See + # :meth:`GlobalCacheEngine.attach_dist_reuse`. + self._evict_guard_fn: Optional[Callable[[int], bool]] = None + + def set_evict_guard(self, fn: Optional[Callable[[int], bool]]) -> None: + """Install (or remove) the refcount guard used in ``take``'s + eviction path. Called by ``GlobalCacheEngine.attach_dist_reuse``. 
+ """ + self._evict_guard_fn = fn + def reset(self) -> None: self.index.reset() self.mempool.reset() @@ -308,7 +364,10 @@ def take(self, int(self.mempool.num_total_blocks * self.evict_ratio) if self.evict_ratio > 0 else 0 # Or minimum evict_ratio ) if evict_block_num > 0: - evicted_blocks, evicted_block_hashes = self.index.evict(evict_block_num) + evicted_blocks, evicted_block_hashes = self.index.evict( + evict_block_num, + is_evictable_fn=self._evict_guard_fn, + ) self.mempool.recycle_blocks(evicted_blocks) # Record eviction metrics @@ -363,6 +422,43 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, redis_m self.ssd_cache_engine = None self.remote_cache_engine = None + # -------------------------------------------------------------- + # Dist-reuse integration hooks (Phase 1 integration glue). + # + # Populated by :meth:`attach_dist_reuse` **after** construction + # (the KVTaskManager owns the coordinator lifetimes and wires them + # in once all handles / Remote ready ACKs are collected). When + # unset, every dist-reuse code path no-ops and the engine behaves + # exactly like pre-Batch-E — this keeps legacy deployments + # (``enable_sharing_domain=False``) byte-identical. + # + # ``_master_coord`` -> ``MasterCoordinator`` (owns the + # ``AggregateRadixTree``, refcounts, and the + # Layer-1 ``FailureDetector``). + # Peer-lost hook is registered on the ``FailureDetector`` via + # ``MasterCoordinator.set_peer_lost_hook`` and is mapped to + # ``aggregate_radix.invalidate_by_peer_instance``. + # + # Phase D-4 (proposal_unify_with_graph_dispatch_2026-05-15.md): + # the previous ``_coord_dispatcher -> CoordinationCoordinator`` + # field was removed. Cross-SD coordination now flows through + # the unified ``TransferOpGraph`` dispatch path with per-op + # ``target_node_ids`` filtering on each Remote, and peer-SD + # acks come back as ``CompletedOp(sd_key, contributing_node_id)`` + # on the master polling worker. 
+ # -------------------------------------------------------------- + self._master_coord = None # type: ignore[assignment] + + # Phase D-2 (proposal_unify_with_graph_dispatch_2026-05-15.md + # §6.3): the PUT-path graph-dispatch ack book. Populated by + # ``_notify_master_sd_ready`` when the Master's own SD finishes, + # consumed by ``_on_peer_sd_completed_op`` when peer-SD + # ``CompletedOp(sd_key, contributing_node_id)`` arrives via the + # master polling thread's ``_completion_sink``. + import threading as _threading + self._pending_put_lock = _threading.Lock() + self._pending_put_batches: Dict[int, list] = {} + self.index_accel = GLOBAL_CONFIG_FROM_ENV.index_accel if cache_config.enable_kv_sharing: assert redis_meta is not None @@ -592,6 +688,52 @@ def get(self, return_mask[block_start_idx* self.tokens_per_block: (block_start_idx + num_gpu_blocks_to_transfer) * self.tokens_per_block] = True + # ------------------------------------------------------------------ + # §2.1 dist_reuse GET main-path hook. + # + # When ``enable_sharing_domain`` is on AND the instance owns more + # than one SD (PP>1 or tp_node_count>1), every cross-instance + # reuse hit must clear a coordination barrier: the Master needs + # every peer SD to ACK that it also has the prefix, otherwise + # the GPU would receive half-assembled KV. ``_sharing_domain_gate_get`` + # returns False on barrier failure — in which case we reset the + # transfer_graph + return_mask to an empty result, equivalent to + # a cache miss. Upstream will fall back to normal prefill. + # + # For the single-SD degenerate case (most common today: PP=1 + + # TP does not cross nodes) the gate is a no-op: ``coord_get`` + # has nothing to coordinate so this short-circuits. Zero + # regression for existing deployments that only use + # ``enable_p2p_cpu`` but not ``enable_sharing_domain``. 
+ # ------------------------------------------------------------------ + barrier_ok = self._sharing_domain_gate_get( + sequence_meta=sequence_meta, + return_mask=return_mask, + block_start_idx=block_start_idx, + num_gpu_blocks_to_transfer=num_gpu_blocks_to_transfer, + ) + if not barrier_ok: + # Release anything we had locked during matching, reset to + # the empty-return shape — equivalent to a cache miss so + # upstream falls back to normal prefill. + for device_type, (node, _) in (node_to_unlock or {}).items(): + try: + self.cache_engines[device_type].unlock(node) + except Exception: + pass + for device_type, blks in (buffer_to_free or {}).items(): + try: + if blks is not None and len(blks) > 0: + self.cache_engines[device_type].recycle(blks) + except Exception: + pass + empty_graph = TransferOpGraph.create_empty_graph() + empty_graph.bind_to_worker(dp_rank, pp_rank) + empty_mask = np.zeros_like(token_mask, dtype=np.bool_) + empty_cb = partial(self._transfer_callback, + node_to_unlock={}, buffer_to_free={}) + return empty_graph, empty_mask, empty_cb, {}, -1 + # if layer_num // layer_granularity != 1: # transfer_graph, finished_ops_ids = convert_read_graph_to_layer_wise_graph(transfer_graph=transfer_graph, # finished_ops_ids=finished_ops_ids, @@ -947,6 +1089,24 @@ def _get_impl_local(self, dp_client_id = dp_client_id, ) transfer_graph.add_transfer_op(op_peerh2h) + # Phase D-3 (proposal_unify_with_graph_dispatch_2026-05-15.md + # §6.4): in a multi-SD instance, fan the PEERH2H out to one + # clone per peer SD so each SD pulls its own slice from the + # contributing peer instance through that SD's mooncake. + # Returns [] for single-SD / dist_reuse-off / not-fully-ready, + # in which case the legacy single-op path stays bit-identical. + # ``block_mask_start + fragment1_num_blocks - 1`` is the index + # in the *full* sequence of the last block we're reusing + # — the prefix terminal block whose hash keys the + # AggregateRadixTree fully-ready entry. 
+ peerh2h_clones = self._maybe_attach_multi_sd_peerh2h_ops( + transfer_graph=transfer_graph, + op_peerh2h=op_peerh2h, + sequence_meta=sequence_meta, + prefix_terminal_block_idx=int( + block_mask_start + fragment1_num_blocks - 1 + ), + ) #TODO here we dont combine peer cpu or local cpu match results, so we can safely add remote results to local cpu #TODO here assume all matched blocks are ready blocks for peer cpu if (cpu_matched_result.insert_to_local_cpu_index and @@ -958,6 +1118,8 @@ def _get_impl_local(self, op_node_to_ready[op_peerh2h.op_id] = (DeviceType.CPU, cpu_node_to_unlock, cpu_node_to_unlock.size()) else: cpu_blocks_to_free = np.concatenate([cpu_blocks_to_free, fragment1_cpu_blocks_local]) + else: + peerh2h_clones = [] if fragment2_num_blocks > 0: if enable_gds: @@ -1025,6 +1187,15 @@ def _get_impl_local(self, transfer_graph.add_dependency(op_h2d.op_id, op_disk2h.op_id) if cpu_matched_result.matched_pos == "remote" and fragment1_num_blocks > 0: transfer_graph.add_dependency(op_h2d.op_id, op_peerh2h.op_id) + # Phase D-3: H2D must wait for *every* peer-SD PEERH2H + # clone to land its slice into the master CPU pool + # before the GPU copy fires. The peer-SD clones run on + # their respective Remote handles and their + # CompletedOp(success=True) flows back to the master + # polling thread through D-2's _completion_sink, which + # is what the graph dependency engine waits on. + for clone in peerh2h_clones: + transfer_graph.add_dependency(op_h2d.op_id, clone.op_id) finished_ops_ids.append(op_h2d.op_id) node_to_unlock = {} @@ -1104,10 +1275,26 @@ def put(self, for device_type in node_to_unlock: self.cache_engines[device_type].lock_node(node_to_unlock[device_type][0]) + # §2.1 dist_reuse PUT-path glue: once the local PUT really + # lands its block meta in Redis (which happens inside + # ``_transfer_callback`` via ``insert_and_publish``), we want + # to tell the AggregateRadixTree that *this* SD now contributes + # to the prefix. 
``_notify_sd_ready_on_put`` is a no-op when + # dist_reuse isn't attached, so passing the parameters + # unconditionally keeps the legacy path zero-overhead. + sd_notify_kwargs = { + "sequence_meta": sequence_meta, + "inserted_block_ids": gpu_block_ids[ + skipped_gpu_blocks: skipped_gpu_blocks + num_gpu_blocks_to_transfer + ] if num_gpu_blocks_to_transfer > 0 else None, + "block_start_idx": int(block_start_idx + skipped_gpu_blocks), + "num_blocks_inserted": int(num_gpu_blocks_to_transfer), + } callback = partial(self._transfer_callback, node_to_unlock=node_to_unlock, buffer_to_free=buffer_to_free, - is_put=True) + is_put=True, + sd_notify_kwargs=sd_notify_kwargs) op_callback_dict = {} for op_id in op_node_to_ready: @@ -1229,6 +1416,9 @@ def _put_impl_global(self, transfer_graph.add_transfer_op(op_d2h) finished_ops_ids.append(op_d2h.op_id) + # Phase D-2: tag self-SD + clone for each peer SD. + self._maybe_attach_multi_sd_d2h_ops(transfer_graph, op_d2h) + if put_to_ssd: if len(fragment12_cpu_blocks) < fragment2_num_blocks: num_needed_from_cpu_matched = fragment2_num_blocks - len(fragment12_cpu_blocks) @@ -1386,6 +1576,9 @@ def _put_impl_local(self, transfer_graph.add_transfer_op(op_d2h) finished_ops_ids.append(op_d2h.op_id) + # Phase D-2: tag self-SD + clone for each peer SD. 
+ self._maybe_attach_multi_sd_d2h_ops(transfer_graph, op_d2h) + if fragment2_num_blocks > 0: if len(fragment12_cpu_blocks) < fragment2_num_blocks: flexkv_logger.warning(f"fragment12_cpu_blocks: {len(fragment12_cpu_blocks)}, " @@ -1436,7 +1629,8 @@ def _put_impl_local(self, def _transfer_callback(self, node_to_unlock: Dict[DeviceType, Tuple[RadixNode, int]], buffer_to_free: Optional[Dict[DeviceType, np.ndarray]] = None, - is_put: bool = False) -> None: + is_put: bool = False, + sd_notify_kwargs: Optional[Dict] = None) -> None: if DeviceType.CPU in node_to_unlock: assert self.cpu_cache_engine is not None cpu_node = node_to_unlock[DeviceType.CPU][0] @@ -1470,6 +1664,20 @@ def _transfer_callback(self, assert self.remote_cache_engine is not None self.remote_cache_engine.recycle(buffer_to_free[DeviceType.REMOTE]) + # §2.1 dist_reuse PUT-path glue: once every cache level has + # unlocked + published, announce the new prefix to our own + # AggregateRadixTree (self-SD ACK). Kept after the + # insert_and_publish calls above so the block meta is + # guaranteed visible in Redis before the in-memory ack fires. + # No-op when dist_reuse isn't attached. + if is_put and sd_notify_kwargs: + try: + self._notify_sd_ready_on_put(**sd_notify_kwargs) + except Exception: + # Absolute must-not-raise: the callback runs on the + # transfer worker's completion path. 
+ pass + def _op_callback(self, device_type: DeviceType, node_to_ready: RadixNode, ready_length: int) -> None: if device_type == DeviceType.CPU: assert self.cpu_cache_engine is not None @@ -1481,6 +1689,709 @@ def _op_callback(self, device_type: DeviceType, node_to_ready: RadixNode, ready_ assert self.remote_cache_engine is not None self.remote_cache_engine.set_ready(node_to_ready, True, ready_length) + # ================================================================== + # Dist-reuse integration API (Batch F — cache_engine ↔ MasterCoordinator) + # ================================================================== + def attach_dist_reuse(self, master_coord) -> None: + """Wire this cache engine to the ``MasterCoordinator`` that the + ``KVTaskManager`` built after the Remote ready handshake completed. + + Phase D-4 (proposal_unify_with_graph_dispatch_2026-05-15.md): + cross-SD coordination flows through the unified + ``TransferOpGraph`` dispatch path with per-op + ``target_node_ids`` filtering on each Remote, which replaced + the previous ``CoordinationCoordinator`` plumbing entirely. + The deprecated ``coord_dispatcher`` parameter was removed in + the same cleanup pass that deleted the legacy + ``coord_query`` / ``coord_get`` / ``coord_put`` ZMQ protocol. + + Args: + master_coord: A + :class:`flexkv.common.dist_reuse.MasterCoordinator` owning + the per-instance ``AggregateRadixTree`` + refcount book- + keeping + Layer-1 ``FailureDetector``. + """ + self._master_coord = master_coord + + # Broadcast the refcount guard to every subengine so that each + # subengine's ``RadixTreeIndex.evict`` (or the accel equivalent, + # once plumbed on the C++ side) can honour + # ``MasterCoordinator.is_evictable``. See §2.2 in + # docs/dist_reuse/implementation_gap_2026-05-11.md. + self._broadcast_evict_guard(self.is_evictable) + + # Wire the Layer-1 FailureDetector's on_peer_lost callback to + # invalidate aggregate-radix entries from the dying peer. 
Do + # this lazily (only if coordinator owns a detector). + if master_coord is not None and hasattr(master_coord, "set_peer_lost_hook"): + master_coord.set_peer_lost_hook(self._on_peer_lost) + + def detach_dist_reuse(self) -> None: + """Drop references to the coordinator — invoked from shutdown.""" + self._master_coord = None + # Remove the guard — eviction falls back to the legacy + # behaviour that only looks at ``lock_cnt``. + self._broadcast_evict_guard(None) + + def _maybe_attach_multi_sd_d2h_ops( + self, + transfer_graph, + op_d2h, + ) -> None: + """Phase D-2 (proposal_unify_with_graph_dispatch_2026-05-15.md + §6.3): when the instance has multiple SDs, mirror the master's + D2H op into the per-peer-SD form so the broadcast graph + dispatch on each Remote handle picks up the right slice. + + Specifically: + + * Stamp ``target_node_ids=[self_node_id]`` on the master's own + ``op_d2h`` so the master in-process handle (and only that + handle) executes it. Each peer Remote's ``_handle_submit`` + will drop this op via the ``target_node_ids`` filter. + * For every peer SD known to the MasterCoordinator (with a + finished ``RemoteReadyMsg``), append a clone of ``op_d2h`` + with ``target_node_ids=[peer_node_id]``. Peer Remote's + ``_handle_submit`` keeps only its own clone. The peer's + ``CompletedOp(sd_key, contributing_node_id)`` then routes + back through the master polling thread's + ``_completion_sink`` → :meth:`_on_peer_sd_completed_op`. + + Single-SD / dist_reuse-disabled paths leave ``op_d2h`` + untouched (``target_node_ids=None`` ⇒ no filter, identical to + legacy behaviour). + """ + if self._master_coord is None: + return + try: + sd_to_nid = self._master_coord.get_sd_to_nid_map() + except Exception: + return + if not sd_to_nid: + return # bootstrap not finished yet — leave legacy behaviour. 
+ + try: + self_sd_str = self._master_coord.self_sd.serialize() + except Exception: + return + self_node_id = sd_to_nid.get(self_sd_str, -1) + if self_node_id < 0: + return + + # Tag master's own D2H op. + try: + op_d2h.target_node_ids = [int(self_node_id)] + except Exception: + return + + # Append a clone per peer SD. + try: + from flexkv.common.transfer import TransferOp, TransferType + except Exception: + return + for sd_str, peer_nid in sd_to_nid.items(): + if sd_str == self_sd_str: + continue + try: + peer_op = TransferOp( + graph_id=transfer_graph.graph_id, + transfer_type=TransferType.D2H, + src_block_ids=op_d2h.src_block_ids, + dst_block_ids=op_d2h.dst_block_ids, + dp_client_id=op_d2h.dp_client_id, + target_node_ids=[int(peer_nid)], + ) + transfer_graph.add_transfer_op(peer_op) + except Exception as e: # pragma: no cover + try: + from flexkv.common.debug import flexkv_logger + flexkv_logger.warning( + f"[DistReuse:D-2] failed to append peer-SD D2H op " + f"for sd={sd_str} (nid={peer_nid}): {e}" + ) + except Exception: + pass + + def _maybe_attach_multi_sd_peerh2h_ops( + self, + transfer_graph, + op_peerh2h, + sequence_meta, + prefix_terminal_block_idx: int, + ) -> List: + """Phase D-3 (proposal_unify_with_graph_dispatch_2026-05-15.md + §6.4): mirror of :meth:`_maybe_attach_multi_sd_d2h_ops` for the + GET-path PEERH2H op when the instance owns multiple SDs. + + Why this is necessary + --------------------- + The single-SD PEERH2H op pulls one peer instance's CPU slice + through mooncake on the *Master's* SD. In a multi-SD instance + every other SD must also pull its own slice from the same peer + instance — otherwise the GPU would see only the Master-SD + layer/head shard and PP / cross-node TP would be incomplete. 
+ + The legacy single-SD op as constructed by the caller already + targets the peer instance correctly on the Master's own SD + (``src_block_node_ids = cpu_matched_result.matched_node_ids``, + which is the peer instance's master-SD ``distributed_node_id``). + Phase D-3 transforms it into N ops, one per SD: + + * Stamp ``target_node_ids = [self_node_id]`` on the master's + op so only the Master's local TransferEngine executes it + (``_filter_graph_inplace_by_target_node_ids`` on the in-proc + handle drops it from peer Remotes, and the per-Remote + ``_filter_graph_by_target_node_ids`` drops it from peers + on the wire). + * For every peer SD known to be ``fully_ready`` (per + :meth:`AggregateRadixTree.match_fully_ready`), append a + clone with ``target_node_ids = [peer_sd_node_id]`` and + ``src_block_node_ids = [peer_instance_node_on_that_sd]`` + fetched from ``ReadyEntry.ready_sds``. The clone runs on + that peer SD's Remote node and pulls *its* layer/head shard + from the same peer instance through that peer SD's mooncake + server. + + Returns the list of clone ops added (empty list when single-SD + / dist_reuse off / aggregate not fully ready / bootstrap not + finished). Caller is expected to add a graph dependency + ``op_h2d → clone`` for every returned clone so the H2D waits + for *all* SDs to land their slice before issuing the H2D copy. + + Single-SD / dist_reuse-disabled paths leave ``op_peerh2h`` + untouched (``target_node_ids=None`` ⇒ no filter, identical to + legacy behaviour). No new ZMQ messages or sync primitives + are introduced — completion of each clone flows back through + the existing ``CompletedOp`` → ``_completion_sink`` route + already wired up by Phase D-2. + """ + clones: List = [] + if self._master_coord is None: + return clones + try: + sd_to_nid = self._master_coord.get_sd_to_nid_map() + except Exception: + return clones + if not sd_to_nid: + return clones # bootstrap not finished yet — leave legacy behaviour. 
+ try: + self_sd_str = self._master_coord.self_sd.serialize() + except Exception: + return clones + self_node_id = sd_to_nid.get(self_sd_str, -1) + if self_node_id < 0: + return clones + + # Single-SD instance — no clones to make; do not even tag the + # master op (legacy code path stays bit-identical). + if len(sd_to_nid) <= 1: + return clones + + # Look up the per-SD ``contributing peer node_id`` map from the + # AggregateRadixTree. ``ready_sds[sd_key] = peer_instance's + # node_id on that SD``. This is populated by D-2's PUT-path + # ``mark_sd_ready`` calls fired from each peer SD's CompletedOp + # ack. If we somehow get here without a fully-ready entry + # (race with eviction, leak, etc.), fall back to legacy and + # let ``_sharing_domain_gate_get`` reject the GET downstream. + try: + sequence_meta.gen_hashes() + except Exception: + return clones + if (prefix_terminal_block_idx < 0 or + prefix_terminal_block_idx >= sequence_meta.block_hashes.shape[0]): + return clones + try: + prefix_hash = int( + sequence_meta.block_hashes[prefix_terminal_block_idx].item() + ) + except Exception: + return clones + try: + entry = self._master_coord.match_fully_ready(prefix_hash) + except Exception: + return clones + if entry is None or not entry.ready_sds: + # Aggregate not fully ready → ``_sharing_domain_gate_get`` + # will reject this GET anyway; do not pollute the graph + # with peer-SD clones that would never resolve. + return clones + + # Tag master's own PEERH2H op so only the Master executes it. + try: + op_peerh2h.target_node_ids = [int(self_node_id)] + except Exception: + return clones + + # Append a clone per peer SD, pointing at the same peer + # instance's slice on that SD. + try: + from flexkv.common.transfer import TransferOp, TransferType + except Exception: + return clones + + # Cache shape needed for src_block_node_ids: an array of length + # ``len(src_block_ids)`` filled with the peer's node_id on that + # SD. 
``cpu_matched_result.matched_node_ids`` (master SD) is + # already in this shape; we mirror it for each peer SD. + n_blocks = len(op_peerh2h.src_block_ids) + for sd_str, peer_sd_nid in sd_to_nid.items(): + if sd_str == self_sd_str: + continue + # Peer instance's node_id on *this* peer SD (as recorded + # by D-2's ack path). -1 sentinel means we never got a + # confirming ack from that SD; skip — gate will reject. + try: + peer_instance_nid_on_sd = int(entry.ready_sds.get(sd_str, -1)) + except Exception: + peer_instance_nid_on_sd = -1 + if peer_instance_nid_on_sd < 0: + continue + try: + peer_src_node_ids = np.full( + n_blocks, int(peer_instance_nid_on_sd), dtype=np.int64, + ) + clone = TransferOp( + graph_id=transfer_graph.graph_id, + transfer_type=TransferType.PEERH2H, + src_block_ids=op_peerh2h.src_block_ids, + dst_block_ids=op_peerh2h.dst_block_ids, + dp_client_id=op_peerh2h.dp_client_id, + remote_node_ids=peer_src_node_ids, + src_block_node_ids=peer_src_node_ids, + target_node_ids=[int(peer_sd_nid)], + ) + transfer_graph.add_transfer_op(clone) + clones.append(clone) + except Exception as e: # pragma: no cover + try: + from flexkv.common.debug import flexkv_logger + flexkv_logger.warning( + f"[DistReuse:D-3] failed to append peer-SD PEERH2H op " + f"for sd={sd_str} (sd_nid={peer_sd_nid}, " + f"peer_inst_nid={peer_instance_nid_on_sd}): {e}" + ) + except Exception: + pass + return clones + + def _broadcast_evict_guard(self, fn: Optional[Callable[[int], bool]]) -> None: + """Install ``fn`` as the refcount guard on every owned subengine. + + Subengines that don't (yet) support a guard expose a no-op + ``set_evict_guard`` (see ``CacheEngineAccel``). Legacy engines + without the method at all are simply skipped — their evict + behaviour is unchanged. 
+ """ + for engine in ( + self.cpu_cache_engine, + self.ssd_cache_engine, + self.remote_cache_engine, + ): + if engine is None: + continue + setter = getattr(engine, "set_evict_guard", None) + if callable(setter): + try: + setter(fn) + except Exception: + # Defensive — a buggy subengine must not wedge + # attach/detach for the rest of the system. + pass + + def _on_peer_lost(self, peer_instance_id: str) -> None: + """FailureDetector callback. Best-effort — MUST NOT raise + (the callback runs on the detector's polling thread).""" + if self._master_coord is None: + return + try: + self._master_coord.invalidate_by_peer_instance(peer_instance_id) + except Exception as e: # pragma: no cover — defensive + try: + from flexkv.common.debug import flexkv_logger + flexkv_logger.warning( + f"[DistReuse] invalidate_by_peer_instance({peer_instance_id}) raised: {e}" + ) + except Exception: + pass + + # ---- refcount guard for evict paths -------------------------------- + def is_evictable(self, block_id: int) -> bool: + """Evict path check: refcount>0 blocks participating in an in-flight + coord GET must NOT be evicted. Defaults to True when dist-reuse + is off (legacy behaviour).""" + if self._master_coord is None: + return True + try: + return bool(self._master_coord.is_evictable(int(block_id))) + except Exception: + return True + + # ---- hooks Master transfer_callback calls when a PUT lands -------- + def _notify_master_sd_ready( + self, + prefix_hash: int, + block_ids: list, + ) -> None: + """Phase D-2 (proposal_unify_with_graph_dispatch_2026-05-15.md §6.3): + announce that the Master's own SD finished publishing a prefix. + + The Master's self-SD ack is recorded in the ``AggregateRadixTree`` + immediately. 
In multi-SD deployments the **peer-SD acks arrive + asynchronously** through the graph-dispatch path: each peer SD's + ``TransferManagerOnRemote`` runs the per-SD D2H op (filtered into + its own slice by ``target_node_ids``) and ships back a + ``CompletedOp(sd_key, contributing_node_id, success=True)``. The + Master's ``TransferManagerMultiNodeHandle._completion_sink`` then + invokes :meth:`_on_peer_sd_completed_op` which calls + ``mark_sd_ready(peer_sd, node_id=...)``. + + That replaces the old ``coord_put`` broadcast-and-collect pattern + with a single mechanism (graph dispatch) shared with cross-machine + TP / PP — see proposal §2. + + For ``total_sd_count == 1`` (the common single-SD shape) this + method is byte-identical to legacy: only the self-SD mark fires. + """ + if self._master_coord is None: + return + + # Self-SD mark — always first, always unconditional. Pass + # node_id so the GET-side cross-instance reuse path knows which + # node holds the master SD's slice. Best-effort lookup; + # default to -1 if the master coord doesn't expose it yet. + try: + self_node_id = int(getattr(self._master_coord, "self_node_id", -1)) + except Exception: + self_node_id = -1 + try: + self._master_coord.mark_sd_ready( + prefix_hash=int(prefix_hash), + sd_key_str=self._master_coord.self_sd.serialize(), + block_ids=list(block_ids) if block_ids is not None else [], + node_id=self_node_id, + ) + except Exception as e: # pragma: no cover + try: + from flexkv.common.debug import flexkv_logger + flexkv_logger.warning(f"[DistReuse] mark_sd_ready raised: {e}") + except Exception: + pass + + # Phase D-2: register a pending PUT batch for the + # ``_completion_sink`` to consume when peer SD CompletedOps + # arrive. Keyed by ``prefix_hash`` because the CompletedOp + # carries no batch identity beyond ``graph_id`` — and a single + # graph may carry several PUT prefixes (e.g. a merged batch + # graph). 
def _on_peer_sd_completed_op(self, completed_op) -> None:
    """Handle a peer sharing-domain ``CompletedOp`` from the completion sink.

    The op is ignored unless all of the following hold:

    * a master coordinator is attached;
    * ``completed_op.sd_key`` is non-empty and differs from the master's
      own serialized SD (the self-SD is marked by
      ``_notify_master_sd_ready``);
    * ``completed_op.success`` is truthy (a missing attribute counts as
      success).

    For an accepted op, every entry currently present in
    ``_pending_put_batches`` is passed to the coordinator's
    ``mark_sd_ready`` for that ``sd_key``.  The op carries no prefix hash,
    so with several PUT batches in flight a prefix can be marked ready for
    an SD before that SD's matching batch has actually finished.

    NOTE(review): this handler never removes entries from
    ``_pending_put_batches``; confirm the registry is pruned elsewhere.

    Args:
        completed_op: Object exposing ``sd_key``, ``success`` and
            ``contributing_node_id`` attributes (all read via ``getattr``
            with defaults).

    Never raises for a malformed ``contributing_node_id``: a ``None`` or
    non-numeric value is reported to the coordinator as ``-1``.
    """
    if self._master_coord is None:
        return
    sd_key = getattr(completed_op, "sd_key", "") or ""
    if not sd_key:
        return  # Not a peer-SD CompletedOp — ignore.
    if sd_key == self._master_coord.self_sd.serialize():
        # The master's own op can loop back through the same sink; the
        # self-SD was already marked when the PUT was announced.
        return
    if not getattr(completed_op, "success", True):
        # Failed op — invalidation is handled by the failure path.
        return

    # The original converted this outside any ``try``; an op carrying
    # ``contributing_node_id=None`` would then raise out of the handler.
    try:
        node_id = int(getattr(completed_op, "contributing_node_id", -1))
    except (TypeError, ValueError):
        node_id = -1

    try:
        with self._pending_put_lock:
            # Snapshot under the lock, call the coordinator outside it.
            pending_snapshot = list(self._pending_put_batches.items())
    except AttributeError:
        # Registry was never initialised — nothing is pending.
        return

    for prefix_hash, block_ids_list in pending_snapshot:
        try:
            self._master_coord.mark_sd_ready(
                prefix_hash=int(prefix_hash),
                sd_key_str=sd_key,
                block_ids=block_ids_list,
                node_id=node_id,
            )
        except Exception as e:  # pragma: no cover
            try:
                from flexkv.common.debug import flexkv_logger
                flexkv_logger.debug(
                    f"[DistReuse] _on_peer_sd_completed_op: "
                    f"mark_sd_ready({sd_key}, prefix={prefix_hash}) "
                    f"raised: {e}"
                )
            except Exception:
                pass
Each peer SD's + ``TransferManagerOnRemote`` runs the per-SD D2H clone op + (filtered into its own slice by ``target_node_ids``) and + ships back a ``CompletedOp(sd_key, + contributing_node_id, success=True)`` via the + ``TransferManagerMultiNodeHandle._completion_sink``, + which routes into :meth:`_on_peer_sd_completed_op` → + ``mark_sd_ready``. + + Together these two paths flip ``fully_ready`` True for the + prefix, at which point a subsequent GET clears this gate. + + * ``dist_reuse`` not attached (``has_dist_reuse`` is False) — + no-op, behave like the legacy path. + + Return True to allow reuse; False to force the caller into a + cache-miss fallback. + + Why a local ``fully_ready`` check rather than firing a + per-GET cross-SD query: + + - Design-doc §4.4 favours PUT-driven propagation of the + aggregate state over GET-driven queries to keep the + per-request latency close to the existing single-SD path. + (An earlier draft proposed an on-demand ``coord_query`` + round-trip; that protocol was dropped in Phase D-4 in + favour of the PUT-driven aggregate radix.) + - For ``total_sd_count == 1`` the prefix is trivially fully + ready once we inserted it locally. No round-trip cost. + - For ``total_sd_count > 1`` with no peer SD acks yet (e.g. + single-node dev setup), the gate still enforces the + contract defensively: if we don't have a positive + ``fully_ready`` signal, we refuse to reuse. This is + consistent with design §5.1 "any miss → fallback to + prefill". + """ + if not self.has_dist_reuse: + return True + if self._master_coord is None: + return True + + # Single-SD instance — degenerate case, no coordination needed. + try: + total_sd = int(getattr(self._master_coord, "total_sd_count", 1)) + except Exception: + total_sd = 1 + if total_sd <= 1: + return True + + # Multi-SD instance — require the aggregate radix to have a + # fully-ready entry for the prefix we're about to reuse. 
def _notify_sd_ready_on_put(
    self,
    *,
    sequence_meta,
    inserted_block_ids,
    block_start_idx: int,
    num_blocks_inserted: int,
) -> None:
    """PUT-path hook: announce a freshly inserted prefix for this SD.

    The prefix is identified by the hash of its last inserted block
    (index ``block_start_idx + num_blocks_inserted - 1`` in
    ``sequence_meta.block_hashes``).  That hash and the inserted block ids
    are forwarded to ``_notify_master_sd_ready``.

    The call is a silent no-op when dist-reuse is not attached, when no
    blocks were inserted, when hashing fails, or when the terminal index
    falls outside ``block_hashes``.  Failures of the forwarded call are
    logged as warnings and swallowed.

    Args:
        sequence_meta: Object providing ``gen_hashes()`` and an indexable
            ``block_hashes`` with ``.shape`` and ``.item()`` elements.
        inserted_block_ids: Iterable of block ids, or ``None``; ids that
            cannot be converted to ``int`` collapse the list to ``[]``.
        block_start_idx: Index of the first inserted block.
        num_blocks_inserted: Number of blocks inserted.
    """
    if not (self.has_dist_reuse and self._master_coord is not None):
        return
    if num_blocks_inserted <= 0:
        return

    try:
        sequence_meta.gen_hashes()
    except Exception:
        return

    last_idx = block_start_idx + num_blocks_inserted - 1
    if not 0 <= last_idx < sequence_meta.block_hashes.shape[0]:
        return

    try:
        tail_hash = int(sequence_meta.block_hashes[last_idx].item())
    except Exception:
        return

    try:
        ids = [] if inserted_block_ids is None else [int(b) for b in inserted_block_ids]
    except Exception:
        ids = []

    try:
        self._notify_master_sd_ready(prefix_hash=tail_hash, block_ids=ids)
    except Exception as e:
        try:
            from flexkv.common.debug import flexkv_logger
            flexkv_logger.warning(f"[DistReuse] _notify_sd_ready_on_put failed: {e}")
        except Exception:
            pass
@property
def has_dist_reuse(self) -> bool:
    """Report whether a master coordinator is attached to this engine.

    Returns:
        ``True`` when ``self._master_coord`` is anything other than ``None``.
    """
    coordinator = self._master_coord
    return coordinator is not None
single_node_id, dtype=np.uint32) + if self.device_type == DeviceType.REMOTE: + bnids_np = self.nodeids_to_file_nodeids(raw_nids, nps.cpu().numpy()) + if bnids_np is None: + chosen = mr_local + matched_pos = "local" + single_node_id = None + else: + bnids_np = raw_nids else: - bnids_np = None + # No valid matched_node_id → fall back to local if mr_remote.num_matched_blocks > 0: - #print(f"[REMOTE_MATCH {self.device_type.name}] Warning: remote matched but block_node_ids is empty, falling back to local") chosen = mr_local - matched_pos = "local" # Update matched_pos after fallback + matched_pos = "local" + single_node_id = None phys_np = chosen.physical_blocks.cpu().numpy() #maybe we should always not insert if self.device_type == DeviceType.CPU and matched_pos == "remote" and mr_local.num_matched_blocks > 0: insert_to_local_cpu_index = False else: insert_to_local_cpu_index = True - #TODO A big question is how to get the node id for peer_cpu and peer_ssd? return MatchResultAccel( num_ready_matched_blocks=int(chosen.num_ready_matched_blocks), num_matched_blocks=int(chosen.num_matched_blocks), @@ -239,9 +239,10 @@ def match_all(self, sequence_meta: SequenceMeta, gpu_matched_blocks: int = 0) -> last_node=chosen.last_node, last_node_matched_length=int(chosen.last_node_matched_length), physical_blocks=phys_np, + matched_node_id=single_node_id, block_node_ids=bnids_np, matched_pos=matched_pos, - matched_node_ids=bnids_np, # Set matched_node_ids for P2P transfer + matched_node_ids=bnids_np, # deprecated: kept for backward compat insert_to_local_cpu_index=insert_to_local_cpu_index, ) diff --git a/flexkv/cache/radixtree.py b/flexkv/cache/radixtree.py index c6a3ef6b29..eb8d605c04 100644 --- a/flexkv/cache/radixtree.py +++ b/flexkv/cache/radixtree.py @@ -15,7 +15,7 @@ import heapq import time from dataclasses import dataclass, field -from typing import Dict, Tuple, Optional +from typing import Callable, Dict, Tuple, Optional import numpy as np import torch @@ -281,7 +281,30 @@ 
def insert(self, return new_node - def evict(self, num_evicted: int) -> Tuple[np.ndarray, np.ndarray]: + def evict(self, + num_evicted: int, + is_evictable_fn: Optional[Callable[[int], bool]] = None, + ) -> Tuple[np.ndarray, np.ndarray]: + """Evict up to ``num_evicted`` blocks from the tree. + + Args: + num_evicted: Target number of physical blocks to evict. + is_evictable_fn: Optional predicate. When supplied, a + physical block id is only evicted if + ``is_evictable_fn(block_id)`` returns True. This is + the dist-reuse refcount guard — blocks with an + in-flight coord GET refcount > 0 must stay pinned. + When ``None`` (default) the legacy behaviour is kept. + """ + def _block_ok(block_id: int) -> bool: + if is_evictable_fn is None: + return True + try: + return bool(is_evictable_fn(int(block_id))) + except Exception: + # Defensive: never let a buggy guard wedge eviction. + return True + candidates = [] for node in self.leaf_nodes.values(): if node.evictable(): @@ -307,6 +330,22 @@ def evict(self, num_evicted: int) -> Tuple[np.ndarray, np.ndarray]: _block_hashes = node.block_hashes node.parent = None + # Dist-reuse refcount guard: drop block ids the Master + # coordinator has marked as "in-flight coord GET". The + # block has already been physically detached from the + # tree above, which is acceptable — the caller will not + # recycle those block ids, so the physical slot stays + # pinned until the refcount drains. We simply omit them + # from the evicted set returned to the caller. 
+ if is_evictable_fn is not None and physical_blocks.size > 0: + keep_mask = np.array( + [_block_ok(int(b)) for b in physical_blocks], + dtype=bool, + ) + if not keep_mask.all(): + physical_blocks = physical_blocks[keep_mask] + _block_hashes = _block_hashes[keep_mask] + evicted_blocks = np.concatenate([evicted_blocks, physical_blocks]) evicted_block_hashes = np.concatenate([evicted_block_hashes, _block_hashes]) diff --git a/flexkv/cache/redis_meta.py b/flexkv/cache/redis_meta.py index 1721f80f85..6b708b927e 100644 --- a/flexkv/cache/redis_meta.py +++ b/flexkv/cache/redis_meta.py @@ -1,5 +1,25 @@ +"""Redis metadata layer with explicit ``SharingDomainNamespace`` scoping. + +Phase 0 task 0-C: every Redis key produced here flows through a +:class:`SharingDomainNamespace`, so the legacy flat keys (``node:*`` / +``meta:*`` / ``buffer:*`` / ``CPUB:block:*`` / ``SSDB:block:*`` / +``PCFSB:block:*``) become ``sd::node:*`` / ``sd::meta:*`` / +... . Per design doc §4.7 we go all-in on the new layout — there is no +backward compatibility for the bare keys. + +Callers that don't care about sharing domains (legacy single-instance +dist_reuse) can pass ``SharingDomainKey.default()`` and continue to work as +before; the only observable difference is the Redis key prefix. + +The ``RedisMetaChannel`` Python wrapper keeps the same surface as before +but its underlying C++ ``blocks_key`` argument is now expected to carry the +**full namespace** (``sd::``) so that +``make_block_key`` produces ``sd:::block::`` +without any further changes on the C++ side. +""" + from __future__ import annotations -from typing import Iterable, List, Tuple, Optional, Union, Dict +from typing import Iterable, List, Tuple, Optional, Union, Dict, Set, cast from dataclasses import dataclass from enum import IntEnum from uuid import uuid1 @@ -13,16 +33,31 @@ except Exception: # pragma: no cover _redis = None # type: ignore -# Import C++ dist extensions (RedisMetaChannel, BlockMeta). 
def _channel_blocks_key(namespace: SharingDomainNamespace, device_prefix: str) -> str:
    """Compose the ``blocks_key`` string handed to the C++ channel.

    Args:
        namespace: Object whose ``prefix`` attribute is the sharing-domain
            key prefix.
        device_prefix: Device tag appended after the SD prefix; a falsy
            value (empty string, ``None``) means "no device component".

    Returns:
        ``"<namespace.prefix>:<device_prefix>"`` when ``device_prefix`` is
        truthy, otherwise ``namespace.prefix`` unchanged.
    """
    sd_prefix = namespace.prefix
    return f"{sd_prefix}:{device_prefix}" if device_prefix else sd_prefix
def list_node_keys(self) -> List[str]:
    """List the node keys that belong to this channel's sharing domain.

    The native ``list_node_keys`` of the wrapped C++ channel is tried
    first.  Its result is filtered so that only keys under this channel's
    own ``<sd_prefix>:node:`` prefix survive.  The original kept every key
    starting with ``"sd:"``, which disagreed with the fallback scan below
    and could return node keys of other sharing domains.

    When ``_derive_sd_prefix()`` yields an empty prefix (a ``blocks_key``
    that does not start with ``"sd:"``), the original ``"sd:"`` filter is
    kept, so bare ``node:*`` results from the native call are still
    dropped and the fallback scan runs.

    If the native call raises or nothing survives the filter, the keys are
    obtained with ``list_keys("<sd_prefix>:node:*")``, or ``"node:*"`` when
    there is no SD prefix.

    Returns:
        A list of Redis key strings.
    """
    sd_prefix = self._derive_sd_prefix()
    wanted = f"{sd_prefix}:node:" if sd_prefix else "sd:"

    try:
        native = list(self._c.list_node_keys())
    except Exception:
        # Best-effort: the native method may be missing or may fail; the
        # pattern scan below covers that case.
        native = []

    keys = [k for k in native if k.startswith(wanted)]
    if keys:
        return keys

    pattern = f"{sd_prefix}:node:*" if sd_prefix else "node:*"
    return list(self._c.list_keys(pattern))
+ + The convention (see :func:`_channel_blocks_key`) is:: + + blocks_key = sd:: # full SD path + blocks_key = sd: # SD-only (no device prefix) + blocks_key = blocks # legacy / tests + + Under the simplified dist_reuse design (CP not in the SD key), the + serialized SD has 4 segments: ``:pp/:tpn/:nsa<0|1>``. + With the leading ``sd:`` literal that makes ``sd:`` exactly + 5 colon-separated parts. + """ + bk = self._blocks_key + if not bk.startswith("sd:"): + return "" + # Format: sd::pp/<>:tpn/<>:nsa<0|1>[:] + parts = bk.split(":") + # ``sd``, ````, ``pp...``, ``tpn...``, ``nsa...`` = 5 parts. + if len(parts) < 5: + return bk # malformed but caller will get an empty result anyway + return ":".join(parts[:5]) + + +# --------------------------------------------------------------------------- +# RedisNodeInfo — heartbeat + active-node discovery, scoped to one SD +# --------------------------------------------------------------------------- class RedisNodeInfo: - """Redis node information management class implemented in Python""" + """Manages the ``sd::node:`` key family. + + Each instance owns exactly one SD's namespace. The Master typically + creates one ``RedisNodeInfo`` per SD it participates in (task 0-G); + Remote nodes create a single ``RedisNodeInfo`` for their own SD + (task 0-F). + """ - # Default TTL for node: key in seconds. Active nodes renew before expiry. - # If a process crashes (kill -9), the key auto-expires after this period. 
DEFAULT_NODE_TTL_SECONDS: int = 30 - - def __init__(self, host: str, port: int, local_ip: str, password: str = "", node_ttl_seconds: int = 0) -> None: + + def __init__( + self, + host: str, + port: int, + local_ip: str, + password: str = "", + node_ttl_seconds: int = 0, + *, + namespace: Optional[SharingDomainNamespace] = None, + db: int = 0, + ) -> None: if _redis is None: raise ImportError("redis-py is required: pip install redis") self.host = host self.port = int(port) self.local_ip = str(local_ip) self.password = str(password) + # Honour ``CacheConfig.flexkv_redis_db`` so all FlexKV clients land + # on the same logical db; defaulting to 0 keeps legacy behaviour. + self.db: int = int(db) self.uuid = str(uuid1()) - # Use provided TTL or fall back to default self.node_ttl_seconds: int = node_ttl_seconds if node_ttl_seconds > 0 else self.DEFAULT_NODE_TTL_SECONDS - # Heartbeat interval – renew TTL at roughly 1/3 of the TTL period self.heartbeat_interval_seconds: float = max(1.0, self.node_ttl_seconds / 3.0) + + self._namespace: SharingDomainNamespace = _resolve_namespace(namespace) + self._node_id: Optional[int] = None self._running = False self._listener_thread: Optional[threading.Thread] = None self._heartbeat_thread: Optional[threading.Thread] = None - self.current_node_id_set: set = set() + self.current_node_id_set: Set[int] = set() self._client: Optional["_redis.Redis"] = None self._sub_client: Optional["_redis.Redis"] = None self._cleanup_done = False - - # register cleanup function on exit + atexit.register(self._cleanup_on_exit) - signal.signal(signal.SIGINT, self._signal_handler) - signal.signal(signal.SIGTERM, self._signal_handler) - + # Some hosting environments (e.g. unit tests, threadpool workers) + # don't allow signal handlers in non-main threads. Be tolerant. 
def connect(self) -> bool:
    """Open the Redis client and start the listener and heartbeat threads.

    A client is created with ``_get_client()`` and verified with
    ``ping()``.  On success ``_running`` is set and two daemon threads are
    started, targeting ``_listener_worker`` and ``_heartbeat_worker``;
    their names embed ``sd_key_str``.

    If the client was created but ``ping()`` fails, the client is closed
    and ``self._client`` is set to ``None``.  The original left the
    unusable client assigned, so later ``if not self._client`` guards
    treated the object as connected.

    Returns:
        ``True`` when the connection was verified and both threads were
        started, ``False`` on any failure.  Never raises.
    """
    candidate = None
    try:
        candidate = self._get_client()
        candidate.ping()
    except Exception:
        if candidate is not None:
            # Release the client that could not be verified.
            try:
                candidate.close()
            except Exception:
                pass
            self._client = None
        return False

    self._client = candidate
    try:
        self._running = True
        self._listener_thread = threading.Thread(
            target=self._listener_worker,
            name=f"redis-node-info-listener[{self.sd_key_str}]",
            daemon=True,
        )
        self._listener_thread.start()

        self._heartbeat_thread = threading.Thread(
            target=self._heartbeat_worker,
            name=f"redis-node-heartbeat[{self.sd_key_str}]",
            daemon=True,
        )
        self._heartbeat_thread.start()

        return True
    except Exception:
        return False
def register_node(self) -> Optional[int]:
    """Allocate a node id and publish this node's key with a TTL.

    Steps, all inside one ``try``:

    1. ``_cleanup_stale_nodes_by_ip()``.
    2. ``INCR global:node_id`` to obtain the new id.
    3. ``HSET`` the key returned by ``self._namespace.node_key(node_id)``
       with the node's ip, uuid, status, timestamp and ``sd_key``.
    4. ``EXPIRE`` that key with ``node_ttl_seconds``.
    5. ``PUBLISH`` the id on ``_pubsub_channel()``.

    If any step raises, the registration is rolled back.  ``_node_id`` is
    restored to its previous value and the node key, if it was computed,
    is deleted best-effort.  The original returned ``None`` here but left
    ``_node_id`` pointing at the new id, and could leave a key behind
    without a TTL when ``EXPIRE`` failed.

    Returns:
        The new node id, or ``None`` when there is no client or the
        registration failed.
    """
    if not self._client:
        return None

    previous_id = self._node_id
    node_key = None
    try:
        self._cleanup_stale_nodes_by_ip()

        # redis-py types ``incr`` as possibly awaitable; this client is
        # used synchronously, so cast for the type checker.
        node_id = cast(int, self._client.incr("global:node_id"))
        self._node_id = node_id

        node_key = self._namespace.node_key(node_id)
        self._client.hset(node_key, mapping={
            "node_id": str(node_id),
            "ip": self.local_ip,
            "local_ip": self.local_ip,
            "uuid": self.uuid,
            "status": "active",
            "timestamp": str(int(time.time())),
            "sd_key": self.sd_key_str,
        })
        self._client.expire(node_key, self.node_ttl_seconds)
        self._client.publish(self._pubsub_channel(), str(node_id))
        return node_id
    except Exception:
        # Undo the partial registration so local state and Redis agree
        # with the ``None`` result reported to the caller.
        self._node_id = previous_id
        if node_key is not None:
            try:
                self._client.delete(node_key)
            except Exception:
                pass
        return None
self._client.publish("flexkv_node_id_updated", str(node_id)) - + self._client.publish(self._pubsub_channel(), str(node_id)) self._node_id = None return True except Exception: return False - + @property def node_id(self) -> Optional[int]: - """Get current node_id""" return self._node_id - + def get_uuid(self) -> str: - """Get the UUID of this node""" return self.uuid - + def get_active_node_ids(self) -> List[int]: - """Get all active node IDs - lock-free RCU read""" return list(self.current_node_id_set) - + def is_node_active(self, node_id: int) -> bool: - """Check if a node_id is active - lock-free RCU check""" return node_id in self.current_node_id_set - + + # -- heartbeat / listener --------------------------------------------- def _heartbeat_worker(self) -> None: - """Background thread that periodically renews the TTL of node: key. - - This ensures that if the process is alive, the node key never expires. - If the process crashes (kill -9), the TTL will not be renewed and the - key will auto-expire after NODE_TTL_SECONDS, allowing other nodes to - detect the crash and stop using stale meta/block data. 
- """ heartbeat_client: Optional["_redis.Redis"] = None while self._running: try: if heartbeat_client is None: heartbeat_client = self._get_client() - if self._node_id is not None: - node_key = f"node:{self._node_id}" - # Renew TTL + node_key = self._namespace.node_key(self._node_id) heartbeat_client.expire(node_key, self.node_ttl_seconds) - # Also update the timestamp field heartbeat_client.hset(node_key, "timestamp", str(int(time.time()))) - except Exception: - # Connection lost, reset client so it reconnects next iteration if heartbeat_client: try: heartbeat_client.close() except Exception: pass heartbeat_client = None - - # Sleep in small increments so we can exit quickly when _running becomes False for _ in range(int(self.heartbeat_interval_seconds * 10)): if not self._running: break time.sleep(0.1) - if heartbeat_client: try: heartbeat_client.close() @@ -380,31 +556,20 @@ def _heartbeat_worker(self) -> None: pass def _listener_worker(self) -> None: - """Background thread that listens for node updates""" backoff = 0.5 + ch = self._pubsub_channel() while self._running: try: - # Create a separate connection for pub/sub self._sub_client = self._get_client() - - # Subscribe to flexkv_node_id_updated channel pubsub = self._sub_client.pubsub() - pubsub.subscribe("flexkv_node_id_updated") - - # Listen for messages with blocking read + pubsub.subscribe(ch) for message in pubsub.listen(): if not self._running: break - - if message["type"] == "message" and message["channel"] == "flexkv_node_id_updated": - # Scan active nodes when we receive an update + if message["type"] == "message" and message["channel"] == ch: self.scan_active_nodes() - - # Normal exit from listen loop break - except Exception: - # Network/reconnection exception: exponential backoff time.sleep(backoff) backoff = min(backoff * 2, 5.0) finally: @@ -414,206 +579,224 @@ def _listener_worker(self) -> None: except Exception: pass self._sub_client = None - + + # -- discovery 
--------------------------------------------------------- def scan_active_nodes(self) -> None: - """Scan Redis for active node keys and update current_node_id_set - - This method can be called externally to manually refresh the active nodes list. - It uses SCAN to avoid blocking Redis server. - - Because node: keys now have a TTL (heartbeat), expired keys are - automatically removed by Redis. SCAN will only return keys that are - still alive, so stale/crashed nodes are naturally excluded. - """ if not self._client: return - try: - new_active_nodes = set() + new_active_nodes: Set[int] = set() cursor = 0 - + pattern = self._namespace.node_key_pattern() + prefix = f"{self._namespace.prefix}:node:" while True: - cursor, keys = self._client.scan(cursor=cursor, match="node:*", count=100) - + cursor, keys = cast( + Tuple[int, List[str]], + self._client.scan(cursor=cursor, match=pattern, count=100), + ) for key in keys: - if key.startswith("node:"): - try: - node_id = int(key[5:]) # Remove "node:" prefix - new_active_nodes.add(node_id) - except (ValueError, IndexError): - # Skip invalid node IDs - continue - + if not key.startswith(prefix): + continue + try: + node_id = int(key[len(prefix):]) + new_active_nodes.add(node_id) + except (ValueError, IndexError): + continue if cursor == 0: break - - # Detect nodes that disappeared (TTL expired or unregistered) + disappeared = self.current_node_id_set - new_active_nodes if disappeared: - # Clean up meta and block data for disappeared nodes for stale_nid in disappeared: if stale_nid == self._node_id: - continue # Don't clean up ourselves + continue self._cleanup_node_data(stale_nid) - - # lock-free RCU switch: atomic assignment self.current_node_id_set = new_active_nodes - except Exception: - # If scan fails, continue with current active nodes pass def _cleanup_stale_nodes_by_ip(self) -> None: - """Clean up stale node registrations from the same IP. 
- - On startup, scan all node:* keys and remove those that have the same - local_ip but a different UUID (i.e. leftover from a previous crashed process). - """ if not self._client: return - try: cursor = 0 - stale_node_ids = [] - + stale_node_ids: List[int] = [] + pattern = self._namespace.node_key_pattern() + prefix = f"{self._namespace.prefix}:node:" while True: - cursor, keys = self._client.scan(cursor=cursor, match="node:*", count=100) + cursor, keys = cast( + Tuple[int, List[str]], + self._client.scan(cursor=cursor, match=pattern, count=100), + ) for key in keys: - if not key.startswith("node:"): + if not key.startswith(prefix): continue try: - nid = int(key[5:]) + nid = int(key[len(prefix):]) except (ValueError, IndexError): continue - - data = self._client.hgetall(key) + data = cast(Dict[str, str], self._client.hgetall(key) or {}) node_ip = data.get("ip", "") or data.get("local_ip", "") node_uuid = data.get("uuid", "") - - # Same IP but different UUID → stale node from a previous process if node_ip == self.local_ip and node_uuid != self.uuid: stale_node_ids.append(nid) - if cursor == 0: break for stale_nid in stale_node_ids: - print(f"[RedisNodeInfo] Cleaning up stale node:{stale_nid} (same IP={self.local_ip}, different UUID)") - self._client.delete(f"node:{stale_nid}") + print( + f"[RedisNodeInfo:{self.sd_key_str}] Cleaning up stale " + f"node:{stale_nid} (same IP={self.local_ip}, different UUID)" + ) + self._client.delete(self._namespace.node_key(stale_nid)) self._cleanup_node_data(stale_nid) if stale_node_ids: - # Notify other nodes about the cleanup - self._client.publish("flexkv_node_id_updated", "cleanup") - + self._client.publish(self._pubsub_channel(), "cleanup") except Exception: pass def _cleanup_node_data(self, node_id: int) -> None: - """Clean up meta: and CPUB/SSDB/PCFSB block keys for a given node. - - This is called when: - 1. A node is unregistered (graceful shutdown) - 2. 
A stale node is detected (TTL expired / startup cleanup) + """Drop every key associated with a dead node in this SD. + + Removes ``meta:``, ``buffer::*`` and the per-device + ``...::block::*`` families. All under + ``sd::`` so other SDs are unaffected. """ if not self._client: return - try: - # Delete meta: (and meta::pp* for pipeline parallel) + # meta key (single key, no SCAN) + meta_key = self._namespace.meta_key(node_id) + self._client.delete(meta_key) + + # buffer keys: sd::buffer::* cursor = 0 - meta_keys = [] + buffer_pattern = f"{self._namespace.prefix}:buffer:{int(node_id)}:*" + buffer_keys: List[str] = [] while True: - cursor, keys = self._client.scan(cursor=cursor, match=f"meta:{node_id}*", count=100) - meta_keys.extend(keys) + cursor, keys = cast( + Tuple[int, List[str]], + self._client.scan(cursor=cursor, match=buffer_pattern, count=200), + ) + buffer_keys.extend(keys) if cursor == 0: break - if meta_keys: - self._client.delete(*meta_keys) - print(f"[RedisNodeInfo] Deleted {len(meta_keys)} meta key(s) for node {node_id}") - - # Delete CPUB:block::* / SSDB:block::* / PCFSB:block::* keys - for prefix in ("CPUB", "SSDB", "PCFSB"): + if buffer_keys: + for i in range(0, len(buffer_keys), 500): + self._client.delete(*buffer_keys[i:i + 500]) + + # block keys, both legacy device-prefix flavours and the new + # device-less SD-only flavour: + # sd::CPUB:block::* + # sd::SSDB:block::* + # sd::PCFSB:block::* + # sd::block::* (when no device prefix is used) + patterns = [ + f"{self._namespace.prefix}:CPUB:block:{int(node_id)}:*", + f"{self._namespace.prefix}:SSDB:block:{int(node_id)}:*", + f"{self._namespace.prefix}:PCFSB:block:{int(node_id)}:*", + self._namespace.block_key_pattern_for_node(node_id), + ] + for pat in patterns: cursor = 0 - block_keys = [] + block_keys: List[str] = [] while True: - cursor, keys = self._client.scan(cursor=cursor, match=f"{prefix}:block:{node_id}:*", count=500) + cursor, keys = cast( + Tuple[int, List[str]], + 
self._client.scan(cursor=cursor, match=pat, count=500), + ) block_keys.extend(keys) if cursor == 0: break if block_keys: - # Delete in batches to avoid blocking Redis - batch_size = 500 - for i in range(0, len(block_keys), batch_size): - self._client.delete(*block_keys[i:i + batch_size]) - print(f"[RedisNodeInfo] Deleted {len(block_keys)} {prefix}:block key(s) for node {node_id}") - + for i in range(0, len(block_keys), 500): + self._client.delete(*block_keys[i:i + 500]) except Exception as e: - print(f"[RedisNodeInfo] Warning: failed to clean up data for node {node_id}: {e}") + print(f"[RedisNodeInfo:{self.sd_key_str}] Warning: failed to clean up data for node {node_id}: {e}") +# --------------------------------------------------------------------------- +# RedisMeta — top-level facade +# --------------------------------------------------------------------------- class RedisMeta: - def __init__(self, host: str, port: int, password: Optional[str] = None, local_ip: str = "127.0.0.1", decode_responses: bool = True, node_ttl_seconds: int = 0) -> None: + """Top-level wrapper that owns one SD's :class:`RedisNodeInfo` plus a + cached redis-py client for non-block metadata writes (``meta:`` / + ``buffer:`` / ``flexkv:instance:*``).""" + + def __init__( + self, + host: str, + port: int, + password: Optional[str] = None, + local_ip: str = "127.0.0.1", + decode_responses: bool = True, + node_ttl_seconds: int = 0, + *, + namespace: Optional[SharingDomainNamespace] = None, + db: int = 0, + ) -> None: if _redis is None: # pragma: no cover raise ImportError("redis-py is required: pip install redis") self.host = host self.port = int(port) self.local_ip = str(local_ip) - self.db = 0 + # Logical Redis db — comes from ``CacheConfig.flexkv_redis_db``. + # Kept as an instance attr so every ``self._client()`` call picks + # up the same value; must match ``RedisNodeInfo.db`` below or the + # node-set and block-set end up on different dbs! 
+ self.db: int = int(db) self.password = password self.decode_responses = bool(decode_responses) self._node_id: Optional[int] = None - - # initialize state management + + self._namespace: SharingDomainNamespace = _resolve_namespace(namespace) + self._init_lock = threading.Lock() self._initialized = False self._init_error: Optional[Exception] = None - - # create RedisNodeInfo object - self.nodeinfo = RedisNodeInfo(host, port, local_ip, password or "", node_ttl_seconds=node_ttl_seconds) - # get UUID via nodeinfo + + self.nodeinfo = RedisNodeInfo( + host, port, local_ip, password or "", + node_ttl_seconds=node_ttl_seconds, + namespace=self._namespace, + db=self.db, + ) self._uuid = self.nodeinfo.get_uuid() - def _client(self): - return _redis.Redis(host=self.host, port=self.port, db=self.db, password=self.password, decode_responses=self.decode_responses) + # -- properties -------------------------------------------------------- + @property + def namespace(self) -> SharingDomainNamespace: + return self._namespace + + @property + def sd_key_str(self) -> str: + return self._namespace.serialized_sd + + def _client(self) -> "_redis.Redis": + return _redis.Redis( + host=self.host, port=self.port, db=self.db, + password=self.password, decode_responses=self.decode_responses, + ) + # -- lifecycle --------------------------------------------------------- def init_meta(self) -> Optional[int]: - """Initialize Redis metadata. This method is thread-safe and can only be called once per instance. 
- - Returns: - Optional[int]: The registered node ID, or None if initialization fails - - Raises: - RuntimeError: If initialization fails or has already been called - """ with self._init_lock: - # check if already initialized if self._initialized: if self._init_error: raise self._init_error return self._node_id - try: - # connect to RedisNodeInfo if not self.nodeinfo.connect(): raise RuntimeError("Failed to connect to Redis via RedisNodeInfo") - - # register node node_id = self.nodeinfo.register_node() if node_id is None: raise RuntimeError("Failed to register node via RedisNodeInfo") - self._node_id = node_id - # initialization phase, scan active nodes first self.nodeinfo.scan_active_nodes() - - # mark as initialized self._initialized = True - return node_id - except Exception as e: - # record initialization error self._init_error = e return None @@ -621,74 +804,101 @@ def get_node_id(self) -> int: if self._node_id is None: raise RuntimeError("node_id is not registered yet. Call init_meta() first.") return int(self._node_id) - + def is_initialized(self) -> bool: - """Check if RedisMeta has been initialized. - - Returns: - bool: True if initialized, False otherwise - """ with self._init_lock: return self._initialized - + def get_init_error(self) -> Optional[Exception]: - """Get the initialization error if any. - - Returns: - Optional[Exception]: The initialization error, or None if no error - """ with self._init_lock: return self._init_error - def get_redis_meta_channel(self, blocks_key: str = "blocks") -> "RedisMetaChannel": + # -- channel factory --------------------------------------------------- + def get_redis_meta_channel(self, device_prefix: str = "") -> "RedisMetaChannel": + """Build a C++-backed ``RedisMetaChannel`` whose key-prefix is + ``sd:[:]``. + + The legacy parameter name ``blocks_key`` (which used to carry + ``"CPUB"`` / ``"SSDB"`` / ``"PCFSB"``) is replaced by the more + explicit ``device_prefix`` to make the SD prefixing obvious in + callers. 
Callers that pass ``device_prefix=""`` get a channel + whose ``make_block_key`` produces ``sd::block::`` + (used by Master nodes that don't multi-tier blocks across CPU / + SSD / PCFS). + """ nid = self.get_node_id() - # Avoid passing string "None" when no password is set pwd = "" if (self.password is None or str(self.password).lower() == "none") else str(self.password) - channel = RedisMetaChannel(self.host, int(self.port), int(nid), self.local_ip, str(blocks_key), pwd) + bk = _channel_blocks_key(self._namespace, str(device_prefix)) + channel = RedisMetaChannel( + self.host, int(self.port), int(nid), + self.local_ip, bk, pwd, db=int(self.db), + ) if not channel.connect(): raise RuntimeError("Failed to connect to Redis") return channel def unregister_node(self, node_id: Optional[int] = None) -> None: - # use RedisNodeInfo to unregister node if self.nodeinfo: self.nodeinfo.unregister_node() self._node_id = None def get_uuid(self) -> str: return self._uuid - + def get_active_node_ids(self) -> List[int]: - """get all active node IDs list""" if self.nodeinfo: return self.nodeinfo.get_active_node_ids() return [] - + def is_node_active(self, node_id: int) -> bool: - """check if specified node is active""" if self.nodeinfo: return self.nodeinfo.is_node_active(node_id) return False + # -- pcfs file-nodeid mapping (legacy; SD-scoped now) ------------------ def add_node_ids(self, node_ids: Iterable[Union[int, str]]) -> int: - # Append a list of pcfs file node ids to Redis list key pcfs: nid = self.get_node_id() values = [str(v) for v in node_ids] if not values: return 0 r = self._client() - # rpush returns the new length of the list - return int(r.rpush(f"pcfs:{nid}", *values)) + return int(cast(int, r.rpush(f"{self._namespace.prefix}:pcfs:{nid}", *values))) - def regist_buffer(self, mrs: Iterable[object]) -> int: - """Register RDMA memory regions in Redis. 
+ def load_pcfs_file_nodeids(self) -> Dict[int, List[int]]: + r = self._client() + result: Dict[int, List[int]] = {} + try: + cursor = 0 + pattern = f"{self._namespace.prefix}:pcfs:*" + prefix = f"{self._namespace.prefix}:pcfs:" + while True: + cursor, keys = cast( + Tuple[int, List[str]], + r.scan(cursor=cursor, match=pattern, count=100), + ) + for key in keys: + if not isinstance(key, str): + key = str(key) + if not key.startswith(prefix): + continue + try: + node_id = int(key[len(prefix):]) + except Exception: + continue + try: + values = cast(List[str], r.lrange(key, 0, -1) or []) + file_nodeids = [int(v) for v in values] + except Exception: + file_nodeids = [] + result[node_id] = file_nodeids + if cursor == 0: + break + except Exception: + return result + return result - Each element in mrs can be one of: - - dict with keys {"buffer_ptr": ..., "buffer_size": ...} - - tuple/list (buffer_ptr, buffer_size) - Stored as hash: key = buffer::, field "buffer_size" = . - Returns the number of regions processed. - """ + # -- buffer registration ---------------------------------------------- + def regist_buffer(self, mrs: Iterable[object]) -> int: nid = self.get_node_id() r = self._client() pipe = r.pipeline() @@ -703,7 +913,7 @@ def regist_buffer(self, mrs: Iterable[object]) -> int: continue if ptr is None or size is None: continue - key = f"buffer:{nid}:{int(ptr)}" + key = self._namespace.buffer_key(nid, int(ptr)) pipe.hset(key, mapping={"buffer_size": int(size)}) processed += 1 if processed: @@ -711,13 +921,8 @@ def regist_buffer(self, mrs: Iterable[object]) -> int: return processed def unregist_buffer(self, buffer_ptr: Union[int, str]) -> bool: - """Unregister a previously registered RDMA memory region by buffer_ptr. - - Looks up key buffer:: and deletes it if present. - Returns True if the key existed and was deleted, otherwise False. 
- """ nid = self.get_node_id() - key = f"buffer:{nid}:{int(buffer_ptr)}" + key = self._namespace.buffer_key(nid, int(buffer_ptr)) r = self._client() exists = bool(r.exists(key)) if exists: @@ -725,14 +930,10 @@ def unregist_buffer(self, buffer_ptr: Union[int, str]) -> bool: return True return False + # -- node meta hash ---------------------------------------------------- def regist_node_meta(self, node_id: int, addr: str, zmq_addr: str, cpu_buffer_ptr: int, ssd_buffer_ptr: int) -> None: - """Register node meta information as a Redis hash. - - Key: meta: - Fields: node_id (int), addr (str), cpu_buffer_ptr (int), ssd_buffer_ptr (int) - """ r = self._client() - key = f"meta:{int(node_id)}" + key = self._namespace.meta_key(node_id) r.hset(key, mapping={ "node_id": int(node_id), "addr": str(addr), @@ -742,15 +943,9 @@ def regist_node_meta(self, node_id: int, addr: str, zmq_addr: str, cpu_buffer_pt }) def get_node_meta(self, node_id: int) -> dict: - """Get node meta information from Redis. - - Reads key meta: and returns a dict with fields: - node_id (int), addr (str), cpu_buffer_ptr (int), ssd_buffer_ptr (int). - Returns empty dict if the key does not exist. - """ r = self._client() - key = f"meta:{int(node_id)}" - data = r.hgetall(key) + key = self._namespace.meta_key(node_id) + data = cast(Dict[str, str], r.hgetall(key) or {}) if not data: return {} out: Dict[str, Union[int, str]] = {} @@ -765,49 +960,45 @@ def get_node_meta(self, node_id: int) -> dict: return out def unregist_node_meta(self, node_id: int) -> bool: - """Unregister node meta by node_id. Returns True if deleted.""" r = self._client() - key = f"meta:{int(node_id)}" + key = self._namespace.meta_key(node_id) return bool(r.delete(key)) - - def set_node_id(self, node_id: int): + def set_node_id(self, node_id: int) -> None: self._node_id = int(node_id) - def load_pcfs_file_nodeids(self) -> Dict[int, List[int]]: - """Load all PCFS file node IDs grouped by node id from Redis. 
+ # -- instance-level helpers (cross-SD; design doc §4.7.1.6) ----------- + def register_instance_sd_nodes(self, instance_id: str, sd_to_nid: Dict[str, int]) -> None: + """Write the ``sd_key → node_id`` mapping for one full FlexKV instance. - - Uses SCAN instead of KEYS to avoid blocking Redis server - - Scans keys matching pattern "pcfs:*" (each is a list for a node's file node IDs) - - For each key, fetches the list via LRANGE and converts elements to ints - - Returns dict: { node_id: [file_nodeid, ...], ... } + Called once on Master startup after collecting all Remote ack'ed + node_ids. Peers consume this via :meth:`load_instance_sd_nodes` + when first discovering a new instance via the failure detector. """ + if not sd_to_nid: + return r = self._client() - result: Dict[int, List[int]] = {} - try: - # Use SCAN instead of KEYS to avoid blocking - cursor = 0 - while True: - cursor, keys = r.scan(cursor=cursor, match="pcfs:*", count=100) - for key in keys: - try: - if not isinstance(key, str): - key = str(key) - if not key.startswith("pcfs:"): - continue - nid_part = key.split(":", 1)[1] - node_id = int(nid_part) - except Exception: - continue - try: - values = r.lrange(key, 0, -1) - file_nodeids = [int(v) for v in values] - except Exception: - file_nodeids = [] - result[node_id] = file_nodeids - - if cursor == 0: - break - except Exception: - return result - return result + key = SharingDomainNamespace.instance_sd_nodes_key(instance_id) + # Stringify values so HSET round-trips cleanly through redis-py. + mapping = {str(sd): str(int(nid)) for sd, nid in sd_to_nid.items()} + r.hset(key, mapping=mapping) + + def load_instance_sd_nodes(self, instance_id: str) -> Dict[str, int]: + r = self._client() + key = SharingDomainNamespace.instance_sd_nodes_key(instance_id) + # redis-py 5.x types ``hgetall`` as ``Awaitable[dict] | dict`` so that + # the same stub serves both sync and async clients. We're always in + # sync mode here — cast defensively. 
+ data = cast(Dict[str, str], r.hgetall(key) or {}) + out: Dict[str, int] = {} + for sd_str, nid_str in data.items(): + try: + out[str(sd_str)] = int(nid_str) + except (TypeError, ValueError): + continue + return out + + def unregister_instance_sd_nodes(self, instance_id: str) -> bool: + r = self._client() + key = SharingDomainNamespace.instance_sd_nodes_key(instance_id) + return bool(r.delete(key)) diff --git a/flexkv/common/config.py b/flexkv/common/config.py index 76691d2d79..658ef91894 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -81,6 +81,16 @@ class ModelConfig: # ------------------------------------------------------------------ instance_num: int = 1 + # NSA model layout flag: when True, the model is an NSA (Native Sparse + # Attention) model whose KV cache pool includes an extra ``index_k_with_scale_buffer`` + # in addition to the main MLA KV cache. This changes the per-block physical + # layout (extra indexer pool registered alongside main KV) and is therefore + # used as a layout-isolation key in SharingDomainKey.serialize(). This flag + # is independent of whether CP is enabled — even with cp_size=1, an NSA model + # still has the indexer K cache and must be isolated from non-NSA models in + # cross-instance reuse. 
+ is_nsa: bool = False + # ------------------------------------------------------------------ # Freeze mechanism: after post_init, ModelConfig must not be mutated # ------------------------------------------------------------------ @@ -190,6 +200,83 @@ def effective_tp_size_per_node(self) -> int: """Per-node counterpart of :pyattr:`effective_tp_size`.""" return self.attn_tp_size_per_node * self.attn_cp_size_per_node + # ------------------------------------------------------------------ + # Decoupled multinode flags (Phase 1 task 1-A) + # + # These two flags are intentionally *independent* — they describe + # orthogonal physical situations that historically were conflated + # under a single ``is_multinode`` switch in the sglang connector: + # + # * ``is_multinode_tp`` is True when one TP group physically + # spans more than one node (i.e. ``tp_size > gpus_per_node``). + # Such deployments require cross-node KV transfer (SD-Remote) + # because the local CPU pool only sees a *fraction* of the + # KV heads. This is the only case where a non-pp_rank-0 + # node needs to participate in dist_reuse handle construction + # (see ``KVTaskManager.transfer_handles``). + # + # * ``is_multinode_cp`` is True when CP > 1 *and* the CP group + # spans more than one node. CP all-gather makes every CP + # rank's KV pool bit-wise identical, so for dist_reuse purposes + # we treat all CP ranks as a single SD (CP does *not* enter + # the SD key — see simplified design §4.5). However, *physical* + # transfer still needs one transfer worker per CP rank's GPU + # so that the H2D scatter from sync_leader can actually land + # in every cp_rank's local pool. + # + # Keeping these two flags separate (instead of an ``is_multinode`` + # union) is a deliberate API decision — sglang's connector layer + # can branch on the dimension that genuinely affects its behavior + # rather than guessing. 
The legacy single-flag formulation in + # sglang's ``flexkv_connector.py`` should migrate to read these + # two properties instead. + # ------------------------------------------------------------------ + @property + def is_multinode_tp(self) -> bool: + """One TP group spans more than one physical node. + + Equivalent to ``self.tp_node_count > 1`` (= ``nnodes_per_tp_group > 1``). + Use this — *not* ``self.nnodes > 1`` — to decide whether the + local KV pool is partial and therefore requires cross-node + SD-Remote transfer to assemble a full prefix. + """ + return self.tp_node_count > 1 + + @property + def is_multinode_cp(self) -> bool: + """CP > 1 *and* the CP group spans more than one physical node. + + CP within a single node ``(attn_cp_size > 1, fits on one host)`` + does *not* trigger this flag — it's a purely intra-node concern + handled by the connector's local sync_leader scatter. Only when + CP physically crosses node boundaries do additional transfer + considerations apply. + + **Topology note:** in FlexKV/sglang, CP is *not* a top-level GPU + dimension — it lives inside the TP group as a sub-partition of + ``attn_tp_size``. ``ModelConfig.total_gpus = tp × pp × dp`` does + *not* include CP. Each CP group occupies ``attn_cp_size`` GPUs + carved out of the ``tp_size`` GPUs assigned to one TP group. + Therefore CP crosses nodes iff one TP group's per-node share + ``tp_size_per_node`` is **smaller than** ``attn_cp_size``. + + **Reality check (2026):** under standard megatron-style topology + every TP group is sized so that ``tp_size_per_node`` is at least + as large as ``attn_cp_size`` (CP rarely exceeds 8 while TP per + node is 8 GPUs), so this flag is **always False** in production + deployments seen so far. We keep the property as a stable API + surface for future deployments where CP could outgrow a node. 
+ + Note: CP never enters the SD key (simplified design §4.5), so + this flag does NOT influence ``SharingDomainKey``; it is purely + a transport-layer hint. + """ + if self.attn_cp_size <= 1: + return False + # CP crosses nodes when one CP group cannot fit inside the slice + # of GPUs that a single node contributes to one TP group. + return self.attn_cp_size > max(1, self.tp_size_per_node) + @property def num_kv_heads_per_node(self) -> int: """Number of KV heads visible to a single node.""" @@ -197,6 +284,55 @@ def num_kv_heads_per_node(self) -> int: return self.num_kv_heads return self.num_kv_heads * self.tp_size_per_node // max(1, self.attn_tp_size) + # ------------------------------------------------------------------ + # Sharing-domain dimensions (Phase 0 task 0-A) + # + # These three properties translate FlexKV's existing topology fields + # into the (pp_rank, tp_node_idx, cp_rank) triple consumed by + # ``SharingDomainKey``. ``tp_node_count`` is just an alias for the + # already-derived ``nnodes_per_tp_group``; ``tp_node_idx`` partitions + # the local TP rank by ``tp_size_per_node``. ``model_id`` digests the + # invariant model architecture into a 16-char hex string suitable for + # embedding in a Redis key. + # + # Importing inside the property avoids a circular import between + # ``flexkv.common.config`` and ``flexkv.cache.sharing_domain``. + # ------------------------------------------------------------------ + @property + def tp_node_count(self) -> int: + """Number of physical nodes one TP group spans (= + ``nnodes_per_tp_group``). ``1`` when TP fits on a single node.""" + return self.nnodes_per_tp_group + + # NOTE: ``tp_node_idx`` is a per-rank concept and was moved to + # ``RankInfo`` in PR #165 (separate-per-rank-state-into-RankInfo). + # The previous ``ModelConfig.tp_node_idx`` referenced ``self.tp_rank`` + # which no longer exists on ``ModelConfig``. 
Use + # ``RankInfo.tp_node_idx`` instead; ``SharingDomainKey.from_model_config`` + # already prefers ``rank_info.tp_node_idx`` and falls back to ``0`` via + # ``getattr``. + + @property + def model_id(self) -> str: + """Stable cross-process digest of the invariant model architecture. + + Two FlexKV instances that produce physically interchangeable CPU + blocks derive the same ``model_id`` regardless of TP/PP/CP + topology. Used as the leading segment of every SD key (design + doc §3.1). + """ + # Local import keeps ``flexkv.common.config`` a leaf in the import + # graph — anything that imports SharingDomain machinery already + # depends on config. + from flexkv.common.dist_reuse.sharing_domain import derive_model_id + return derive_model_id( + num_layers=self.num_layers, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + dtype=self.dtype, + use_mla=self.use_mla, + ) + @property def kv_dim(self) -> int: """KV dimension: 1 for MLA (no head split), 2 for standard (head split).""" @@ -240,6 +376,22 @@ def tp_rank_per_node(self) -> int: """TP rank index within the local node (within one TP group).""" return self.tp_rank % self.model_config.tp_size_per_node + @property + def tp_node_idx(self) -> int: + """Index of this rank's node inside its TP group (0..tp_node_count-1). + + Cross-node TP partitions ``tp_size`` ranks across + ``nnodes_per_tp_group`` physical nodes, with + ``tp_size_per_node = tp_size // nnodes_per_tp_group`` ranks per + node. The node index is the floor-divide complement of + ``tp_rank_per_node``: ``tp_rank // tp_size_per_node``. Used by + ``SharingDomainKey`` to scope per-shard reuse. + """ + per_node = self.model_config.tp_size_per_node + if per_node <= 0: + return 0 + return self.tp_rank // per_node + @property def dp_client_id(self) -> int: """Flat DP route label: unique int across all instances. 
@@ -320,6 +472,37 @@ def __str__(self) -> str: ) +@dataclass +class RemoteEndpoint: + """ZMQ endpoint of a Remote node belonging to a specific sharing domain. + + Phase 0 task 0-J: ``CacheConfig.remote_endpoints_by_sd`` maps each + non-master sd_key (string form, see ``SharingDomainKey.serialize``) to + one of these. The Master uses these to wire up its + ``TransferManagerMultiNodeHandle`` ZMQ peers — see + ``flexkv/transfer_manager.py``. + + Fields mirror the three ports already exposed by + ``resolve_master_host_and_ports``: + + * ``ip`` — host the Remote listens on (e.g. the Remote's RDMA-capable IP). + * ``gpu_register_port`` — port the local sglang worker uses to push GPU + handles into the Remote's TransferManagerOnRemote (channel ④). + * ``command_port`` — port the Master uses to push transfer-graph and + coordination commands to the Remote. + * ``result_port`` — port the Remote uses to push completion / failure + reports back to the Master. + + All ports are ``str`` (matching ``master_ports``'s type) so they round-trip + through env vars cleanly. + """ + + ip: str + gpu_register_port: str + command_port: str + result_port: str + + @dataclass class CacheConfig: tokens_per_block: int = 16 @@ -378,6 +561,15 @@ class CacheConfig: redis_port: int = 6379 local_ip: str = "127.0.0.1" redis_password: Optional[str] = None + # Logical Redis database number (0..15 on default Redis servers). Every + # FlexKV Redis client — both the Python ``RedisMeta`` / ``RedisNodeInfo`` + # and the C++ ``RedisMetaChannel`` — honours this single value, so all + # dist-reuse metadata (nodes, blocks, sessions, aggregate markers) lives + # on the same logical db. Defaults to 0 for backward compatibility with + # pre-Batch-D deployments. Using a dedicated db (e.g. ``15``) is strongly + # recommended in environments where FlexKV shares a Redis instance with + # other tenants to make bulk cleanup (``FLUSHDB``) safe. + flexkv_redis_db: int = 0 # TTL (seconds) for node: key in Redis. 
Active nodes renew via heartbeat. # If a process crashes, the key auto-expires after this period. node_ttl_seconds: int = 30 @@ -385,10 +577,97 @@ class CacheConfig: # Mooncake transfer engine config path (serialized via pickle to survive spawn subprocesses) mooncake_config_path: Optional[str] = None + # ------------------------------------------------------------------ + # Sharing-domain support (Phase 0 task 0-A) + # + # ``enable_sharing_domain`` is the master switch that turns on the new + # ``sd::*`` Redis key layout, the cross-SD aggregate radix and + # the Master/Remote coordinated GET/PUT protocol. It is auto-enabled + # whenever any P2P backend is on (``enable_p2p_cpu`` / + # ``enable_p2p_ssd``) — the legacy single-instance dist_reuse path is + # then served by the degenerate single-SD namespace produced by + # :meth:`SharingDomainKey.default`. + # + # ``instance_id`` and ``session_epoch`` identify the *whole FlexKV + # instance* (Master + every Remote) for the failure detector. When + # left ``None`` the Master fills them in lazily (uuid4 / monotonic-uuid + # respectively) and propagates to Remotes via + # ``send_config_to_remotes``. + # ------------------------------------------------------------------ + enable_sharing_domain: bool = False + instance_id: Optional[str] = None + session_epoch: Optional[str] = None + + # Failure detector tunables (design doc §4.3.2 Layer-1 / §4.3.1 leak guard) + instance_session_ttl_seconds: int = 8 + instance_session_renew_interval_seconds: int = 3 + refcount_leak_timeout_seconds: int = 30 + refcount_leak_scan_interval_seconds: int = 10 + + # ------------------------------------------------------------------ + # Remote-endpoint discovery (Phase 0 task 0-J) + # + # ``remote_endpoints_by_sd`` is a ``{sd_key_str: RemoteEndpoint}`` dict + # populated by the framework launcher (e.g. sglang connector) before + # ``KVManager`` is constructed. 
Each entry tells the Master which IP / + # ZMQ ports to use when reaching a Remote node belonging to the named + # SD. When the dict is empty (legacy / single-Remote setups) the + # ``TransferManagerHandle`` falls back to ``resolve_master_host_and_ports`` + # which derives a single endpoint from ``master_host`` + env vars. + # + # The dict values are plain dataclass instances rather than dicts so + # type-checkers / IDEs can verify field names — see ``RemoteEndpoint`` + # above. We keep them in this module (not in transfer_manager.py) to + # avoid a config → transfer_manager import cycle. + # ------------------------------------------------------------------ + remote_endpoints_by_sd: Dict[str, "RemoteEndpoint"] = field(default_factory=dict) + def __post_init__(self): self.enable_kv_sharing = self.enable_p2p_cpu or \ self.enable_p2p_ssd or self.enable_3rd_remote self.enable_remote = self.enable_3rd_remote + # Auto-enable sharing-domain whenever any P2P path is on. Users can + # still flip it on manually for offline testing of the new key + # layout. Idempotent. + if self.enable_p2p_cpu or self.enable_p2p_ssd: + self.enable_sharing_domain = True + + # ------------------------------------------------------------------ + # Lease-TTL safety check for P2P cross-instance reuse. + # + # The P2P CPU pull path (peer instance → master via mooncake RDMA + # READ) currently lacks an end-to-end refcount handshake — see + # ``docs/dist_reuse/KNOWN_ISSUE_p2p_refcount_2026-05-14.md``. + # We rely on the master-side lease (LocalRadixTree.lease_ttl_ms) + # plus the peer-side 1500ms freshness check + # (DistributedRadixTree.match_prefix) to keep peer reads from + # racing with master eviction. + # + # ``FLEXKV_LEASE_TTL_MS < 10000`` shrinks the safety margin to + # the same order of magnitude as a worst-case mooncake_read + # batch (a few hundred ms when the peer pulls thousands of + # blocks across multiple PEERH2H batches), at which point the + # lease alone is insufficient.
Refuse to start in that + # configuration so operators are not surprised in production. + if self.enable_p2p_cpu or self.enable_p2p_ssd: + try: + lease_ttl_ms = int(GLOBAL_CONFIG_FROM_ENV.lease_ttl_ms) + except Exception: + lease_ttl_ms = 0 + _MIN_LEASE_TTL_MS_FOR_P2P = 10000 + if lease_ttl_ms > 0 and lease_ttl_ms < _MIN_LEASE_TTL_MS_FOR_P2P: + raise ValueError( + f"CacheConfig: enable_p2p_cpu / enable_p2p_ssd is True but " + f"FLEXKV_LEASE_TTL_MS={lease_ttl_ms} is below the safety " + f"floor of {_MIN_LEASE_TTL_MS_FOR_P2P}ms. Without an " + f"in-flight refcount handshake (tracked in " + f"docs/dist_reuse/KNOWN_ISSUE_p2p_refcount_2026-05-14.md), " + f"a short lease lets master eviction race with peer " + f"mooncake_read and corrupt the KV cache. Either raise " + f"FLEXKV_LEASE_TTL_MS to >= {_MIN_LEASE_TTL_MS_FOR_P2P} " + f"or implement the refcount glue (Plan A in the known-issue " + f"doc) before enabling P2P in this configuration." + ) def __str__(self) -> str: return ( @@ -481,6 +760,9 @@ class UserConfig: redis_port: Optional[int] = None local_ip: Optional[str] = None redis_password: Optional[str] = None + # Override for :attr:`CacheConfig.flexkv_redis_db` — typically only set + # in deployments where FlexKV shares a Redis instance with other services. + flexkv_redis_db: Optional[int] = None node_ttl_seconds: Optional[int] = None kv_cache_dtype: Optional[str] = None # Override kv_cache_dtype when TRT config uses "auto". Supported values: "fp8", "float8", "e4m3", "fp16", "float16", "bf16", "bfloat16", "fp32", "float32" @@ -563,6 +845,40 @@ def update_default_config_from_user_config(rank_info: RankInfo, cache_config.enable_p2p_ssd or cache_config.enable_3rd_remote) cache_config.enable_remote = cache_config.enable_3rd_remote + # Re-derive ``enable_sharing_domain`` here. ``CacheConfig.__post_init__`` + # only fires at construction time (with ``enable_p2p_cpu=False``), so any + # later flip via UserConfig (e.g. 
``benchmark_dist_direct.py`` reading + # ``enable_p2p_cpu: true`` from yaml) would leave ``enable_sharing_domain`` + # stuck at ``False`` and silently disable the dist_reuse fast-path. Mirror + # the post-init invariant here so cross-instance KV reuse works whenever + # P2P is on. + if cache_config.enable_p2p_cpu or cache_config.enable_p2p_ssd: + cache_config.enable_sharing_domain = True + + # Same lease-TTL safety check as ``CacheConfig.__post_init__`` — + # the yaml/UserConfig path bypasses ``__post_init__`` because the + # CacheConfig dataclass was already constructed (with P2P=False) + # before this function copies the UserConfig flags over. Without + # this re-check, operators can ship a short-lease prod config that + # races peer mooncake_read with master eviction. See + # ``docs/dist_reuse/KNOWN_ISSUE_p2p_refcount_2026-05-14.md``. + try: + lease_ttl_ms = int(GLOBAL_CONFIG_FROM_ENV.lease_ttl_ms) + except Exception: + lease_ttl_ms = 0 + _MIN_LEASE_TTL_MS_FOR_P2P = 10000 + if lease_ttl_ms > 0 and lease_ttl_ms < _MIN_LEASE_TTL_MS_FOR_P2P: + raise ValueError( + f"update_default_config_from_user_config: enable_p2p_cpu / " + f"enable_p2p_ssd is True but FLEXKV_LEASE_TTL_MS={lease_ttl_ms} " + f"is below the safety floor of {_MIN_LEASE_TTL_MS_FOR_P2P}ms. " + f"Without an in-flight refcount handshake (see " + f"docs/dist_reuse/KNOWN_ISSUE_p2p_refcount_2026-05-14.md), a " + f"short lease lets master eviction race with peer mooncake_read " + f"and corrupt the KV cache. Either raise FLEXKV_LEASE_TTL_MS " + f"to >= {_MIN_LEASE_TTL_MS_FOR_P2P} or implement the refcount " + f"glue (Plan A in the known-issue doc) before enabling P2P." 
+ ) if cache_config.num_ssd_blocks % len(cache_config.ssd_cache_dir) != 0: cache_config.num_ssd_blocks = \ @@ -646,6 +962,8 @@ def update_default_config_from_user_config(rank_info: RankInfo, cache_config.local_ip = user_config.local_ip if user_config.redis_password is not None: cache_config.redis_password = user_config.redis_password + if user_config.flexkv_redis_db is not None: + cache_config.flexkv_redis_db = int(user_config.flexkv_redis_db) if user_config.node_ttl_seconds is not None: cache_config.node_ttl_seconds = user_config.node_ttl_seconds diff --git a/flexkv/common/dist_reuse/__init__.py b/flexkv/common/dist_reuse/__init__.py new file mode 100644 index 0000000000..ba9bb98aeb --- /dev/null +++ b/flexkv/common/dist_reuse/__init__.py @@ -0,0 +1,98 @@ +"""Distributed-reuse data structures and wire protocols. + +This subpackage is **pure Python** and free of GPU / C++ dependencies, so it +can be imported on CPU-only nodes (e.g. CI workers, lightweight test +environments) without pulling in ``flexkv.c_ext``. See design doc +``docs/dist_reuse/dist_reuse_with_cp_pp_multinode_tp_simplified.md`` and +``docs/dist_reuse/proposal_unify_with_graph_dispatch_2026-05-15.md``. + +Phase D-4: ``CoordinationCoordinator`` / ``RemoteCoordHandler`` / +``BlockIndex`` / ``CoordQueryMsg`` / ``CoordGetCmdMsg`` / ``CoordPutCmdMsg`` +(plus their ack siblings) were deleted in this refactor. Per-SD +coordination is now expressed as multi-target ops on a single +``TransferOpGraph`` broadcast through the existing ``_launch_task`` +graph-dispatch path; peer-SD acks come back via ``CompletedOp(sd_key, +contributing_node_id)`` to the master polling worker. 
+""" + +from .sharing_domain import ( + DEFAULT_MODEL_ID, + SharingDomainKey, + derive_model_id, +) +from .sharing_domain_namespace import ( + INSTANCE_KEY_PREFIX, + SD_KEY_PREFIX, + SharingDomainNamespace, +) +from .aggregate_radix import ( + AggregateMatchResult, + AggregateRadixTree, + BlockNotTrackedError, + ReadyEntry, +) +from .coordination_protocol import ( + CoordMsgType, + EpochVerifyError, + FailureReportMsg, + RemoteReadyMsg, + decode_coord_message, + encode_coord_message, +) +from .failure_detector import ( + FailureDetector, + InstanceSession, + RedisSessionClient, + make_redis_client_from_cache_config, + make_session_epoch, +) +from .remote_init import ( + BootstrapResult, + RemoteDistReuseInitializer, +) +from .master_coordinator import ( + MasterCoordinator, + SharingDomainHandleSpec, + build_sharing_domain_handles, + find_endpoint_for_sd, + graph_needs_gpu_clear, +) + + +__all__ = [ + # sharing_domain + "DEFAULT_MODEL_ID", + "SharingDomainKey", + "derive_model_id", + # namespace + "INSTANCE_KEY_PREFIX", + "SD_KEY_PREFIX", + "SharingDomainNamespace", + # aggregate_radix + "AggregateMatchResult", + "AggregateRadixTree", + "BlockNotTrackedError", + "ReadyEntry", + # coordination_protocol (Phase D-4: trimmed to RemoteReady + FailureReport) + "CoordMsgType", + "EpochVerifyError", + "FailureReportMsg", + "RemoteReadyMsg", + "decode_coord_message", + "encode_coord_message", + # failure_detector + "FailureDetector", + "InstanceSession", + "RedisSessionClient", + "make_redis_client_from_cache_config", + "make_session_epoch", + # remote_init (Batch C: task 0-F) + "BootstrapResult", + "RemoteDistReuseInitializer", + # master_coordinator + "MasterCoordinator", + "SharingDomainHandleSpec", + "build_sharing_domain_handles", + "find_endpoint_for_sd", + "graph_needs_gpu_clear", +] diff --git a/flexkv/common/dist_reuse/aggregate_radix.py b/flexkv/common/dist_reuse/aggregate_radix.py new file mode 100644 index 0000000000..7c71f813ec --- /dev/null +++ 
b/flexkv/common/dist_reuse/aggregate_radix.py @@ -0,0 +1,408 @@ +"""Aggregate-layer radix tree for cross-SD KV reuse consistency. + +Design doc §4.3 (option B, "fully-ready aggregate radix") and §4.3.1 +(refcount-protected eviction). This module is the **central truth** about +which token prefixes have been confirmed by every SD in the instance and +are therefore safe to expose as remote-hit candidates. + +Implementation notes +-------------------- + +* The aggregate radix is a *flat hash map keyed by leaf block hash*, not a + full radix tree — design doc §4.3 only requires per-prefix readiness + + per-block refcount, both of which collapse to a flat map once the per-SD + ack count hits ``total_sd_count``. A genuine tree adds no semantic value + here; the *legacy* radix in ``RefRadixTree`` already handles longest-prefix + match. +* All public methods are thread-safe via a single ``RLock``. Concurrency + pressure is low (a few hundred ack/release events per second), so a + finer-grained scheme would be premature optimization. +* The only Phase-0 consumers are unit tests — the actual wiring into + ``GlobalCacheEngine`` is task 0-H "integration" and lands in Batch C. +""" + +from __future__ import annotations + +import threading +import time +from dataclasses import dataclass, field +from typing import Callable, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple + + +__all__ = [ + "ReadyEntry", + "AggregateMatchResult", + "AggregateRadixTree", + "BlockNotTrackedError", +] + + +class BlockNotTrackedError(KeyError): + """Raised when ``release()`` / ``mark_sd_ready()`` is called for a block + that the aggregate radix has never seen.""" + + +# --------------------------------------------------------------------------- +# Data shapes +# --------------------------------------------------------------------------- +@dataclass +class ReadyEntry: + """Per-prefix bookkeeping. + + ``prefix_hash`` identifies a *leaf block hash* — i.e. 
the hash of the + last block in some token prefix. Two distinct prefixes that happen to + share a leaf block hash collapse to a single entry, which is exactly + what we want (the on-wire effect is identical). + + ``ready_sds`` was historically a ``Set[str]`` of acked SD keys. As of + the multi-SD GET-path work it is now a ``Dict[str, int]`` mapping each + acked SD key → the *distributed_node_id* of the FlexKV instance that + contributed that SD's slice. For the Master's own SD the value is the + Master's own node_id; for peer SDs it is whichever peer instance's + Remote completed the per-SD D2H clone op and shipped back a + ``CompletedOp(sd_key, contributing_node_id, success=True)`` through + the graph-dispatch completion sink (Phase D-2 PUT path). Knowing + the per-SD node_id at GET time lets the cross-instance reuse path + target each peer SD's Mooncake server independently — required + when PP>1 splits the layers across machines (design doc §4.5 / §5.1). + + ``contributing_peers`` continues to track the **set** of peer-instance + IDs whose blocks we pulled in to fill this prefix; the failure + detector uses this to do O(N_prefix) batch invalidation when a peer + dies (design doc §4.3.2 bullet 3). This stays as a Set because + invalidation is keyed by *instance* not SD. + """ + + prefix_hash: int + block_ids: Tuple[int, ...] + # sd_key_str -> contributing peer's distributed_node_id. ``-1`` is the + # "unknown yet" sentinel (we may not know our own node_id at the time + # the self-SD ack lands — callers can update later via ``mark_sd_ready`` + # which is idempotent). + ready_sds: Dict[str, int] = field(default_factory=dict) + contributing_peers: Set[str] = field(default_factory=set) + # ``time_fn`` reading (``time.monotonic`` by default) when this prefix first became + # fully-ready. Used by leak detection to age out abandoned entries.
+ first_ready_at: Optional[float] = None + + def is_fully_ready(self, total_sd_count: int) -> bool: + return len(self.ready_sds) >= total_sd_count + + def node_id_for_sd(self, sd_key: str) -> Optional[int]: + """Return the contributing peer's node_id for a given SD, or None + if that SD has not acked yet.""" + return self.ready_sds.get(sd_key) + + +@dataclass +class AggregateMatchResult: + """Aggregate match summary (NOTE: :meth:`AggregateRadixTree.match_fully_ready` returns ``ReadyEntry``, not this type).""" + + matched_block_ids: Tuple[int, ...] + contributing_peers: FrozenSet[str] + # Always a single value (matches §4.7.1.4 single-Node match constraint + # already enforced by the C++ ``RefRadixTree``). ``None`` when the + # match length is zero. + matched_node_id: Optional[int] = None + + +# --------------------------------------------------------------------------- +# Refcount entry +# --------------------------------------------------------------------------- +@dataclass +class _RefCountEntry: + count: int = 0 + # ``time_fn`` reading (``time.monotonic`` by default) at the most recent acquire. + # Used by ``scan_leaked_refcount`` to identify stuck refcounts. + last_acquired_at: float = 0.0 + + +# --------------------------------------------------------------------------- +# Tree +# --------------------------------------------------------------------------- +class AggregateRadixTree: + """Cross-SD aggregate radix + block-level refcount manager.
+ + Public API surface mirrors design doc §4.3 / §4.3.1: + + * :meth:`mark_sd_ready` / :meth:`mark_sd_evicted` — per-SD ack tracker + * :meth:`match_fully_ready` — query the longest fully-ready prefix + * :meth:`acquire` / :meth:`release` / :meth:`is_evictable` — refcount + * :meth:`scan_leaked_refcount` — leak detector + * :meth:`invalidate_by_peer_instance` / :meth:`invalidate_prefix` — + reactions to failure-detector events + """ + + def __init__( + self, + total_sd_count: int, + *, + time_fn: Callable[[], float] = time.monotonic, + ) -> None: + if not isinstance(total_sd_count, int) or total_sd_count < 1: + raise ValueError(f"total_sd_count must be int>=1, got {total_sd_count!r}") + self._total_sd_count: int = total_sd_count + self._time_fn: Callable[[], float] = time_fn + self._lock = threading.RLock() + + # prefix_hash -> ReadyEntry + self._prefixes: Dict[int, ReadyEntry] = {} + # Reverse index: block_id -> set of prefix_hash that contain it. + # Lets us O(1) look up "which prefixes own this block" when + # invalidating or releasing. 
+ self._block_to_prefixes: Dict[int, Set[int]] = {} + # block_id -> _RefCountEntry + self._refcounts: Dict[int, _RefCountEntry] = {} + + # ------------------------------------------------------------------ + # Inspection + # ------------------------------------------------------------------ + @property + def total_sd_count(self) -> int: + return self._total_sd_count + + def __len__(self) -> int: + with self._lock: + return len(self._prefixes) + + def known_prefixes(self) -> List[int]: + """Snapshot of prefix hashes currently tracked (not necessarily ready).""" + with self._lock: + return list(self._prefixes.keys()) + + # ------------------------------------------------------------------ + # Per-SD ack tracker + # ------------------------------------------------------------------ + def mark_sd_ready( + self, + prefix_hash: int, + sd_key: str, + block_ids: Iterable[int], + *, + contributing_peer: Optional[str] = None, + node_id: int = -1, + ) -> bool: + """Record that ``sd_key`` finished its share of ``prefix_hash``. + + Returns True iff this call transitioned the prefix to *fully ready* + (i.e. all SDs are now accounted for). Subsequent calls for the + same SD are idempotent and return False — but they **do** update + the recorded ``node_id`` if the previous call passed -1 (sentinel). + That lets callers fill in the node_id after the fact when it + wasn't available at first-ack time. + + Args: + prefix_hash: leaf-block hash that identifies the prefix. + sd_key: serialized SD key of the SD that just acked. + block_ids: physical block IDs in the Master's CPU pool. + Must agree across acks for the same prefix. + contributing_peer: instance_id of the peer FlexKV instance + whose data we pulled to fill this SD's slot (None when + this is a self-SD ack). + node_id: distributed_node_id of the FlexKV node that holds + the data for this SD. 
``-1`` is the sentinel "unknown + yet"; callers can re-issue ``mark_sd_ready`` later with + the real node_id and the entry will be patched. + """ + if not isinstance(sd_key, str) or not sd_key: + raise ValueError(f"sd_key must be a non-empty str, got {sd_key!r}") + + block_tuple = tuple(int(b) for b in block_ids) + + with self._lock: + entry = self._prefixes.get(prefix_hash) + if entry is None: + entry = ReadyEntry(prefix_hash=int(prefix_hash), block_ids=block_tuple) + self._prefixes[entry.prefix_hash] = entry + for b in block_tuple: + self._block_to_prefixes.setdefault(b, set()).add(entry.prefix_hash) + else: + # Validate block_ids stay consistent across acks. Mismatched + # block_ids would mean two SDs disagree on the physical + # placement, which is an upstream bug; fail loudly. + if entry.block_ids != block_tuple and block_tuple: + raise ValueError( + f"AggregateRadixTree: prefix_hash={prefix_hash} block_ids " + f"mismatch (existing={entry.block_ids}, new={block_tuple})" + ) + + was_ready = entry.is_fully_ready(self._total_sd_count) + existing_nid = entry.ready_sds.get(sd_key, None) + # If the SD has acked before and we now have a real node_id + # to fill in, update — otherwise leave it alone. This makes + # mark_sd_ready idempotent w.r.t. transition flag but still + # useful for late-binding the node_id. + if existing_nid is None or (existing_nid == -1 and int(node_id) != -1): + entry.ready_sds[sd_key] = int(node_id) + if contributing_peer: + entry.contributing_peers.add(contributing_peer) + became_ready = (not was_ready) and entry.is_fully_ready(self._total_sd_count) + if became_ready: + entry.first_ready_at = self._time_fn() + return became_ready + + def mark_sd_evicted(self, prefix_hash: int, sd_key: str) -> None: + """Remove ``sd_key`` from a prefix's ready-set. + + If the set becomes empty the prefix is dropped entirely. 
Silently + no-op if the prefix is unknown (matches the "Master single-handedly + evicts" semantics in design doc §4.3.1 — Remotes do not need to + observe an eviction).""" + with self._lock: + entry = self._prefixes.get(prefix_hash) + if entry is None: + return + entry.ready_sds.pop(sd_key, None) + if not entry.ready_sds: + self._drop_prefix_locked(entry) + + # ------------------------------------------------------------------ + # Match + # ------------------------------------------------------------------ + def match_fully_ready(self, prefix_hash: int) -> Optional[ReadyEntry]: + """Return the ``ReadyEntry`` iff ``prefix_hash`` is fully ready, else None. + + Note: this is *not* a longest-prefix matcher. The legacy + ``RefRadixTree`` does prefix matching; the aggregate layer just + gates whether each candidate from the legacy match is allowed + through. Callers iterate over candidate prefixes and ask us for a + gate decision. + """ + with self._lock: + entry = self._prefixes.get(prefix_hash) + if entry is None: + return None + if not entry.is_fully_ready(self._total_sd_count): + return None + # Defensive copy — keep the entry immutable from the caller's pov. + return ReadyEntry( + prefix_hash=entry.prefix_hash, + block_ids=entry.block_ids, + ready_sds=dict(entry.ready_sds), + contributing_peers=set(entry.contributing_peers), + first_ready_at=entry.first_ready_at, + ) + + # ------------------------------------------------------------------ + # Refcount + # ------------------------------------------------------------------ + def acquire(self, block_ids: Iterable[int]) -> None: + """Increment the refcount of every block in ``block_ids``. + + Refcounts are recorded *per block*, regardless of which prefix(es) + the block participates in. This matches design doc §4.3.1 + prerequisite B which says "the block has refcount > 0 if it is in + flight on any SD". 
+ """ + now = self._time_fn() + with self._lock: + for raw in block_ids: + b = int(raw) + ent = self._refcounts.get(b) + if ent is None: + ent = _RefCountEntry() + self._refcounts[b] = ent + ent.count += 1 + ent.last_acquired_at = now + + def release(self, block_ids: Iterable[int]) -> None: + """Decrement refcounts. Raises :class:`BlockNotTrackedError` if a + block was never acquired (catches double-release bugs early).""" + with self._lock: + for raw in block_ids: + b = int(raw) + ent = self._refcounts.get(b) + if ent is None or ent.count <= 0: + raise BlockNotTrackedError( + f"AggregateRadixTree.release: block_id={b} not tracked or refcount already zero" + ) + ent.count -= 1 + if ent.count == 0: + # Reclaim memory eagerly — leaked entries can pile up + # in long-running processes otherwise. + del self._refcounts[b] + + def is_evictable(self, block_id: int) -> bool: + """Return True iff ``block_id`` has zero in-flight uses.""" + with self._lock: + ent = self._refcounts.get(int(block_id)) + return ent is None or ent.count <= 0 + + def get_refcount(self, block_id: int) -> int: + """Helper for tests / observability. Never raises.""" + with self._lock: + ent = self._refcounts.get(int(block_id)) + return ent.count if ent is not None else 0 + + def scan_leaked_refcount(self, timeout_seconds: float) -> List[int]: + """Return all block_ids whose refcount has been > 0 longer than + ``timeout_seconds``. + + The Master is expected to call this periodically (design doc §4.3.1 + prerequisite C "refcount timeout safety net") and then forcibly + zero each leaked refcount + invalidate the owning prefix(es). + """ + if timeout_seconds < 0: + raise ValueError(f"timeout_seconds must be >= 0, got {timeout_seconds!r}") + cutoff = self._time_fn() - timeout_seconds + with self._lock: + return [b for b, ent in self._refcounts.items() if ent.last_acquired_at <= cutoff] + + def force_release(self, block_id: int) -> int: + """Hard-reset a block's refcount to zero. 
+ + Returns the previous refcount (0 if block was untracked). Designed + to be the second half of the leak-recovery sequence — call it for + every entry returned by :meth:`scan_leaked_refcount`, then call + :meth:`invalidate_prefix` to drop the prefix from the radix. + """ + with self._lock: + ent = self._refcounts.pop(int(block_id), None) + return ent.count if ent is not None else 0 + + # ------------------------------------------------------------------ + # Invalidation + # ------------------------------------------------------------------ + def invalidate_prefix(self, prefix_hash: int) -> bool: + """Drop a single prefix. Returns True if it existed.""" + with self._lock: + entry = self._prefixes.get(prefix_hash) + if entry is None: + return False + self._drop_prefix_locked(entry) + return True + + def invalidate_by_peer_instance(self, peer_instance_id: str) -> int: + """Batch-drop every prefix that lists ``peer_instance_id`` as a + contributing peer. Returns the number of prefixes invalidated. + + Design doc §4.3.2 Layer-1: when the failure detector observes + ``peer_instance_id`` go away (TTL expiry / epoch bump), we tear down + any "fully-ready" claim that depended on it.""" + if not peer_instance_id: + raise ValueError("peer_instance_id must be a non-empty str") + with self._lock: + victims = [ + e for e in self._prefixes.values() + if peer_instance_id in e.contributing_peers + ] + for e in victims: + self._drop_prefix_locked(e) + return len(victims) + + # ------------------------------------------------------------------ + # Internals + # ------------------------------------------------------------------ + def _drop_prefix_locked(self, entry: ReadyEntry) -> None: + """Remove a prefix from both the prefix map and the reverse index. + + Caller must hold ``self._lock``. 
+ """ + self._prefixes.pop(entry.prefix_hash, None) + for b in entry.block_ids: + owners = self._block_to_prefixes.get(b) + if owners is None: + continue + owners.discard(entry.prefix_hash) + if not owners: + self._block_to_prefixes.pop(b, None) diff --git a/flexkv/common/dist_reuse/coordination_protocol.py b/flexkv/common/dist_reuse/coordination_protocol.py new file mode 100644 index 0000000000..26aa495dd9 --- /dev/null +++ b/flexkv/common/dist_reuse/coordination_protocol.py @@ -0,0 +1,184 @@ +"""Wire format for Master ↔ Remote coordination of distributed KV reuse. + +Phase D-4 (proposal_unify_with_graph_dispatch_2026-05-15.md): the +``CoordQuery*`` / ``CoordGet*`` / ``CoordPut*`` message types from the +early implementation are **deleted** here. They were the dataclasses +behind the old per-SD ZMQ coord protocol; the unified graph-dispatch +path replaces them — peer-SD acks now arrive as +``CompletedOp(sd_key, contributing_node_id)`` through the existing +``TransferManagerMultiNodeHandle`` polling thread. + +What survives: + +* :class:`RemoteReadyMsg` — Remote → Master one-shot bootstrap ack. + Used during instance startup before the graph-dispatch path is up. +* :class:`FailureReportMsg` — Remote → Master data-plane failure ping + (Layer-2 closed loop in design doc §4.3.2). Asynchronous and + orthogonal to the per-PUT/GET coordination flow, so it stays on the + ZMQ side channel. + +All messages remain ``dataclass``-based so they pickle cleanly through +the current ZMQ payload format. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, ClassVar, Dict, List, Type, Union + + +__all__ = [ + "CoordMsgType", + "RemoteReadyMsg", + "FailureReportMsg", + "EpochVerifyError", + "encode_coord_message", + "decode_coord_message", +] + + +class CoordMsgType(str, Enum): + """Discriminator embedded in every coordination message. + + Phase D-4: only the bootstrap + failure-report types remain. 
The + PUT/GET coordination types (COORD_QUERY / COORD_GET_* / COORD_PUT_*) + were deleted with the unified graph-dispatch refactor. + """ + + REMOTE_READY = "remote_ready" + FAILURE_REPORT = "failure_report" + + +# --------------------------------------------------------------------------- +# Base +# --------------------------------------------------------------------------- +@dataclass +class _BaseCoordMsg: + """Common fields embedded in every wire message. + + ``sender_epoch`` carries the *sender's* expected ``session_epoch``; receivers + cross-check it against their own and raise :class:`EpochVerifyError` + when they disagree (design doc §4.3.2 anti-replay rule). + """ + + # Class-level discriminator; every concrete subclass gets it assigned via + # ``_msg_class_with_type``. + type: ClassVar[CoordMsgType] + + # Identifies which FlexKV instance authored the message. This is the + # value of ``CacheConfig.instance_id`` of the sender. + sender_instance_id: str = "" + + # Per-message correlation ID (set by the originator; copied verbatim + # into the ack). ``-1`` is the sentinel for "not yet assigned". + request_id: int = -1 + + # Snapshot of the sender's ``session_epoch`` at the time the message was + # sent. Required by the receiver to invalidate stale traffic across an + # instance restart (design doc §4.3.2). Empty string permitted in tests. + sender_epoch: str = "" + + +def _msg_class_with_type(cls: Type[_BaseCoordMsg], type_value: CoordMsgType) -> Type[_BaseCoordMsg]: + """Helper to attach a ``type`` ClassVar to a dataclass.""" + cls.type = type_value # type: ignore[assignment] + return cls + + +# --------------------------------------------------------------------------- +# Remote → Master: ready handshake +# --------------------------------------------------------------------------- +@dataclass +class RemoteReadyMsg(_BaseCoordMsg): + """Sent by a Remote once its RedisMeta + Mooncake init finishes.
+ + The Master collects N-1 of these (one per non-master SD) before the + instance is considered ready (design doc §4.6.2 step 4). + """ + + sd_key: str = "" + distributed_node_id: int = -1 + mooncake_addr: str = "" + zmq_addr: str = "" + + +_msg_class_with_type(RemoteReadyMsg, CoordMsgType.REMOTE_READY) + + +# --------------------------------------------------------------------------- +# Layer-2 failure closed loop +# --------------------------------------------------------------------------- +@dataclass +class FailureReportMsg(_BaseCoordMsg): + """Remote → Master: a Mooncake P2P read or write hit a hard error. + + The Master invalidates the affected prefix in the aggregate radix and + optionally escalates to ``invalidate_all_by_instance`` after seeing + enough failures from the same peer (design doc §4.3.2 Layer-2). + """ + + peer_instance_id: str = "" + failed_block_hashes: List[int] = field(default_factory=list) + error: str = "" + + +_msg_class_with_type(FailureReportMsg, CoordMsgType.FAILURE_REPORT) + + +# --------------------------------------------------------------------------- +# Encoding helpers — protocol-level, not transport-level +# --------------------------------------------------------------------------- +class EpochVerifyError(RuntimeError): + """Raised when a receiver detects a stale ``sender_epoch``. The caller + is expected to translate this into a ``STALE_EPOCH`` response and let + the sender invalidate its view of the affected peer.""" + + +_TYPE_TO_CLASS: Dict[CoordMsgType, Type[_BaseCoordMsg]] = { + CoordMsgType.REMOTE_READY: RemoteReadyMsg, + CoordMsgType.FAILURE_REPORT: FailureReportMsg, +} + + +AnyCoordMsg = Union[ + RemoteReadyMsg, + FailureReportMsg, +] + + +def encode_coord_message(msg: AnyCoordMsg) -> Dict[str, Any]: + """Convert a dataclass message to a plain ``dict`` (transport-agnostic). + + The result is JSON-serializable as long as the embedded fields are + (block hashes are signed 64-bit ints, which JSON handles fine). 
+ """ + if not isinstance(msg, tuple(_TYPE_TO_CLASS.values())): + raise TypeError(f"encode_coord_message: not a coord message: {type(msg).__name__}") + out: Dict[str, Any] = {"type": msg.type.value} + for f in msg.__dataclass_fields__.values(): # type: ignore[attr-defined] + out[f.name] = getattr(msg, f.name) + return out + + +def decode_coord_message(payload: Dict[str, Any]) -> AnyCoordMsg: + """Inverse of :func:`encode_coord_message`. + + Raises ``ValueError`` if ``payload['type']`` is missing or unknown. + """ + raw_type = payload.get("type") + if raw_type is None: + raise ValueError("decode_coord_message: missing 'type' field") + try: + mtype = CoordMsgType(raw_type) + except ValueError as e: + raise ValueError(f"decode_coord_message: unknown type {raw_type!r}") from e + cls = _TYPE_TO_CLASS[mtype] + # Drop ``type`` before delegating to the dataclass ctor. + fields_payload = {k: v for k, v in payload.items() if k != "type"} + # Validate that no unknown extra keys leak in (catches schema drift). + valid_names = {f.name for f in cls.__dataclass_fields__.values()} # type: ignore[attr-defined] + extra = set(fields_payload) - valid_names + if extra: + raise ValueError(f"decode_coord_message: unknown fields for {mtype.value}: {sorted(extra)}") + return cls(**fields_payload) # type: ignore[return-value] diff --git a/flexkv/common/dist_reuse/failure_detector.py b/flexkv/common/dist_reuse/failure_detector.py new file mode 100644 index 0000000000..4bba9c6fad --- /dev/null +++ b/flexkv/common/dist_reuse/failure_detector.py @@ -0,0 +1,388 @@ +"""Layer-1 failure detector — Redis session + epoch heartbeat. + +Design doc §4.3.2 Layer-1. Each FlexKV instance writes +``flexkv:instance::session`` with a TTL of a few seconds and renews it +every TTL/3. Peer instances poll the ``flexkv:instance:*:session`` keyspace +periodically and react to two events: + +* **session key disappears** → peer was killed / lost network (TTL expiry). 
_LOG = logging.getLogger("flexkv.failure_detector")


def make_session_epoch() -> str:
    """Generate a fresh, monotonic-ish session epoch string.

    Combines the current monotonic-ms timestamp (so two epochs from the
    same process can be totally ordered) and a uuid4 suffix (so two
    distinct processes never collide). Format::

        <12-hex-of-monotonic-ms>-<8-hex-of-uuid4>
    """
    # Mask to 48 bits so the timestamp always fits the 12-hex-digit slot.
    ms = int(time.monotonic() * 1000) & 0xFFFFFFFFFFFF
    rand = uuid.uuid4().hex[:8]
    return f"{ms:012x}-{rand}"


# ---------------------------------------------------------------------------
# Data shapes
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class InstanceSession:
    """Decoded ``flexkv:instance:<instance_id>:session`` payload."""

    instance_id: str
    epoch: str
    master_zmq_addr: str
    node_ids: tuple = ()
    mooncake_addrs_by_sd: Dict[str, str] = field(default_factory=dict)


# ---------------------------------------------------------------------------
# Redis abstraction
# ---------------------------------------------------------------------------
class _RedisLike(Protocol):
    """Subset of ``redis.Redis`` actually used here.

    Defined as a Protocol so tests can plug in an in-memory fake.
    """

    def set(self, name: str, value: str, ex: Optional[int] = None) -> Any: ...
    def get(self, name: str) -> Any: ...
    def expire(self, name: str, ex: int) -> Any: ...
    def delete(self, *names: str) -> Any: ...
    def scan_iter(self, match: Optional[str] = None, count: Optional[int] = None) -> Iterable[Any]: ...


class RedisSessionClient:
    """Thin wrapper that owns the *write* side of an instance session.

    Wires up:
      - :meth:`register` — initial ``SET ... EX ttl`` of the session payload.
      - :meth:`renew` — periodic ``EXPIRE`` (or full re-set if missing).
      - :meth:`unregister` — clean shutdown.

    The class does not run a renewal loop itself; the owner is expected to
    call :meth:`renew` periodically.
    """

    def __init__(
        self,
        redis_client: _RedisLike,
        *,
        instance_id: str,
        epoch: str,
        ttl_seconds: int,
        master_zmq_addr: str = "",
        node_ids: Optional[Iterable[int]] = None,
        mooncake_addrs_by_sd: Optional[Dict[str, str]] = None,
    ) -> None:
        """Capture the session identity and pre-build the static payload.

        Raises ``ValueError`` if ``ttl_seconds`` is below 1.
        """
        if ttl_seconds < 1:
            raise ValueError(f"ttl_seconds must be >= 1, got {ttl_seconds!r}")
        self._client = redis_client
        self._instance_id = instance_id
        self._epoch = epoch
        self._ttl = ttl_seconds
        self._key = SharingDomainNamespace.instance_session_key(instance_id)
        # Snapshot the static portion of the payload once.
        self._payload: Dict[str, Any] = {
            "instance_id": instance_id,
            "epoch": epoch,
            "master_zmq_addr": master_zmq_addr,
            "node_ids": list(node_ids or []),
            "mooncake_addrs_by_sd": dict(mooncake_addrs_by_sd or {}),
        }

    @property
    def instance_id(self) -> str:
        return self._instance_id

    @property
    def epoch(self) -> str:
        return self._epoch

    @property
    def key(self) -> str:
        return self._key

    def register(self) -> None:
        """Write the session payload with a TTL. Overwrites any existing
        key — by design, a restarted instance must replace its old record."""
        self._client.set(self._key, json.dumps(self._payload), ex=self._ttl)

    def renew(self) -> None:
        """Refresh the TTL. If the key has expired since last renewal we
        re-write the full payload to avoid the "ghost peer" scenario where
        a watchdog observes the gap as a restart."""
        ok = bool(self._client.expire(self._key, self._ttl))
        if not ok:
            self.register()

    def unregister(self) -> None:
        """Best-effort delete of the session key; failures are only logged."""
        try:
            self._client.delete(self._key)
        except Exception as e:  # pragma: no cover — best-effort cleanup
            _LOG.warning("unregister(%s) failed: %s", self._key, e)


# ---------------------------------------------------------------------------
# Detector
# ---------------------------------------------------------------------------
class FailureDetector:
    """Polls Redis for peer instance liveness and fires user callbacks.

    Lifecycle:

    >>> detector = FailureDetector(client, "self-instance", on_peer_lost=cb)
    >>> detector.start()
    ... # ... run normally ...
    >>> detector.stop()

    The callbacks must be **thread-safe and fast** — they execute on the
    detector's polling thread (while the iteration lock is held). The
    recommended pattern is to enqueue work into the aggregate radix's
    internal lock-protected state and return immediately.
    """

    def __init__(
        self,
        redis_client: _RedisLike,
        self_instance_id: str,
        *,
        poll_interval_seconds: float = 2.0,
        on_peer_lost: Optional[Callable[[str], None]] = None,
        on_peer_seen: Optional[Callable[[str, InstanceSession], None]] = None,
        time_fn: Callable[[], float] = time.monotonic,
    ) -> None:
        """Store configuration; no thread is started and Redis is not touched.

        Raises ``ValueError`` for a non-positive ``poll_interval_seconds``
        or an empty ``self_instance_id``.
        """
        if poll_interval_seconds <= 0:
            raise ValueError(f"poll_interval_seconds must be > 0, got {poll_interval_seconds!r}")
        if not self_instance_id:
            raise ValueError("self_instance_id must be a non-empty str")
        self._client = redis_client
        self._self_instance_id = self_instance_id
        self._poll_interval = poll_interval_seconds
        self._on_peer_lost = on_peer_lost or (lambda _pid: None)
        self._on_peer_seen = on_peer_seen or (lambda _pid, _s: None)
        self._time_fn = time_fn

        # peer_instance_id -> last observed epoch
        self._known_peers: Dict[str, str] = {}
        self._stop_event = threading.Event()
        self._thread: Optional[threading.Thread] = None
        # Used by tests to drive a single iteration deterministically.
        self._iteration_lock = threading.Lock()

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    def start(self) -> None:
        """Spawn the daemon polling thread.

        Raises ``RuntimeError`` if a previous polling thread is still alive.
        """
        if self._thread is not None and self._thread.is_alive():
            raise RuntimeError("FailureDetector already started")
        self._stop_event.clear()
        self._thread = threading.Thread(
            target=self._run,
            name="flexkv-failure-detector",
            daemon=True,
        )
        self._thread.start()

    def stop(self, timeout: Optional[float] = 2.0) -> None:
        """Signal the polling thread to exit and wait up to ``timeout`` seconds."""
        self._stop_event.set()
        if self._thread is not None:
            self._thread.join(timeout=timeout)
            self._thread = None

    # ------------------------------------------------------------------
    # Polling loop
    # ------------------------------------------------------------------
    def _run(self) -> None:  # pragma: no cover — exercised through tests via poll_once()
        while not self._stop_event.is_set():
            try:
                self.poll_once()
            except Exception as e:
                _LOG.exception("FailureDetector poll error: %s", e)
            # Event.wait doubles as an interruptible sleep so stop() is prompt.
            self._stop_event.wait(self._poll_interval)

    def poll_once(self) -> None:
        """Run a single polling cycle. Public to let tests drive the
        detector deterministically without bringing up the thread."""
        with self._iteration_lock:
            current = self._scan_peers()

            # Handle disappeared peers first (TTL expiry).
            disappeared = set(self._known_peers) - set(current) - {self._self_instance_id}
            for pid in disappeared:
                _LOG.info("FailureDetector: peer %s disappeared (TTL expiry)", pid)
                self._invoke_lost(pid)
                self._known_peers.pop(pid, None)

            # Handle new + epoch-changed peers.
            for pid, session in current.items():
                if pid == self._self_instance_id:
                    continue
                prev = self._known_peers.get(pid)
                if prev is None:
                    _LOG.info("FailureDetector: peer %s appeared (epoch=%s)", pid, session.epoch)
                    self._invoke_seen(pid, session)
                elif prev != session.epoch:
                    _LOG.info("FailureDetector: peer %s restarted (epoch %s -> %s)", pid, prev, session.epoch)
                    # Treat epoch change as "lost then seen" — the lost
                    # callback invalidates stale state and the seen
                    # callback re-registers.
                    self._invoke_lost(pid)
                    self._invoke_seen(pid, session)
                self._known_peers[pid] = session.epoch

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _scan_peers(self) -> Dict[str, InstanceSession]:
        """Read every session key and return ``instance_id -> InstanceSession``.

        Keys that do not parse as session keys, keys that vanished between
        SCAN and GET, and malformed payloads are skipped so that one bad
        record cannot abort the whole scan.
        """
        out: Dict[str, InstanceSession] = {}
        pattern = SharingDomainNamespace.instance_session_key_pattern()
        for raw_key in self._client.scan_iter(match=pattern, count=100):
            key = raw_key.decode("utf-8") if isinstance(raw_key, (bytes, bytearray)) else str(raw_key)
            try:
                # Validation only: the result map is keyed by the payload's
                # own ``instance_id`` field.
                SharingDomainNamespace.parse_instance_session_key(key)
            except ValueError:
                continue
            raw_value = self._client.get(key)
            if raw_value is None:
                continue
            session = self._parse_session(key, raw_value)
            if session is not None:
                out[session.instance_id] = session
        return out

    @staticmethod
    def _parse_session(key: str, raw_value: Any) -> Optional[InstanceSession]:
        """Decode one raw session value; return ``None`` (after logging) if
        it is not a well-formed session payload."""
        try:
            value = (
                raw_value.decode("utf-8")
                if isinstance(raw_value, (bytes, bytearray))
                else str(raw_value)
            )
            payload = json.loads(value)
        except (TypeError, ValueError):
            # ValueError also covers JSONDecodeError and UnicodeDecodeError.
            _LOG.warning("FailureDetector: malformed session payload at %s", key)
            return None
        if not isinstance(payload, dict):
            # Valid JSON but not an object (e.g. a list or a bare string).
            _LOG.warning("FailureDetector: malformed session payload at %s", key)
            return None
        try:
            return InstanceSession(
                instance_id=str(payload["instance_id"]),
                epoch=str(payload["epoch"]),
                master_zmq_addr=str(payload.get("master_zmq_addr", "")),
                node_ids=tuple(payload.get("node_ids", ())),
                mooncake_addrs_by_sd=dict(payload.get("mooncake_addrs_by_sd", {})),
            )
        except KeyError as e:
            _LOG.warning("FailureDetector: missing field %s in %s", e, key)
        except (TypeError, ValueError):
            # e.g. ``node_ids: null`` or a non-mapping ``mooncake_addrs_by_sd``.
            _LOG.warning("FailureDetector: malformed session payload at %s", key)
        return None

    def _invoke_lost(self, pid: str) -> None:
        try:
            self._on_peer_lost(pid)
        except Exception:  # pragma: no cover — defensive logging
            _LOG.exception("on_peer_lost(%s) raised", pid)

    def _invoke_seen(self, pid: str, session: InstanceSession) -> None:
        try:
            self._on_peer_seen(pid, session)
        except Exception:  # pragma: no cover — defensive logging
            _LOG.exception("on_peer_seen(%s) raised", pid)

    # ------------------------------------------------------------------
    # Inspection
    # ------------------------------------------------------------------
    def known_peers(self) -> Set[str]:
        """Return a snapshot of the peer instance IDs currently tracked."""
        with self._iteration_lock:
            return set(self._known_peers)


# ---------------------------------------------------------------------------
# redis-py client factory — single source of truth for the flexkv_redis_db
# resolution rule. All dist-reuse code paths that need a raw ``redis.Redis``
# (e.g. for ``RedisSessionClient`` / ``FailureDetector``) should go through
# this helper so the ``flexkv_redis_db`` override from ``CacheConfig`` is
# honoured in exactly one place.
# ---------------------------------------------------------------------------
def make_redis_client_from_cache_config(
    cache_config: Any,
    *,
    decode_responses: bool = True,
    socket_connect_timeout: Optional[float] = None,
) -> Any:
    """Construct a ``redis.Redis`` client from the given ``CacheConfig``.

    Pulls ``redis_host`` / ``redis_port`` / ``redis_password`` /
    ``flexkv_redis_db`` off the config (defaults: ``127.0.0.1``, ``6379``,
    no password, db ``0``) and returns a ready-to-use client. Importing
    ``redis`` is deferred so callers that run in a CPU-only unit-test
    environment without the redis-py dependency still get a useful
    ``ImportError``.

    Args:
        cache_config: A :class:`~flexkv.common.config.CacheConfig` or any
            duck-typed object exposing the same attrs.
        decode_responses: Forwarded to ``redis.Redis`` — most FlexKV code
            expects ``str`` round-trips so this defaults to ``True``.
        socket_connect_timeout: Optional connect timeout in seconds. Set
            this to a small value (e.g. 1.0) in tests to fail fast when
            Redis is absent.

    Returns:
        A ``redis.Redis`` instance bound to ``cache_config.flexkv_redis_db``.

    Raises:
        ImportError: If ``redis-py`` is not installed.
    """
    try:
        import redis as _redis  # type: ignore[import-not-found]
    except ImportError as e:
        raise ImportError(
            "redis-py is required for dist-reuse operations: pip install redis"
        ) from e

    kwargs: Dict[str, Any] = {
        "host": getattr(cache_config, "redis_host", "127.0.0.1"),
        "port": int(getattr(cache_config, "redis_port", 6379)),
        "db": int(getattr(cache_config, "flexkv_redis_db", 0)),
        "decode_responses": decode_responses,
    }
    password = getattr(cache_config, "redis_password", None)
    # Only forward a password when one is actually configured.
    if password:
        kwargs["password"] = password
    if socket_connect_timeout is not None:
        kwargs["socket_connect_timeout"] = float(socket_connect_timeout)
    return _redis.Redis(**kwargs)
# ---------------------------------------------------------------------------
# Handle spec (used by KVTaskManager)
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class SharingDomainHandleSpec:
    """Immutable recipe for a single TransferManagerHandle on the Master.

    Fields:

    * ``sd_key`` — the sharing domain the handle serves.
    * ``mode`` — ``"process"`` (the Master's own SD) or ``"remote"``
      (any other SD).
    * ``endpoint`` — the peer's endpoint object for ``"remote"`` specs;
      ``None`` for the ``"process"`` spec.
    """

    sd_key: SharingDomainKey
    mode: str
    endpoint: Optional[Any] = None  # RemoteEndpoint; typed Any to dodge config-import cycle


def build_sharing_domain_handles(
    *,
    self_sd: SharingDomainKey,
    remote_endpoints_by_sd: Optional[Dict[str, Any]] = None,
) -> List[SharingDomainHandleSpec]:
    """Build one handle spec per sharing domain, Master's own SD first.

    Index 0 of the result is always the ``"process"`` spec for ``self_sd``.
    Every other SD yielded by ``self_sd.enumerate_peers()`` (entries equal
    to ``self_sd`` are ignored) becomes a ``"remote"`` spec whose endpoint
    is looked up in ``remote_endpoints_by_sd`` under the peer's serialized
    key.

    Raises ``KeyError`` as soon as a peer SD has no endpoint (missing key
    or ``None`` value), so a misconfigured Master fails at startup rather
    than running with fewer handles than expected.
    """
    endpoints = remote_endpoints_by_sd if remote_endpoints_by_sd else {}

    # The local, in-process handle always leads the list.
    handle_specs: List[SharingDomainHandleSpec] = [
        SharingDomainHandleSpec(sd_key=self_sd, mode="process", endpoint=None),
    ]

    remote_candidates = (sd for sd in self_sd.enumerate_peers() if not (sd == self_sd))
    for candidate in remote_candidates:
        serialized = candidate.serialize()
        peer_endpoint = endpoints.get(serialized)
        if peer_endpoint is None:
            raise KeyError(
                f"build_sharing_domain_handles: missing endpoint for peer SD "
                f"{serialized!r}. Populate CacheConfig.remote_endpoints_by_sd "
                f"from the framework launcher before constructing KVManager."
            )
        handle_specs.append(
            SharingDomainHandleSpec(sd_key=candidate, mode="remote", endpoint=peer_endpoint)
        )
    return handle_specs


# ---------------------------------------------------------------------------
# GPU-clear predicate (task 0-K)
# ---------------------------------------------------------------------------
def graph_needs_gpu_clear(self_sd: SharingDomainKey, peer_sd: SharingDomainKey) -> bool:
    """Tell whether a TransferOpGraph sent to ``peer_sd`` must have its GPU
    blocks reset first.

    Decision table (design doc §4.12.2 (4), simplified):

    * ``peer_sd`` equals ``self_sd`` → False (it is the local graph).
    * ``pp_node_idx`` differs → True.
    * anything else (e.g. only ``tp_node_idx`` differs) → False.

    Only equality and ``pp_node_idx`` are inspected here; per the design
    notes, CP is not part of the SD key, so CP-level differences never
    reach this predicate.
    """
    same_domain = peer_sd == self_sd
    return (not same_domain) and peer_sd.pp_node_idx != self_sd.pp_node_idx
+ + The ``GlobalCacheEngine`` holds exactly one of these when + ``CacheConfig.enable_sharing_domain=True``. The class is deliberately + framework-agnostic: it doesn't import ``torch``, ``zmq``, or + ``transfer_manager``; it only exposes pure-Python hooks that the + relevant modules call from within their own flow. + + Lifecycle: + + 1. Construct with the Master's own ``SharingDomainKey``. + 2. Call :meth:`expect_remotes` once you've built the handle list — + tells the coordinator how many Remote ready acks to wait for. + 3. For each incoming :class:`RemoteReadyMsg`, call + :meth:`on_remote_ready`. Returns True when the last Remote has + reported, at which point ``register_instance_discoverables`` + is safe to call. + 4. ``get_match`` / ``put_match`` / ``_transfer_callback`` in + ``GlobalCacheEngine`` call ``acquire_blocks`` / ``release_blocks`` / + ``mark_sd_ready`` to update the aggregate radix. + 5. ``evict`` in the hierarchical cache engine calls + :meth:`is_evictable` before actually dropping a block. + 6. The failure detector fires :meth:`on_peer_lost` which invalidates + every aggregate entry contributed by the dead peer. + """ + + def __init__( + self, + *, + self_sd: SharingDomainKey, + instance_id: str, + session_epoch: Optional[str] = None, + refcount_leak_timeout_seconds: float = 30.0, + failure_escalation_threshold: int = 3, + ) -> None: + self._self_sd = self_sd + self._namespace = SharingDomainNamespace(self_sd) + self._instance_id = str(instance_id) + self._session_epoch = session_epoch or make_session_epoch() + + # The Master's own distributed_node_id. Filled in by + # ``register_instance_discoverables`` once the Master knows what + # it was assigned in Redis. Defaults to ``-1`` (sentinel for + # "not yet known"); ``GlobalCacheEngine._notify_master_sd_ready`` + # threads this value into ``mark_sd_ready`` so the per-SD + # node_id map starts populated for the master SD on the very + # first ack. 
+ self._self_node_id: int = -1 + + self._aggregate = AggregateRadixTree(total_sd_count=self_sd.total_sd_count()) + self._refcount_leak_timeout = float(refcount_leak_timeout_seconds) + + # Phase D-4: Layer-2 closed loop bookkeeping (migrated from + # the now-deleted CoordinationCoordinator). + self._failure_escalation_threshold = int(failure_escalation_threshold) + self._peer_failure_counts: Dict[str, int] = {} + + self._lock = threading.RLock() + + # Filled in by on_remote_ready() + self._expected_remote_count: Optional[int] = None + self._ready_remotes: Dict[str, RemoteReadyMsg] = {} + + # Populated on register_instance_discoverables() + self._session_client: Optional[RedisSessionClient] = None + self._failure_detector: Optional[FailureDetector] = None + + # Optional external callback invoked from ``_on_peer_lost`` after + # the internal aggregate invalidation. Set via + # :meth:`set_peer_lost_hook`. + self._extra_peer_lost_hook: Optional[Any] = None + + # ---------------------------------------------------------------- state + @property + def self_sd(self) -> SharingDomainKey: + return self._self_sd + + @property + def namespace(self) -> SharingDomainNamespace: + return self._namespace + + @property + def instance_id(self) -> str: + return self._instance_id + + @property + def session_epoch(self) -> str: + return self._session_epoch + + @property + def aggregate_radix(self) -> AggregateRadixTree: + return self._aggregate + + @property + def self_node_id(self) -> int: + """Distributed node_id of the Master itself. ``-1`` until + :meth:`register_instance_discoverables` is called. + """ + return int(self._self_node_id) + + @property + def total_sd_count(self) -> int: + """Total number of SDs in this instance. == ``self_sd.total_sd_count()``. + + Used by the GET main path to decide whether a coord GET + barrier is needed (``> 1``) or we can short-circuit + (``== 1``, single-SD degenerate case). 
+ """ + return int(self._self_sd.total_sd_count()) + + def expect_remotes(self, count: int) -> None: + """Tell the coordinator how many ``RemoteReadyMsg`` to wait for. + + Must be called before :meth:`on_remote_ready`. ``count`` is the + length of ``transfer_handles`` minus 1 (the Master's own handle). + """ + if count < 0: + raise ValueError(f"count must be >= 0, got {count}") + with self._lock: + if self._expected_remote_count is not None: + raise RuntimeError("expect_remotes() has already been called") + self._expected_remote_count = int(count) + + # ---------------------------------------------------------------- remotes + def on_remote_ready(self, msg: RemoteReadyMsg) -> bool: + """Record a Remote's ready-ack. Returns True when the last Remote + has reported (i.e. ``len(ready) == expected``). Further calls + return True too (idempotent).""" + if self._expected_remote_count is None: + raise RuntimeError( + "on_remote_ready called before expect_remotes(); wire up the " + "handle list first" + ) + with self._lock: + # Ack comes straight off the wire — keep the canonical SD-key + # as the dict key so tests can assert on it. 
+ self._ready_remotes[msg.sd_key] = msg + return len(self._ready_remotes) >= self._expected_remote_count + + def all_remotes_ready(self) -> bool: + with self._lock: + return ( + self._expected_remote_count is not None + and len(self._ready_remotes) >= self._expected_remote_count + ) + + def ready_remote_infos(self) -> Dict[str, RemoteReadyMsg]: + with self._lock: + return dict(self._ready_remotes) + + def build_sd_to_nid_map(self, self_node_id: int) -> Dict[str, int]: + """Produce the ``sd_key → node_id`` mapping for + ``RedisMeta.register_instance_sd_nodes``.""" + with self._lock: + out: Dict[str, int] = {self._self_sd.serialize(): int(self_node_id)} + for sd_key_str, msg in self._ready_remotes.items(): + out[sd_key_str] = int(msg.distributed_node_id) + return out + + def get_sd_to_nid_map(self) -> Dict[str, int]: + """Public read-only accessor — Phase D-2 (proposal §6.3): the + Master's PUT-path graph builder needs to attach + ``target_node_ids=[peer_node_id]`` to each peer-SD D2H op so that + each Remote's ``_handle_submit`` filter picks up the right slice. + + Returns ``{sd_key_str: distributed_node_id}`` for every SD known + to this MasterCoordinator (master's own SD + every Remote that + has finished its ready handshake). Returns an empty dict if the + Master's own ``self_node_id`` hasn't been set yet (i.e. + ``register_instance_discoverables`` hasn't been called). + """ + with self._lock: + if int(self._self_node_id) < 0: + # Bootstrap not finished — no canonical mapping yet. 
+ return {} + out: Dict[str, int] = { + self._self_sd.serialize(): int(self._self_node_id), + } + for sd_key_str, msg in self._ready_remotes.items(): + out[sd_key_str] = int(msg.distributed_node_id) + return out + + # ------------------------------------------------------------ lookup + def lookup_peer_by_node_id(self, node_id: int) -> Optional[Dict[str, str]]: + """Reverse-lookup: given a (global) distributed_node_id, return + the peer SD it belongs to plus the peer instance_id that owns + that SD. + + Returns a dict with keys ``sd_key_str`` and ``instance_id``, + or ``None`` if the node_id isn't one of our ready peers (i.e. + it's either this instance's own node, unknown, or belongs to + an instance that hasn't finished its ready handshake). + + Used by the GET main-path glue + (``GlobalCacheEngine._maybe_attach_multi_sd_peerh2h_ops``, + Phase D-3) and by ``handle_failure_report`` to map an + offending node_id back to the peer SD + instance it sits on. + """ + if node_id is None: + return None + try: + nid = int(node_id) + except Exception: + return None + if nid < 0: + return None + with self._lock: + for sd_key_str, msg in self._ready_remotes.items(): + if int(getattr(msg, "distributed_node_id", -1)) == nid: + return { + "sd_key_str": sd_key_str, + "instance_id": str(getattr(msg, "sender_instance_id", "") or ""), + } + return None + + # -------------------------------------------------------- pin helpers + def pin_blocks_for_coord_get(self, block_ids: Iterable[int]) -> None: + """Refcount-pin ``block_ids`` against Master-side eviction while + an in-flight coord GET is expected to land on them. + + Thin alias for :meth:`acquire_blocks` — kept so the GET-path + glue reads as intent rather than the lower-level primitive. + Used in conjunction with the multi-SD PEERH2H fan-out + (``GlobalCacheEngine._maybe_attach_multi_sd_peerh2h_ops``, + Phase D-3). 
+ """ + self._aggregate.acquire(block_ids) + + def unpin_blocks_for_coord_get(self, block_ids: Iterable[int]) -> None: + """Release the refcount pin set by + :meth:`pin_blocks_for_coord_get`. Must be called on both the + success and failure paths of the coord GET. + """ + self._aggregate.release(block_ids) + + # ------------------------------------------------------------- discovery + def register_instance_discoverables( + self, + *, + redis_meta: Any, + self_node_id: int, + master_zmq_addr: str, + mooncake_addrs_by_sd: Optional[Dict[str, str]] = None, + ttl_seconds: int = 8, + ) -> None: + """Write the two instance-level Redis keys (design doc §4.6.1 / + §4.7.1.6) and start the heartbeat session. Safe to call once all + Remotes have acked; raises RuntimeError otherwise. + """ + with self._lock: + if not self.all_remotes_ready(): + raise RuntimeError( + "register_instance_discoverables requires all Remotes to be ready; " + f"got {len(self._ready_remotes)} / {self._expected_remote_count}" + ) + sd_to_nid = self.build_sd_to_nid_map(self_node_id) + # Cache the master's own node_id so PUT-path callers can + # populate AggregateRadixTree entries without re-deriving it. + self._self_node_id = int(self_node_id) + + redis_meta.register_instance_sd_nodes(self._instance_id, sd_to_nid) + + # The failure detector's RedisSessionClient owns its own redis client + # — it does NOT share ``redis_meta``'s client connection because + # the heartbeat runs on a background thread. 
def shutdown(self) -> None:
    """Tear the coordinator down gracefully.

    Stops the failure detector (when one was started) and then
    unregisters the Redis session (when one was registered). Each step is
    best-effort: an exception is logged with its traceback and swallowed,
    and the corresponding attribute is reset to ``None`` either way, so
    calling ``shutdown`` again is a no-op.
    """
    # (attribute holding the component, teardown method, log text on failure)
    teardown_plan = (
        ("_failure_detector", "stop", "FailureDetector.stop() raised"),
        ("_session_client", "unregister", "RedisSessionClient.unregister() raised"),
    )
    for attr_name, method_name, failure_text in teardown_plan:
        component = getattr(self, attr_name)
        if component is None:
            continue
        try:
            getattr(component, method_name)()
        except Exception:  # pragma: no cover
            _LOG.exception(failure_text)
        setattr(self, attr_name, None)
def scan_leaked_refcount(self) -> List[int]:
    """Force-release blocks whose refcount pin has been held too long.

    Meant to be called periodically (the original note names the
    KVTaskManager's background thread; design doc §4.3.1 prerequisite C).
    The aggregate tree is asked for every block in flight longer than
    ``self._refcount_leak_timeout`` and each one is passed to
    ``force_release``.

    NOTE(review): the earlier docstring also promised that owning
    prefixes get invalidated; nothing in this method does that, so it
    presumably happens inside ``force_release`` — confirm.

    Returns:
        Whatever ``scan_leaked_refcount`` on the aggregate returned (the
        leaked block ids).
    """
    aggregate = self._aggregate
    stale_blocks = aggregate.scan_leaked_refcount(self._refcount_leak_timeout)
    for stale_block in stale_blocks:
        aggregate.force_release(stale_block)
    return stale_blocks
def handle_failure_report(self, report) -> None:
    """Invalidate the reported prefixes and escalate on repeated failures.

    ``report`` must duck-type as a ``FailureReportMsg``:
    ``peer_instance_id`` (str) and ``failed_block_hashes``
    (Iterable[int]). A report without a peer id is ignored entirely.

    Processing is best-effort per hash: a hash that cannot be converted
    to ``int`` or whose invalidation raises is logged (with traceback)
    and skipped, so one bad entry never prevents the remaining hashes —
    or the failure accounting below — from being handled.

    Every accepted report bumps the peer's failure counter. Once the
    counter reaches ``self._failure_escalation_threshold`` the whole peer
    is invalidated in the aggregate tree and the counter is reset to 0.
    """
    peer = getattr(report, "peer_instance_id", "") or ""
    if not peer:
        return

    for h in (getattr(report, "failed_block_hashes", []) or []):
        try:
            self._aggregate.invalidate_prefix(int(h))
        except Exception:  # pragma: no cover
            # Stay best-effort, but leave a trace: a prefix that failed to
            # invalidate may keep pointing at blocks the peer cannot serve.
            _LOG.warning(
                "[MasterCoordinator:%s] could not invalidate prefix %r "
                "reported for peer %s",
                self._instance_id, h, peer,
                exc_info=True,
            )

    with self._lock:
        self._peer_failure_counts[peer] = self._peer_failure_counts.get(peer, 0) + 1
        escalate = self._peer_failure_counts[peer] >= int(self._failure_escalation_threshold)

    if escalate:
        # The counter is reset only after the invalidation returned, so a
        # failing escalation is retried on the next report.
        self._aggregate.invalidate_by_peer_instance(peer)
        with self._lock:
            self._peer_failure_counts[peer] = 0
        _LOG.warning(
            "[MasterCoordinator:%s] Layer-2 escalated to full-peer "
            "invalidation for %s",
            self._instance_id, peer,
        )
def find_endpoint_for_sd(
    cache_config: Any, sd_key: SharingDomainKey,
) -> Any:
    """Look up the Remote endpoint configured for ``sd_key``.

    The table is read from ``cache_config.remote_endpoints_by_sd`` and is
    keyed by the serialized SD key; a missing or falsy attribute is
    treated as an empty table.

    Raises:
        KeyError: no endpoint is configured for ``sd_key``; the message
            names the missing key and where the table should be filled.
    """
    endpoints = getattr(cache_config, "remote_endpoints_by_sd", None)
    if not endpoints:
        endpoints = {}

    serialized = sd_key.serialize()
    if serialized in endpoints:
        return endpoints[serialized]

    raise KeyError(
        f"CacheConfig.remote_endpoints_by_sd is missing an entry for {serialized!r}. "
        f"Populate it from the launcher (e.g. sglang connector) before "
        f"constructing KVTaskManager."
    )
@dataclass
class BootstrapResult:
    """Everything a Remote keeps (and reports) after a successful bootstrap.

    Attributes:
        sd_key: the sharing domain this Remote registered under.
        namespace: namespace wrapper built from ``sd_key``.
        redis_meta: the per-SD ``RedisMeta`` handle. Typed ``Any`` so this
            module does not need to import it.
        mooncake_engine: the Mooncake transfer engine handle. Typed
            ``Any`` for the same reason.
        distributed_node_id: node id obtained during bootstrap.
        ready_msg: handshake message for the Master. Pickle-based
            transports can send it as-is; JSON-oriented transports should
            run it through ``encode_coord_message`` first.

    ``redis_meta`` and ``mooncake_engine`` are meant to be held by the
    Remote for the lifetime of the worker process.
    """

    # Field order is part of the positional constructor signature.
    sd_key: SharingDomainKey
    namespace: SharingDomainNamespace
    redis_meta: Any
    mooncake_engine: Any
    distributed_node_id: int
    ready_msg: RemoteReadyMsg
def __init__(
    self,
    *,
    cache_config: Any,
    sd_key_str: str,
    instance_id: str,
    session_epoch: str,
    cpu_buffer_ptr: int,
    cpu_buffer_size: int,
    local_zmq_addr: str,
    redis_meta_factory: Optional[Callable[[Any, SharingDomainNamespace], Any]] = None,
    mooncake_engine_factory: Optional[Callable[[Any], Any]] = None,
) -> None:
    """Capture the bootstrap inputs; no Redis or Mooncake work happens here.

    All parameters are keyword-only.

    Args:
        cache_config: configuration object handed to both factories.
        sd_key_str: serialized sharing-domain key; parsed immediately
            with ``SharingDomainKey.deserialize``.
        instance_id: stored as ``str``.
        session_epoch: stored as ``str``.
        cpu_buffer_ptr: address of the CPU buffer, stored as ``int``.
        cpu_buffer_size: size of that buffer, stored as ``int``.
        local_zmq_addr: stored as ``str``.
        redis_meta_factory: ``(cache_config, namespace) -> redis_meta``;
            a falsy value selects ``_default_redis_meta_factory``.
        mooncake_engine_factory: ``(cache_config) -> engine``; a falsy
            value selects ``_default_mooncake_factory``.
    """
    parsed_key = SharingDomainKey.deserialize(sd_key_str)

    self._cache_config = cache_config
    self._sd_key = parsed_key  # SharingDomainKey
    self._namespace = SharingDomainNamespace(parsed_key)  # SharingDomainNamespace
    self._instance_id = str(instance_id)
    self._session_epoch = str(session_epoch)
    self._cpu_buffer_ptr = int(cpu_buffer_ptr)
    self._cpu_buffer_size = int(cpu_buffer_size)
    self._local_zmq_addr = str(local_zmq_addr)
    self._redis_meta_factory = (
        redis_meta_factory if redis_meta_factory else self._default_redis_meta_factory
    )
    self._mooncake_engine_factory = (
        mooncake_engine_factory if mooncake_engine_factory else self._default_mooncake_factory
    )
+ + Any failure propagates the original exception — callers are + expected to treat that as fatal (the instance is "co-destined" with + its Master; design doc §4.3.1). + """ + # 1. Redis side: register node, get global distributed_node_id. + redis_meta = self._redis_meta_factory(self._cache_config, self._namespace) + node_id = redis_meta.init_meta() + if node_id is None: + err = getattr(redis_meta, "get_init_error", lambda: None)() + raise RuntimeError( + f"[RemoteDistReuseInit:{self._sd_key.serialize()}] " + f"redis init_meta() failed: {err!r}" + ) + _LOG.info( + "[RemoteDistReuseInit:%s] registered as node_id=%d", + self._sd_key.serialize(), + node_id, + ) + + # 2. Mooncake side: init engine, register the CPU block pool buffer, + # publish node meta back to Redis so the Master's peer discovery can + # find us. + mooncake_engine = self._mooncake_engine_factory(self._cache_config) + _init_mooncake_if_needed(mooncake_engine, self._cache_config) + + regist = getattr(mooncake_engine, "regist_buffer", None) + if regist is None: + raise AttributeError("mooncake engine lacks regist_buffer()") + regist(self._cpu_buffer_ptr, self._cpu_buffer_size) + + # Record buffer + node meta in Redis so peers can resolve the + # (nid -> mooncake_addr, zmq_addr, cpu_buffer_ptr) triple. + redis_meta.regist_buffer([{ + "buffer_ptr": self._cpu_buffer_ptr, + "buffer_size": self._cpu_buffer_size, + }]) + mooncake_addr = _safe_call(mooncake_engine, "get_engine_addr", default="") + redis_meta.regist_node_meta( + node_id=node_id, + addr=str(mooncake_addr), + zmq_addr=self._local_zmq_addr, + cpu_buffer_ptr=self._cpu_buffer_ptr, + ssd_buffer_ptr=0, + ) + + # 3. Build the ready message for the Master. 
def _init_mooncake_if_needed(engine: Any, cache_config: Any) -> None:
    """Invoke ``engine.init(path)`` when both the hook and a path exist.

    Nothing happens when ``engine`` has no callable ``init`` attribute or
    when ``cache_config.mooncake_config_path`` is missing / ``None``.
    According to the original note the real Mooncake wrapper initializes
    itself in its constructor, so this mainly serves injected stubs.
    """
    hook = getattr(engine, "init", None)
    if not callable(hook):
        return
    config_path = getattr(cache_config, "mooncake_config_path", None)
    if config_path is None:
        return
    hook(config_path)


def _safe_call(obj: Any, method: str, default: Any = None) -> Any:
    """Call the zero-argument ``obj.<method>()`` without ever raising.

    Returns ``default`` when the attribute is absent (or ``None``) and
    also when the call raises; in the latter case a warning naming the
    object's type, the method and the exception is logged.
    """
    target = getattr(obj, method, None)
    if target is None:
        return default
    try:
        outcome = target()
    except Exception as e:
        _LOG.warning("%s.%s() raised: %s", type(obj).__name__, method, e)
        return default
    return outcome
+ +A *sharing domain* (SD) groups together FlexKV instances that hold the **same +KV slice** along two orthogonal dimensions: + +- ``pp_node_idx`` (which physical node within the PP pipeline) — which + layer range +- ``tp_node_idx`` (cross-node TP node index) — which KV head shard + +**CP is not part of the SD key.** Based on the code-fact review (see +``docs/dist_reuse/dist_reuse_with_cp_pp_multinode_tp.md`` §1.3 / §2.3 / §4.5): +CP attention does an all-gather before writing back to the local +``token_to_kv_pool``, so every ``cp_rank`` in the same CP group holds +bit-wise identical main KV and (for NSA) indexer K. Cross-instance reuse is +therefore legal at the CP-group granularity — the SD layer does not need +``cp_rank``/``cp_size`` isolation. The CP dimension is handled purely as a +sync_leader control-plane + scatter-based data-plane concern in the +connector layer. + +**Why node-granularity for PP, not rank-granularity?** The KV cache +sharing boundary is the *physical CPU pool* of a node — every PP rank +co-located on the same machine writes into the *same* CPU pool. Keying +the SD by PP rank would (a) inflate the SD count to ``pp_size`` even on +single-node PP=2/4 deployments where no remote-peer SD actually exists, +and (b) make the master coordinator believe two PP ranks on the same +machine are remote peers. Keying by ``pp_node_idx`` collapses +co-located PP ranks into the same SD, so SD count exactly equals the +number of physical nodes participating in this instance's KV sharing. + + ``pp_node_count = max(min(pp_size, nnodes), 1)`` + ``pp_per_node = max(pp_size // pp_node_count, 1)`` + ``pp_node_idx = pp_rank // pp_per_node`` + +Only instances that share the same ``(pp_node_idx, tp_node_idx)`` pair +(plus the same ``model_id`` and ``is_nsa`` layout flag) may participate in +P2P KVCache reuse with each other. See design doc §3.1 / §4.1 / §4.5. 
+ +The serialized form is used as a Redis key namespace prefix: + + sd:MODEL_ID:ppnPP_NODE_IDX/PP_NODE_COUNT:tpnTP_NODE_IDX/TP_NODE_COUNT:nsa<0|1>:<...> + +``is_nsa`` distinguishes NSA model layouts (extra indexer K cache buffer) +from non-NSA models; it is independent of whether CP is enabled. + +**IMPORTANT — sd_key equality is NOT physical-pool equality.** +Two SDs being equal only means "same SD namespace, eligible to negotiate +KV reuse". It does *not* imply that the underlying CPU pools store the +same physical layer range. In particular, on a single-node PP=2/4 +deployment all PP ranks collapse to ``pp_node_idx=0``, but each rank +physically owns a disjoint ``[pp_start_layer, pp_end_layer)`` slice of +the CPU pool — so cross-rank KV reads inside the same node would return +the wrong layer's bytes. This is why +:func:`SharingDomainKey.from_model_config` raises on +``pp_size > 1 and nnodes == 1`` until the CPU pool is reworked to store +full-layer KV. + +This module is **pure Python** and has no runtime dependency on RedisMeta / +Mooncake / CacheEngine, so it is safe to import from any layer (config, +transfer manager, tests). +""" + +from __future__ import annotations + +import hashlib +import re +from dataclasses import dataclass, replace +from typing import Any, Iterator, List, Optional + + +__all__ = [ + "SharingDomainKey", + "DEFAULT_MODEL_ID", + "derive_model_id", +] + + +# Sentinel used by :meth:`SharingDomainKey.default` to mark the degenerate +# single-SD fallback (when sharing-domain support is disabled). Anything that +# uses ``DEFAULT_MODEL_ID`` is opting out of any cross-instance reuse. +DEFAULT_MODEL_ID = "__default__" + + +# A model_id must be safe to embed in a Redis key. We restrict it to +# ``[A-Za-z0-9_.-]`` so the resulting ``sd:`` prefix never accidentally +# contains the ``:`` separator we use ourselves.
_MODEL_ID_RE = re.compile(r"^[A-Za-z0-9_.\-]+$")


def derive_model_id(
    *,
    num_layers: int,
    num_kv_heads: int,
    head_size: int,
    dtype: Any,
    use_mla: bool,
) -> str:
    """Fingerprint a model architecture as a short, stable identifier.

    The five knobs are normalized (``int`` for the sizes, ``str`` for the
    dtype, ``0``/``1`` for the MLA flag), joined with ``|`` and hashed
    with SHA-1. Equal knobs therefore give the same id in every process
    and on every host — Python's randomized ``hash()`` is never involved.

    Returns:
        The first 16 hex characters of the digest.
    """
    knobs = (
        str(int(num_layers)),
        str(int(num_kv_heads)),
        str(int(head_size)),
        str(dtype),
        str(int(bool(use_mla))),
    )
    fingerprint = "|".join(knobs).encode("utf-8")
    return hashlib.sha1(fingerprint).hexdigest()[:16]


def _compute_pp_node_dims(pp_size: int, nnodes: int) -> tuple[int, int]:
    """Return ``(pp_node_count, pp_per_node)`` for a model topology.

    Both inputs are coerced to ``int`` and clamped to at least 1, then::

        pp_node_count = min(pp_size, nnodes)
        pp_per_node   = pp_size // pp_node_count

    * ``nnodes == 1``: every PP rank shares one machine, giving
      ``(1, pp_size)``.
    * ``pp_size >= nnodes``: the pipeline spans the nodes, giving
      ``(nnodes, pp_size // nnodes)``.
    * ``pp_size < nnodes`` (e.g. cross-node TP with PP=1): giving
      ``(pp_size, 1)``.
    """
    ranks = max(int(pp_size), 1)
    machines = max(int(nnodes), 1)
    # With both operands >= 1 the minimum is >= 1 and never exceeds
    # ``ranks``, so the quotient below is >= 1 without further clamping.
    node_count = min(ranks, machines)
    return node_count, ranks // node_count
def __post_init__(self) -> None:
    """Validate every field right after dataclass construction.

    Checks, in order:

    * ``model_id`` is a non-empty ``str`` matching ``_MODEL_ID_RE``.
    * ``pp_node_count`` / ``tp_node_count`` are integers ``>= 1``.
    * ``pp_node_idx`` / ``tp_node_idx`` are integers ``>= 0`` and smaller
      than their matching count.
    * ``is_nsa`` is a real ``bool``.

    ``bool`` values are rejected for the integer fields even though
    ``bool`` subclasses ``int``: ``True`` would otherwise be accepted and
    rendered as ``"True"`` by ``serialize``, producing a key that
    ``deserialize`` cannot parse back.

    Raises:
        ValueError: on the first field that fails validation.
    """
    if not isinstance(self.model_id, str) or not self.model_id:
        raise ValueError(f"SharingDomainKey.model_id must be a non-empty str, got {self.model_id!r}")
    if not _MODEL_ID_RE.match(self.model_id):
        raise ValueError(
            f"SharingDomainKey.model_id {self.model_id!r} contains forbidden characters; "
            f"allowed: [A-Za-z0-9_.-]"
        )

    def _is_plain_int(value) -> bool:
        # ``isinstance(True, int)`` is True, so bools are excluded explicitly.
        return isinstance(value, int) and not isinstance(value, bool)

    for name, val in (
        ("pp_node_count", self.pp_node_count),
        ("tp_node_count", self.tp_node_count),
    ):
        if not _is_plain_int(val) or val < 1:
            raise ValueError(f"SharingDomainKey.{name} must be int>=1, got {val!r}")

    for name, idx, count in (
        ("pp_node_idx", self.pp_node_idx, self.pp_node_count),
        ("tp_node_idx", self.tp_node_idx, self.tp_node_count),
    ):
        if not _is_plain_int(idx) or idx < 0:
            raise ValueError(f"SharingDomainKey.{name} must be int>=0, got {idx!r}")
        if idx >= count:
            raise ValueError(
                f"SharingDomainKey.{name}={idx} out of range for "
                f"{name.replace('idx', 'count')}={count}"
            )

    if not isinstance(self.is_nsa, bool):
        raise ValueError(f"SharingDomainKey.is_nsa must be bool, got {self.is_nsa!r}")
@classmethod
def deserialize(cls, s: str) -> "SharingDomainKey":
    """Rebuild a key from the string produced by :meth:`serialize`.

    The input is split on ``:`` into exactly four segments — model id,
    ``ppn`` segment, ``tpn`` segment and ``nsa`` flag. The split is
    unambiguous because a model id may not contain ``:`` (see
    ``_MODEL_ID_RE``).

    Raises:
        ValueError: wrong segment count, a malformed ``ppn``/``tpn``
            segment, or an ``nsa`` segment other than ``nsa0``/``nsa1``.
            Field validation done by the constructor applies as well.
    """
    segments = s.split(":")
    if len(segments) != 4:
        raise ValueError(
            f"SharingDomainKey.deserialize: expected 4 ':'-separated segments, got {len(segments)} in {s!r}"
        )
    model_id, ppn_part, tpn_part, nsa_part = segments

    pp_pair = cls._parse_idx_count(ppn_part, "ppn")
    tp_pair = cls._parse_idx_count(tpn_part, "tpn")

    # Only the two exact spellings below are legal flag segments.
    flag_by_text = {"nsa0": False, "nsa1": True}
    if nsa_part not in flag_by_text:
        raise ValueError(f"SharingDomainKey.deserialize: bad nsa segment {nsa_part!r} in {s!r}")

    return cls(
        model_id=model_id,
        pp_node_idx=pp_pair[0],
        pp_node_count=pp_pair[1],
        tp_node_idx=tp_pair[0],
        tp_node_count=tp_pair[1],
        is_nsa=flag_by_text[nsa_part],
    )


@staticmethod
def _parse_idx_count(seg: str, prefix: str) -> tuple[int, int]:
    """Parse one ``<prefix><idx>/<count>`` segment into ``(idx, count)``.

    Raises:
        ValueError: the segment does not start with ``prefix``, has no
            ``/`` after the prefix, or either side is not an integer.
    """
    if not seg.startswith(prefix):
        raise ValueError(f"SharingDomainKey: segment {seg!r} does not start with {prefix!r}")
    idx_text, slash, count_text = seg[len(prefix):].partition("/")
    if not slash:
        raise ValueError(f"SharingDomainKey: segment {seg!r} missing '/' between idx and count")
    try:
        return int(idx_text), int(count_text)
    except ValueError as e:
        raise ValueError(f"SharingDomainKey: cannot parse idx/count from {seg!r}: {e}") from e
@classmethod
def from_model_config(
    cls,
    model_config: Any,
    *,
    rank_info: Any = None,
    overrides: Optional[dict] = None,
) -> "SharingDomainKey":
    """Derive the SD key of *this node* from a ``ModelConfig``.

    Cluster-wide values (``pp_size``, ``nnodes``, ``tp_node_count``, the
    model id and the NSA flag) are always read from ``model_config``.
    The per-rank values come from ``rank_info`` when it is given
    (``pp_rank`` and ``tp_node_idx``, the latter falling back to
    ``model_config.tp_node_idx``); without ``rank_info`` both are read
    from ``model_config`` — the layout used before the RankInfo refactor
    (PR #165) and by legacy test stubs. Missing attributes default to
    ``1`` for sizes/counts and ``0`` for ranks/indices.

    The PP axis is collapsed from rank to node granularity::

        pp_node_count, pp_per_node = _compute_pp_node_dims(pp_size, nnodes)
        pp_node_idx = pp_rank // pp_per_node

    so PP ranks placed on the same physical node share one SD key.

    NOTE(review): when ``pp_size`` is not a multiple of the node count
    (e.g. ``pp_size=3``, ``nnodes=2``, ``pp_rank=2``) the floor division
    yields an index equal to ``pp_node_count``, which the constructor's
    range validation rejects — confirm how the launcher places ranks in
    that case.

    Args:
        model_config: object exposing the cluster-wide attributes.
        rank_info: optional object exposing the per-rank attributes.
        overrides: optional replacements for any of the six dataclass
            fields, e.g. a Master crafting the key of a Remote that sits
            on another ``pp_node_idx``. Applied last.

    Raises:
        ValueError: ``pp_size > 1`` on a single node (each PP rank then
            owns only a layer slice of the shared CPU pool, so one SD key
            would alias incompatible data), or ``overrides`` contains a
            key that is not a dataclass field.
    """
    pp_size = max(int(getattr(model_config, "pp_size", 1)), 1)
    nnodes = max(int(getattr(model_config, "nnodes", 1)), 1)

    single_node_pipeline = nnodes == 1 and pp_size > 1
    if single_node_pipeline:
        raise ValueError(
            "SharingDomainKey.from_model_config: dist_reuse is not supported "
            f"on a single-node PP>1 deployment (pp_size={pp_size}, nnodes=1). "
            "Each PP rank physically owns only a slice of the CPU pool layers, "
            "so a single sd_key would alias incompatible layer shards. "
            "Either disable enable_sharing_domain, scale to nnodes>=pp_size, "
            "or wait for the full-layer CPU pool refactor."
        )

    node_count, ranks_per_node = _compute_pp_node_dims(pp_size, nnodes)

    # ``rank_info`` is the post-refactor source of truth for per-rank
    # values; legacy callers keep them on ``model_config``.
    rank_source = model_config if rank_info is None else rank_info
    pp_rank = int(getattr(rank_source, "pp_rank", 0))
    tp_idx = int(
        getattr(rank_source, "tp_node_idx",
                getattr(model_config, "tp_node_idx", 0))
    )

    fields: dict = {
        "model_id": _resolve_model_id(model_config),
        "pp_node_idx": pp_rank // ranks_per_node,
        "pp_node_count": node_count,
        "tp_node_idx": tp_idx,
        "tp_node_count": int(getattr(model_config, "tp_node_count", 1)),
        "is_nsa": bool(_resolve_is_nsa(model_config)),
    }
    if not overrides:
        return cls(**fields)

    unknown = set(overrides) - set(fields)
    if unknown:
        raise ValueError(
            f"SharingDomainKey.from_model_config: unknown override keys {sorted(unknown)}"
        )
    fields.update(overrides)
    return cls(**fields)
+ if rank_info is not None: + _pp_rank = int(getattr(rank_info, "pp_rank", 0)) + _tp_node_idx = int( + getattr(rank_info, "tp_node_idx", + getattr(model_config, "tp_node_idx", 0)) + ) + else: + _pp_rank = int(getattr(model_config, "pp_rank", 0)) + _tp_node_idx = int(getattr(model_config, "tp_node_idx", 0)) + + _pp_node_idx = _pp_rank // pp_per_node + + kwargs: dict = { + "model_id": _resolve_model_id(model_config), + "pp_node_idx": _pp_node_idx, + "pp_node_count": pp_node_count, + "tp_node_idx": _tp_node_idx, + "tp_node_count": int(getattr(model_config, "tp_node_count", 1)), + "is_nsa": bool(_resolve_is_nsa(model_config)), + } + if overrides: + unknown = set(overrides) - set(kwargs) + if unknown: + raise ValueError( + f"SharingDomainKey.from_model_config: unknown override keys {sorted(unknown)}" + ) + kwargs.update(overrides) + return cls(**kwargs) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + def is_master(self) -> bool: + """The Master role of an instance is the SD with both node-idx dims = 0. + + Note: CP is not part of the SD key, so whether a physical node is + the *sync_leader* (cp_rank=0 within a CP group) is determined at the + connector layer, not here. This method only tells you whether this + SD is the (ppn=0, tpn=0) SD of its instance. + """ + return self.pp_node_idx == 0 and self.tp_node_idx == 0 + + def total_sd_count(self) -> int: + """Total number of SDs per instance (= ``pp_node_count × tp_node_count``). + + This equals the number of *physical nodes* spanned by the + instance's KV-sharing topology, which is at most ``nnodes`` for + any deployment. CP does not multiply this. + """ + return self.pp_node_count * self.tp_node_count + + def enumerate_peers(self) -> List["SharingDomainKey"]: + """Enumerate every SD belonging to the *same instance* as ``self``. + + Order is deterministic: outer loop ``pp_node_idx``, inner loop + ``tp_node_idx``. 
def _resolve_model_id(model_config: Any) -> str:
    """Return the model id to embed in the SD key.

    A non-empty string ``model_config.model_id`` is used verbatim.
    Anything else (attribute missing, ``None``, empty string, non-str)
    triggers derivation via ``derive_model_id`` from model-shape
    attributes only.
    """
    configured = getattr(model_config, "model_id", None)
    if isinstance(configured, str) and configured:
        return configured

    # tp_size / pp_size / cp_size are deliberately not part of the
    # digest: they are dimensions of the SD key itself, not properties
    # of the underlying model.
    shape = {
        "num_layers": int(getattr(model_config, "num_layers", 1)),
        "num_kv_heads": int(getattr(model_config, "num_kv_heads", 1)),
        "head_size": int(getattr(model_config, "head_size", 1)),
        "dtype": getattr(model_config, "dtype", "unknown"),
        "use_mla": bool(getattr(model_config, "use_mla", False)),
    }
    return derive_model_id(**shape)


def _resolve_is_nsa(model_config: Any) -> bool:
    """Return the NSA layout flag of ``model_config`` as a plain bool.

    Only the ``is_nsa`` attribute is looked at. A config without that
    attribute yields ``False`` so configs that predate NSA still produce
    valid SD keys.
    """
    flag = getattr(model_config, "is_nsa", False)
    return bool(flag)
SD_KEY_PREFIX: Final[str] = "sd"
INSTANCE_KEY_PREFIX: Final[str] = "flexkv:instance"

# Whitelist for instance ids: ASCII letters, digits, "_", "." and "-".
# This keeps ":" (the key-segment separator used below) and "*" (the
# wildcard used by the SCAN patterns) out of instance ids.
_INSTANCE_ID_RE = re.compile(r"^[A-Za-z0-9_.\-]+$")


class SharingDomainNamespace:
    """Builds Redis keys for one sharing domain (SD).

    Every per-SD key has the form ``sd:<serialized sd_key>:<family>:...``.
    The ``sd:<serialized sd_key>`` prefix is computed once in
    ``__init__`` and reused by every key builder. Instance-level keys
    (``flexkv:instance:<instance_id>:...``) do not depend on an SD and
    are exposed as static helpers.

    Instances are meant to be immutable after construction: no method
    here mutates the wrapped key or the cached prefix.
    """

    __slots__ = ("_sd_key", "_serialized", "_prefix")

    def __init__(self, sd_key: "SharingDomainKey") -> None:
        """Wrap ``sd_key`` and cache its serialized prefix.

        Raises:
            TypeError: if ``sd_key`` is not a ``SharingDomainKey``.
        """
        if not isinstance(sd_key, SharingDomainKey):
            raise TypeError(
                f"SharingDomainNamespace expects a SharingDomainKey, got {type(sd_key).__name__}"
            )
        self._sd_key: "SharingDomainKey" = sd_key
        self._serialized: str = sd_key.serialize()
        # Cached so key builders (called per block insert / publish) only
        # do a single f-string format each.
        self._prefix: str = f"{SD_KEY_PREFIX}:{self._serialized}"

    # ------------------------------------------------------------------
    # Identity
    # ------------------------------------------------------------------
    @property
    def sd_key(self) -> "SharingDomainKey":
        """The wrapped ``SharingDomainKey``."""
        return self._sd_key

    @property
    def serialized_sd(self) -> str:
        """The serialized SD key without the leading ``sd:``."""
        return self._serialized

    @property
    def prefix(self) -> str:
        """``sd:<serialized sd_key>`` — common prefix of every per-SD key."""
        return self._prefix

    # ------------------------------------------------------------------
    # Per-SD keys
    # ------------------------------------------------------------------
    def node_key(self, node_id: int) -> str:
        """Return ``<prefix>:node:<node_id>``."""
        return f"{self._prefix}:node:{int(node_id)}"

    def meta_key(self, node_id: int) -> str:
        """Return ``<prefix>:meta:<node_id>``."""
        return f"{self._prefix}:meta:{int(node_id)}"

    def buffer_key(self, node_id: int, buffer_ptr: int) -> str:
        """Return ``<prefix>:buffer:<node_id>:<buffer_ptr>`` (decimal)."""
        return f"{self._prefix}:buffer:{int(node_id)}:{int(buffer_ptr)}"

    def block_key(self, node_id: int, block_hash: int) -> str:
        """Return ``<prefix>:block:<node_id>:<hash>``.

        ``block_hash`` is masked to 64 bits, so signed and unsigned
        representations of the same hash give the same key, and rendered
        as lowercase hex without a ``0x`` prefix.
        """
        h = int(block_hash) & 0xFFFFFFFFFFFFFFFF
        return f"{self._prefix}:block:{int(node_id)}:{h:x}"

    def aggregate_key(self, request_prefix_hash: int) -> str:
        """Return ``<prefix>:aggregate:<hash>`` (aggregate-radix marker).

        The hash is masked to 64 bits and rendered as lowercase hex,
        exactly like in :meth:`block_key`.
        """
        h = int(request_prefix_hash) & 0xFFFFFFFFFFFFFFFF
        return f"{self._prefix}:aggregate:{h:x}"

    # ------------------------------------------------------------------
    # SCAN-friendly patterns
    # ------------------------------------------------------------------
    def node_key_pattern(self) -> str:
        """Glob matching every ``node`` key of this SD."""
        return f"{self._prefix}:node:*"

    def meta_key_pattern(self) -> str:
        """Glob matching every ``meta`` key of this SD."""
        return f"{self._prefix}:meta:*"

    def buffer_key_pattern(self) -> str:
        """Glob matching every ``buffer`` key of this SD."""
        return f"{self._prefix}:buffer:*"

    def block_key_pattern(self) -> str:
        """Glob matching every ``block`` key of this SD, for all node ids."""
        return f"{self._prefix}:block:*"

    def block_key_pattern_for_node(self, node_id: int) -> str:
        """Glob matching the ``block`` keys of a single ``node_id``."""
        return f"{self._prefix}:block:{int(node_id)}:*"

    # ------------------------------------------------------------------
    # Cross-SD (instance-level) keys — static helpers
    # ------------------------------------------------------------------
    @staticmethod
    def instance_session_key(instance_id: str) -> str:
        """Return ``flexkv:instance:<instance_id>:session``.

        Raises:
            ValueError: if ``instance_id`` fails validation.
        """
        SharingDomainNamespace._validate_instance_id(instance_id)
        return f"{INSTANCE_KEY_PREFIX}:{instance_id}:session"

    @staticmethod
    def instance_sd_nodes_key(instance_id: str) -> str:
        """Return ``flexkv:instance:<instance_id>:sd_nodes``.

        Raises:
            ValueError: if ``instance_id`` fails validation.
        """
        SharingDomainNamespace._validate_instance_id(instance_id)
        return f"{INSTANCE_KEY_PREFIX}:{instance_id}:sd_nodes"

    @staticmethod
    def instance_session_key_pattern() -> str:
        """Glob matching the session key of every instance."""
        return f"{INSTANCE_KEY_PREFIX}:*:session"

    @staticmethod
    def parse_instance_session_key(key: str) -> str:
        """Extract ``instance_id`` from a session key.

        Raises:
            ValueError: if ``key`` lacks the expected prefix / suffix or
                the embedded instance id fails validation.
        """
        prefix = f"{INSTANCE_KEY_PREFIX}:"
        suffix = ":session"
        if not key.startswith(prefix) or not key.endswith(suffix):
            raise ValueError(f"Not a flexkv instance session key: {key!r}")
        instance_id = key[len(prefix):-len(suffix)]
        SharingDomainNamespace._validate_instance_id(instance_id)
        return instance_id

    @staticmethod
    def _validate_instance_id(instance_id: str) -> None:
        """Raise ``ValueError`` unless ``instance_id`` is a whitelisted str."""
        if not isinstance(instance_id, str) or not instance_id:
            raise ValueError(f"instance_id must be a non-empty str, got {instance_id!r}")
        # fullmatch, not match: "$" also matches just before a trailing
        # newline, so match() would accept "abc\n" and let the newline
        # leak into the Redis key.
        if not _INSTANCE_ID_RE.fullmatch(instance_id):
            raise ValueError(
                f"instance_id {instance_id!r} contains forbidden characters; allowed [A-Za-z0-9_.-]"
            )

    # ------------------------------------------------------------------
    # Equality / hashing — useful for caching namespaces in dicts
    # ------------------------------------------------------------------
    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SharingDomainNamespace):
            return NotImplemented
        return self._sd_key == other._sd_key

    def __hash__(self) -> int:
        return hash(self._sd_key)

    def __repr__(self) -> str:  # pragma: no cover
        return f"SharingDomainNamespace({self._serialized!r})"
to root-cause the real failure + # (e.g. cross-node CUDA IPC handle device-id mismatch). Always + # propagate the real exception so the original traceback is kept. flexkv_logger.error("Import tensor handle failed: %s", e) - return torch.empty(0) + raise @staticmethod def _create_tensor_from_cuda_ptr( diff --git a/flexkv/common/transfer.py b/flexkv/common/transfer.py index 0f4378e5de..f640b6c2d7 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -23,6 +23,27 @@ class CompletedOp: transfer_type: Optional[str] = None num_blocks: int = 0 num_bytes: int = 0 + # Phase D-1 (proposal_unify_with_graph_dispatch_2026-05-15.md §3.5): + # SD identity tag — set by ``TransferManagerOnRemote._handle_submit`` + # after rebinding the inbound graph, or implicitly equal to the + # Master's own SD when the op was scheduled in-process. Master's + # polling worker uses this to route per-SD completion → ``mark_sd_ready``. + # Empty string ("") means "Master's own SD" (legacy default). + sd_key: str = "" + # Phase D-1: distributed_node_id of the node that physically completed + # this op. For Remotes this is their own ``self_node_id``; for Master + # it's the Master's own ``self_node_id``. ``-1`` is the sentinel + # "unknown" used by callers that don't care (e.g. legacy single-SD + # H2D on the master path). + contributing_node_id: int = -1 + # Phase D-1: success flag. ``False`` indicates the worker / D2H / + # Mooncake transfer failed. Master uses this to decide whether to + # mark the SD ready or escalate to invalidate. Default ``True`` so + # legacy CompletedOp construction (which never set it) keeps the + # historical "completed = success" semantics. + success: bool = True + # Phase D-1: optional opaque error text when ``success=False``. 
+ error: Optional[str] = None def is_graph_completed(self) -> bool: return self.op_id == -1 @@ -106,6 +127,9 @@ class TransferOp: # used for distributed cpu and ssd src_block_node_ids: Optional[np.ndarray] = None pending_count: int = 0 + target_node_ids: Optional[List[int]] = None + block_hashes: Optional[List[int]] = None + def __post_init__(self) -> None: if self.transfer_type != TransferType.VIRTUAL and \ @@ -442,6 +466,13 @@ def merge_to_batch_graph(batch_id: int, callbacks_by_type: Dict[TransferType, List[Callable]] = {} supported_types = {TransferType.DISK2H, TransferType.H2D, TransferType.D2H, TransferType.H2DISK} + # P2P transfer types are *not* merged but passed through individually, + # because each op carries its own ``src_block_node_ids`` (per-block peer + # routing) and possibly distinct ``target_node_ids`` (D-3 multi-SD + # broadcast clones). Merging would silently drop these fields. + passthrough_types = {TransferType.PEERH2H, TransferType.H2PEERH, + TransferType.PEERSSD2H, TransferType.H2PEERSSD} + passthrough_ops: List[Tuple[TransferOp, Optional[Callable]]] = [] for tt in supported_types: ops_by_type[tt] = [] @@ -451,10 +482,17 @@ def merge_to_batch_graph(batch_id: int, for op_id, op in graph._op_map.items(): if op.transfer_type == TransferType.VIRTUAL: continue + if op.transfer_type in passthrough_types: + # P2P ops: keep one-by-one, preserving src_block_node_ids / + # target_node_ids / remote_node_ids etc. + cb = op_callback_dict.get(op.op_id) + passthrough_ops.append((op, cb)) + continue if op.transfer_type not in supported_types: raise NotImplementedError( f"Batch merge does not support transfer type: {op.transfer_type}. " - f"Only DISK2H, H2D, D2H, and H2DISK are supported." + f"Only DISK2H, H2D, D2H, H2DISK and P2P (PEERH2H, H2PEERH, " + f"PEERSSD2H, H2PEERSSD) types are supported." 
) ops_by_type[op.transfer_type].append(op) if op.op_id in op_callback_dict: @@ -522,6 +560,43 @@ def merge_to_batch_graph(batch_id: int, else: batch_end_op_id = -1 + # ----- P2P passthrough (after main merge / layerwise emit) ----- + # PEERH2H / H2PEERH / PEERSSD2H / H2PEERSSD are added to the merged + # graph one-by-one (NOT merged) and made a predecessor of the H2D op + # so the GPU upload waits for the peer-fetched data to land in the + # local CPU / SSD pool first. This preserves per-op metadata + # (src_block_node_ids, target_node_ids, remote_node_ids). + if not layerwise_transfer: + for op, cb in passthrough_ops: + # ``add_transfer_op`` re-stamps op.graph_id to merged_graph's id. + merged_graph.add_transfer_op(op) + if cb is not None: + new_op_callback_dict[op.op_id] = cb + # GET-side P2P (PEERH2H / PEERSSD2H) must precede H2D so the + # data is in CPU / SSD pool before being uploaded to GPU. + if (merged_h2d_op is not None and + op.transfer_type in (TransferType.PEERH2H, + TransferType.PEERSSD2H)): + merged_graph.add_dependency(merged_h2d_op.op_id, op.op_id) + # PUT-side P2P (H2PEERH / H2PEERSSD) must follow D2H so the + # data is already in CPU pool before being shipped to peers. + if (merged_d2h_op is not None and + op.transfer_type in (TransferType.H2PEERH, + TransferType.H2PEERSSD)): + merged_graph.add_dependency(op.op_id, merged_d2h_op.op_id) + else: + # In layerwise mode the GET path collapses into a single + # LayerwiseTransferOp; we still need to passthrough P2P ops and + # have the layerwise op wait on them, but the current layerwise + # path does not surface a separate H2D op object to depend on. + # Defer support for layerwise + P2P until that path is exercised. 
+ if passthrough_ops: + raise NotImplementedError( + "Batch merge: layerwise + P2P passthrough is not yet " + "implemented; got " + f"{[op.transfer_type for op, _ in passthrough_ops]}" + ) + return merged_graph, batch_end_op_id, new_op_callback_dict diff --git a/flexkv/common/type.py b/flexkv/common/type.py index 8b893f2eb5..25f17d2d93 100644 --- a/flexkv/common/type.py +++ b/flexkv/common/type.py @@ -11,9 +11,13 @@ class MatchResultAccel: last_node: Optional['CRadixNode'] = None last_node_matched_length: int = 0 physical_blocks: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) + # Single node_id for all matched blocks (single-node matching constraint). + # -1 means no remote match. Preferred over the deprecated per-block arrays. + matched_node_id: Optional[int] = None + # deprecated: kept for backward compat; prefer matched_node_id block_node_ids: Optional[np.ndarray] = None matched_pos: Optional[str] = None - matched_node_ids: Optional[np.ndarray] = None #TODO id or ids? should we allow one req match results on multiple nodes? 
+ matched_node_ids: Optional[np.ndarray] = None # deprecated: prefer matched_node_id insert_to_local_cpu_index: bool = True def __post_init__(self) -> None: diff --git a/flexkv/integration/config.py b/flexkv/integration/config.py index cc2f98bda2..c8055ea021 100644 --- a/flexkv/integration/config.py +++ b/flexkv/integration/config.py @@ -218,8 +218,8 @@ def post_init_from_sglang_config( sglang_config: sglang.srt.configs.model_config.ModelConfig-like object server_args: sglang ServerArgs — source of tp_size, dp_size, nnodes, node_rank, enable_dp_attention, attn_cp_size, - kv_cache_dtype, - dist_init_addr + is_nsa (read from server_args.enable_nsa_prefill_context_parallel), + kv_cache_dtype, dist_init_addr page_size: KV block size (tokens per block) used by sglang tp_rank: physical tensor parallel rank (runtime, from process group) pp_rank: pipeline parallel rank (runtime, from process group) @@ -234,6 +234,11 @@ def post_init_from_sglang_config( node_rank = server_args.node_rank enable_dp_attention = server_args.enable_dp_attention attn_cp_size = server_args.attn_cp_size + # ``is_nsa`` (NSA model layout flag): True when the model has an + # extra indexer K cache buffer. Sourced from sglang's + # ``enable_nsa_prefill_context_parallel`` server arg, but in dist_reuse + # context the flag represents the *layout*, not whether CP is on. 
+ is_nsa = getattr(server_args, 'enable_nsa_prefill_context_parallel', False) kv_cache_dtype = getattr(server_args, 'kv_cache_dtype', None) dp_rank = 0 if dp_rank is None else int(dp_rank) @@ -327,6 +332,7 @@ def post_init_from_sglang_config( pp_end_layer = self.model_config.num_layers self.model_config.enable_dp_attention = bool(enable_dp_attention) self.model_config.attn_cp_size = int(attn_cp_size) + self.model_config.is_nsa = is_nsa self.model_config.nnodes = max(1, int(nnodes)) _dist_init_addr = getattr(server_args, 'dist_init_addr', None) if _dist_init_addr and int(nnodes) > 1: diff --git a/flexkv/integration/multinode_policy.py b/flexkv/integration/multinode_policy.py new file mode 100644 index 0000000000..ebe5bdbc88 --- /dev/null +++ b/flexkv/integration/multinode_policy.py @@ -0,0 +1,185 @@ +"""Multi-node role decision helpers for the sglang ↔ FlexKV connector. + +Design doc §4.5.5 splits ``is_multinode`` into **two independent axes**: + +* ``is_multinode_tp`` — one TP group spans >1 physical node. + Each such node runs a full SD-Remote (``TransferManagerOnRemote``) + with its own RedisMeta + Mooncake registration. + +* ``is_multinode_cp`` — CP > 1 and the CP group spans >1 physical node. + CP all-gather makes every ``cp_rank``'s KV pool bit-wise identical, + so non-leader CP ranks **do not** run a full SD-Remote; they only + need a GPU-registration stub (``KVTPClient``) + receive coordinated + H2D commands routed by the sync-leader rank. + +The master connector (``flexkv_connector.py``) currently conflates the +two under a ``nnodes > 1 and node_rank > 0 and local_rank == 0`` rule +of thumb. This module provides the policy functions that the +connector **should** call once we can exercise cross-node boots on a +two-machine GPU setup (tracked as §2.4 in +``docs/dist_reuse/implementation_gap_2026-05-11.md``). + +Everything here is pure Python / pure logic so it is unit-testable +without torch, CUDA, or a running sglang process. 
class RemoteProcessRole(str, Enum):
    """Which ``TransferManagerOnRemote`` flavour, if any, a rank should run.

    Four roles exist (descriptions follow design doc §4.5.5):

    * ``MASTER`` — the rank sits on the master node of a multi-node
      instance. The in-process ``KVManager`` is the transfer authority,
      so no ``TransferManagerOnRemote`` is spawned.
    * ``SD_REMOTE_FULL`` — the rank sits on a non-master node and needs
      a full SD-Remote (own Redis metadata and Mooncake registration,
      serves coordinated GET/PUT from the Master).
    * ``CP_PEER_REGISTRATION_ONLY`` — the rank is a non-leader CP rank.
      It only needs a light stub that registers its GPU blocks and
      listens for coordinated H2D commands; it does not touch Redis
      metadata or Mooncake.
    * ``NO_REMOTE`` — nothing to spawn (single-node instance).
    """

    MASTER = "master"
    SD_REMOTE_FULL = "sd_remote_full"
    CP_PEER_REGISTRATION_ONLY = "cp_peer_registration_only"
    NO_REMOTE = "no_remote"


@dataclass(frozen=True)
class RankTopology:
    """Scalar topology facts about one rank, as seen by the connector.

    Frozen and made only of scalars, so instances are hashable and easy
    to build directly in tests.
    """

    # Core dimensions.
    nnodes: int
    node_rank: int
    local_rank: int  # index of the rank within its node, >= 0

    # Rolled-up FlexKV topology flags.
    is_multinode_tp: bool  # one TP group spans more than one node
    is_multinode_cp: bool  # CP > 1 and the CP group crosses a node boundary

    # Optional explicit sync-leader hint. ``None`` means "unknown":
    # ``is_sync_leader()`` then applies its default heuristic.
    is_sync_leader: Optional[bool] = None


def decide_remote_role(topo: RankTopology) -> RemoteProcessRole:
    """Map a rank's topology to its ``RemoteProcessRole``.

    Decision table (first matching row wins; "master box" means
    ``node_rank == 0``):

    ================  ===============  ===============  =========================
    Placement         is_multinode_tp  is_multinode_cp  Role
    ================  ===============  ===============  =========================
    nnodes == 1       (ignored)        (ignored)        NO_REMOTE
    master box        False            False            NO_REMOTE
    master box        any              any              MASTER
    off-master box    True             any              SD_REMOTE_FULL
    off-master box    False            True             CP_PEER_REGISTRATION_ONLY
    off-master box    False            False            SD_REMOTE_FULL
    ================  ===============  ===============  =========================

    TP wins over CP: an off-master rank with both flags set is a full
    SD-Remote, never a CP-only stub. The last row keeps the legacy
    behaviour of spawning a remote on every non-master node when
    ``nnodes > 1``.

    Raises:
        ValueError: if ``topo`` fails ``_validate``.
    """
    _validate(topo)

    instance_spans_nodes = topo.is_multinode_tp or topo.is_multinode_cp

    # Single-node deployment: there is nothing to spawn.
    if topo.nnodes <= 1:
        return RemoteProcessRole.NO_REMOTE

    if topo.node_rank == 0:
        # A multi-node deployment whose instances each fit on one node
        # (e.g. DP across nodes) has no remote peer inside the instance.
        if instance_spans_nodes:
            return RemoteProcessRole.MASTER
        return RemoteProcessRole.NO_REMOTE

    # Off-master node from here on.
    if topo.is_multinode_cp and not topo.is_multinode_tp:
        return RemoteProcessRole.CP_PEER_REGISTRATION_ONLY

    # Either TP spans nodes, or neither flag is set. The latter is kept
    # as SD_REMOTE_FULL for compatibility with the legacy path that
    # cross-node PP relies on.
    # TODO(dist_reuse-§2.4): revisit once ``is_multinode_pp`` has its
    # own property on ModelConfig.
    return RemoteProcessRole.SD_REMOTE_FULL


def is_sync_leader(topo: RankTopology) -> bool:
    """Tell whether ``topo`` describes the sync-leader rank.

    An explicit ``topo.is_sync_leader`` hint always wins. Without it the
    rank is considered the leader when it is local rank 0 on node 0.
    """
    hint = topo.is_sync_leader
    if hint is None:
        return topo.node_rank == 0 and topo.local_rank == 0
    return bool(hint)


def _validate(topo: RankTopology) -> None:
    """Raise ``ValueError`` when the numeric fields of ``topo`` are out of range."""
    if topo.nnodes <= 0:
        raise ValueError(f"nnodes must be > 0, got {topo.nnodes}")
    if topo.node_rank < 0 or topo.node_rank >= topo.nnodes:
        raise ValueError(
            f"node_rank out of range: {topo.node_rank} / {topo.nnodes}"
        )
    if topo.local_rank < 0:
        raise ValueError(f"local_rank must be >= 0, got {topo.local_rank}")
+ self._master_coordinator = None + if getattr(self.cache_config, "enable_sharing_domain", False): + self._setup_sharing_domain_handles(gpu_register_port=gpu_register_port) + self.tasks: ExpiringDict[int, KVTask] = ExpiringDict(max_age_seconds=1800, max_len=100000) # 30 minutes # hash(token_ids) -> task_id @@ -173,6 +183,203 @@ def shutdown(self) -> None: self.remote_process.join() self.remote_process.close() self.remote_process = None + # Phase 0 task 0-G: tear down sharing-domain background threads. + if getattr(self, "_master_coordinator", None) is not None: + try: + # Batch-F: drop cache_engine's references before the + # coordinator goes away so any in-flight completion + # callback (_on_peer_sd_completed_op / + # handle_failure_report) sees a clean None instead of + # a half-torn-down object. + if hasattr(self, "cache_engine") and self.cache_engine is not None: + try: + self.cache_engine.detach_dist_reuse() + except Exception: + pass + self._master_coordinator.shutdown() + except Exception as e: + flexkv_logger.warning(f"MasterCoordinator.shutdown() raised: {e}") + self._master_coordinator = None + + def _setup_sharing_domain_handles(self, *, gpu_register_port: Optional[str]) -> None: + """Populate ``self.transfer_handles`` with one handle per SD and + create a :class:`MasterCoordinator` on the Master node. + + No-op when ``cache_config.enable_sharing_domain`` is False — the + caller already gates this. This method is **best-effort**: if + the user hasn't populated ``remote_endpoints_by_sd`` yet (e.g. + single-SD degenerate mode), we keep the legacy handle list and + just construct the coordinator for the local SD. + """ + from flexkv.common.dist_reuse import ( + MasterCoordinator, + SharingDomainKey, + build_sharing_domain_handles, + make_session_epoch, + ) + + self_sd = SharingDomainKey.from_model_config(self.model_config) + + # Single-SD degenerate case (no sharing) — no extra handles needed. 
+ if self_sd.total_sd_count() <= 1: + flexkv_logger.info( + "[DistReuse] Master SD is the only SD in the instance; " + "skipping multi-SD handle construction." + ) + # Still spin up a MasterCoordinator so aggregate-radix hooks + # work uniformly (it will just track 1 SD). + instance_id = self.cache_config.instance_id or f"inst-{self_sd.serialize()}" + self._master_coordinator = MasterCoordinator( + self_sd=self_sd, + instance_id=instance_id, + session_epoch=self.cache_config.session_epoch or make_session_epoch(), + ) + self._master_coordinator.expect_remotes(0) + # Single-SD path still attaches the coord state to the cache + # engine so refcount / aggregate / failure-detection hooks work. + self._wire_dist_reuse_coord_dispatcher() + return + + try: + specs = build_sharing_domain_handles( + self_sd=self_sd, + remote_endpoints_by_sd=self.cache_config.remote_endpoints_by_sd, + ) + except KeyError as e: + flexkv_logger.warning( + f"[DistReuse] could not build multi-SD handles: {e}. " + f"Falling back to legacy handle list." + ) + return + + # Drop the legacy handles and rebuild from scratch — we're on the + # multi-SD path now. + for h in self.transfer_handles: + try: + h.shutdown() + except Exception: # pragma: no cover + pass + self.transfer_handles = [] + + # Build new handle list per spec. 
+ for spec in specs: + if spec.mode == "process": + handle = TransferManagerHandle( + self.model_config, self.cache_config, + mode="process", gpu_register_port=gpu_register_port, + ) + else: + ep = spec.endpoint + master_host = ep.ip + master_ports = (ep.gpu_register_port, ep.command_port, ep.result_port) + handle = TransferManagerHandle( + self.model_config, self.cache_config, + mode="remote", + gpu_register_port=gpu_register_port, + master_host=master_host, + master_ports=master_ports, + ) + handle._handle.set_target_sd_key(spec.sd_key.serialize()) + handle._handle.send_config_to_remotes() + self.transfer_handles.append(handle) + + self.required_completed_count = len(self.transfer_handles) + + # Create the Master coordinator; it will accept RemoteReadyMsg + # acks as Remote nodes finish their dist_reuse bootstrap. + instance_id = self.cache_config.instance_id or f"inst-{self_sd.serialize()}" + self._master_coordinator = MasterCoordinator( + self_sd=self_sd, + instance_id=instance_id, + session_epoch=self.cache_config.session_epoch or make_session_epoch(), + refcount_leak_timeout_seconds=self.cache_config.refcount_leak_timeout_seconds, + ) + self._master_coordinator.expect_remotes(len(specs) - 1) + + # Batch-F: wire up the cross-SD coordinator now that we have both + # the MasterCoordinator (aggregate radix + refcount owner) and + # the fully-populated transfer_handles list (one per SD). + self._wire_dist_reuse_coord_dispatcher() + + def _wire_dist_reuse_coord_dispatcher(self) -> None: + """Phase D-2 (proposal_unify_with_graph_dispatch_2026-05-15.md + §6.3 / §3.5): wire the master-side completion sink so peer-SD + ``CompletedOp(sd_key, contributing_node_id)`` flowing back through + each Remote handle's ``_polling_worker`` is routed to the + cache_engine's ``_on_peer_sd_completed_op`` method. + + Replaces the pre-Phase-D ``CoordinationCoordinator + + set_coord_ack_sink`` wiring (Phase D-4 deleted the latter + entirely). 
``set_coord_ack_sink`` is still used here, but only + for the surviving ``FailureReportMsg`` channel — see + ``MasterCoordinator.handle_failure_report``. + + Idempotent: safe to call multiple times. + Degenerates to a single-SD no-op when the instance has exactly + one SD (``total_sd_count == 1``). + """ + if self._master_coordinator is None: + return + + self_sd = self._master_coordinator.self_sd + total = self_sd.total_sd_count() + if total <= 1: + # Single-SD instance: still attach the master_coord (for + # refcount / aggregate / failure detection), but no cross-SD + # dispatch is needed. + self.cache_engine.attach_dist_reuse(self._master_coordinator) + return + + # Phase D-2: register the master's completion sink on every + # peer-SD handle so CompletedOp(sd_key=..., contributing_node_id=...) + # gets routed to GlobalCacheEngine._on_peer_sd_completed_op. + sink = self.cache_engine._on_peer_sd_completed_op + # Phase D-4: also register a FailureReportMsg sink so peer + # data-plane failures invalidate aggregate-radix prefixes. + failure_sink = self._master_coordinator.handle_failure_report + # Phase D-3 (proposal_unify_with_graph_dispatch_2026-05-15.md + # §6.4): the Master's own in-proc / inter-proc handle must also + # honour ``target_node_ids`` filtering so peer-SD clone ops + # (D-2 PUT D2H clones, D-3 GET PEERH2H clones) are NOT executed + # by the Master's local TransferEngine. Without this the + # Master would either waste GPU bandwidth (D-2 D2H mirror is + # idempotent) or pull data from peer-SD mooncake endpoints it + # never connected to (D-3 PEERH2H — silent corruption). 
+ master_self_nid = int(getattr(self.cache_config, "distributed_node_id", -1)) + for h in self.transfer_handles: + inner = h._handle + if hasattr(inner, "set_completion_sink"): + try: + inner.set_completion_sink(sink) + except Exception as e: # pragma: no cover + flexkv_logger.warning( + f"[DistReuse] set_completion_sink failed: {e}" + ) + if hasattr(inner, "set_coord_ack_sink"): + try: + inner.set_coord_ack_sink(failure_sink) + except Exception as e: # pragma: no cover + flexkv_logger.warning( + f"[DistReuse] set_coord_ack_sink failed: {e}" + ) + # Phase D-3: only the Master's own in-proc / inter-proc + # handle exposes ``set_dist_reuse_node_id``; the multi-node + # remote handles do their filtering on the Remote side via + # ``TransferManagerOnRemote._filter_graph_by_target_node_ids`` + # and ``_dist_reuse_node_id`` set during bootstrap. + if hasattr(inner, "set_dist_reuse_node_id") and master_self_nid >= 0: + try: + inner.set_dist_reuse_node_id(master_self_nid) + except Exception as e: # pragma: no cover + flexkv_logger.warning( + f"[DistReuse] set_dist_reuse_node_id failed: {e}" + ) + + # Wire into the cache engine. Cross-SD coordination flows + # through the graph-dispatch path with per-op + # ``target_node_ids`` filtering (Phase D-4); no separate coord + # dispatcher is needed. + self.cache_engine.attach_dist_reuse(self._master_coordinator) def create_get_task(self, task_id: int, @@ -288,23 +495,73 @@ def _launch_task(self, task_id: int) -> None: return nvtx.mark(f"launch task: task_id={task_id}, graph_id={transfer_graph.graph_id}") if transfer_graph.num_ops > 0: - for transfer_handle in self.transfer_handles: - # For remote handles: deepcopy graph and clear GPU blocks when - # it's a cross-machine PP handle (different PP stages have - # different GPU block_ids). Cross-machine TP handles share - # the same slot_mapping, so no clear is needed. 
+            # Phase 0 task 0-K: compute per-handle GPU-clear decision *once*
+            # so the sharing-domain-aware logic can be unit-tested separately.
+            clear_flags = self._compute_gpu_clear_flags()
+            for idx, transfer_handle in enumerate(self.transfer_handles):
                 if isinstance(transfer_handle._handle, TransferManagerMultiNodeHandle):
-                    if self.model_config.nnodes > 1 and self.model_config.pp_size > 1:
-                        # Cross-machine PP: each PP rank has different GPU blocks
+                    if clear_flags[idx]:
                         graph_copy = copy.deepcopy(transfer_graph)
                         graph_copy.clear_gpu_blocks()
                         transfer_handle.submit(graph_copy, task_end_op_id=self.tasks[task_id].task_end_op_id)
                     else:
-                        # Cross-machine TP: same slot_mapping across TP ranks
                         transfer_handle.submit(transfer_graph, task_end_op_id=self.tasks[task_id].task_end_op_id)
                 else:
                     transfer_handle.submit(transfer_graph, task_end_op_id=self.tasks[task_id].task_end_op_id)
 
+    def _compute_gpu_clear_flags(self) -> List[bool]:
+        """Return, per transfer handle, whether its graph copy must have
+        its GPU blocks cleared before submission.
+
+        Sharing domains disabled: multi-node handles follow the legacy
+        rule (clear only when ``nnodes > 1 and pp_size > 1``); all other
+        handles are never cleared.
+
+        Sharing domains enabled: a multi-node handle's ``_target_sd_key``
+        is compared with the local key via :func:`graph_needs_gpu_clear`.
+
+        The returned list is index-aligned with ``self.transfer_handles``.
+        """
+        handles = self.transfer_handles
+        if not handles:
+            return []
+
+        def is_multi_node(handle) -> bool:
+            return isinstance(handle._handle, TransferManagerMultiNodeHandle)
+
+        def cross_machine_pp() -> bool:
+            return self.model_config.nnodes > 1 and self.model_config.pp_size > 1
+
+        if not getattr(self.cache_config, "enable_sharing_domain", False):
+            # Evaluate the legacy rule once; it is identical for every remote.
+            legacy_clear = cross_machine_pp()
+            return [
+                legacy_clear if is_multi_node(h) else False for h in handles
+            ]
+
+        # Sharing-domain branch.
+        from flexkv.common.dist_reuse import (
+            SharingDomainKey,
+            graph_needs_gpu_clear,
+        )
+        local_sd = SharingDomainKey.from_model_config(self.model_config)
+
+        def needs_clear(handle) -> bool:
+            if not is_multi_node(handle):
+                return False
+            serialized = getattr(handle._handle, "_target_sd_key", None)
+            if serialized is None:
+                # Remote without an SD tag: fall back to the legacy rule.
+                return cross_machine_pp()
+            try:
+                remote_sd = SharingDomainKey.deserialize(serialized)
+            except ValueError:
+                # Unparseable key: clearing is the conservative choice.
+                return True
+            return graph_needs_gpu_clear(local_sd, remote_sd)
+
+        return [needs_clear(h) for h in handles]
+
     def _update_tasks(self, timeout: float = 0.001) -> None:
         completed_ops = self._get_completed_ops(timeout)
         for completed_op in completed_ops:
diff --git a/flexkv/metrics/collector.py b/flexkv/metrics/collector.py
index 04678cfd58..ba95a0f986 100644
--- a/flexkv/metrics/collector.py
+++ b/flexkv/metrics/collector.py
@@ -12,16 +12,95 @@
 """
 
 import os
+import shutil
+import tempfile
 from typing import Dict, Optional
 
-# Optional import for prometheus_client
+
+# ---------------------------------------------------------------------------
+# Multi-process metrics bootstrap
+# ---------------------------------------------------------------------------
+#
+# FlexKV runs the actual data-plane workers (e.g.
+# ``PEER2CPUTransferWorker`` from ``flexkv/transfer/worker.py``) in
+# ``mp.Process`` subprocesses. Each subprocess imports
+# ``prometheus_client`` independently and ends up with its own in-memory
+# Counter/Histogram values that **do not propagate** back to the HTTP
+# server running in the parent process. The classic symptom is a
+# ``/metrics`` endpoint that always reports zeros for any counter that
+# is incremented from the subprocess (e.g.
+# ``flexkv_py_dist_reuse_peer_mooncake_read_*``).
+#
+# ``prometheus_client`` solves this with the
+# `PROMETHEUS_MULTIPROC_DIR <https://prometheus.github.io/client_python/multiprocess/>`_
+# convention: every process writes its samples to a shared directory of
+# mmap'd files, and the HTTP server uses ``MultiProcessCollector`` to
+# aggregate them on every scrape.
+#
+# We auto-bootstrap the directory so operators don't have to remember to
+# set the env var. This must happen **before** ``prometheus_client`` is
+# imported, because the library reads ``PROMETHEUS_MULTIPROC_DIR`` at
+# import time.
+# ---------------------------------------------------------------------------
+def _bootstrap_multiproc_dir() -> Optional[str]:
+    """Pick a directory for ``prometheus_client`` multiprocess samples.
+
+    Honours an existing ``PROMETHEUS_MULTIPROC_DIR`` (operators may want
+    to point it at a tmpfs or a persistent path). Otherwise creates
+    ``<tmpdir>/flexkv_prom_<pid>`` and exports it, where ``<pid>`` is the
+    PID of the process running this function. Processes that inherit the
+    exported variable take the "existing" branch and reuse that directory.
+
+    Returns the directory path, or ``None`` when metrics are disabled or
+    the directory could not be created.
+    """
+    # Honour caller-set value verbatim.
+    existing = os.environ.get("PROMETHEUS_MULTIPROC_DIR")
+    if existing:
+        try:
+            os.makedirs(existing, exist_ok=True)
+        except OSError:
+            # Best-effort: the path is returned even if it cannot be created.
+            pass
+        return existing
+
+    # Only bootstrap when metrics are actually enabled (avoid littering
+    # tmp on every test import). We intentionally bypass
+    # ``GLOBAL_CONFIG_FROM_ENV`` here because that module is loaded later;
+    # read the env var directly to break the cycle.
+    if os.environ.get("FLEXKV_ENABLE_METRICS", "0") != "1":
+        return None
+
+    # Keyed by ``os.getpid()`` of the process that runs this bootstrap, not
+    # by ``os.getppid()``.
+    multiproc_dir = os.path.join(
+        tempfile.gettempdir(), f"flexkv_prom_{os.getpid()}"
+    )
+    # Wipe stale samples from a previous run with the same pid (rare but
+    # possible after pid wrap-around); a missing directory is ignored.
+    shutil.rmtree(multiproc_dir, ignore_errors=True)
+    try:
+        os.makedirs(multiproc_dir, exist_ok=True)
+    except OSError:
+        return None
+    os.environ["PROMETHEUS_MULTIPROC_DIR"] = multiproc_dir
+    return multiproc_dir
+
+
+_MULTIPROC_DIR = _bootstrap_multiproc_dir()
+
+
+# Optional import for prometheus_client. Must happen AFTER the
+# multiproc dir bootstrap above so ``ValueClass`` picks the mmap'd
+# backend instead of the in-memory one.
 try:
-    from prometheus_client import Counter, Gauge
+    from prometheus_client import Counter, Gauge, Histogram
     PROMETHEUS_AVAILABLE = True
 except ImportError:
     PROMETHEUS_AVAILABLE = False
     Counter = None
     Gauge = None
+    Histogram = None
 
 from flexkv.common.config import GLOBAL_CONFIG_FROM_ENV
 from flexkv.common.debug import flexkv_logger
@@ -211,7 +290,86 @@ def _init_metrics(self):
             documentation="Total number of allocated blocks by device",
             labelnames=["device"],
         )
-        
+
+        # ========== Dist-Reuse P2P Safety Metrics ==========
+        # See docs/dist_reuse/KNOWN_ISSUE_p2p_refcount_2026-05-14.md §4 for
+        # why each of these is critical for safe P2P cross-instance reuse.
+        # All five default to zero / no-op when dist_reuse is not active —
+        # call sites are guarded so existing single-instance deployments pay
+        # zero cost.
+
+        # CRITICAL — non-zero means the master entered the high-watermark
+        # eviction path that bypasses lease protection. Any positive value
+        # in production should page oncall immediately (KNOWN_ISSUE §5
+        # trigger #1). Labelled by device because CPU and SSD pools have
+        # independent watermarks.
+ self.dist_reuse_lease_meta_nullptr_total = Counter( + name="flexkv_py_dist_reuse_lease_meta_nullptr_total", + documentation=( + "Number of blocks inserted with lease_meta=nullptr because " + "the master pool exceeded swap_block_threshold. Such blocks " + "are evictable immediately and break the lease-based P2P " + "safety guarantee. Should be 0 in healthy deployments." + ), + labelnames=["device"], + ) + + # WARN — counts the "fresh" branch of evict (lease still valid but + # we needed the slot anyway). Healthy ratio of + # ``about_to_evict / evicted`` is < 1; sustained > 10 means the + # master is fighting eviction pressure and lease-based P2P safety + # margin is shrinking. + self.dist_reuse_about_to_evict_total = Counter( + name="flexkv_py_dist_reuse_about_to_evict_total", + documentation=( + "Number of blocks marked ABOUT_TO_EVICT (fresh-branch evict) " + "because the expired pool was insufficient. Used together " + "with flexkv_py_evicted_blocks_total to compute the " + "fresh/expired evict ratio (KNOWN_ISSUE §4.1)." + ), + labelnames=["device"], + ) + + # OPS — peer-side mooncake_read latency. P99 > 500ms means the + # remaining lease window is < ~10x typical lease_ttl; risk of lease + # exhaustion rises (KNOWN_ISSUE §4.2). Buckets cover the practical + # range from sub-ms (in-memory) to seconds (network-degraded). + self.dist_reuse_peer_mooncake_read_seconds = Histogram( + name="flexkv_py_dist_reuse_peer_mooncake_read_seconds", + documentation=( + "Latency of peer-side mooncake transfer_sync_read calls " + "(P2P CPU pull from master instance). P99 > 500ms triggers " + "the lease-margin alert (KNOWN_ISSUE §4.2)." + ), + buckets=(0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0), + ) + + # CRITICAL — peer-side mooncake_read failure count. Includes + # mooncake-level errors AND zero-byte transfers (the symptom of + # the P0 bug fixed on 2026-05-14). Sustained > 0.1% failure rate + # warrants oncall (KNOWN_ISSUE §4.2). 
+ self.dist_reuse_peer_mooncake_read_failures_total = Counter( + name="flexkv_py_dist_reuse_peer_mooncake_read_failures_total", + documentation=( + "Peer-side mooncake transfer_sync_read failures (non-zero " + "ret OR zero-byte transfer). Failure rate > 0.1% indicates " + "either lease exhaustion racing master eviction, or peer " + "node discovery breakdown." + ), + labelnames=["reason"], + ) + + # SUCCESS counter — denominator for the failure-rate calculation + # above. Without this, failure_rate would be unbounded if the + # service is idle. + self.dist_reuse_peer_mooncake_read_success_total = Counter( + name="flexkv_py_dist_reuse_peer_mooncake_read_success_total", + documentation=( + "Peer-side mooncake transfer_sync_read successes. Use as " + "denominator together with the _failures_total counter." + ), + ) + logger.info("[FlexKV PyMetrics] Prometheus metrics collector initialized") def _init_dummy_metrics(self): @@ -223,6 +381,8 @@ def inc(self, *args, **kwargs): pass def set(self, *args, **kwargs): pass + def observe(self, *args, **kwargs): + pass dummy = DummyMetric() @@ -236,6 +396,13 @@ def set(self, *args, **kwargs): self.mempool_free_blocks = dummy self.evicted_blocks_total = dummy self.allocated_blocks_total = dummy + + # Dist-reuse P2P safety dummy metrics (mirror of _init_metrics). + self.dist_reuse_lease_meta_nullptr_total = dummy + self.dist_reuse_about_to_evict_total = dummy + self.dist_reuse_peer_mooncake_read_seconds = dummy + self.dist_reuse_peer_mooncake_read_failures_total = dummy + self.dist_reuse_peer_mooncake_read_success_total = dummy @@ -327,7 +494,63 @@ def record_allocation(self, device: str, num_blocks: int): if not self.enabled or num_blocks <= 0: return self.allocated_blocks_total.labels(device=device).inc(num_blocks) - + + # ========== Dist-Reuse P2P Safety Recording Methods ========== + # + # See docs/dist_reuse/KNOWN_ISSUE_p2p_refcount_2026-05-14.md §4 for + # the operational meaning of each metric. 
All five degrade gracefully + # to no-op when metrics are disabled, so call sites can invoke them + # unconditionally. + + def record_dist_reuse_lease_nullptr(self, device: str, count: int = 1): + """Record a master-side block insertion that received + ``lease_meta=nullptr`` because the pool exceeded + ``swap_block_threshold``. + + **CRITICAL** — non-zero in production means the lease-based P2P + safety guarantee has been broken (KNOWN_ISSUE §5 trigger #1). + """ + if not self.enabled or count <= 0: + return + self.dist_reuse_lease_meta_nullptr_total.labels(device=device).inc(count) + + def record_dist_reuse_about_to_evict(self, device: str, count: int): + """Record blocks marked ABOUT_TO_EVICT in the fresh-branch evict + path. Pair with ``record_eviction`` (the expired-branch counter) + to compute the fresh/expired evict ratio.""" + if not self.enabled or count <= 0: + return + self.dist_reuse_about_to_evict_total.labels(device=device).inc(count) + + def observe_dist_reuse_peer_mooncake_read( + self, duration_seconds: float, *, success: bool, reason: str = "ok", + ): + """Record a peer-side mooncake transfer_sync_read attempt. + + Args: + duration_seconds: end-to-end latency of the read call. + Always recorded, including failures (so the latency + histogram captures the timeout / error path too). + success: True iff the read returned 0 bytes-of-error AND + non-zero data was actually moved. See worker.py + P0-fix comment for why ``ret == 0`` alone is not a + sufficient success criterion. + reason: free-form tag for the failure mode when + ``success`` is False. Recommended values: + ``"mooncake_error"`` (ret != 0), + ``"zero_byte_transfer"`` (the P0-bug symptom), + ``"node_meta_missing"`` (peer discovery breakdown), + ``"timeout"`` (long-running stuck read). 
+ """ + if not self.enabled: + return + if duration_seconds >= 0: + self.dist_reuse_peer_mooncake_read_seconds.observe(duration_seconds) + if success: + self.dist_reuse_peer_mooncake_read_success_total.inc() + else: + self.dist_reuse_peer_mooncake_read_failures_total.labels(reason=reason).inc() + # Global collector instance _global_collector: Optional[FlexKVMetricsCollector] = None diff --git a/flexkv/metrics/server.py b/flexkv/metrics/server.py index 44464e4dac..f060826279 100644 --- a/flexkv/metrics/server.py +++ b/flexkv/metrics/server.py @@ -5,6 +5,7 @@ from the FlexKV Python runtime. """ +import os import threading from typing import Optional @@ -96,7 +97,12 @@ def start_metrics_server(port: Optional[int] = None) -> bool: return True try: - from prometheus_client import start_http_server, REGISTRY + from prometheus_client import ( + start_http_server, + REGISTRY, + CollectorRegistry, + multiprocess, + ) except ImportError: raise RuntimeError( "[FlexKV PyMetrics] prometheus_client not installed but metrics server requested. " @@ -107,9 +113,32 @@ def start_metrics_server(port: Optional[int] = None) -> bool: port = get_metrics_port() try: - # Start server with default registry (single process mode) + # ---------------------------------------------------------- + # Pick the right registry. + # + # When ``PROMETHEUS_MULTIPROC_DIR`` is set (the default when + # ``FLEXKV_ENABLE_METRICS=1``; see + # ``flexkv/metrics/collector.py::_bootstrap_multiproc_dir``), + # the HTTP server MUST use a fresh ``CollectorRegistry`` + # wrapped with ``MultiProcessCollector``. Otherwise it would + # serve only the parent process's local samples and silently + # drop everything emitted from ``mp.Process`` subprocess + # workers (the actual data path, e.g. + # ``PEER2CPUTransferWorker.observe_dist_reuse_peer_mooncake_read``). 
+ # ---------------------------------------------------------- + multiproc_dir = os.environ.get("PROMETHEUS_MULTIPROC_DIR") + if multiproc_dir: + registry = CollectorRegistry() + multiprocess.MultiProcessCollector(registry, path=multiproc_dir) + logger.info( + f"[FlexKV PyMetrics] Multiprocess mode: aggregating from " + f"{multiproc_dir}" + ) + else: + registry = REGISTRY + # Always bind to localhost (127.0.0.1) for security - start_http_server(port, addr=BIND_ADDRESS, registry=REGISTRY) + start_http_server(port, addr=BIND_ADDRESS, registry=registry) _server_started = True print( diff --git a/flexkv/server/server.py b/flexkv/server/server.py index 34987e293b..a4ee08821f 100644 --- a/flexkv/server/server.py +++ b/flexkv/server/server.py @@ -175,6 +175,7 @@ def __init__( cache_config.redis_password, cache_config.local_ip, node_ttl_seconds=cache_config.node_ttl_seconds, + db=int(getattr(cache_config, "flexkv_redis_db", 0)), ) self.redis_meta_client.init_meta() # update distributed_node_id diff --git a/flexkv/transfer/worker.py b/flexkv/transfer/worker.py index c6e72575ef..d4828a9eb7 100644 --- a/flexkv/transfer/worker.py +++ b/flexkv/transfer/worker.py @@ -38,6 +38,7 @@ allocate_host_buffer, cudaHostRegister, ) + from flexkv.transfer.worker_op import WorkerTransferOp, WorkerLayerwiseTransferOp from flexkv.mooncakeEngineWrapper import MoonCakeTransferEngineWrapper @@ -207,8 +208,10 @@ def run(self) -> None: transfer_status = self.launch_transfer(op) nvtx.pop_range() except Exception as e: - flexkv_logger.error(f"Error launching transfer: {e}\n" - f"Failed transfer op: {op}") + flexkv_logger.error( + f"Error launching transfer: {e}\n" + f"Failed transfer op: {op}" + ) if transfer_status: ## only put the op when transfer success self.finished_ops_queue.put(op.transfer_op_id) @@ -346,24 +349,29 @@ def _transfer_impl( gpu_tensor_ptrs = self.gpu_blocks_ptrs.contiguous().pin_memory() + # NOTE: use kwargs to bind to the C++ pybind signature exactly, otherwise + # positional 
args misalign on the new `start_layer_id` parameter and + # silently corrupt `transfer_num_cta` / `is_host_to_device` (D2H ends up + # with transfer_num_cta=0 → cudaErrorInvalidConfiguration). transfer_kv_blocks( - gpu_block_id_list, - gpu_tensor_ptrs, - self.gpu_kv_stride_in_bytes, - self.gpu_block_stride_in_bytes, - self.gpu_layer_stride_in_bytes, - cpu_block_id_list, - self.cpu_tensor, - self.cpu_kv_stride_in_bytes, - self.cpu_layer_stride_in_bytes, - self.cpu_block_stride_in_bytes, - self.chunk_size_in_bytes, - self.num_layers, - transfer_num_cta, - transfer_type == TransferType.H2D, - use_ce_transfer, - self.is_mla, - self.gpu_block_type_, + gpu_block_id_tensor=gpu_block_id_list, + gpu_tensor_ptrs_tensor=gpu_tensor_ptrs, + gpu_kv_stride_in_bytes=self.gpu_kv_stride_in_bytes, + gpu_block_stride_in_bytes=self.gpu_block_stride_in_bytes, + gpu_layer_stride_in_bytes=self.gpu_layer_stride_in_bytes, + cpu_block_id_tensor=cpu_block_id_list, + cpu_tensor=self.cpu_tensor, + cpu_kv_stride_in_bytes=self.cpu_kv_stride_in_bytes, + cpu_layer_stride_in_bytes=self.cpu_layer_stride_in_bytes, + cpu_block_stride_in_bytes=self.cpu_block_stride_in_bytes, + chunk_size_in_bytes=self.chunk_size_in_bytes, + start_layer_id=0, + num_layers=self.num_layers, + transfer_num_cta=transfer_num_cta, + is_host_to_device=(transfer_type == TransferType.H2D), + use_ce_transfer=use_ce_transfer, + is_mla=self.is_mla, + gpu_block_type=self.gpu_block_type_, ) def launch_transfer(self, transfer_op: WorkerTransferOp) -> bool: @@ -1267,9 +1275,31 @@ def __init__(self, self.cache_config.redis_password, self.cache_config.local_ip, node_ttl_seconds=self.cache_config.node_ttl_seconds, + db=int(getattr(self.cache_config, "flexkv_redis_db", 0)), ) self.redis_meta_client.set_node_id(self.cache_config.distributed_node_id) + # P0 FIX: bootstrap RedisNodeInfo discovery in this worker subprocess. 
+        try:
+            if not self.redis_meta_client.nodeinfo.connect():
+                flexkv_logger.warning(
+                    "PEER2CPUTransferWorker: nodeinfo.connect() returned "
+                    "False — cross-instance peer discovery may be degraded."
+                )
+            else:
+                self.redis_meta_client.nodeinfo.scan_active_nodes()
+                active = self.redis_meta_client.nodeinfo.get_active_node_ids()
+                flexkv_logger.info(
+                    f"PEER2CPUTransferWorker: nodeinfo bootstrap OK — "
+                    f"self_node_id={self.cache_config.distributed_node_id}, "
+                    f"current_active_nodes={sorted(active)}"
+                )
+        except Exception as _e:
+            flexkv_logger.warning(
+                f"PEER2CPUTransferWorker: nodeinfo bootstrap raised {_e!r}; "
+                "is_node_active() may return False for live peers."
+            )
+
         # Persistent NodeMetaInfo Pool for skip redis operation when getting
         # NodeMetaInfo according to node_id
         # assuming that every flexkv progress has unique node id
@@ -1462,9 +1492,72 @@ def _batch_transfer_impl(self,
                              transfer_type: TransferType,
                              **kwargs,):
         if transfer_type == TransferType.PEERH2H:
+            # ----------------------------------------------------------
+            # P2P CPU pull observability (KNOWN_ISSUE_p2p_refcount_2026-05-14
+            # §4.2). We time the mooncake_read and tag the outcome so the
+            # operator can monitor:
+            #   * latency P99 (lease-margin proxy)
+            #   * failure rate (mooncake errors + zero-byte transfers,
+            #     the latter being the symptom of the P0 bug fixed today)
+            #
+            # Metric collector lookup is best-effort — never raise from the
+            # data path, never block on metrics infrastructure failure.
+            # ----------------------------------------------------------
+            try:
+                # Use ``init_global_collector`` (not ``get_…``) because
+                # the global ``_global_collector`` singleton lives in
+                # *this* subprocess's address space — it's None on first
+                # call here even though the parent process initialized
+                # one before forking. ``init_…`` is idempotent: it
+                # constructs a new collector and registers metrics
+                # against the multiprocess dir if and only if there
+                # isn't one already. All metric writes route to the
+                # mmap'd files in ``$PROMETHEUS_MULTIPROC_DIR`` (set
+                # by the parent's ``_bootstrap_multiproc_dir`` and
+                # inherited via env into the spawn'd subprocess), so
+                # the parent's HTTP server's ``MultiProcessCollector``
+                # picks them up on the next scrape.
+                from flexkv.metrics import init_global_collector
+                _metrics = init_global_collector()
+            except Exception:
+                _metrics = None
+
+            try:
+                _expected_bytes = int(sum(task_info.data_lens or []))
+            except Exception:
+                _expected_bytes = -1  # size unknown: never flag as zero-byte
+
+            _t0 = time.perf_counter()
             ret = self.mooncake_transfer_engine.batch_transfer_sync_read(
                 task_info.peer_engine_addr, task_info.src_ptrs, task_info.dst_ptrs, task_info.data_lens
             )
+            _elapsed = time.perf_counter() - _t0
+
+            # Determine outcome:
+            #   * mooncake error → reason="mooncake_error"
+            #   * ret == 0 but the request carried zero bytes → reason=
+            #     "zero_byte_transfer" (the P0-bug signature)
+            #   * otherwise success
+            _success = True
+            _reason = "ok"
+            if ret != 0:
+                _success = False
+                _reason = "mooncake_error"
+            elif _expected_bytes == 0:
+                # ret == 0 but data_lens is empty / sums to zero, i.e. the
+                # read was asked to move zero bytes; count it as a failure.
+                _success = False
+                _reason = "zero_byte_transfer"
+
+            if _metrics is not None:
+                try:
+                    _metrics.observe_dist_reuse_peer_mooncake_read(
+                        _elapsed, success=_success, reason=_reason,
+                    )
+                except Exception:
+                    # Never let metrics break the data path.
+                    pass
+
             if ret != 0:
                 flexkv_logger.error(f"RDMA transfer failed with error code: {ret}")
                 return False
@@ -2011,6 +2104,16 @@ def get_node_meta(self, node_id: int) -> Optional[NodeMetaInfo]:
             remote node has crashed.
         """
         # ===== Active-node validation (Scheme 4) =====
+        if not self.redis_meta_client.is_node_active(node_id):
+            # Cache may be stale: the listener thread is best-effort and an
+            # initial ``scan_active_nodes()`` may have run before the peer
+            # registered itself.
Force a fresh SCAN once before giving up so + # we don't drop a transfer just because of bootstrap order. + try: + self.redis_meta_client.nodeinfo.scan_active_nodes() + except Exception: # noqa: BLE001 + pass + if not self.redis_meta_client.is_node_active(node_id): # Node is no longer active – purge cached meta if any if node_id in self.node_metas: diff --git a/flexkv/transfer_manager.py b/flexkv/transfer_manager.py index 37157eb6fc..f88c980b6f 100644 --- a/flexkv/transfer_manager.py +++ b/flexkv/transfer_manager.py @@ -31,6 +31,130 @@ from flexkv.server.request import RegisterTPClientRequest, Response +# --------------------------------------------------------------------------- +# Coord-message routing (Phase D-4: heavy trim) +# +# Phase D-4 (proposal_unify_with_graph_dispatch_2026-05-15.md): the +# CoordQuery / CoordGet / CoordPut messages were deleted. Per-SD PUT/ +# GET coordination now flows through the unified TransferOpGraph +# dispatch path; only ``failure_report`` remains as an asynchronous +# Layer-2 closed-loop side channel (Remote → Master). +# --------------------------------------------------------------------------- +_COORD_ACK_TYPES = frozenset({ + "failure_report", +}) + + +def _filter_graph_inplace_by_target_node_ids( + graph: TransferOpGraph, self_nid: int, +) -> int: + """Drop ops whose ``target_node_ids`` does not include ``self_nid``. + + Phase D-3 (proposal_unify_with_graph_dispatch_2026-05-15.md §6.4): + ``target_node_ids`` is a per-op SD-routing tag. Originally the + filter ran only on cross-machine ``TransferManagerOnRemote`` (Phase + D-1). Phase D-3 lifts it to a module-level helper so the Master's + in-proc / inter-proc handles can apply the same rule before + handing the graph to its local TransferEngine. 
Without this, the + Master would eagerly execute every peer-SD clone op in the graph, + which: + + * For PUT-path D2H clones (Phase D-2): produces N-1 redundant + GPU→CPU copies — correct because the clones share src/dst block + IDs (mirror assumption), but wastes bandwidth. + * For GET-path PEERH2H clones (Phase D-3): each clone targets a + *different* peer instance's per-SD mooncake endpoint via + ``src_block_node_ids``; running them on the Master would fetch + data from endpoints the Master never connected to → silent + corruption / mooncake errors. + + Semantics: + + * ``target_node_ids`` is None / empty → keep (legacy single-SD or + cross-machine TP/PP behaviour, unchanged). + * ``self_nid < 0`` → keep everything (dist_reuse + bootstrap not finished yet — pre-D-3 behaviour). + * Otherwise: keep iff ``self_nid in target_node_ids``. + + Dependency-graph repair (Phase D-3): + A dropped op may still appear in the ``predecessors`` set of an + op we *keep* (typical case: ``op_h2d → op_peerh2h_clone`` was + wired by the master before broadcast; on the Master the clones + are dropped but ``op_h2d`` remains). The graph engine waits for + every predecessor to finish before scheduling a successor — if + the predecessor was filtered out it never completes, so the + successor would deadlock. We therefore treat dropped ops as + *vacuously satisfied* on this handle: remove their op_id from + every kept op's ``predecessors`` set. This is the natural + semantic — "this op runs on a different SD, so locally there is + nothing to wait for here; the cross-SD ack flows back through + the master polling thread's ``_completion_sink`` and is observed + on the Master's own graph copy". + + Returns the number of dropped ops (for logging / metrics). + Mutates ``graph`` in place. 
+ """ + if self_nid < 0: + return 0 + + op_map = graph._op_map # type: ignore[attr-defined] + to_drop: List[int] = [] + for op_id, op in op_map.items(): + tnids = getattr(op, "target_node_ids", None) + if not tnids: + continue # No filter — keep. + if self_nid not in tnids: + to_drop.append(op_id) + + if not to_drop: + return 0 + + drop_set = set(to_drop) + + # Mirror the bookkeeping ``TransferOpGraph.add_transfer_op`` + # populates so the resulting graph stays internally consistent. + for op_id in to_drop: + op_map.pop(op_id, None) + graph._ready_ops.discard(op_id) # type: ignore[attr-defined] + graph._trigger_ops.discard(op_id) # type: ignore[attr-defined] + try: + graph._gpu_transfer_op_id.remove(op_id) # type: ignore[attr-defined] + except (ValueError, AttributeError): + pass + + # Phase D-3: repair predecessor / successor sets on the kept ops + # so dropped ops appear "vacuously satisfied" to the local graph + # engine. Without this fix, an ``op_h2d`` left in the master's + # graph would still list the (now-dropped) peer-SD ``op_peerh2h`` + # clones in its ``predecessors`` and never become ready. + for op in op_map.values(): + if op.predecessors: + op.predecessors.difference_update(drop_set) + if op.successors: + op.successors.difference_update(drop_set) + # If repairing predecessors made this op fully ready and it + # wasn't already in ``_ready_ops``, put it back so the next + # ``take_ready_ops`` call picks it up. Status check guards + # against re-adding ops that have already completed in earlier + # passes (defensive — graphs are normally filtered exactly + # once at submit time). 
+ if ( + len(op.predecessors) == 0 + and op.op_id not in graph._ready_ops # type: ignore[attr-defined] + and op.op_id not in graph._trigger_ops # type: ignore[attr-defined] + ): + try: + from flexkv.common.transfer import TransferOpStatus + if op.status == TransferOpStatus.PENDING: + graph._ready_ops.add(op.op_id) # type: ignore[attr-defined] + except Exception: + # Defensive — never let a status import / equality + # check wedge the filter pass. + pass + + return len(to_drop) + + class TransferManager: def __init__(self, model_config: ModelConfig, @@ -308,6 +432,13 @@ def _initialize_with_config(self) -> None: self.model_config = config_msg.get('model_config') self.cache_config = config_msg.get('cache_config') self.gpu_register_port = config_msg.get('gpu_register_port') + # Phase 0 task 0-F: stash SD bootstrap fields for post-init use. + self._dist_reuse_bootstrap_msg = { + 'sd_key': config_msg.get('sd_key'), + 'instance_id': config_msg.get('instance_id'), + 'session_epoch': config_msg.get('session_epoch'), + 'master_zmq_addr': config_msg.get('master_zmq_addr'), + } flexkv_logger.info(f"Received config from master, {self.model_config = }, \ {self.cache_config = }, {self.gpu_register_port = }.") else: @@ -315,6 +446,98 @@ def _initialize_with_config(self) -> None: flexkv_logger.info("Received config from master successfully") super().__init__(self.model_config, self.cache_config, self.gpu_register_port) + # Phase 0 task 0-F: opt-in to sharing-domain bootstrap. + if getattr(self.cache_config, "enable_sharing_domain", False): + self._bootstrap_sharing_domain() + + def _bootstrap_sharing_domain(self) -> None: + """Run the RedisMeta + Mooncake bootstrap for this Remote node. + + Called from :meth:`_initialize_with_config` after ``super().__init__`` + so that the parent class has already allocated the CPU block pool + and we can register its buffer pointer with Mooncake. 
+ + Failure to bootstrap is fatal — we raise and let the parent process + restart the instance (design doc §4.3.1 co-destined failure model). + """ + boot = self._dist_reuse_bootstrap_msg + sd_key_str = boot.get('sd_key') if boot else None + if not sd_key_str: + flexkv_logger.warning( + "[DistReuse] enable_sharing_domain=True but no sd_key in config; " + "skipping Remote bootstrap (Master still uses legacy discovery)." + ) + return + + from flexkv.common.dist_reuse import RemoteDistReuseInitializer + + # ``cpu_blocks`` is allocated by the parent class; its data pointer + # and byte size drive Mooncake buffer registration. + cpu_blocks = getattr(self, "cpu_blocks", None) + if cpu_blocks is None: + flexkv_logger.warning( + "[DistReuse] parent did not allocate cpu_blocks; cannot register " + "Mooncake buffer. Bootstrap skipped." + ) + return + + cpu_buffer_ptr = int(cpu_blocks.data_ptr()) + cpu_buffer_size = int(cpu_blocks.numel() * cpu_blocks.element_size()) + local_zmq_addr = f"{self.cache_config.local_zmq_ip}:{self.cache_config.local_zmq_port}" + + init = RemoteDistReuseInitializer( + cache_config=self.cache_config, + sd_key_str=sd_key_str, + instance_id=boot.get('instance_id') or "", + session_epoch=boot.get('session_epoch') or "", + cpu_buffer_ptr=cpu_buffer_ptr, + cpu_buffer_size=cpu_buffer_size, + local_zmq_addr=local_zmq_addr, + ) + result = init.bootstrap() + + self._dist_reuse_sd_key = result.sd_key + self._dist_reuse_namespace = result.namespace + self._dist_reuse_redis_meta = result.redis_meta + self._dist_reuse_mooncake_engine = result.mooncake_engine + self._dist_reuse_node_id = result.distributed_node_id + + # Phase D-4 (proposal_unify_with_graph_dispatch_2026-05-15.md): + # the old per-SD ZMQ coord protocol (``CoordQueryMsg``/ + # ``CoordGetCmdMsg``/``CoordPutCmdMsg`` + ``RemoteCoordHandler`` + # + ``BlockIndex`` + ``_d2h_and_publish`` closure) was deleted + # in this refactor. 
Multi-SD coordination now flows through + # the unified ``TransferOpGraph`` dispatch path: the master + # broadcasts a graph carrying multiple D2H / PEERH2H ops with + # per-op ``target_node_ids``, each Remote's ``_handle_submit`` + # filters by its own ``self_node_id``, the local transfer + # worker executes the slice, and the resulting ``CompletedOp`` + # carries ``sd_key`` + ``contributing_node_id`` back through + # the existing master-side polling thread. + # + # No coord handler is constructed here — the Remote is now + # purely data-plane and reuses the regular submit/wait + # plumbing for both legacy cross-machine TP/PP graphs and + # dist_reuse multi-SD graphs. + + # Push the ready message back to Master via the result channel. + # We deliberately reuse ``result_socket`` (rather than a dedicated + # dist_reuse channel) because the Master's polling worker already + # accepts arbitrary pyobjs on this socket. + try: + self.result_socket.send_pyobj({ + 'type': 'remote_ready', + 'msg': result.ready_msg, + }) + flexkv_logger.info( + f"[DistReuse] Remote bootstrap complete for sd_key={sd_key_str}, " + f"node_id={result.distributed_node_id}" + ) + except Exception as e: + flexkv_logger.error(f"[DistReuse] failed to send ready to master: {e}") + raise + + def _polling_worker(self) -> None: flexkv_logger.info("Polling worker thread started") @@ -343,6 +566,21 @@ def _polling_worker(self) -> None: elif msg_type == 'submit_batch': graphs = message.get('graphs', []) for graph in graphs: + # Phase D-1: filter ops by + # target_node_ids before rebinding. 
+ self._filter_graph_by_target_node_ids(graph) + if graph.num_ops == 0: + try: + self.result_socket.send_pyobj( + CompletedOp.completed_graph(graph.graph_id) + ) + except Exception as e: # pragma: no cover + flexkv_logger.error( + f"[TransferManagerOnRemote] failed to send " + f"empty-batch-graph completion for " + f"graph={graph.graph_id}: {e}" + ) + continue graph_id = graph.graph_id with self._active_graphs_lock: self._active_graphs[graph_id] = -1 @@ -376,16 +614,46 @@ def _polling_worker(self) -> None: completed = self.wait(timeout=0.001) if completed: + # Phase D-1: stamp ``sd_key`` + ``contributing_node_id`` + # onto every CompletedOp this Remote ships back so the + # Master's polling worker can route per-SD completion + # to ``mark_sd_ready``. The values are read from the + # bootstrap result; default to "" / -1 when bootstrap + # didn't run (legacy single-SD path). + sd_key_str = str(getattr(self, "_dist_reuse_sd_key", "") or "") + contributing_nid = int(getattr(self, "_dist_reuse_node_id", -1)) + with self._active_graphs_lock: for completed_op in completed: if completed_op.graph_id in self._active_graphs: task_end_op_id = self._active_graphs[completed_op.graph_id] if task_end_op_id != -1 and completed_op.op_id == task_end_op_id: - end_op = CompletedOp(graph_id=completed_op.graph_id, op_id=task_end_op_id) + end_op = CompletedOp( + graph_id=completed_op.graph_id, + op_id=task_end_op_id, + sd_key=sd_key_str, + contributing_node_id=contributing_nid, + ) self.result_socket.send_pyobj(end_op) if completed_op.is_graph_completed(): - self.result_socket.send_pyobj(completed_op) + # ``completed_op`` is frozen — re-create + # it with the SD/node tags. Preserve + # existing fields (transfer_type / num_* + # / success / error) so the master's + # bookkeeping doesn't lose info. 
+ tagged = CompletedOp( + graph_id=completed_op.graph_id, + op_id=completed_op.op_id, + transfer_type=completed_op.transfer_type, + num_blocks=completed_op.num_blocks, + num_bytes=completed_op.num_bytes, + sd_key=sd_key_str, + contributing_node_id=contributing_nid, + success=completed_op.success, + error=completed_op.error, + ) + self.result_socket.send_pyobj(tagged) del self._active_graphs[completed_op.graph_id] except queue.Empty: @@ -437,6 +705,29 @@ def _handle_submit(self, graph: TransferOpGraph, task_end_op_id: int = -1) -> No If slot_mapping already arrived, set_gpu_blocks and submit immediately. Otherwise, store graph in pending_graphs for later matching. """ + # Phase D-1: filter ops by ``target_node_ids`` BEFORE pending / + # rebind logic. When the Master broadcasts a multi-SD graph + # (proposal_unify_with_graph_dispatch_2026-05-15.md §6), each + # Remote drops the ops not addressed to its own + # ``self_node_id`` so the TransferEngine only sees its own + # slice. Legacy graphs whose ops have ``target_node_ids=None`` + # are kept verbatim — backwards-compatible with single-SD + # cross-machine TP/PP graphs. + self._filter_graph_by_target_node_ids(graph) + if graph.num_ops == 0: + # Nothing for this Remote to do. Still acknowledge graph + # completion so the Master's barrier doesn't hang waiting. 
+ try: + self.result_socket.send_pyobj( + CompletedOp.completed_graph(graph.graph_id) + ) + except Exception as e: # pragma: no cover + flexkv_logger.error( + f"[TransferManagerOnRemote] failed to send empty-graph " + f"completion for graph={graph.graph_id}: {e}" + ) + return + task_id = graph.graph_id # Use graph_id as task_id for matching with self._pending_lock: if task_id in self._pending_slot_mappings: @@ -461,6 +752,30 @@ def _handle_submit(self, graph: TransferOpGraph, task_end_op_id: int = -1) -> No self._active_graphs[graph.graph_id] = task_end_op_id self.submit(graph) + def _filter_graph_by_target_node_ids(self, graph: TransferOpGraph) -> None: + """Drop ops whose ``target_node_ids`` does not include this + Remote's ``self_node_id``. + + Phase D-3 (proposal_unify_with_graph_dispatch_2026-05-15.md §6.4): + the actual filter logic now lives in the module-level + :func:`_filter_graph_inplace_by_target_node_ids` so the Master's + in-proc / inter-proc handles can call it too. This thin wrapper + only resolves ``self_nid`` and adds the per-Remote debug log. + """ + # Resolve self node_id. ``_dist_reuse_node_id`` is set by + # ``_bootstrap_sharing_domain`` once the Remote has registered + # with Redis; before that we have no SD identity, so legacy + # behaviour (no filtering) is the only safe choice. 
+ self_nid = int(getattr(self, "_dist_reuse_node_id", -1)) + dropped = _filter_graph_inplace_by_target_node_ids(graph, self_nid) + if dropped: + flexkv_logger.debug( + f"[TransferManagerOnRemote:nid={self_nid}] dropped " + f"{dropped} op(s) not addressed to me from " + f"graph={graph.graph_id}; remaining ops={graph.num_ops}" + ) + + def start(self) -> None: self.initialize_transfer_engine() super().start() @@ -643,6 +958,20 @@ def __init__(self, gpu_register_port: str): self.transfer_manager = TransferManager(model_config, cache_config, gpu_register_port) self._is_ready = False + # Phase D-3 (proposal_unify_with_graph_dispatch_2026-05-15.md + # §6.4): Master's own ``distributed_node_id`` for ``target_node_ids`` + # filtering on graphs submitted into the local TransferEngine. + # ``-1`` is the sentinel meaning "dist_reuse off / not yet wired"; + # ``_filter_graph_inplace_by_target_node_ids`` then keeps every op + # (legacy behaviour). Set via :meth:`set_dist_reuse_node_id` from + # ``KVTaskManager._wire_dist_reuse_coord_dispatcher``. + self._dist_reuse_self_nid: int = -1 + + def set_dist_reuse_node_id(self, self_nid: int) -> None: + """Phase D-3: install the Master's own ``distributed_node_id`` so + the in-proc submit path drops ops addressed to peer SDs. Idempotent. + """ + self._dist_reuse_self_nid = int(self_nid) def start(self) -> None: self.transfer_manager.initialize_transfer_engine() @@ -653,9 +982,20 @@ def is_ready(self) -> bool: return self._is_ready def submit(self, transfer_graph: TransferOpGraph, task_end_op_id: int = -1) -> None: + # Phase D-3: drop peer-SD ops (target_node_ids set + does not + # include self) before the local TransferEngine sees the graph. + # Legacy graphs (target_node_ids=None on every op) and pre- + # bootstrap state (self_nid=-1) are unaffected. 
+ _filter_graph_inplace_by_target_node_ids( + transfer_graph, self._dist_reuse_self_nid, + ) self.transfer_manager.submit(transfer_graph) def submit_batch(self, transfer_graphs: List[TransferOpGraph]) -> None: + for g in transfer_graphs: + _filter_graph_inplace_by_target_node_ids( + g, self._dist_reuse_self_nid, + ) self.transfer_manager.submit_batch(transfer_graphs) def wait(self, timeout: Optional[float] = None) -> List[CompletedOp]: @@ -685,6 +1025,19 @@ def __init__(self, self._completed_results: List[CompletedOp] = [] + # Phase D-3 (proposal_unify_with_graph_dispatch_2026-05-15.md + # §6.4): Master's own ``distributed_node_id`` for ``target_node_ids`` + # filtering applied **in the parent process** before the graph + # is shipped through the Pipe. ``-1`` = dist_reuse off / not + # yet wired → utility keeps every op (legacy behaviour). + self._dist_reuse_self_nid: int = -1 + + def set_dist_reuse_node_id(self, self_nid: int) -> None: + """Phase D-3: install the Master's own ``distributed_node_id`` + so peer-SD ops are dropped before crossing the Pipe. Idempotent. + """ + self._dist_reuse_self_nid = int(self_nid) + def _start_process(self) -> None: if self.process is not None and self.process.is_alive(): return @@ -828,6 +1181,11 @@ def is_ready(self) -> bool: def submit(self, transfer_graph: TransferOpGraph, task_end_op_id: int = -1) -> None: nvtx_range = nvtx.start_range(message="TransferManagerInterProcessHandle.submit", color="green") + # Phase D-3: filter peer-SD ops *before* the Pipe send so the + # subprocess never sees them. No-op for legacy graphs. 
+ _filter_graph_inplace_by_target_node_ids( + transfer_graph, self._dist_reuse_self_nid, + ) self.command_parent_conn.send({ 'type': 'submit', 'transfer_graph': transfer_graph @@ -840,6 +1198,11 @@ def submit_batch(self, transfer_graphs: List[TransferOpGraph]) -> None: message=f"TransferManagerInterProcessHandle.submit_batch count={len(transfer_graphs)}", color="green" ) + # Phase D-3: filter every graph in place before shipping the batch. + for g in transfer_graphs: + _filter_graph_inplace_by_target_node_ids( + g, self._dist_reuse_self_nid, + ) self.command_parent_conn.send({ 'type': 'submit_batch', 'transfer_graphs': transfer_graphs @@ -905,6 +1268,18 @@ def __init__(self, self._result_buffer: List[CompletedOp] = [] self._result_buffer_lock = threading.Lock() + # Phase 0 task 0-F / 0-G: sharing-domain Remote ready acks land + # here. The KVTaskManager drains this list and forwards each + # message to its MasterCoordinator. + self._remote_ready_acks: List[Any] = [] + + # Phase D-4: sink for FailureReportMsg (only remaining ack + # type after the coord-protocol cleanup). The KVTaskManager + # may set this to a handler that calls + # ``MasterCoordinator.handle_failure_report``. Unset = drop + # silently (harmless — failure detector's Layer-1 path will + # eventually catch persistent peer outages). + self._coord_ack_sink: Optional[Any] = None self._bind_master_ports() @@ -949,11 +1324,32 @@ def send_config_to_remotes(self) -> None: 'cache_config': self.cache_config, 'gpu_register_port': self.gpu_register_port } + # Phase 0 task 0-F: when sharing-domain is on, ship the + # Remote's ``sd_key`` + ``instance_id`` + ``session_epoch`` in + # the same config message. ``target_sd_key`` is populated + # externally by the Master (see KVTaskManager._build_sd_handles). 
+ if getattr(self.cache_config, "enable_sharing_domain", False): + config_msg['sd_key'] = getattr(self, '_target_sd_key', None) + config_msg['instance_id'] = getattr(self.cache_config, 'instance_id', None) + config_msg['session_epoch'] = getattr(self.cache_config, 'session_epoch', None) + config_msg['master_zmq_addr'] = ( + f"{self.master_host}:{self.master_ports[0]}" + ) self.command_socket.send_pyobj(config_msg) flexkv_logger.info(f"Config sent to remote at {self.master_host}:{self.master_ports[0]}") except Exception as e: flexkv_logger.error(f"Failed to send config to remote: {e}") + def set_target_sd_key(self, sd_key_str: str) -> None: + """Set the ``sd_key`` this handle is delivering config to. + + Called by :class:`KVTaskManager` **before** + :meth:`send_config_to_remotes`. The master must tag each remote + handle with the SD key the corresponding Remote node owns so the + Remote can register itself under the right namespace. + """ + self._target_sd_key = str(sd_key_str) + def _polling_worker(self) -> None: while not self._shutdown_flag: try: @@ -961,6 +1357,54 @@ def _polling_worker(self) -> None: if isinstance(result, CompletedOp): with self._result_buffer_lock: self._result_buffer.append(result) + # Phase D-1: when the inbound CompletedOp carries an + # ``sd_key`` (i.e. the Remote tagged it during + # _polling_worker after dist_reuse bootstrap), notify + # the registered ``_completion_sink`` so the Master's + # GlobalCacheEngine can flip the per-SD ready bit on + # the AggregateRadixTree. This replaces the old + # CoordPutAckMsg / CoordGetAckMsg ack route. 
+ if getattr(result, "sd_key", "") and result.success: + sink = getattr(self, "_completion_sink", None) + if sink is not None: + try: + sink(result) + except Exception as e: # pragma: no cover + flexkv_logger.error( + f"[DistReuse] completion sink raised on " + f"CompletedOp(graph={result.graph_id}, " + f"op={result.op_id}, sd={result.sd_key}): {e}" + ) + elif ( + isinstance(result, dict) + and result.get('type') == 'remote_ready' + ): + # Phase 0 task 0-F / 0-G: Remote finished its + # sharing-domain bootstrap. Stash the message so the + # Master's KVTaskManager can forward it to its + # MasterCoordinator. + msg = result.get('msg') + if msg is not None: + with self._result_buffer_lock: + self._remote_ready_acks.append(msg) + elif ( + isinstance(result, dict) + and result.get('type') in _COORD_ACK_TYPES + ): + # Phase D-4: only ``failure_report`` reaches this + # branch. Decode and hand off to the registered + # sink (typically ``MasterCoordinator.handle_failure_report``). + try: + from flexkv.common.dist_reuse import decode_coord_message + ack_obj = decode_coord_message(result) + sink = getattr(self, "_coord_ack_sink", None) + if sink is not None: + sink(ack_obj) + except Exception as e: + flexkv_logger.error( + f"[DistReuse] failed to decode coord ACK " + f"{result.get('type')!r}: {e}" + ) else: flexkv_logger.warning(f"Unexpected result format from remote: {result}") @@ -975,6 +1419,33 @@ def start(self) -> None: self._polling_thread = threading.Thread(target=self._polling_worker, daemon=True) self._polling_thread.start() + def set_completion_sink(self, sink) -> None: + """Phase D-1 (proposal §3.5): register a callable invoked by the + master-side polling worker for every ``CompletedOp`` that arrives + with a non-empty ``sd_key`` tag and ``success=True``. + + ``sink(completed_op)`` runs in the polling thread context. + Pass ``None`` to detach. 
+ + The KVTaskManager wires this to a small adapter that calls + ``GlobalCacheEngine._on_peer_sd_completed_op`` so the Master's + ``AggregateRadixTree`` learns which SDs are ready and from which + peer node — replacing the old CoordPutAckMsg / CoordGetAckMsg + ack route. + """ + self._completion_sink = sink + + def set_coord_ack_sink(self, sink) -> None: + """Phase D-4: install a callable ``sink(failure_report_msg)`` + invoked when a ``FailureReportMsg`` arrives from this handle's + Remote. Other coord types no longer exist (per-SD PUT/GET + coordination is done via ``CompletedOp(sd_key, ...)`` — + see ``set_completion_sink``). + + Pass ``None`` to detach. + """ + self._coord_ack_sink = sink + def is_ready(self) -> bool: if not self._connected: flexkv_logger.warning("Master not ready: ports not bound yet") diff --git a/scripts/multi-nodes/start_dist_reuse_serving.sh b/scripts/multi-nodes/start_dist_reuse_serving.sh new file mode 100644 index 0000000000..a0444c8c98 --- /dev/null +++ b/scripts/multi-nodes/start_dist_reuse_serving.sh @@ -0,0 +1,253 @@ +#!/bin/bash +# ============================================================================ +# FlexKV dist_reuse 多物理节点启动脚本(§2.5 / Phase 1-G) +# +# 场景(设计文档 §4.5.5 / §4.6.3): +# * Prefill 实例跨 2 台 GPU 机器(pp_size × tp_node_count ≤ 2) +# * CP 在节点内做(CP 不进 sd_key,CP rank > 0 只做 GPU 注册) +# * 每个 SD-Remote 物理机独立 Mooncake TransferEngine +# * Master 单写 Redis,Remote 读 + 收协同 GET/PUT 指令 +# +# 用法: +# # Master(node_rank=0) +# ./start_dist_reuse_serving.sh \ +# --nnodes 2 --node-rank 0 \ +# --master-ip 10.0.0.1 --dist-init-port 29500 \ +# --tp-size 8 --pp-size 1 --cp-size 4 \ +# --model /workspace/models/DeepSeek-V3 \ +# --redis-host 10.0.0.1 --redis-password 123456 +# +# # Remote(node_rank=1)— 脚本参数与 master 相同 +# ./start_dist_reuse_serving.sh \ +# --nnodes 2 --node-rank 1 \ +# --master-ip 10.0.0.1 --dist-init-port 29500 \ +# --tp-size 8 --pp-size 1 --cp-size 4 \ +# --model /workspace/models/DeepSeek-V3 \ +# --redis-host 10.0.0.1 
--redis-password 123456 +# +# 必填: +# --nnodes / --node-rank / --master-ip / --tp-size / --pp-size / --cp-size / +# --model / --redis-host +# +# 备注: +# * 如果该 instance 只用 CP 跨节点(CP 跨机、TP/PP 不跨机),脚本会自动 +# 把 node_rank>0 的机器放到 CP_PEER_REGISTRATION_ONLY 路径上(不启 +# TransferManagerOnRemote,sglang connector 侧按 multinode_policy +# 策略自己决定)。 +# * 脚本**不直接启动** sglang/vLLM 进程;它只做环境变量和配置文件生 +# 成,然后 exec 用户指定的启动命令(--launcher-cmd / 默认是 +# sglang 的 router 入口)。这样脚本本身单测友好。 +# ============================================================================ +set -euo pipefail + +# ---------------------------------------------------------------- argparse +NNODES="" +NODE_RANK="" +MASTER_IP="" +DIST_INIT_PORT=29500 +TP_SIZE="" +PP_SIZE="" +CP_SIZE="1" +MODEL="" +REDIS_HOST="" +REDIS_PORT=6379 +REDIS_PASSWORD="" +MOONCAKE_REDIS_PORT=6380 # separate metadata redis for mooncake +MOONCAKE_ENGINE_PORT_BASE=12345 +LAUNCHER_CMD="" +RDMA_DEVICE="${RDMA_DEVICE:-mlx5_0}" +INSTANCE_ID="" +DRY_RUN="false" + +usage() { + grep '^#' "$0" | sed 's/^# \{0,1\}//' + exit 1 +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --nnodes) NNODES="$2"; shift 2 ;; + --node-rank) NODE_RANK="$2"; shift 2 ;; + --master-ip) MASTER_IP="$2"; shift 2 ;; + --dist-init-port) DIST_INIT_PORT="$2"; shift 2 ;; + --tp-size) TP_SIZE="$2"; shift 2 ;; + --pp-size) PP_SIZE="$2"; shift 2 ;; + --cp-size) CP_SIZE="$2"; shift 2 ;; + --model) MODEL="$2"; shift 2 ;; + --redis-host) REDIS_HOST="$2"; shift 2 ;; + --redis-port) REDIS_PORT="$2"; shift 2 ;; + --redis-password) REDIS_PASSWORD="$2"; shift 2 ;; + --mooncake-redis-port) MOONCAKE_REDIS_PORT="$2"; shift 2 ;; + --mooncake-engine-port-base) MOONCAKE_ENGINE_PORT_BASE="$2"; shift 2 ;; + --rdma-device) RDMA_DEVICE="$2"; shift 2 ;; + --instance-id) INSTANCE_ID="$2"; shift 2 ;; + --launcher-cmd) LAUNCHER_CMD="$2"; shift 2 ;; + --dry-run) DRY_RUN="true"; shift ;; + -h|--help) usage ;; + *) echo "Unknown arg: $1"; usage ;; + esac +done + +# 
---------------------------------------------------------------- validate +for v in NNODES NODE_RANK MASTER_IP TP_SIZE PP_SIZE CP_SIZE MODEL REDIS_HOST; do + if [[ -z "${!v}" ]]; then + echo "Missing required argument: --$(echo "$v" | tr 'A-Z_' 'a-z-')" + usage + fi +done + +if ! [[ "$NNODES" =~ ^[1-9][0-9]*$ ]]; then + echo "nnodes must be >= 1"; exit 2 +fi +if ! [[ "$NODE_RANK" =~ ^[0-9]+$ ]] || (( NODE_RANK >= NNODES )); then + echo "node_rank must be in [0, nnodes)"; exit 2 +fi + +# Deployment constraint (design-doc §3.3): prefill crosses at most 2 nodes. +# Enforce the same bound here so mis-configured clusters fail fast. +if (( NNODES > 2 )); then + echo "dist_reuse currently supports <= 2 physical nodes per instance" + echo "(see docs/dist_reuse/dist_reuse_with_cp_pp_multinode_tp.md §3.3)" + exit 2 +fi + +# -------------------------------------------------------- derived topology +# Basic rule: MASTER = node_rank 0. Everything else is an off-master role +# whose concrete type is picked by the connector's multinode_policy +# (FlexKV/flexkv/integration/multinode_policy.py). The script only needs +# to decide whether to emit a Mooncake config (full SD-Remote) or just an +# empty config (CP-peer stub); it does that by asking whether the +# instance has any TP-level cross-node spread. + +# tp_node_count = TP_SIZE / gpus_per_node. We auto-detect gpus_per_node +# via nvidia-smi (default 8 if unavailable). 
+if command -v nvidia-smi >/dev/null 2>&1; then + GPUS_PER_NODE=$(nvidia-smi -L | wc -l) +else + GPUS_PER_NODE=8 +fi +if (( GPUS_PER_NODE == 0 )); then GPUS_PER_NODE=8; fi + +if (( TP_SIZE <= GPUS_PER_NODE )); then + TP_NODE_COUNT=1 +else + if (( TP_SIZE % GPUS_PER_NODE != 0 )); then + echo "tp_size ($TP_SIZE) must be a multiple of gpus_per_node ($GPUS_PER_NODE)" + exit 2 + fi + TP_NODE_COUNT=$((TP_SIZE / GPUS_PER_NODE)) +fi + +IS_MULTINODE_TP="false" +if (( TP_NODE_COUNT > 1 )) || (( PP_SIZE > 1 )); then + # PP > 1 crossing nodes always needs a full SD-Remote on each PP-peer + # node too. We conservatively flag it as multinode_tp for the + # mooncake-config emission branch. The connector's policy module + # still does the fine-grained role decision at runtime. + IS_MULTINODE_TP="true" +fi + +IS_MULTINODE_CP="false" +if (( CP_SIZE > 1 && NNODES > 1 )); then + IS_MULTINODE_CP="true" +fi + +# Default INSTANCE_ID if not provided (stable per-invocation value). +if [[ -z "$INSTANCE_ID" ]]; then + INSTANCE_ID="flexkv-$(hostname -s)-${DIST_INIT_PORT}" +fi + +# ---------------------------------------------------------------- layout +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +LOG_DIR="${SCRIPT_DIR}/logs/dist_reuse" +CFG_DIR="${SCRIPT_DIR}/gen/dist_reuse_node${NODE_RANK}" +mkdir -p "$LOG_DIR" "$CFG_DIR" + +# ---------------------------------------------------------- mooncake config +# Only emit a real mooncake config when we need a full SD-Remote on this +# node. CP-peer-only nodes don't touch mooncake. +NEED_FULL_SD_REMOTE="false" +if [[ "$NODE_RANK" == "0" ]]; then + # Master always needs mooncake — it owns the TransferEngine in-process. + NEED_FULL_SD_REMOTE="true" +else + if [[ "$IS_MULTINODE_TP" == "true" ]]; then + NEED_FULL_SD_REMOTE="true" + fi + # CP-only multi-node: off-master is peer-stub only, no mooncake. 
+fi + +MOONCAKE_ENGINE_PORT=$((MOONCAKE_ENGINE_PORT_BASE + NODE_RANK)) +MOONCAKE_CONFIG_FILE="${CFG_DIR}/mooncake_config.json" + +if [[ "$NEED_FULL_SD_REMOTE" == "true" ]]; then + cat > "$MOONCAKE_CONFIG_FILE" < "${MOONCAKE_CONFIG_FILE}" +fi + +# --------------------------------------------------------- flexkv env vars +# The sglang connector reads these at import time; we export them so +# whatever --launcher-cmd the user passes inherits them. +export FLEXKV_ENABLE_SHARING_DOMAIN=1 +export FLEXKV_INSTANCE_ID="${INSTANCE_ID}" +export FLEXKV_REDIS_HOST="${REDIS_HOST}" +export FLEXKV_REDIS_PORT="${REDIS_PORT}" +export FLEXKV_REDIS_PASSWORD="${REDIS_PASSWORD}" +export FLEXKV_MASTER_HOST="${MASTER_IP}" +export FLEXKV_DIST_INIT_ADDR="${MASTER_IP}:${DIST_INIT_PORT}" +export FLEXKV_NNODES="${NNODES}" +export FLEXKV_NODE_RANK="${NODE_RANK}" +export FLEXKV_TP_NODE_COUNT="${TP_NODE_COUNT}" +export MOONCAKE_CONFIG_PATH="${MOONCAKE_CONFIG_FILE}" +export MC_REDIS_PASSWORD="${REDIS_PASSWORD}" + +# --------------------------------------------------------- summary + launch +echo "================================================================" +echo "FlexKV dist_reuse — node_rank=${NODE_RANK}/${NNODES}" +echo " model : ${MODEL}" +echo " tp / pp / cp : ${TP_SIZE} / ${PP_SIZE} / ${CP_SIZE}" +echo " tp_node_count : ${TP_NODE_COUNT} (gpus_per_node=${GPUS_PER_NODE})" +echo " is_multinode_tp : ${IS_MULTINODE_TP}" +echo " is_multinode_cp : ${IS_MULTINODE_CP}" +echo " need_full_sd_remote: ${NEED_FULL_SD_REMOTE}" +echo " instance_id : ${INSTANCE_ID}" +echo " master_ip : ${MASTER_IP}" +echo " dist_init : ${MASTER_IP}:${DIST_INIT_PORT}" +echo " redis : ${REDIS_HOST}:${REDIS_PORT} (flexkv) / ${MOONCAKE_REDIS_PORT} (mooncake)" +echo " mooncake cfg : ${MOONCAKE_CONFIG_FILE}" +echo " launcher : ${LAUNCHER_CMD:-}" +echo "================================================================" + +if [[ "$DRY_RUN" == "true" ]]; then + echo "[dry-run] not executing launcher." 
+ exit 0 +fi + +if [[ -z "$LAUNCHER_CMD" ]]; then + echo "No --launcher-cmd provided; env is set, exec the launcher yourself:" + echo "" + echo " # env exported:" + env | grep -E '^(FLEXKV_|MOONCAKE_|MC_)' | sort + echo "" + exit 0 +fi + +# exec the user launcher — stdout/stderr to log. +LOG_FILE="${LOG_DIR}/node${NODE_RANK}_$(date +%Y%m%d_%H%M%S).log" +echo "Launching: ${LAUNCHER_CMD}" +echo "Log: ${LOG_FILE}" +# shellcheck disable=SC2086 +exec bash -c "${LAUNCHER_CMD}" >"${LOG_FILE}" 2>&1 diff --git a/tests/_dist_reuse_fakes.py b/tests/_dist_reuse_fakes.py new file mode 100644 index 0000000000..14c5138d3f --- /dev/null +++ b/tests/_dist_reuse_fakes.py @@ -0,0 +1,285 @@ +"""Shared test fakes for dist_reuse unit tests. + +Underscore-prefixed so pytest doesn't try to collect this file. Uses no +external dependencies beyond the Python stdlib — in particular, it does +**not** import ``redis`` or ``torch``, so it is safe to import on a +CPU-only test machine with the ``--noconftest`` flag. +""" +from __future__ import annotations + +import fnmatch +import threading +import time +from typing import Any, Dict, Iterable, List, Optional, Tuple + + +class FakeRedis: + """Just enough of the redis-py surface for dist_reuse unit tests. + + Thread-safe at the command level (not at multi-command transactions) — + individual method calls are each protected by an internal RLock. 
The + fake supports: + + * ``set/get/expire/delete`` + * ``hset(name, key=, value=)``, ``hset(name, mapping=)``, + ``hgetall(name)``, ``hget(name, field)`` + * ``scan(cursor, match, count)``, ``scan_iter(match, count)`` + * ``incr(name)`` + * ``publish(channel, message)`` (no-op; return subscriber count = 0) + * ``pipeline()`` (emits a ``FakePipeline`` that queues ops and + applies them atomically on ``execute()``) + * ``rpush(name, *values)``, ``lrange(name, start, end)`` + * ``exists(*names)`` + + TTLs are honored lazily — a key is treated as expired only when a + command touches it and finds ``now >= expiry_at``. This matches + redis-py's observable behaviour closely enough for unit testing. + """ + + def __init__(self, time_fn=time.monotonic) -> None: + self._strs: Dict[str, str] = {} + self._hashes: Dict[str, Dict[str, str]] = {} + self._lists: Dict[str, List[str]] = {} + self._expiry: Dict[str, float] = {} + self._time = time_fn + self._lock = threading.RLock() + + # ---------------------------------------------------------------- GC + def _gc(self, name: str) -> None: + exp = self._expiry.get(name) + if exp is not None and self._time() >= exp: + self._strs.pop(name, None) + self._hashes.pop(name, None) + self._lists.pop(name, None) + self._expiry.pop(name, None) + + def _all_keys(self) -> List[str]: + """Snapshot of every live key (after lazy GC).""" + keys = set(self._strs) | set(self._hashes) | set(self._lists) + for k in list(keys): + self._gc(k) + return list(set(self._strs) | set(self._hashes) | set(self._lists)) + + # ------------------------------------------------------------- strings + def set(self, name: str, value: str, ex: Optional[int] = None) -> bool: + with self._lock: + self._strs[name] = value + if ex is not None: + self._expiry[name] = self._time() + float(ex) + else: + self._expiry.pop(name, None) + return True + + def get(self, name: str) -> Optional[str]: + with self._lock: + self._gc(name) + return self._strs.get(name) + + def 
incr(self, name: str) -> int: + with self._lock: + self._gc(name) + cur = int(self._strs.get(name, "0")) + cur += 1 + self._strs[name] = str(cur) + return cur + + def exists(self, *names: str) -> int: + with self._lock: + count = 0 + for n in names: + self._gc(n) + if ( + n in self._strs + or n in self._hashes + or n in self._lists + ): + count += 1 + return count + + # ------------------------------------------------------------ TTL + def expire(self, name: str, ex: int) -> bool: + with self._lock: + self._gc(name) + if ( + name not in self._strs + and name not in self._hashes + and name not in self._lists + ): + return False + self._expiry[name] = self._time() + float(ex) + return True + + def delete(self, *names: str) -> int: + with self._lock: + n = 0 + for k in names: + removed = False + if k in self._strs: + self._strs.pop(k, None) + removed = True + if k in self._hashes: + self._hashes.pop(k, None) + removed = True + if k in self._lists: + self._lists.pop(k, None) + removed = True + self._expiry.pop(k, None) + if removed: + n += 1 + return n + + # ------------------------------------------------------------ hashes + def hset( + self, + name: str, + key: Optional[str] = None, + value: Optional[str] = None, + mapping: Optional[Dict[str, Any]] = None, + ) -> int: + with self._lock: + h = self._hashes.setdefault(name, {}) + new_fields = 0 + if mapping: + for k, v in mapping.items(): + if k not in h: + new_fields += 1 + h[str(k)] = str(v) + if key is not None: + if key not in h: + new_fields += 1 + h[str(key)] = str(value) if value is not None else "" + return new_fields + + def hget(self, name: str, field: str) -> Optional[str]: + with self._lock: + self._gc(name) + return self._hashes.get(name, {}).get(field) + + def hgetall(self, name: str) -> Dict[str, str]: + with self._lock: + self._gc(name) + return dict(self._hashes.get(name, {})) + + # -------------------------------------------------------------- lists + def rpush(self, name: str, *values: Any) -> int: 
+ with self._lock: + lst = self._lists.setdefault(name, []) + for v in values: + lst.append(str(v)) + return len(lst) + + def lrange(self, name: str, start: int, end: int) -> List[str]: + with self._lock: + self._gc(name) + lst = self._lists.get(name, []) + if end == -1: + end = len(lst) - 1 + return list(lst[start : end + 1]) + + # --------------------------------------------------------------- scan + def scan( + self, + cursor: int = 0, + match: Optional[str] = None, + count: Optional[int] = None, + ) -> Tuple[int, List[str]]: + with self._lock: + all_keys = sorted(self._all_keys()) + if match is None: + filtered = all_keys + else: + filtered = [k for k in all_keys if fnmatch.fnmatchcase(k, match)] + # For simplicity (and since FakeRedis is tiny) we return everything + # in one go; real redis returns batches of size ~count. + return 0, filtered + + def scan_iter( + self, match: Optional[str] = None, count: Optional[int] = None + ) -> Iterable[str]: + _, keys = self.scan(cursor=0, match=match, count=count) + return iter(keys) + + # ---------------------------------------------------- misc / no-ops + def ping(self) -> bool: + return True + + def publish(self, channel: str, message: str) -> int: + return 0 + + def pubsub(self): + raise NotImplementedError("FakeRedis.pubsub() is not supported in unit tests") + + def close(self) -> None: + pass + + # ---------------------------------------------------------- pipeline + def pipeline(self, transaction: bool = True): # noqa: ARG002 — kept for compat + return _FakePipeline(self) + + # ------------------------------------------------------------ helpers + def force_expire(self, name: str) -> None: + """Test helper: drop a key as if its TTL had elapsed.""" + with self._lock: + self._strs.pop(name, None) + self._hashes.pop(name, None) + self._lists.pop(name, None) + self._expiry.pop(name, None) + + def snapshot(self) -> Dict[str, Any]: + """Return a read-only view of the current store — convenient for + assert-style 
introspection in tests.""" + with self._lock: + return { + "strs": dict(self._strs), + "hashes": {k: dict(v) for k, v in self._hashes.items()}, + "lists": {k: list(v) for k, v in self._lists.items()}, + } + + +class _FakePipeline: + """Mimic ``redis.Redis.pipeline()`` — queue commands, execute on flush.""" + + def __init__(self, owner: FakeRedis) -> None: + self._owner = owner + self._ops: List[Tuple[str, Tuple[Any, ...], Dict[str, Any]]] = [] + + def hset(self, *args, **kwargs): + self._ops.append(("hset", args, kwargs)) + return self + + def delete(self, *args, **kwargs): + self._ops.append(("delete", args, kwargs)) + return self + + def rpush(self, *args, **kwargs): + self._ops.append(("rpush", args, kwargs)) + return self + + def set(self, *args, **kwargs): + self._ops.append(("set", args, kwargs)) + return self + + def expire(self, *args, **kwargs): + self._ops.append(("expire", args, kwargs)) + return self + + def execute(self) -> List[Any]: + results = [] + for name, args, kwargs in self._ops: + method = getattr(self._owner, name) + results.append(method(*args, **kwargs)) + self._ops.clear() + return results + + +class ManualClock: + """Deterministic monotonic clock, for time-sensitive tests.""" + + def __init__(self, start: float = 0.0) -> None: + self.now = start + + def __call__(self) -> float: + return self.now + + def advance(self, dt: float) -> None: + self.now += dt diff --git a/tests/multinode/test_cross_instance_reuse_e2e.py b/tests/multinode/test_cross_instance_reuse_e2e.py new file mode 100644 index 0000000000..65c625622d --- /dev/null +++ b/tests/multinode/test_cross_instance_reuse_e2e.py @@ -0,0 +1,312 @@ +"""§2.6 — Cross-node end-to-end test scaffolding. + +Real dist_reuse e2e (≥ 2 GPU machines, real Mooncake RDMA path, +cross-instance reuse) cannot run on a single-box CI. This module +provides a ``pytest`` scaffold so that once a multi-machine harness is +available, wiring the tests is a matter of filling in the fixtures. 
+ +All tests here are skipped by default via ``pytest.mark.multinode`` +unless the ``FLEXKV_MULTINODE_TEST=1`` environment variable is set. +The intent is: + +* CI default — tests are collected (so ``pytest --collect-only`` shows + them) but skipped with a clear reason. +* Developer local multi-machine run — set the env var, point + ``FLEXKV_MULTINODE_MASTER_ADDR`` / ``FLEXKV_MULTINODE_REMOTE_ADDR`` + at pre-booted instances, run pytest. + +When we do light up the harness (docs/dist_reuse/ +implementation_gap_2026-05-11.md §2.6), fill in: + +1. ``launch_two_node_instance`` fixture: boots one Master + one Remote + using ``scripts/multi-nodes/start_dist_reuse_serving.sh`` and + collects their shutdown hooks. +2. Client-side orchestrator: sends a prompt to both instances, captures + hit-rate counters, latency, and bytes transferred over Mooncake. +3. Assertion on cross-instance reuse: after instance A processes a + common prefix, instance B's cold-start of the same prefix must + show ``remote_hits > 0`` in ``kv_manager.stats()`` (or equivalent). + +**Status (2026-05-18)**: the single-SD degenerate path was verified +end-to-end on a 2-machine RDMA harness with +``benchmark_dist_direct.py``. That result is codified below as +``test_single_sd_degenerate_reuse_cache_ratio_100pct`` so CI can +regression-gate it when the harness env is available. The +multi-SD (PP>1 / tp_node_count>1) implementation has landed in code +(Phase D-1/D-2/D-3 — see +``docs/dist_reuse/proposal_unify_with_graph_dispatch_2026-05-15.md``) +but the real-hardware e2e tests below remain ``xfail`` pending a +working multi-SD harness fixture (``s4_multi_sd_pp2.sh``). 
+""" +from __future__ import annotations + +import json +import os +import re +import shlex +import subprocess +from pathlib import Path +from typing import Any, Dict, Optional + +import pytest + + +def _multinode_enabled() -> bool: + return os.environ.get("FLEXKV_MULTINODE_TEST", "") == "1" + + +# Project-level marker; any test tagged with this is opt-in. +multinode = pytest.mark.multinode + + +pytestmark = pytest.mark.skipif( + not _multinode_enabled(), + reason="Set FLEXKV_MULTINODE_TEST=1 and provide the harness addresses " + "(FLEXKV_MULTINODE_MASTER_ADDR / FLEXKV_MULTINODE_REMOTE_ADDR) " + "to enable cross-node e2e tests. See docs/dist_reuse/" + "implementation_gap_2026-05-11.md §2.6.", +) + + +# --------------------------------------------------------------------------- +# Fixtures (stubs; fill in when harness is ready) +# --------------------------------------------------------------------------- +@pytest.fixture(scope="module") +def master_addr() -> str: + addr = os.environ.get("FLEXKV_MULTINODE_MASTER_ADDR") + if not addr: + pytest.skip("FLEXKV_MULTINODE_MASTER_ADDR not set") + return addr + + +@pytest.fixture(scope="module") +def remote_addr() -> str: + addr = os.environ.get("FLEXKV_MULTINODE_REMOTE_ADDR") + if not addr: + pytest.skip("FLEXKV_MULTINODE_REMOTE_ADDR not set") + return addr + + +@pytest.fixture(scope="module") +def two_instance_cluster(master_addr, remote_addr) -> Dict[str, Any]: + """Two pre-booted FlexKV instances ready for cross-instance reuse. + + TODO(§2.6): replace this placeholder with real boot+teardown logic + driven by ``scripts/multi-nodes/start_dist_reuse_serving.sh``. + For now, tests assume the user booted them by hand and we just + expose their addresses. 
+ """ + return { + "master": master_addr, + "remote": remote_addr, + } + + +# --------------------------------------------------------------------------- +# Helpers — parse benchmark_dist_direct.py stdout +# --------------------------------------------------------------------------- +_CACHE_RATIO_RE = re.compile( + r"cache_ratio:\s*(?P<pct>\d+(?:\.\d+)?)%" +) +_STATS_REUSE_RE = re.compile( + r"\[STATS\]\[REUSE\].*?flexkv_global=(?P<global_hit>\d+).*?" +) + + +def _parse_benchmark_result(stdout: str) -> Dict[str, float]: + """Extract cache_ratio and REUSE stats from benchmark_dist_direct.py + stdout. + + Returns + ------- + dict + ``{"cache_ratio_pct": float, "flexkv_global_hits": int}``. + Missing fields default to 0. + """ + out: Dict[str, float] = {"cache_ratio_pct": 0.0, "flexkv_global_hits": 0} + m = _CACHE_RATIO_RE.search(stdout) + if m: + out["cache_ratio_pct"] = float(m.group("pct")) + m = _STATS_REUSE_RE.search(stdout) + if m: + out["flexkv_global_hits"] = int(m.group("global_hit")) + return out + + +# --------------------------------------------------------------------------- +# Test bodies +# --------------------------------------------------------------------------- +@multinode +def test_single_sd_degenerate_reuse_cache_ratio_100pct(two_instance_cluster): + """Single-SD degenerate dist_reuse e2e regression test. + + Captures the state reached on 2026-05-13: + - Two-machine harness (146 ⇄ 129) with mooncake RDMA transfer. + - PP=1, tp_node_count=1 → ``total_sd_count == 1``. + - PUT on machine A, GET on machine B with same seed → 100% hit. + + The ``benchmark_dist_direct.py`` invocation shape is verbatim to + what the manual harness uses — if this assertion flips (e.g. 0% + after some code change), the single-SD path of §2.1 broke and a + real bisect is needed. + + Environment preconditions: + + * ``FLEXKV_MULTINODE_BENCHMARK_CMD_PUT`` — full shell command that + starts the PUT-side benchmark on machine A (Master). 
Must block + until "Press Enter to shutdown" appears in stdout. + * ``FLEXKV_MULTINODE_BENCHMARK_CMD_GET`` — full shell command for + the GET side on machine B. + * Both commands must use the same ``--seed``, + ``--sequence-length``, and ``--batch-size`` so the hash chain + matches. + + If either env var is missing, the test is ``skip``ed (we don't + xfail — a harness mismatch isn't a code regression). + """ + cmd_put = os.environ.get("FLEXKV_MULTINODE_BENCHMARK_CMD_PUT") + cmd_get = os.environ.get("FLEXKV_MULTINODE_BENCHMARK_CMD_GET") + if not cmd_put or not cmd_get: + pytest.skip( + "Set FLEXKV_MULTINODE_BENCHMARK_CMD_PUT and " + "FLEXKV_MULTINODE_BENCHMARK_CMD_GET to exercise this test." + ) + + # Fire PUT in the background. The benchmark prints + # "Press Enter to shutdown" when the prefix is idle in Redis and + # ready to be consumed. + put_proc = subprocess.Popen( + shlex.split(cmd_put), + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + try: + # Wait until PUT is idle (max 60 s). + deadline = _deadline(60) + put_stdout_buf = [] + while True: + if put_proc.poll() is not None: + pytest.fail( + f"PUT side exited prematurely with rc={put_proc.returncode}: " + f"{''.join(put_stdout_buf)[:2000]}" + ) + line = put_proc.stdout.readline() if put_proc.stdout else "" + if line: + put_stdout_buf.append(line) + if "Press Enter" in line or "Data published" in line: + break + if _past_deadline(deadline): + pytest.fail( + "PUT side never reached idle state within 60s.\n" + f"stdout:\n{''.join(put_stdout_buf)[-2000:]}" + ) + + # Now run GET and capture full stdout. 
+ get = subprocess.run( + shlex.split(cmd_get), + capture_output=True, + text=True, + timeout=120, + ) + assert get.returncode == 0, ( + f"GET side crashed (rc={get.returncode}): {get.stderr[-2000:]}" + ) + + stats = _parse_benchmark_result(get.stdout + "\n" + get.stderr) + assert stats["cache_ratio_pct"] >= 99.9, ( + f"Cross-instance reuse regressed — " + f"expected ~100% cache hit, got {stats['cache_ratio_pct']:.2f}%.\n" + f"Last 2 KB of GET stdout:\n{get.stdout[-2000:]}" + ) + assert stats["flexkv_global_hits"] > 0, ( + f"No flexkv_global hits registered — aggregate radix not " + f"rebuilt on GET side.\nLast 2 KB of GET stdout:\n" + f"{get.stdout[-2000:]}" + ) + finally: + # Tear down PUT side cleanly. + try: + put_proc.terminate() + put_proc.wait(timeout=10) + except Exception: + put_proc.kill() + + +@multinode +def test_cross_instance_reuse_hit_rate(two_instance_cluster): + """Multi-SD hit-rate validation (PP>1 or tp_node_count>1). + + Still ``xfail`` — the code-level implementation (Phase D-1/D-2/D-3 + of ``proposal_unify_with_graph_dispatch_2026-05-15.md``) has + landed, but a working multi-SD harness fixture (PP=2 yaml + + ``s4_multi_sd_pp2.sh`` 2-instance launcher) is still required + to drive this test end-to-end. Once that fixture is online the + test body should mirror + ``test_single_sd_degenerate_reuse_cache_ratio_100pct`` but with + a multi-SD yaml (``pp_size: 2``) and additionally assert that + peer-SD D2H clones produced ``CompletedOp(sd_key, + contributing_node_id, success=True)`` events on the master's + completion sink (once the metrics hook exposing them lands). + """ + pytest.xfail( + "Multi-SD e2e still blocked on the multi-SD harness fixture " + "(s4_multi_sd_pp2.sh) — the D-1/D-2/D-3 code path is in " + "place; see proposal_unify_with_graph_dispatch_2026-05-15.md " + "§11." 
+ ) + + +@multinode +def test_master_node_failure_triggers_peer_invalidation(two_instance_cluster): + """Kill the Master; the Remote must stop accepting coord GETs and + its peer-lost hook should fire on other instances.""" + pytest.xfail("harness not implemented yet — §2.6 TODO") + + +@multinode +def test_mooncake_transfer_sync_read_path(two_instance_cluster): + """End-to-end validation of the Phase 1-C data path: a coord GET + triggers a real Mooncake RDMA read that lands bytes into the + requesting instance's CPU block pool. Requires real hardware. + + Note: single-SD Mooncake read path IS already covered by + ``test_single_sd_degenerate_reuse_cache_ratio_100pct`` above — + ``op_remote2h`` in ``_get_impl_global`` dispatches the Mooncake + read. This test remains for the multi-SD GET fan-out path: + Master constructs a unified ``TransferOpGraph`` with one + ``PEERH2H`` op per peer SD (target_node_ids stamping handled by + ``GlobalCacheEngine._maybe_attach_multi_sd_peerh2h_ops``, Phase + D-3), each peer SD's Remote executes its own clone, and the + master receives one ``CompletedOp`` per SD before triggering the + H2D. + """ + pytest.xfail( + "Multi-SD PEERH2H fan-out needs the same multi-SD harness " + "fixture (s4_multi_sd_pp2.sh) as " + "test_cross_instance_reuse_hit_rate." 
+ ) + + +# --------------------------------------------------------------------------- +# Conftest-level registration — lets ``pytest -m multinode`` select these +# --------------------------------------------------------------------------- +def pytest_configure(config): # pragma: no cover (hook, no unit test) + config.addinivalue_line( + "markers", + "multinode: requires ≥ 2 GPU machines and FLEXKV_MULTINODE_TEST=1", + ) + + +# --------------------------------------------------------------------------- +# Internal helpers for the stdout-polling loop +# --------------------------------------------------------------------------- +import time + + +def _deadline(seconds: float) -> float: + return time.monotonic() + float(seconds) + + +def _past_deadline(deadline: float) -> bool: + return time.monotonic() >= deadline diff --git a/tests/test_aggregate_radix.py b/tests/test_aggregate_radix.py new file mode 100644 index 0000000000..d9d04b1913 --- /dev/null +++ b/tests/test_aggregate_radix.py @@ -0,0 +1,306 @@ +"""Unit tests for ``flexkv.common.dist_reuse.aggregate_radix`` (Phase 0 task 0-H).""" +from __future__ import annotations + +from typing import List + +import pytest + +from flexkv.common.dist_reuse.aggregate_radix import ( + AggregateRadixTree, + BlockNotTrackedError, +) + + +# --------------------------------------------------------------------------- +# Manual clock for deterministic time-based assertions +# --------------------------------------------------------------------------- +class _ManualClock: + def __init__(self, start: float = 0.0) -> None: + self.now = start + + def __call__(self) -> float: + return self.now + + def advance(self, dt: float) -> None: + self.now += dt + + +# --------------------------------------------------------------------------- +# mark_sd_ready / fully-ready transition +# --------------------------------------------------------------------------- +class TestMarkSdReady: + def test_single_sd_immediately_ready(self): + agg = 
AggregateRadixTree(total_sd_count=1) + became = agg.mark_sd_ready(prefix_hash=0xAA, sd_key="sd0", block_ids=[1, 2]) + assert became is True + result = agg.match_fully_ready(0xAA) + assert result is not None + assert result.block_ids == (1, 2) + # ready_sds is now a Dict[str, int] (sd_key -> contributing node_id). + # Default node_id is -1 when caller doesn't pass one. + assert set(result.ready_sds) == {"sd0"} + assert result.ready_sds == {"sd0": -1} + + def test_multi_sd_progressive(self): + agg = AggregateRadixTree(total_sd_count=3) + for sd in ("sd0", "sd1"): + assert agg.mark_sd_ready(0xBB, sd, [10, 20]) is False + # Not yet fully ready. + assert agg.match_fully_ready(0xBB) is None + # Last SD acks → transitions to fully ready. + became = agg.mark_sd_ready(0xBB, "sd2", [10, 20]) + assert became is True + result = agg.match_fully_ready(0xBB) + assert result is not None + assert set(result.ready_sds) == {"sd0", "sd1", "sd2"} + + def test_idempotent_per_sd(self): + agg = AggregateRadixTree(total_sd_count=2) + agg.mark_sd_ready(1, "sd0", [1]) + # Re-acking sd0 must NOT count as a transition. + assert agg.mark_sd_ready(1, "sd0", [1]) is False + assert agg.match_fully_ready(1) is None + # sd1 transitions. 
+ assert agg.mark_sd_ready(1, "sd1", [1]) is True + + def test_block_id_mismatch_raises(self): + agg = AggregateRadixTree(total_sd_count=2) + agg.mark_sd_ready(1, "sd0", [1, 2]) + with pytest.raises(ValueError): + agg.mark_sd_ready(1, "sd1", [1, 99]) # different block layout + + def test_invalid_sd_key(self): + agg = AggregateRadixTree(total_sd_count=1) + with pytest.raises(ValueError): + agg.mark_sd_ready(1, "", [1]) + + def test_contributing_peer_recorded(self): + agg = AggregateRadixTree(total_sd_count=2) + agg.mark_sd_ready(1, "sd0", [1], contributing_peer="peer-A") + agg.mark_sd_ready(1, "sd1", [1], contributing_peer="peer-A") + result = agg.match_fully_ready(1) + assert result is not None + assert result.contributing_peers == {"peer-A"} + + def test_node_id_recorded_per_sd(self): + """Per-SD node_id schema (multi-SD GET-path support). + + Each SD ack carries the contributing peer's node_id so the + GET-path knows which Mooncake server to RDMA-READ from for + that SD's slice. Master SD typically passes its own node_id; + peer SDs pass the node_id from + ``CompletedOp.contributing_node_id`` (Phase D-2 graph-dispatch + PUT path). + """ + agg = AggregateRadixTree(total_sd_count=3) + agg.mark_sd_ready(1, "sd0", [10, 20], node_id=100) + agg.mark_sd_ready(1, "sd1", [10, 20], node_id=200) + agg.mark_sd_ready(1, "sd2", [10, 20], node_id=300) + result = agg.match_fully_ready(1) + assert result is not None + assert result.ready_sds == {"sd0": 100, "sd1": 200, "sd2": 300} + # node_id_for_sd helper returns the per-SD node_id. + assert result.node_id_for_sd("sd0") == 100 + assert result.node_id_for_sd("sd2") == 300 + assert result.node_id_for_sd("nonexistent") is None + + def test_node_id_late_binding(self): + """Idempotent re-acks with a real node_id must overwrite the + sentinel ``-1`` (so callers can defer node_id resolution).""" + agg = AggregateRadixTree(total_sd_count=2) + # First ack — node_id unknown. 
+ agg.mark_sd_ready(1, "sd0", [10], node_id=-1) + # Second ack — same SD, real node_id. Idempotent transition, + # but node_id should patch. + assert agg.mark_sd_ready(1, "sd0", [10], node_id=42) is False + # Other SD acks; prefix becomes fully ready. + agg.mark_sd_ready(1, "sd1", [10], node_id=99) + result = agg.match_fully_ready(1) + assert result is not None + assert result.node_id_for_sd("sd0") == 42 # patched + assert result.node_id_for_sd("sd1") == 99 + + def test_node_id_real_value_not_overwritten_by_sentinel(self): + """Once a real node_id is recorded, a subsequent ``-1`` ack + must NOT clobber it.""" + agg = AggregateRadixTree(total_sd_count=1) + agg.mark_sd_ready(1, "sd0", [10], node_id=42) + agg.mark_sd_ready(1, "sd0", [10], node_id=-1) + result = agg.match_fully_ready(1) + assert result is not None + assert result.node_id_for_sd("sd0") == 42 + + +# --------------------------------------------------------------------------- +# mark_sd_evicted +# --------------------------------------------------------------------------- +class TestMarkSdEvicted: + def test_evict_single_sd_drops_to_partial(self): + agg = AggregateRadixTree(total_sd_count=2) + agg.mark_sd_ready(1, "sd0", [10]) + agg.mark_sd_ready(1, "sd1", [10]) + assert agg.match_fully_ready(1) is not None + agg.mark_sd_evicted(1, "sd0") + assert agg.match_fully_ready(1) is None # no longer fully ready + + def test_evict_last_sd_drops_entry(self): + agg = AggregateRadixTree(total_sd_count=2) + agg.mark_sd_ready(1, "sd0", [10]) + agg.mark_sd_evicted(1, "sd0") + # No SDs left → entry is gone, not just "partial" + assert 1 not in agg.known_prefixes() + + def test_evict_unknown_prefix_is_noop(self): + agg = AggregateRadixTree(total_sd_count=1) + agg.mark_sd_evicted(99, "sd0") # must not raise + + +# --------------------------------------------------------------------------- +# Refcount lifecycle +# --------------------------------------------------------------------------- +class TestRefcount: + def 
test_acquire_release_cycle(self): + agg = AggregateRadixTree(total_sd_count=1) + assert agg.is_evictable(7) is True + agg.acquire([7]) + assert agg.is_evictable(7) is False + assert agg.get_refcount(7) == 1 + agg.release([7]) + assert agg.is_evictable(7) is True + assert agg.get_refcount(7) == 0 + + def test_acquire_increments(self): + agg = AggregateRadixTree(total_sd_count=1) + agg.acquire([1, 1, 1]) + assert agg.get_refcount(1) == 3 + agg.release([1, 1]) + assert agg.get_refcount(1) == 1 + assert agg.is_evictable(1) is False + + def test_double_release_raises(self): + agg = AggregateRadixTree(total_sd_count=1) + agg.acquire([1]) + agg.release([1]) + with pytest.raises(BlockNotTrackedError): + agg.release([1]) + + def test_release_untracked_raises(self): + agg = AggregateRadixTree(total_sd_count=1) + with pytest.raises(BlockNotTrackedError): + agg.release([42]) + + +# --------------------------------------------------------------------------- +# Leak scanner +# --------------------------------------------------------------------------- +class TestLeakScanner: + def test_finds_leaked_blocks(self): + clock = _ManualClock() + agg = AggregateRadixTree(total_sd_count=1, time_fn=clock) + agg.acquire([1]) + clock.advance(40.0) + leaked = agg.scan_leaked_refcount(timeout_seconds=30.0) + assert leaked == [1] + + def test_fresh_acquires_not_leaked(self): + clock = _ManualClock() + agg = AggregateRadixTree(total_sd_count=1, time_fn=clock) + agg.acquire([1]) + leaked = agg.scan_leaked_refcount(timeout_seconds=30.0) + assert leaked == [] + + def test_force_release_clears_refcount(self): + clock = _ManualClock() + agg = AggregateRadixTree(total_sd_count=1, time_fn=clock) + agg.acquire([1, 1, 1]) + prev = agg.force_release(1) + assert prev == 3 + assert agg.is_evictable(1) is True + + def test_force_release_unknown_returns_zero(self): + agg = AggregateRadixTree(total_sd_count=1) + assert agg.force_release(999) == 0 + + def test_negative_timeout_raises(self): + agg = 
AggregateRadixTree(total_sd_count=1) + with pytest.raises(ValueError): + agg.scan_leaked_refcount(-1.0) + + +# --------------------------------------------------------------------------- +# Invalidation +# --------------------------------------------------------------------------- +class TestInvalidation: + def test_invalidate_prefix(self): + agg = AggregateRadixTree(total_sd_count=1) + agg.mark_sd_ready(1, "sd0", [10]) + assert agg.invalidate_prefix(1) is True + assert agg.match_fully_ready(1) is None + # Idempotent + assert agg.invalidate_prefix(1) is False + + def test_invalidate_by_peer(self): + agg = AggregateRadixTree(total_sd_count=2) + # Two prefixes from peer-A + agg.mark_sd_ready(1, "sd0", [10], contributing_peer="peer-A") + agg.mark_sd_ready(1, "sd1", [10], contributing_peer="peer-A") + agg.mark_sd_ready(2, "sd0", [20], contributing_peer="peer-A") + agg.mark_sd_ready(2, "sd1", [20], contributing_peer="peer-A") + # One prefix from peer-B + agg.mark_sd_ready(3, "sd0", [30], contributing_peer="peer-B") + agg.mark_sd_ready(3, "sd1", [30], contributing_peer="peer-B") + + n = agg.invalidate_by_peer_instance("peer-A") + assert n == 2 + assert agg.match_fully_ready(1) is None + assert agg.match_fully_ready(2) is None + assert agg.match_fully_ready(3) is not None # peer-B prefix survives + + def test_invalidate_by_unknown_peer(self): + agg = AggregateRadixTree(total_sd_count=1) + agg.mark_sd_ready(1, "sd0", [10], contributing_peer="peer-A") + assert agg.invalidate_by_peer_instance("peer-Z") == 0 + + def test_invalidate_by_empty_peer_raises(self): + agg = AggregateRadixTree(total_sd_count=1) + with pytest.raises(ValueError): + agg.invalidate_by_peer_instance("") + + +# --------------------------------------------------------------------------- +# Constructor +# --------------------------------------------------------------------------- +class TestConstructor: + @pytest.mark.parametrize("bad", [0, -1, "1", 1.5]) + def test_bad_total_sd_count(self, bad): + with 
pytest.raises(ValueError): + AggregateRadixTree(total_sd_count=bad) # type: ignore[arg-type] + + def test_total_sd_count_property(self): + agg = AggregateRadixTree(total_sd_count=5) + assert agg.total_sd_count == 5 + + +# --------------------------------------------------------------------------- +# Concurrency smoke (light — full stress would slow CI) +# --------------------------------------------------------------------------- +def test_concurrent_acquire_release_does_not_corrupt(): + import threading + + agg = AggregateRadixTree(total_sd_count=1) + block_ids: List[int] = list(range(100)) + + def worker(): + for _ in range(50): + agg.acquire(block_ids) + agg.release(block_ids) + + threads = [threading.Thread(target=worker) for _ in range(8)] + for t in threads: + t.start() + for t in threads: + t.join() + + # All refcounts should be back to zero. + for b in block_ids: + assert agg.is_evictable(b) is True diff --git a/tests/test_cache_config_batch_b.py b/tests/test_cache_config_batch_b.py new file mode 100644 index 0000000000..a75cd71d81 --- /dev/null +++ b/tests/test_cache_config_batch_b.py @@ -0,0 +1,194 @@ +"""Unit tests for Phase 0 task 0-A / 0-J CacheConfig additions. + +These tests cover: + +* ``RemoteEndpoint`` dataclass construction & field names + (task 0-J scaffolding for Remote-endpoint discovery). +* ``CacheConfig.remote_endpoints_by_sd`` default value and basic usage. +* ``CacheConfig.enable_sharing_domain`` auto-derivation from P2P flags + (task 0-A ``__post_init__`` behaviour). +* ``CacheConfig`` sharing-domain field defaults match plan.md §1.2 0-A. + +We bypass the ``tests/conftest.py`` torch-heavy import pipeline via +``--noconftest`` — all we need here is the config module itself and the +stdlib dataclass machinery. We **do** import ``torch`` indirectly (it's a +hard dependency of ``flexkv.common.config``), so the test machine needs a +CPU-only torch install but nothing CUDA. 
+""" +from __future__ import annotations + +import dataclasses +import importlib.util +import sys +from pathlib import Path + +import pytest + +# Important: loading ``flexkv.common.config`` from source so we pick up +# the freshly-edited file instead of the stale compiled ``config.so`` +# that sits next to it on some deployment machines. ``flexkv.common`` +# itself has an empty ``__init__``, so importing sub-modules does not +# pull in the ``flexkv.cache`` package (which requires CUDA-linked c_ext). +pkg_root = Path(__file__).resolve().parent.parent + + +def _load_config(): + """Load ``flexkv.common.config`` preferring the ``.py`` source. + + If a stale ``config.so`` shadows our ``.py`` on ``sys.path`` we + side-step it by loading the source file directly. That keeps the + tests relevant regardless of whether the target environment has + rebuilt its Cython/pybind ``flexkv/common/config.so`` artefact. + """ + src = pkg_root / "flexkv" / "common" / "config.py" + assert src.exists(), f"missing source file {src}" + spec = importlib.util.spec_from_file_location( + "_cfg_under_test", str(src), + ) + assert spec is not None and spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module # so @dataclass can resolve the module + spec.loader.exec_module(module) + return module + + +@pytest.fixture(scope="module") +def cfg_mod(): + return _load_config() + + +# --------------------------------------------------------------------------- +# RemoteEndpoint +# --------------------------------------------------------------------------- +class TestRemoteEndpoint: + def test_construct(self, cfg_mod): + ep = cfg_mod.RemoteEndpoint( + ip="10.0.0.1", + gpu_register_port="6001", + command_port="6002", + result_port="6003", + ) + assert ep.ip == "10.0.0.1" + assert ep.gpu_register_port == "6001" + assert ep.command_port == "6002" + assert ep.result_port == "6003" + + def test_is_dataclass(self, cfg_mod): + assert 
dataclasses.is_dataclass(cfg_mod.RemoteEndpoint) + field_names = {f.name for f in dataclasses.fields(cfg_mod.RemoteEndpoint)} + assert field_names == {"ip", "gpu_register_port", "command_port", "result_port"} + + def test_fields_are_strings(self, cfg_mod): + # The ``master_ports`` tuple that this mirrors is a tuple of str, so + # we enforce the same here — users should str(int) explicitly. + for f in dataclasses.fields(cfg_mod.RemoteEndpoint): + assert f.type is str or f.type == "str" + + +# --------------------------------------------------------------------------- +# CacheConfig sharing-domain defaults +# --------------------------------------------------------------------------- +class TestCacheConfigDefaults: + def test_sharing_domain_fields_exist(self, cfg_mod): + cc = cfg_mod.CacheConfig() + # All five sharing-domain knobs are present with documented defaults. + assert cc.enable_sharing_domain is False + assert cc.instance_id is None + assert cc.session_epoch is None + assert cc.instance_session_ttl_seconds == 8 + assert cc.instance_session_renew_interval_seconds == 3 + assert cc.refcount_leak_timeout_seconds == 30 + assert cc.refcount_leak_scan_interval_seconds == 10 + assert cc.remote_endpoints_by_sd == {} + + def test_remote_endpoints_default_is_fresh_per_instance(self, cfg_mod): + """Guard against the classic ``default=[]`` dataclass foot-gun — + each CacheConfig must get its own empty dict.""" + a = cfg_mod.CacheConfig() + b = cfg_mod.CacheConfig() + a.remote_endpoints_by_sd["sd0"] = cfg_mod.RemoteEndpoint( + ip="x", gpu_register_port="1", command_port="2", result_port="3", + ) + assert b.remote_endpoints_by_sd == {} + + +class TestPostInit: + def test_enable_sharing_domain_on_p2p_cpu(self, cfg_mod): + cc = cfg_mod.CacheConfig(enable_p2p_cpu=True) + assert cc.enable_sharing_domain is True + + def test_enable_sharing_domain_on_p2p_ssd(self, cfg_mod): + cc = cfg_mod.CacheConfig(enable_p2p_ssd=True) + assert cc.enable_sharing_domain is True + + def 
test_enable_sharing_domain_off_by_default(self, cfg_mod): + cc = cfg_mod.CacheConfig() + assert cc.enable_sharing_domain is False + + def test_explicit_enable_sharing_domain_preserved(self, cfg_mod): + cc = cfg_mod.CacheConfig(enable_sharing_domain=True) + # Explicit enable must NOT be clobbered by the auto-derive logic. + assert cc.enable_sharing_domain is True + + def test_enable_kv_sharing_still_auto_derived(self, cfg_mod): + # Regression test: the pre-Batch-A behaviour of auto-deriving + # ``enable_kv_sharing`` from p2p flags must still hold. + cc = cfg_mod.CacheConfig(enable_p2p_cpu=True) + assert cc.enable_kv_sharing is True + assert cc.enable_remote is False # p2p_cpu alone does NOT enable 3rd remote + + +# --------------------------------------------------------------------------- +# ModelConfig new properties +# --------------------------------------------------------------------------- +class TestModelConfigProperties: + def test_tp_node_count_no_cross_node(self, cfg_mod): + import torch + mc = cfg_mod.ModelConfig( + num_layers=16, num_kv_heads=8, head_size=128, + dtype=torch.bfloat16, use_mla=False, + tp_size=4, pp_size=1, nnodes=1, + ) + assert mc.tp_node_count == 1 + # tp_node_idx is now per-rank (lives on RankInfo, not ModelConfig) + ri = cfg_mod.RankInfo(model_config=mc, tp_rank=0, pp_rank=0) + assert ri.tp_node_idx == 0 + + def test_tp_node_count_cross_node(self, cfg_mod): + import torch + mc = cfg_mod.ModelConfig( + num_layers=16, num_kv_heads=8, head_size=128, + dtype=torch.bfloat16, use_mla=False, + tp_size=8, pp_size=1, nnodes=2, + ) + # 8 TP / 2 nodes = 4 per node; tp_rank=5 → node index 1 + assert mc.tp_node_count == 2 + ri = cfg_mod.RankInfo(model_config=mc, tp_rank=5, pp_rank=0) + assert ri.tp_node_idx == 1 + + def test_model_id_stable(self, cfg_mod): + import torch + a = cfg_mod.ModelConfig( + num_layers=32, num_kv_heads=8, head_size=128, + dtype=torch.bfloat16, use_mla=False, + ).model_id + b = cfg_mod.ModelConfig( + num_layers=32, 
num_kv_heads=8, head_size=128, + dtype=torch.bfloat16, use_mla=False, + tp_size=4, pp_size=2, # topology differs + ).model_id + # model_id must ignore topology and depend only on architecture. + assert a == b + assert len(a) == 16 + + def test_model_id_changes_with_architecture(self, cfg_mod): + import torch + a = cfg_mod.ModelConfig( + num_layers=32, num_kv_heads=8, head_size=128, + dtype=torch.bfloat16, use_mla=False, + ).model_id + b = cfg_mod.ModelConfig( + num_layers=32, num_kv_heads=8, head_size=128, + dtype=torch.float16, use_mla=False, # dtype differs + ).model_id + assert a != b diff --git a/tests/test_cache_engine.py b/tests/test_cache_engine.py index 6b1cc5779d..012d91a1a4 100644 --- a/tests/test_cache_engine.py +++ b/tests/test_cache_engine.py @@ -757,3 +757,69 @@ def test_eviction_policy_reinsert_after_eviction(engine_cls): assert engine.match(seqs[2]).num_matched_blocks == 0, ( "C should be evicted to make room for re-inserted B" ) + + +# --------------------------------------------------------------------------- +# Tests – MatchResultAccel matched_node_id field +# --------------------------------------------------------------------------- +class TestMatchResultAccelNodeId: + """Verify the single-node matching constraint data structures.""" + + def test_matched_node_id_default_none(self): + """matched_node_id defaults to None when not set.""" + from flexkv.common.type import MatchResultAccel + result = MatchResultAccel( + num_ready_matched_blocks=0, + num_matched_blocks=0, + physical_blocks=np.array([], dtype=np.int64), + ) + assert result.matched_node_id is None + + def test_matched_node_id_set(self): + """matched_node_id can be set to a single integer.""" + from flexkv.common.type import MatchResultAccel + result = MatchResultAccel( + num_ready_matched_blocks=5, + num_matched_blocks=5, + physical_blocks=np.arange(5, dtype=np.int64), + matched_node_id=42, + ) + assert result.matched_node_id == 42 + assert isinstance(result.matched_node_id, int) + + 
def test_backward_compat_block_node_ids(self): + """block_node_ids (deprecated) still works alongside matched_node_id.""" + from flexkv.common.type import MatchResultAccel + bnids = np.array([42, 42, 42], dtype=np.uint32) + result = MatchResultAccel( + num_ready_matched_blocks=3, + num_matched_blocks=3, + physical_blocks=np.arange(3, dtype=np.int64), + matched_node_id=42, + block_node_ids=bnids, + ) + assert result.matched_node_id == 42 + assert np.all(result.block_node_ids == 42) + + +# --------------------------------------------------------------------------- +# Tests – CMatchResult matched_node_id field (C++ binding) +# --------------------------------------------------------------------------- +class TestCMatchResultNodeId: + """Verify the C++ CMatchResult exposes matched_node_id.""" + + def test_cmatch_result_default_node_id(self): + """CMatchResult.matched_node_id defaults to -1.""" + import torch + from flexkv.c_ext import CMatchResult + result = CMatchResult(0, 0, 0, None, None, torch.empty(0, dtype=torch.int64)) + assert result.matched_node_id == -1 + + def test_cmatch_result_with_node_id(self): + """CMatchResult.matched_node_id can be set via constructor.""" + import torch + from flexkv.c_ext import CMatchResult + blocks = torch.arange(3, dtype=torch.int64) + result = CMatchResult(3, 3, 0, None, None, blocks, 7) + assert result.matched_node_id == 7 + assert result.physical_blocks.shape[0] == 3 diff --git a/tests/test_cache_engine_dist_reuse_gate.py b/tests/test_cache_engine_dist_reuse_gate.py new file mode 100644 index 0000000000..4c9f608889 --- /dev/null +++ b/tests/test_cache_engine_dist_reuse_gate.py @@ -0,0 +1,490 @@ +"""Phase D-2/D-4 unit tests for ``_sharing_domain_gate_get`` + +``_notify_sd_ready_on_put`` + ``_notify_master_sd_ready`` + +``_on_peer_sd_completed_op``. + +These methods are grafted onto the huge ``GlobalCacheEngine`` class, +which is itself impossible to import here (it drags in torch + c_ext ++ redis_meta). 
We use a lean ``_CacheEngineGateStub`` that mirrors +the real method bodies in spirit, plus source-level greps to guard +the real file from drift. + +After Phase D-4 (proposal_unify_with_graph_dispatch_2026-05-15.md): +the per-SD ZMQ ``coord_put`` broadcast is gone — peer SD acks now +arrive as ``CompletedOp(sd_key, contributing_node_id)`` through the +graph-dispatch path, and ``_notify_master_sd_ready`` only does the +self-SD mark plus a pending-batch registration consumed by +``_on_peer_sd_completed_op``. +""" +from __future__ import annotations + +import sys +import threading +from pathlib import Path +from types import SimpleNamespace +from typing import Dict, Optional +from unittest.mock import MagicMock + +import numpy as np +import pytest + +REPO_ROOT = Path(__file__).resolve().parent.parent +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from flexkv.common.dist_reuse.aggregate_radix import AggregateRadixTree # noqa: E402 +from flexkv.common.dist_reuse.master_coordinator import MasterCoordinator # noqa: E402 +from flexkv.common.dist_reuse.sharing_domain import SharingDomainKey # noqa: E402 + + +# --------------------------------------------------------------------------- +# Master fixture helpers +# --------------------------------------------------------------------------- +def _mk_single_sd_master() -> MasterCoordinator: + sd = SharingDomainKey( + model_id="test-model", pp_node_idx=0, pp_node_count=1, + tp_node_idx=0, tp_node_count=1, is_nsa=False, + ) + mc = MasterCoordinator(self_sd=sd, instance_id="inst-master") + mc.expect_remotes(0) + return mc + + +def _mk_multi_sd_master(total_sd: int = 2) -> MasterCoordinator: + sd = SharingDomainKey( + model_id="test-model", pp_node_idx=0, pp_node_count=total_sd, + tp_node_idx=0, tp_node_count=1, is_nsa=False, + ) + mc = MasterCoordinator(self_sd=sd, instance_id="inst-master") + mc.expect_remotes(total_sd - 1) + return mc + + +# 
# ---------------------------------------------------------------------------
# MasterCoordinator.total_sd_count property
# ---------------------------------------------------------------------------
def test_master_total_sd_count_single():
    """A master built by ``_mk_single_sd_master`` reports one sharing domain."""
    mc = _mk_single_sd_master()
    try:
        assert mc.total_sd_count == 1
    finally:
        # Tear the coordinator down even when the assertion fails, the same
        # way the class-based tests in this file do.
        mc.shutdown()


def test_master_total_sd_count_multi():
    """``total_sd_count`` equals the ``total_sd`` the master was built with."""
    for total_sd in (2, 4):
        mc = _mk_multi_sd_master(total_sd=total_sd)
        try:
            assert mc.total_sd_count == total_sd
        finally:
            mc.shutdown()


# ---------------------------------------------------------------------------
# Stubbed cache engine — mirrors the real ``GlobalCacheEngine`` shape we test.
# ---------------------------------------------------------------------------
class _CacheEngineGateStub:
    """Lean stand-in for ``GlobalCacheEngine`` with only the gate/notify methods.

    ``master_coord`` is the coordinator object the methods talk to; ``None``
    disables every dist_reuse code path (see ``has_dist_reuse``).
    """

    def __init__(self, master_coord=None) -> None:
        self._master_coord = master_coord
        # Guards ``_pending_put_batches``.
        self._pending_put_lock = threading.Lock()
        # prefix_hash -> block ids registered by ``_notify_master_sd_ready``.
        self._pending_put_batches: Dict[int, list] = {}

    @property
    def has_dist_reuse(self) -> bool:
        """True when a master coordinator is attached."""
        return self._master_coord is not None

    def _sharing_domain_gate_get(
        self,
        *,
        sequence_meta,
        return_mask,
        block_start_idx: int,
        num_gpu_blocks_to_transfer: int,
    ) -> bool:
        """Return whether a GET over the given block window may proceed.

        The gate passes (``True``) when dist_reuse is off, when the
        coordinator reports at most one sharing domain, when no blocks are
        transferred, when the terminal index is past the end of
        ``block_hashes``, or when any lookup step raises.  Otherwise the
        result is whether ``match_fully_ready`` returns an entry for the hash
        of the last block in the window.

        ``return_mask`` is accepted for signature parity and is not read.
        """
        if not self.has_dist_reuse or self._master_coord is None:
            return True
        try:
            total_sd = int(getattr(self._master_coord, "total_sd_count", 1))
        except Exception:
            total_sd = 1
        if total_sd <= 1:
            return True
        if num_gpu_blocks_to_transfer <= 0:
            return True
        try:
            sequence_meta.gen_hashes()
        except Exception:
            return True
        terminal_block_idx = block_start_idx + num_gpu_blocks_to_transfer - 1
        # NOTE(review): unlike ``_notify_sd_ready_on_put`` there is no ``< 0``
        # guard here, so a negative index would be used as-is — confirm this
        # matches the real engine.
        if terminal_block_idx >= sequence_meta.block_hashes.shape[0]:
            return True
        try:
            prefix_hash = int(sequence_meta.block_hashes[terminal_block_idx].item())
        except Exception:
            return True
        try:
            entry = self._master_coord.match_fully_ready(prefix_hash)
        except Exception:
            return True
        return entry is not None

    def _notify_sd_ready_on_put(
        self,
        *,
        sequence_meta,
        inserted_block_ids,
        block_start_idx: int,
        num_blocks_inserted: int,
    ) -> None:
        """Forward a finished PUT to ``_notify_master_sd_ready``.

        The prefix hash is the hash of the last inserted block
        (``block_start_idx + num_blocks_inserted - 1``).  The call is a no-op
        when dist_reuse is off, nothing was inserted, the terminal index is
        out of range, or hashing / hash extraction raises.
        """
        if not self.has_dist_reuse or self._master_coord is None:
            return
        if num_blocks_inserted <= 0:
            return
        try:
            sequence_meta.gen_hashes()
        except Exception:
            return
        terminal_idx = block_start_idx + num_blocks_inserted - 1
        if terminal_idx < 0 or terminal_idx >= sequence_meta.block_hashes.shape[0]:
            return
        try:
            prefix_hash = int(sequence_meta.block_hashes[terminal_idx].item())
        except Exception:
            return
        block_ids_list = []
        try:
            if inserted_block_ids is not None:
                block_ids_list = [int(b) for b in inserted_block_ids]
        except Exception:
            # Unconvertible ids degrade to an empty list rather than aborting.
            block_ids_list = []
        self._notify_master_sd_ready(
            prefix_hash=prefix_hash,
            block_ids=block_ids_list,
        )

    def _notify_master_sd_ready(
        self,
        prefix_hash: int,
        block_ids: list,
    ) -> None:
        """Phase D-2 stub: self-SD mark + pending PUT batch registration.

        No coord_put, no peer-SD broadcast — those acks now arrive
        asynchronously via ``_on_peer_sd_completed_op``.
        """
        if self._master_coord is None:
            return
        try:
            self_node_id = int(getattr(self._master_coord, "self_node_id", -1))
        except Exception:
            self_node_id = -1
        try:
            self._master_coord.mark_sd_ready(
                prefix_hash=int(prefix_hash),
                sd_key_str=self._master_coord.self_sd.serialize(),
                block_ids=list(block_ids) if block_ids is not None else [],
                node_id=self_node_id,
            )
        except Exception:
            # Errors raised by the coordinator are swallowed (best-effort).
            pass

        # Multi-SD: register a pending PUT batch for the
        # _completion_sink to consume on peer-SD CompletedOp arrival.
        try:
            total_sd = int(getattr(self._master_coord, "total_sd_count", 1))
        except Exception:
            total_sd = 1
        if total_sd <= 1:
            return
        try:
            block_ids_list = [int(b) for b in (block_ids or [])]
        except Exception:
            block_ids_list = []
        with self._pending_put_lock:
            self._pending_put_batches[int(prefix_hash)] = block_ids_list

    def _on_peer_sd_completed_op(self, completed_op) -> None:
        """Phase D-2 stub: route peer-SD CompletedOp into mark_sd_ready.

        Ops with an empty ``sd_key``, with the master's own SD key, or with
        ``success`` falsy are ignored.  For any other op, every currently
        pending PUT batch is marked ready for that ``sd_key``.
        """
        if self._master_coord is None:
            return
        sd_key = getattr(completed_op, "sd_key", "") or ""
        if not sd_key:
            return
        if sd_key == self._master_coord.self_sd.serialize():
            return
        if not getattr(completed_op, "success", True):
            return
        node_id = int(getattr(completed_op, "contributing_node_id", -1))
        # Snapshot under the lock; the coordinator is called outside it.
        with self._pending_put_lock:
            pending_snapshot = list(self._pending_put_batches.items())
        # NOTE(review): entries are never removed from
        # ``_pending_put_batches`` here — confirm the real engine drains them.
        for prefix_hash, block_ids_list in pending_snapshot:
            try:
                self._master_coord.mark_sd_ready(
                    prefix_hash=int(prefix_hash),
                    sd_key_str=sd_key,
                    block_ids=block_ids_list,
                    node_id=node_id,
                )
            except Exception:
                # Best-effort: one failing mark does not stop the others.
                pass


# ---------------------------------------------------------------------------
# SequenceMeta helper — avoid pulling torch just to build one.
# ---------------------------------------------------------------------------
class _FakeSeqMeta:
    """SequenceMeta substitute exposing ``block_hashes`` and ``gen_hashes()``."""

    def __init__(self, block_hashes):
        self._hashed = True
        self.block_hashes = np.array(block_hashes, dtype=np.int64)

    def gen_hashes(self):
        # Hashes are supplied to the constructor; only the flag is touched.
        self._hashed = True


def _gate_from_block_zero(engine, hashes, num_blocks):
    """Run the GET gate on a fake sequence, window starting at block 0."""
    return engine._sharing_domain_gate_get(
        sequence_meta=_FakeSeqMeta(hashes),
        return_mask=None,
        block_start_idx=0,
        num_gpu_blocks_to_transfer=num_blocks,
    )


def _first_peer_sd_key(master):
    """Serialized key of the first enumerated peer SD that is not the master's own."""
    peer = next(sd for sd in master.self_sd.enumerate_peers() if sd != master.self_sd)
    return peer.serialize()


def _completed_op(sd_key, node_id, success):
    """Build a minimal CompletedOp look-alike with the three fields the sink reads."""
    return SimpleNamespace(
        sd_key=sd_key,
        contributing_node_id=node_id,
        success=success,
    )


# ---------------------------------------------------------------------------
# Gate behaviour
# ---------------------------------------------------------------------------
class TestGateGet:
    """Outcomes of ``_sharing_domain_gate_get`` per deployment shape."""

    def test_no_dist_reuse_passes(self):
        engine = _CacheEngineGateStub(master_coord=None)
        verdict = _gate_from_block_zero(engine, [1, 2, 3], 3)
        assert verdict is True

    def test_single_sd_passes(self):
        master = _mk_single_sd_master()
        try:
            engine = _CacheEngineGateStub(master_coord=master)
            verdict = _gate_from_block_zero(engine, [1, 2, 3], 3)
            assert verdict is True
        finally:
            master.shutdown()

    def test_multi_sd_not_ready_rejects(self):
        master = _mk_multi_sd_master(total_sd=2)
        try:
            engine = _CacheEngineGateStub(master_coord=master)
            # Nothing has been marked ready, so the gate has to say no.
            verdict = _gate_from_block_zero(engine, [0xAA, 0xBB, 0xCC], 3)
            assert verdict is False
        finally:
            master.shutdown()

    def test_multi_sd_fully_ready_passes(self):
        master = _mk_multi_sd_master(total_sd=2)
        try:
            engine = _CacheEngineGateStub(master_coord=master)
            # Both the own SD and the peer SD are marked ready for the hash
            # of the last block (0xCC).
            own_key = master.self_sd.serialize()
            peer_key = _first_peer_sd_key(master)
            for sd_key in (own_key, peer_key):
                master.mark_sd_ready(
                    prefix_hash=0xCC, sd_key_str=sd_key, block_ids=[1, 2, 3],
                )
            verdict = _gate_from_block_zero(engine, [0xAA, 0xBB, 0xCC], 3)
            assert verdict is True
        finally:
            master.shutdown()

    def test_zero_blocks_passes(self):
        master = _mk_multi_sd_master()
        try:
            engine = _CacheEngineGateStub(master_coord=master)
            verdict = _gate_from_block_zero(engine, [1], 0)
            assert verdict is True
        finally:
            master.shutdown()


# ---------------------------------------------------------------------------
# _notify_sd_ready_on_put → _notify_master_sd_ready
# ---------------------------------------------------------------------------
class TestNotifySdReadyOnPut:
    """PUT-side notification reaches ``mark_sd_ready`` with the terminal hash."""

    def test_no_dist_reuse_is_noop(self):
        engine = _CacheEngineGateStub(master_coord=None)
        put_kwargs = {
            "sequence_meta": _FakeSeqMeta([1, 2]),
            "inserted_block_ids": [10, 20],
            "block_start_idx": 0,
            "num_blocks_inserted": 2,
        }
        # The call completing without an exception is the whole assertion.
        engine._notify_sd_ready_on_put(**put_kwargs)

    def test_self_sd_marked(self):
        master = _mk_single_sd_master()
        try:
            master.mark_sd_ready = MagicMock(return_value=True)
            engine = _CacheEngineGateStub(master_coord=master)
            engine._notify_sd_ready_on_put(
                sequence_meta=_FakeSeqMeta([0xAA, 0xBB]),
                inserted_block_ids=[10, 20],
                block_start_idx=0,
                num_blocks_inserted=2,
            )
            master.mark_sd_ready.assert_called_once()
            seen = master.mark_sd_ready.call_args.kwargs
            expected = {
                "prefix_hash": 0xBB,
                "sd_key_str": master.self_sd.serialize(),
                "block_ids": [10, 20],
            }
            for key, want in expected.items():
                assert seen[key] == want
        finally:
            master.shutdown()


# ---------------------------------------------------------------------------
# Phase D-2: _on_peer_sd_completed_op marks peer SD ready
# ---------------------------------------------------------------------------
class TestPeerSdCompletionSink:
    """Which CompletedOps the sink turns into ``mark_sd_ready`` calls."""

    def test_self_sd_completed_op_ignored(self):
        master = _mk_multi_sd_master(total_sd=2)
        try:
            engine = _CacheEngineGateStub(master_coord=master)
            # A pending PUT batch has to exist before the op arrives.
            engine._notify_master_sd_ready(prefix_hash=0xABC, block_ids=[1, 2])
            master.mark_sd_ready = MagicMock()
            # An op carrying the master's own SD key must not mark anything
            # (the self-SD mark already happened in the call above).
            own_op = _completed_op(master.self_sd.serialize(), 99, True)
            engine._on_peer_sd_completed_op(own_op)
            master.mark_sd_ready.assert_not_called()
        finally:
            master.shutdown()

    def test_peer_sd_completed_op_marks_ready(self):
        master = _mk_multi_sd_master(total_sd=2)
        try:
            engine = _CacheEngineGateStub(master_coord=master)
            # Register the batch the peer ack will be matched against.
            engine._notify_master_sd_ready(prefix_hash=0xABC, block_ids=[10, 20])
            assert engine._pending_put_batches[0xABC] == [10, 20]

            master.mark_sd_ready = MagicMock(return_value=True)
            peer_key = _first_peer_sd_key(master)
            engine._on_peer_sd_completed_op(_completed_op(peer_key, 42, True))
            master.mark_sd_ready.assert_called_once()
            seen = master.mark_sd_ready.call_args.kwargs
            expected = {
                "prefix_hash": 0xABC,
                "sd_key_str": peer_key,
                "block_ids": [10, 20],
                "node_id": 42,
            }
            for key, want in expected.items():
                assert seen[key] == want
        finally:
            master.shutdown()

    def test_failed_completed_op_ignored(self):
        master = _mk_multi_sd_master(total_sd=2)
        try:
            engine = _CacheEngineGateStub(master_coord=master)
            engine._notify_master_sd_ready(prefix_hash=0xABC, block_ids=[1])
            master.mark_sd_ready = MagicMock()
            failed_op = _completed_op("peer-sd-key", 42, False)
            engine._on_peer_sd_completed_op(failed_op)
            master.mark_sd_ready.assert_not_called()
        finally:
            master.shutdown()

    def test_empty_sd_key_ignored(self):
        master = _mk_multi_sd_master(total_sd=2)
        try:
            engine = _CacheEngineGateStub(master_coord=master)
            engine._notify_master_sd_ready(prefix_hash=0xABC, block_ids=[1])
            master.mark_sd_ready = MagicMock()
            keyless_op = _completed_op("", 42, True)
            engine._on_peer_sd_completed_op(keyless_op)
            master.mark_sd_ready.assert_not_called()
        finally:
            master.shutdown()


# ---------------------------------------------------------------------------
# Source-level guards — keep the real file in lock-step with this stub.
+# --------------------------------------------------------------------------- +def test_source_has_gate_and_notify_methods(): + path = REPO_ROOT / "flexkv" / "cache" / "cache_engine.py" + src = path.read_text() + required = [ + "def _sharing_domain_gate_get(", + "def _notify_sd_ready_on_put(", + "def _notify_master_sd_ready(", + "def _on_peer_sd_completed_op(", + "def is_evictable(", + "has_dist_reuse", + "_master_coord", + "_pending_put_batches", + ] + missing = [m for m in required if m not in src] + assert not missing, ( + f"GlobalCacheEngine source-guard: missing required tokens {missing}" + ) + + +def test_source_put_callback_carries_sd_notify_kwargs(): + """Defensive: the PUT-path callback must thread ``sd_notify_kwargs`` + so that the dist_reuse fully_ready signal reaches the aggregate. + Empirically this is tied to the line that builds + ``sd_notify_kwargs`` and the one that passes it to + ``_transfer_callback``. + """ + path = REPO_ROOT / "flexkv" / "cache" / "cache_engine.py" + src = path.read_text() + assert "sd_notify_kwargs = {" in src + assert "is_put=True" in src + assert "sd_notify_kwargs=sd_notify_kwargs" in src + + +def test_source_master_total_sd_count_property_present(): + path = REPO_ROOT / "flexkv" / "common" / "dist_reuse" / "master_coordinator.py" + src = path.read_text() + assert "def total_sd_count" in src + + +def test_source_notify_master_sd_ready_uses_pending_put_registry(): + """Phase D-4 source guard: ``_notify_master_sd_ready`` must + populate ``_pending_put_batches`` for multi-SD deployments + (replacing the old ``coord_put`` broadcast).""" + path = REPO_ROOT / "flexkv" / "cache" / "cache_engine.py" + src = path.read_text() + assert "def _notify_master_sd_ready" in src + assert "_pending_put_batches[" in src, ( + "_notify_master_sd_ready must register pending PUT batches " + "for graph-dispatch peer-SD ack consumption" + ) + # Negative check: the old coord_put route AND the deprecated + # ``_coord_dispatcher`` field must be gone 
(Phase D-4 cleanup). + assert "self._coord_dispatcher" not in src, ( + "_coord_dispatcher field was supposed to be removed in the " + "Phase D-4 cleanup pass" + ) + assert ".coord_put(" not in src, ( + "coord_put broadcast was supposed to be deleted in Phase D-4" + ) diff --git a/tests/test_cext_evict_refcount_guard.py b/tests/test_cext_evict_refcount_guard.py new file mode 100644 index 0000000000..39fec5263a --- /dev/null +++ b/tests/test_cext_evict_refcount_guard.py @@ -0,0 +1,206 @@ +"""§2.2(b) — C++ accel path for the eviction refcount guard. + +Parallels :mod:`test_evict_refcount_guard` (Python ``RadixTreeIndex``) +but drives the C++ ``CRadixTreeIndex`` that's exposed through the +pybind11 ``c_ext`` module and consumed by :class:`CacheEngineAccel`. + +These tests require ``c_ext.so`` to have been built with the +§2.2(b) 4-arg ``evict`` overload; when the extension is not available +(e.g. clean Mac dev machine) the whole file is skipped. The real CI +path runs inside the ``flexkv_distreuse`` container which always has +``c_ext`` compiled. + +See docs/dist_reuse/implementation_gap_2026-05-11.md §2.2 and +docs/dist_reuse/implementation_progress_2026-05-13.md for the status +of the Accel path. +""" +from __future__ import annotations + +import sys +from pathlib import Path + +import numpy as np +import pytest + +REPO_ROOT = Path(__file__).resolve().parent.parent +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +torch = pytest.importorskip("torch") +c_ext = pytest.importorskip("flexkv.c_ext") + +# The 4-arg evict overload is the new addition (§2.2(b)). If the +# loaded ``c_ext.so`` predates it, skip this whole file: the guard +# isn't actually enforced on the Accel path yet in that build. 
+_CRadixTreeIndex = getattr(c_ext, "CRadixTreeIndex", None) +if _CRadixTreeIndex is None: # pragma: no cover — extension layout guard + pytest.skip( + "c_ext.CRadixTreeIndex not exported — rebuild c_ext with " + "§2.2(b) patches applied.", + allow_module_level=True, + ) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +TOKENS_PER_BLOCK = 16 + + +def _make_tree() -> "c_ext.CRadixTreeIndex": + return _CRadixTreeIndex(TOKENS_PER_BLOCK, 1_000_000, 0, "lru") + + +def _insert(idx, num_blocks: int, phys_start: int, salt: int = 0): + """Insert ``num_blocks`` consecutive physical block ids + hashes + starting at ``phys_start``. Returns the physical block array. + + Hashes are chosen to be distinct across calls with different + ``salt`` to avoid collision in the radix prefix tree. + """ + phys = torch.arange(phys_start, phys_start + num_blocks, dtype=torch.int64) + # Deterministic but collision-free hashes across different salts. + # Use a large-ish stride so prefixes don't accidentally extend. 
+ hashes = torch.arange( + phys_start + salt * 1_000_000, + phys_start + salt * 1_000_000 + num_blocks, + dtype=torch.int64, + ) + idx.insert(phys, hashes, num_blocks, num_blocks, True) + return phys.numpy() + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- +def test_cext_evict_3arg_legacy_unchanged(): + """Baseline: 3-arg ``evict`` (no predicate) behaves as before.""" + idx = _make_tree() + p1 = _insert(idx, 3, phys_start=0, salt=0) + p2 = _insert(idx, 2, phys_start=100, salt=1) + + out_blocks = torch.zeros(4, dtype=torch.int64) + out_hashes = torch.zeros(4, dtype=torch.int64) + n = idx.evict(out_blocks, out_hashes, 4) + + assert n == 4 + evicted = set(out_blocks[:n].numpy().tolist()) + assert evicted.issubset(set(p1.tolist()) | set(p2.tolist())) + + +def test_cext_evict_4arg_with_null_predicate_behaves_like_3arg(): + """A None predicate is allowed and behaves identically to the 3-arg + overload — required because ``CacheEngineAccel.take`` only plumbs + the predicate when ``set_evict_guard`` has installed one, but the + C++ side accepts either form. + """ + idx = _make_tree() + _insert(idx, 2, phys_start=0, salt=0) + out_blocks = torch.zeros(2, dtype=torch.int64) + out_hashes = torch.zeros(2, dtype=torch.int64) + + # ``lambda x: True`` is the null-object equivalent for callers that + # always go through the 4-arg overload. + n = idx.evict(out_blocks, out_hashes, 2, lambda _bid: True) + assert n == 2 + + +def test_cext_evict_guard_pins_specific_block_ids(): + """Block ids for which the guard returns False must NOT appear in + the returned eviction set, even though the LRU picked them. 
+ """ + idx = _make_tree() + p1 = _insert(idx, 3, phys_start=0, salt=0) # 0,1,2 + p2 = _insert(idx, 3, phys_start=10, salt=1) # 10,11,12 + + pinned = {1, 11} + + def guard(block_id: int) -> bool: + return block_id not in pinned + + out_blocks = torch.zeros(6, dtype=torch.int64) + out_hashes = torch.zeros(6, dtype=torch.int64) + n = idx.evict(out_blocks, out_hashes, 6, guard) + + evicted = set(out_blocks[:n].numpy().tolist()) + # Pinned ids must be absent. + assert pinned.isdisjoint(evicted) + # And since we pinned 2 out of 6 candidates, at most 4 should + # come back. + assert n <= 4 + + +def test_cext_evict_guard_all_pinned_returns_zero(): + """If *every* candidate is pinned, ``evict`` returns 0 — no crash, + no partial exception.""" + idx = _make_tree() + _insert(idx, 2, phys_start=0) + + def guard(_block_id: int) -> bool: + return False + + out_blocks = torch.zeros(10, dtype=torch.int64) + out_hashes = torch.zeros(10, dtype=torch.int64) + n = idx.evict(out_blocks, out_hashes, 10, guard) + assert n == 0 + + +def test_cext_evict_guard_exception_treats_as_evictable(): + """A buggy predicate must not wedge the allocator. The C++ path + catches the exception and treats the block as evictable (same + contract as the Python ``RadixTreeIndex.evict``). + """ + idx = _make_tree() + _insert(idx, 2, phys_start=0) + + def bad_guard(_bid: int) -> bool: + raise RuntimeError("boom") + + out_blocks = torch.zeros(2, dtype=torch.int64) + out_hashes = torch.zeros(2, dtype=torch.int64) + n = idx.evict(out_blocks, out_hashes, 2, bad_guard) + # Exception-fallback treats as evictable → we still make progress. + assert n >= 1 + + +def test_cext_evict_guard_shrink_path_respects_pins(): + """Regression: the shrink branch (when node.size > num_remaining) + must also drop pinned ids from its partial output. 
+ """ + idx = _make_tree() + _insert(idx, 5, phys_start=0) # one long leaf 0..4 + + pinned = {2} + + def guard(block_id: int) -> bool: + return block_id not in pinned + + out_blocks = torch.zeros(3, dtype=torch.int64) + out_hashes = torch.zeros(3, dtype=torch.int64) + n = idx.evict(out_blocks, out_hashes, 3, guard) + + evicted = set(out_blocks[:n].numpy().tolist()) + assert pinned.isdisjoint(evicted) + assert n <= 3 + + +def test_cext_evict_binding_has_4arg_overload(): + """Static guard: make sure the binding file actually publishes the + new overload. This catches accidental build regressions without + having to run pytest in the container.""" + bindings_path = REPO_ROOT / "csrc" / "bindings.cpp" + src = bindings_path.read_text() + assert "std::function" in src, ( + "The 4-arg CRadixTreeIndex::evict binding is missing — " + "check csrc/bindings.cpp and rebuild c_ext." + ) + + +def test_cext_cpp_has_4arg_overload_definition(): + """Static guard: the C++ impl must carry the 4-arg override.""" + cpp_path = REPO_ROOT / "csrc" / "radix_tree.cpp" + src = cpp_path.read_text() + assert "is_evictable_fn" in src, ( + "CRadixTreeIndex::evict 4-arg impl missing the predicate " + "parameter — §2.2(b) C++ work not applied." + ) diff --git a/tests/test_coord_protocol.py b/tests/test_coord_protocol.py new file mode 100644 index 0000000000..1172fb6bd9 --- /dev/null +++ b/tests/test_coord_protocol.py @@ -0,0 +1,128 @@ +"""Unit tests for ``flexkv.cache.coordination_protocol``. + +Phase D-4 (proposal_unify_with_graph_dispatch_2026-05-15.md): only +``RemoteReadyMsg`` and ``FailureReportMsg`` survive — the +``CoordQuery*`` / ``CoordGet*`` / ``CoordPut*`` messages were tied to +the old per-SD ZMQ coord protocol and are replaced by the unified +graph-dispatch path (peer SD acks now come back as +``CompletedOp(sd_key, contributing_node_id)`` via +``TransferManagerMultiNodeHandle._completion_sink``). 
+""" +from __future__ import annotations + +import pickle + +import pytest + +from flexkv.common.dist_reuse.coordination_protocol import ( + CoordMsgType, + EpochVerifyError, + FailureReportMsg, + RemoteReadyMsg, + decode_coord_message, + encode_coord_message, +) + + +# --------------------------------------------------------------------------- +# Discriminator +# --------------------------------------------------------------------------- +class TestType: + @pytest.mark.parametrize("cls,expected", [ + (RemoteReadyMsg, CoordMsgType.REMOTE_READY), + (FailureReportMsg, CoordMsgType.FAILURE_REPORT), + ]) + def test_class_type_attached(self, cls, expected): + assert cls.type is expected + # Also accessible via instance + msg = cls() + assert msg.type is expected + + +# --------------------------------------------------------------------------- +# Encode / decode round trips +# --------------------------------------------------------------------------- +class TestRoundTrip: + @pytest.mark.parametrize("msg", [ + RemoteReadyMsg( + sender_instance_id="inst1", + sender_epoch="epoch-1", + request_id=42, + sd_key="abc:ppn1/2:tpn0/1:nsa0", + distributed_node_id=7, + mooncake_addr="10.0.0.1:5555", + zmq_addr="tcp://10.0.0.1:6666", + ), + FailureReportMsg( + sender_instance_id="inst-r", + sender_epoch="e", + request_id=4, + peer_instance_id="inst1", + failed_block_hashes=[42, 43], + error="hca dropped", + ), + ]) + def test_encode_decode_round_trip(self, msg): + payload = encode_coord_message(msg) + # ``type`` is preserved as the enum value (str) + assert payload["type"] == msg.type.value + out = decode_coord_message(payload) + assert out == msg + assert out is not msg + assert type(out) is type(msg) + + @pytest.mark.parametrize("msg", [ + RemoteReadyMsg(), + FailureReportMsg(error="x"), + ]) + def test_pickle_round_trip(self, msg): + # Current ZMQ transport uses pickle, so make sure the dataclasses + # survive a pickle round-trip (default fields, mutable defaults, etc.) 
+ out = pickle.loads(pickle.dumps(msg)) + assert out == msg + + +# --------------------------------------------------------------------------- +# Decode error cases +# --------------------------------------------------------------------------- +class TestDecodeErrors: + def test_missing_type(self): + with pytest.raises(ValueError, match="missing 'type'"): + decode_coord_message({"sender_instance_id": "x"}) + + def test_unknown_type(self): + with pytest.raises(ValueError, match="unknown type"): + decode_coord_message({"type": "definitely-not-a-real-type"}) + + def test_unknown_field(self): + with pytest.raises(ValueError, match="unknown fields"): + decode_coord_message({ + "type": CoordMsgType.REMOTE_READY.value, + "sender_instance_id": "x", + "this_field_does_not_exist": 1, + }) + + def test_encode_rejects_non_message(self): + with pytest.raises(TypeError): + encode_coord_message({"type": "fake"}) # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# EpochVerifyError sanity +# --------------------------------------------------------------------------- +def test_epoch_verify_error_is_runtime_error(): + assert issubclass(EpochVerifyError, RuntimeError) + with pytest.raises(EpochVerifyError): + raise EpochVerifyError("stale") + + +# --------------------------------------------------------------------------- +# Default values +# --------------------------------------------------------------------------- +def test_default_message_has_empty_lists_not_shared(): + """Make sure mutable default fields (``field(default_factory=list)``) do + NOT share state between instances — a classic dataclass foot-gun.""" + a = FailureReportMsg() + b = FailureReportMsg() + a.failed_block_hashes.append(1) + assert b.failed_block_hashes == [] diff --git a/tests/test_d3_filter_and_get_clones.py b/tests/test_d3_filter_and_get_clones.py new file mode 100644 index 0000000000..888bad7c4e --- /dev/null +++ 
b/tests/test_d3_filter_and_get_clones.py @@ -0,0 +1,622 @@ +"""Phase D-3 unit tests +(proposal_unify_with_graph_dispatch_2026-05-15.md §6.4). + +Covers two pieces of functionality introduced in D-3: + +1. ``_filter_graph_inplace_by_target_node_ids`` — module-level + utility extracted out of ``TransferManagerOnRemote`` so the + Master's in-proc / inter-proc handles can use it too. In + particular tests: + + * legacy graphs (no ``target_node_ids`` set) are untouched + * ops not addressed to ``self_nid`` are dropped + * dropped ops are removed from kept ops' ``predecessors`` / + ``successors`` so the graph engine does not deadlock on a + filtered-away dependency + +2. ``GlobalCacheEngine._maybe_attach_multi_sd_peerh2h_ops`` — fans the + GET-path PEERH2H op out to one clone per peer SD, with + ``target_node_ids`` and ``src_block_node_ids`` derived from the + ``AggregateRadixTree``'s ``ready_sds`` map. Tests: + + * single-SD instance / dist_reuse-off → no clones, master op + untouched (legacy bit-identical behaviour) + * multi-SD instance with all peer SDs ready → master op stamped + with ``target_node_ids=[self_node_id]`` + one clone per peer SD + * peer SD missing from ``ready_sds`` → that SD is silently + skipped (gate will reject the GET downstream) + +The tests are pure Python: they construct a real +``TransferOpGraph`` and a stub ``GlobalCacheEngine`` instance with +just the attributes the helper touches. No GPU / mooncake / +TransferEngine startup is needed. 
+""" +from __future__ import absolute_import + +from typing import Dict, Iterable, List, Optional + +import numpy as np +import pytest + +from flexkv.common.transfer import ( + TransferOp, + TransferOpGraph, + TransferType, +) +from flexkv.transfer_manager import _filter_graph_inplace_by_target_node_ids + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def _mk_op( + transfer_type: TransferType, + *, + src=(0, 1), + dst=(2, 3), + target_node_ids: Optional[List[int]] = None, +) -> TransferOp: + # NOTE: ``layer_id`` / ``layer_granularity`` were removed from the + # base ``TransferOp`` dataclass (commit abbf299 + RankInfo refactor) + # — those fields now live only on ``LayerwiseTransferOp``. D-3 + # filter-graph behaviour is independent of layer info, so we simply + # don't supply them here. + return TransferOp( + graph_id=-1, # add_transfer_op stamps the real graph_id + transfer_type=transfer_type, + src_block_ids=np.array(src, dtype=np.int64), + dst_block_ids=np.array(dst, dtype=np.int64), + target_node_ids=list(target_node_ids) if target_node_ids else None, + ) + + +# =========================================================================== +# Section 1: filter utility +# =========================================================================== +class TestFilterUtility: + def test_legacy_graph_no_target_node_ids_unchanged(self): + """Op with target_node_ids=None must always be kept, regardless + of self_nid. 
This is the legacy single-SD / cross-machine TP + contract — Phase D-3 must not regress it.""" + g = TransferOpGraph() + a = _mk_op(TransferType.D2H) # no target → legacy + b = _mk_op(TransferType.H2D) # no target → legacy + g.add_transfer_op(a) + g.add_transfer_op(b) + g.add_dependency(b.op_id, a.op_id) # b depends on a + + dropped = _filter_graph_inplace_by_target_node_ids(g, self_nid=42) + + assert dropped == 0 + assert g.num_ops == 2 + assert b.op_id in [op.op_id for op in g._op_map.values()] + assert a.op_id in [op.op_id for op in g._op_map.values()] + assert a.op_id in b.predecessors + + def test_self_nid_negative_keeps_everything(self): + """Before dist_reuse bootstrap finishes, self_nid is -1 and the + helper must behave like a no-op so legacy paths run as before. + """ + g = TransferOpGraph() + a = _mk_op(TransferType.PEERH2H, target_node_ids=[7]) + g.add_transfer_op(a) + dropped = _filter_graph_inplace_by_target_node_ids(g, self_nid=-1) + assert dropped == 0 + assert g.num_ops == 1 + + def test_drops_ops_not_addressed_to_self(self): + """Master's own op (target=self) is kept; peer-SD clone (target + != self) is dropped.""" + g = TransferOpGraph() + master_op = _mk_op(TransferType.PEERH2H, target_node_ids=[10]) + peer_op = _mk_op(TransferType.PEERH2H, target_node_ids=[11]) + g.add_transfer_op(master_op) + g.add_transfer_op(peer_op) + + dropped = _filter_graph_inplace_by_target_node_ids(g, self_nid=10) + + assert dropped == 1 + assert g.num_ops == 1 + assert master_op.op_id in g._op_map + assert peer_op.op_id not in g._op_map + + def test_dependency_repair_dropped_predecessor_does_not_deadlock(self): + """If the dropped op was a predecessor of a kept op, the kept op + must NOT carry the dropped op_id in its predecessors set — + otherwise ``take_ready_ops`` would never schedule it. 
+ """ + g = TransferOpGraph() + # master op_peerh2h (kept on self=10) + master_peerh2h = _mk_op( + TransferType.PEERH2H, target_node_ids=[10], + ) + # peer-SD clone (dropped on self=10) + peer_clone = _mk_op( + TransferType.PEERH2H, target_node_ids=[11], + ) + # op_h2d depends on BOTH (legacy + Phase D-3 dependency) + op_h2d = _mk_op(TransferType.H2D) # target=None → kept + g.add_transfer_op(master_peerh2h) + g.add_transfer_op(peer_clone) + g.add_transfer_op(op_h2d) + g.add_dependency(op_h2d.op_id, master_peerh2h.op_id) + g.add_dependency(op_h2d.op_id, peer_clone.op_id) + assert op_h2d.predecessors == { + master_peerh2h.op_id, peer_clone.op_id, + } + + dropped = _filter_graph_inplace_by_target_node_ids(g, self_nid=10) + + assert dropped == 1 + # Surviving op_h2d must still wait on master_peerh2h, but the + # filtered-away peer_clone must be gone from its predecessors. + assert op_h2d.predecessors == {master_peerh2h.op_id} + # Conversely master_peerh2h.successors must lose op_h2d? No — + # successors only loses references to *dropped* ops; op_h2d + # is kept. Sanity check. + assert op_h2d.op_id in master_peerh2h.successors + + def test_dependency_repair_clears_orphan_successors(self): + """If we drop op X that listed kept op Y in its successors, the + helper should not leave a dangling ref to a dropped op_id in + any kept op's successors set either (defensive symmetry — even + if no current code path relies on it).""" + g = TransferOpGraph() + kept = _mk_op(TransferType.D2H) # target=None → kept + dropped_op = _mk_op( + TransferType.PEERH2H, target_node_ids=[99], + ) + g.add_transfer_op(kept) + g.add_transfer_op(dropped_op) + # kept depends on dropped_op (so kept.predecessors has it, + # dropped_op.successors has kept) + g.add_dependency(kept.op_id, dropped_op.op_id) + + _filter_graph_inplace_by_target_node_ids(g, self_nid=10) + + # No reference to the dropped op_id anywhere in kept op's sets. 
+ assert dropped_op.op_id not in kept.predecessors + assert dropped_op.op_id not in kept.successors + + def test_kept_op_with_only_dropped_predecessors_becomes_ready(self): + """If filtering empties an op's predecessors, the op should be + moved back to ``_ready_ops`` so the graph engine schedules it. + """ + g = TransferOpGraph() + peer_a = _mk_op(TransferType.PEERH2H, target_node_ids=[11]) + peer_b = _mk_op(TransferType.PEERH2H, target_node_ids=[12]) + local_h2d = _mk_op(TransferType.H2D) # target=None → kept + + g.add_transfer_op(peer_a) + g.add_transfer_op(peer_b) + g.add_transfer_op(local_h2d) + g.add_dependency(local_h2d.op_id, peer_a.op_id) + g.add_dependency(local_h2d.op_id, peer_b.op_id) + # add_dependency moves successor out of _ready_ops + assert local_h2d.op_id not in g._ready_ops + + _filter_graph_inplace_by_target_node_ids(g, self_nid=10) + + # All predecessors gone → local_h2d back in _ready_ops so the + # engine actually schedules it. + assert local_h2d.predecessors == set() + assert local_h2d.op_id in g._ready_ops + + +# =========================================================================== +# Section 2: _maybe_attach_multi_sd_peerh2h_ops on GlobalCacheEngine +# =========================================================================== +# +# We don't construct a real GlobalCacheEngine (heavy: needs a torch +# device / RedisMeta / etc.). We instead test the unbound method +# against a duck-typed stub that exposes only the attributes the +# helper reaches for: ``_master_coord`` (a ``MasterCoordinator``-like +# object). +# +# This keeps the test pure-Python while pinning the behavioural +# contract. + +class _StubSequenceMeta: + """Just enough of ``SequenceMeta`` for the helper. 
+ + Helper paths used: + + * ``sequence_meta.gen_hashes()`` — must be safe to call + * ``sequence_meta.block_hashes`` — np.ndarray[int64] + """ + def __init__(self, hashes: List[int]): + self.block_hashes = np.array(hashes, dtype=np.int64) + + def gen_hashes(self) -> None: + # Hashes are pre-computed in the constructor. + return + + +class _StubReadyEntry: + def __init__(self, ready_sds: Dict[str, int]): + self.ready_sds = dict(ready_sds) + + +class _StubMasterCoord: + """Mimics enough of ``MasterCoordinator`` for the helper.""" + + def __init__( + self, + *, + self_sd_str: str, + sd_to_nid: Dict[str, int], + ready_sds: Optional[Dict[int, Dict[str, int]]] = None, + ): + # ``self_sd`` exposes ``serialize()`` like the real key. + class _SD: + def __init__(self, s): + self._s = s + def serialize(self) -> str: + return self._s + self.self_sd = _SD(self_sd_str) + self._sd_to_nid = dict(sd_to_nid) + # prefix_hash → ready_sds map (per prefix) + self._ready_sds_per_prefix = ready_sds or {} + + def get_sd_to_nid_map(self) -> Dict[str, int]: + return dict(self._sd_to_nid) + + def match_fully_ready(self, prefix_hash: int): + rs = self._ready_sds_per_prefix.get(int(prefix_hash)) + if rs is None: + return None + return _StubReadyEntry(rs) + + +class _StubGlobalCacheEngine: + """The helper is bound to ``self._master_coord`` only.""" + def __init__(self, master_coord): + self._master_coord = master_coord + + +def _bind_helper(): + """Pull the helper off the real class so we can call it bound to + a stub instance — keeps the production code under test honest.""" + from flexkv.cache.cache_engine import GlobalCacheEngine + return GlobalCacheEngine._maybe_attach_multi_sd_peerh2h_ops + + +class TestMaybeAttachMultiSDPeerH2H: + def _make_graph_with_master_op(self): + g = TransferOpGraph() + master_peerh2h = TransferOp( + graph_id=g.graph_id, + transfer_type=TransferType.PEERH2H, + src_block_ids=np.array([10, 11, 12], dtype=np.int64), + dst_block_ids=np.array([20, 21, 22], 
dtype=np.int64), + remote_node_ids=np.array([7, 7, 7], dtype=np.int64), + src_block_node_ids=np.array([7, 7, 7], dtype=np.int64), + ) + g.add_transfer_op(master_peerh2h) + return g, master_peerh2h + + def test_no_master_coord_legacy_passthrough(self): + helper = _bind_helper() + g, master_op = self._make_graph_with_master_op() + seq = _StubSequenceMeta([0xAA]) + stub = _StubGlobalCacheEngine(master_coord=None) + + clones = helper( + stub, + transfer_graph=g, + op_peerh2h=master_op, + sequence_meta=seq, + prefix_terminal_block_idx=0, + ) + + assert clones == [] + assert master_op.target_node_ids is None # untouched + assert g.num_ops == 1 # no clones added + + def test_single_sd_instance_no_clones(self): + helper = _bind_helper() + g, master_op = self._make_graph_with_master_op() + seq = _StubSequenceMeta([0xAA]) + coord = _StubMasterCoord( + self_sd_str="sd-master", + sd_to_nid={"sd-master": 100}, + ) + stub = _StubGlobalCacheEngine(master_coord=coord) + + clones = helper( + stub, + transfer_graph=g, + op_peerh2h=master_op, + sequence_meta=seq, + prefix_terminal_block_idx=0, + ) + + assert clones == [] + # Critical: the master op stays untouched in the single-SD case + # (otherwise the in-proc handle filter would drop it spuriously). + assert master_op.target_node_ids is None + assert g.num_ops == 1 + + def test_bootstrap_not_finished_no_clones(self): + """``get_sd_to_nid_map`` returns {} until the master's own + node_id has been registered. 
The helper must short-circuit and + leave the master op untouched.""" + helper = _bind_helper() + g, master_op = self._make_graph_with_master_op() + seq = _StubSequenceMeta([0xAA]) + coord = _StubMasterCoord( + self_sd_str="sd-master", + sd_to_nid={}, # bootstrap not done yet + ) + stub = _StubGlobalCacheEngine(master_coord=coord) + + clones = helper( + stub, + transfer_graph=g, + op_peerh2h=master_op, + sequence_meta=seq, + prefix_terminal_block_idx=0, + ) + + assert clones == [] + assert master_op.target_node_ids is None + assert g.num_ops == 1 + + def test_multi_sd_all_ready_clones_per_peer(self): + helper = _bind_helper() + g, master_op = self._make_graph_with_master_op() + prefix_hash = 0xCAFE + seq = _StubSequenceMeta([prefix_hash]) + + # 3 SDs total: master + 2 peers. ``ready_sds`` records each + # SD's contributing peer's node_id (note: master SD is owned + # locally so node_id == master's own node_id). + coord = _StubMasterCoord( + self_sd_str="sd-master", + sd_to_nid={ + "sd-master": 100, + "sd-peer-A": 200, + "sd-peer-B": 300, + }, + ready_sds={ + prefix_hash: { + "sd-master": 100, + "sd-peer-A": 555, # peer instance's nid on SD-A + "sd-peer-B": 666, # peer instance's nid on SD-B + }, + }, + ) + stub = _StubGlobalCacheEngine(master_coord=coord) + + clones = helper( + stub, + transfer_graph=g, + op_peerh2h=master_op, + sequence_meta=seq, + prefix_terminal_block_idx=0, + ) + + # Master op stamped to itself. + assert master_op.target_node_ids == [100] + # Two clones added. + assert len(clones) == 2 + # Total ops in graph: master + 2 clones = 3. + assert g.num_ops == 3 + + clones_by_target = {tuple(c.target_node_ids): c for c in clones} + assert (200,) in clones_by_target + assert (300,) in clones_by_target + + clone_a = clones_by_target[(200,)] + clone_b = clones_by_target[(300,)] + + # src/dst block_ids mirror the master op (mirror assumption). 
+ np.testing.assert_array_equal( + clone_a.src_block_ids, master_op.src_block_ids, + ) + np.testing.assert_array_equal( + clone_a.dst_block_ids, master_op.dst_block_ids, + ) + # src_block_node_ids point to the peer instance's nid on each + # peer SD (pulled from ready_sds). + np.testing.assert_array_equal( + clone_a.src_block_node_ids, + np.array([555, 555, 555], dtype=np.int64), + ) + np.testing.assert_array_equal( + clone_b.src_block_node_ids, + np.array([666, 666, 666], dtype=np.int64), + ) + # transfer_type preserved. + assert clone_a.transfer_type == TransferType.PEERH2H + assert clone_b.transfer_type == TransferType.PEERH2H + + def test_multi_sd_partial_ready_skips_unacked_peers(self): + """If a peer SD has not yet acked (missing from ``ready_sds`` + or value < 0), the helper skips it. The downstream gate + (``_sharing_domain_gate_get``) then rejects the GET when + ``match_fully_ready`` returns None on the same prefix. + + This test only validates the per-SD skip behaviour; the gate + is exercised by ``test_cache_engine_dist_reuse_gate.py``. + """ + helper = _bind_helper() + g, master_op = self._make_graph_with_master_op() + prefix_hash = 0xC0DE + seq = _StubSequenceMeta([prefix_hash]) + + coord = _StubMasterCoord( + self_sd_str="sd-master", + sd_to_nid={ + "sd-master": 100, + "sd-peer-A": 200, + "sd-peer-B": 300, + }, + ready_sds={ + prefix_hash: { + "sd-master": 100, + "sd-peer-A": 555, + # sd-peer-B intentionally missing → not yet acked + }, + }, + ) + stub = _StubGlobalCacheEngine(master_coord=coord) + + clones = helper( + stub, + transfer_graph=g, + op_peerh2h=master_op, + sequence_meta=seq, + prefix_terminal_block_idx=0, + ) + + # Only peer A has a clone; peer B is silently skipped. + assert len(clones) == 1 + assert clones[0].target_node_ids == [200] + # Master op still tagged (2-of-3 SDs visible is enough — gate + # decides the GET fate elsewhere). 
+ assert master_op.target_node_ids == [100] + assert g.num_ops == 2 + + def test_multi_sd_no_aggregate_entry_no_clones(self): + """``match_fully_ready`` returns None (gate will reject) → + helper bails out without mutating the graph. Avoids + polluting the graph with clones that would never get a + contributor.""" + helper = _bind_helper() + g, master_op = self._make_graph_with_master_op() + seq = _StubSequenceMeta([0xDEAD]) + coord = _StubMasterCoord( + self_sd_str="sd-master", + sd_to_nid={ + "sd-master": 100, + "sd-peer-A": 200, + }, + ready_sds={}, # no entry for any prefix + ) + stub = _StubGlobalCacheEngine(master_coord=coord) + + clones = helper( + stub, + transfer_graph=g, + op_peerh2h=master_op, + sequence_meta=seq, + prefix_terminal_block_idx=0, + ) + + assert clones == [] + # Master op left alone — gate will reject the GET so no + # contradictory tagging is needed. + assert master_op.target_node_ids is None + assert g.num_ops == 1 + + +# =========================================================================== +# Section 3: end-to-end small fixture — filter pipeline preserves +# the per-SD invariant for a realistic GET graph +# =========================================================================== +class TestEndToEndFilterPipeline: + """Build a graph that mimics what ``_get_impl_local`` produces in + the multi-SD reuse case, run it through the filter on each SD's + handle, and verify each SD ends up with exactly the ops it should + execute and a ready ``op_h2d`` whose predecessors collapse to the + SD's own peerh2h op.""" + + def _build_get_graph(self): + g = TransferOpGraph() + # master op_peerh2h, target=master_nid + master_peerh2h = TransferOp( + graph_id=g.graph_id, + transfer_type=TransferType.PEERH2H, + src_block_ids=np.array([0, 1], dtype=np.int64), + dst_block_ids=np.array([10, 11], dtype=np.int64), + target_node_ids=[100], + ) + # peer-A clone, target=peer_a_nid + peer_a_clone = TransferOp( + graph_id=g.graph_id, + 
transfer_type=TransferType.PEERH2H, + src_block_ids=np.array([0, 1], dtype=np.int64), + dst_block_ids=np.array([10, 11], dtype=np.int64), + target_node_ids=[200], + ) + # peer-B clone + peer_b_clone = TransferOp( + graph_id=g.graph_id, + transfer_type=TransferType.PEERH2H, + src_block_ids=np.array([0, 1], dtype=np.int64), + dst_block_ids=np.array([10, 11], dtype=np.int64), + target_node_ids=[300], + ) + # op_h2d — no SD tag, every handle keeps it. + op_h2d = TransferOp( + graph_id=g.graph_id, + transfer_type=TransferType.H2D, + src_block_ids=np.array([10, 11], dtype=np.int64), + dst_block_ids=np.array([1000, 1001], dtype=np.int64), + ) + g.add_transfer_op(master_peerh2h) + g.add_transfer_op(peer_a_clone) + g.add_transfer_op(peer_b_clone) + g.add_transfer_op(op_h2d) + # H2D depends on every PEERH2H (master + each peer clone) + g.add_dependency(op_h2d.op_id, master_peerh2h.op_id) + g.add_dependency(op_h2d.op_id, peer_a_clone.op_id) + g.add_dependency(op_h2d.op_id, peer_b_clone.op_id) + return g, master_peerh2h, peer_a_clone, peer_b_clone, op_h2d + + def _deep_copy_graph(self, g): + import pickle + return pickle.loads(pickle.dumps(g)) + + def test_master_handle_keeps_only_self_peerh2h_and_h2d(self): + g, master_op, peer_a, peer_b, h2d = self._build_get_graph() + master_g = self._deep_copy_graph(g) + + dropped = _filter_graph_inplace_by_target_node_ids( + master_g, self_nid=100, + ) + assert dropped == 2 + + kept_ids = set(master_g._op_map.keys()) + assert master_op.op_id in kept_ids + assert h2d.op_id in kept_ids + assert peer_a.op_id not in kept_ids + assert peer_b.op_id not in kept_ids + + # h2d must depend ONLY on master's own peerh2h after filtering. 
+ master_h2d = master_g._op_map[h2d.op_id] + assert master_h2d.predecessors == {master_op.op_id} + + def test_peer_a_handle_keeps_only_clone_and_h2d(self): + g, master_op, peer_a, peer_b, h2d = self._build_get_graph() + peer_a_g = self._deep_copy_graph(g) + + dropped = _filter_graph_inplace_by_target_node_ids( + peer_a_g, self_nid=200, + ) + assert dropped == 2 + + kept_ids = set(peer_a_g._op_map.keys()) + assert peer_a.op_id in kept_ids + assert h2d.op_id in kept_ids + assert master_op.op_id not in kept_ids + assert peer_b.op_id not in kept_ids + + peer_h2d = peer_a_g._op_map[h2d.op_id] + assert peer_h2d.predecessors == {peer_a.op_id} + + def test_peer_b_handle_keeps_only_clone_and_h2d(self): + g, master_op, peer_a, peer_b, h2d = self._build_get_graph() + peer_b_g = self._deep_copy_graph(g) + + dropped = _filter_graph_inplace_by_target_node_ids( + peer_b_g, self_nid=300, + ) + assert dropped == 2 + peer_h2d = peer_b_g._op_map[h2d.op_id] + assert peer_h2d.predecessors == {peer_b.op_id} + + +if __name__ == "__main__": # pragma: no cover + raise SystemExit(pytest.main([__file__, "-v"])) diff --git a/tests/test_dist_reuse_launcher.py b/tests/test_dist_reuse_launcher.py new file mode 100644 index 0000000000..043a917fe8 --- /dev/null +++ b/tests/test_dist_reuse_launcher.py @@ -0,0 +1,163 @@ +"""§2.5 — Smoke test for start_dist_reuse_serving.sh. + +Exercises the shell script's ``--dry-run`` mode: no vLLM/sglang boot, +but we verify the script accepts the documented flag set, emits the +expected ``mooncake_config.json`` shape based on topology, and prints +a summary. + +This is deliberately smoke-level — a full multi-node boot requires +real GPU machines (§2.6). 
+""" +from __future__ import annotations + +import json +import os +import subprocess +import sys +from pathlib import Path + +import pytest + + +REPO_ROOT = Path(__file__).resolve().parent.parent +SCRIPT = REPO_ROOT / "scripts" / "multi-nodes" / "start_dist_reuse_serving.sh" + + +# Skip everything when the launcher script is absent (bash itself is not probed). +pytestmark = pytest.mark.skipif( + not SCRIPT.exists(), + reason="start_dist_reuse_serving.sh not present", +) + + +def _run(args, cwd=None, env=None, check_returncode=True): + proc = subprocess.run( + ["bash", str(SCRIPT), *args], + cwd=cwd or str(REPO_ROOT), + env=env or os.environ.copy(), + capture_output=True, + text=True, + timeout=20, + ) + if check_returncode and proc.returncode != 0: + raise AssertionError( + f"script exited {proc.returncode}\nstdout:\n{proc.stdout}\n" + f"stderr:\n{proc.stderr}" + ) + return proc + + +class TestDryRun: + def test_master_rank_0_emits_mooncake_config(self, tmp_path): + # NOTE: the script runs in place, so its CFG_DIR artefacts land under + # scripts/multi-nodes/gen in the repo; ``tmp_path`` is currently unused. + proc = _run([ + "--nnodes", "2", + "--node-rank", "0", + "--master-ip", "10.0.0.1", + "--tp-size", "8", + "--pp-size", "1", + "--cp-size", "4", + "--model", "/tmp/fake-model", + "--redis-host", "10.0.0.1", + "--redis-password", "x", + "--rdma-device", "mlx5_0", + "--instance-id", "unit-test-master", + "--dry-run", + ]) + assert "node_rank=0/2" in proc.stdout + assert "need_full_sd_remote: true" in proc.stdout + assert "instance_id : unit-test-master" in proc.stdout + + # Check the generated mooncake config looks like valid JSON. 
+ cfg = (REPO_ROOT / "scripts" / "multi-nodes" / "gen" + / "dist_reuse_node0" / "mooncake_config.json") + assert cfg.exists(), f"expected mooncake config at {cfg}" + data = json.loads(cfg.read_text()) + assert data["metadata_backend"] == "redis" + assert data["device_name"] == "mlx5_0" + assert data["engine_port"] == 12345 # base (+ node_rank 0) + + def test_cp_only_offmaster_emits_empty_config(self): + """CP-only cross-node instance: node_rank=1 is CP-peer stub, + mooncake config must be empty (sentinel for connector's + ``CP_PEER_REGISTRATION_ONLY`` path).""" + proc = _run([ + "--nnodes", "2", + "--node-rank", "1", + "--master-ip", "10.0.0.1", + "--tp-size", "4", # TP fits on one node (≤ 8 gpus) + "--pp-size", "1", + "--cp-size", "2", # CP = 2, cross-node + "--model", "/tmp/fake-model", + "--redis-host", "10.0.0.1", + "--redis-password", "x", + "--dry-run", + ]) + assert "node_rank=1/2" in proc.stdout + assert "is_multinode_cp : true" in proc.stdout + assert "need_full_sd_remote: false" in proc.stdout + + cfg = (REPO_ROOT / "scripts" / "multi-nodes" / "gen" + / "dist_reuse_node1" / "mooncake_config.json") + assert cfg.exists() + # Empty sentinel — bytes length must be 0 to match the + # connector's "no mooncake here" contract. 
+ assert cfg.read_text() == "" + + def test_tp_cross_node_offmaster_emits_full_config(self): + proc = _run([ + "--nnodes", "2", + "--node-rank", "1", + "--master-ip", "10.0.0.1", + "--tp-size", "16", # > gpus_per_node default 8 → tp_node_count=2 + "--pp-size", "1", + "--cp-size", "1", + "--model", "/tmp/fake-model", + "--redis-host", "10.0.0.1", + "--redis-password", "x", + "--dry-run", + ]) + assert "is_multinode_tp : true" in proc.stdout + assert "need_full_sd_remote: true" in proc.stdout + + cfg = (REPO_ROOT / "scripts" / "multi-nodes" / "gen" + / "dist_reuse_node1" / "mooncake_config.json") + assert cfg.exists() and cfg.stat().st_size > 0 + data = json.loads(cfg.read_text()) + assert data["engine_port"] == 12346 # base + 1 + + +class TestValidation: + def test_missing_required_arg_exits_nonzero(self): + proc = _run( + ["--nnodes", "2", "--dry-run"], # missing --node-rank etc. + check_returncode=False, + ) + assert proc.returncode != 0 + + def test_nnodes_greater_than_two_rejected(self): + """Current dist_reuse deployment constraint (§3.3) — fail loudly + rather than silently proceed.""" + proc = _run([ + "--nnodes", "3", + "--node-rank", "0", + "--master-ip", "10.0.0.1", + "--tp-size", "8", "--pp-size", "1", "--cp-size", "1", + "--model", "/tmp/x", "--redis-host", "10.0.0.1", + "--dry-run", + ], check_returncode=False) + assert proc.returncode != 0 + assert "supports <= 2 physical nodes" in proc.stdout or \ + "supports <= 2 physical nodes" in proc.stderr + + def test_bad_node_rank_rejected(self): + proc = _run([ + "--nnodes", "2", + "--node-rank", "2", # out of range + "--master-ip", "10.0.0.1", + "--tp-size", "8", "--pp-size", "1", "--cp-size", "1", + "--model", "/tmp/x", "--redis-host", "10.0.0.1", + "--dry-run", + ], check_returncode=False) + assert proc.returncode != 0 diff --git a/tests/test_evict_refcount_guard.py b/tests/test_evict_refcount_guard.py new file mode 100644 index 0000000000..521680a950 --- /dev/null +++ b/tests/test_evict_refcount_guard.py 
@@ -0,0 +1,162 @@ +"""§2.2 — Eviction refcount guard (RadixTreeIndex layer). + +These tests exercise the ``is_evictable_fn`` predicate that +:meth:`RadixTreeIndex.evict` now accepts. The guard is how +``GlobalCacheEngine`` plumbs ``MasterCoordinator.is_evictable`` down +to the eviction path so that blocks pinned by an in-flight coord GET +(refcount > 0) are never recycled. + +See docs/dist_reuse/implementation_gap_2026-05-11.md §2.2. + +Dependencies are kept light on purpose — ``numpy`` + stdlib, plus +``torch`` (pulled in by the ``radixtree`` module header; the file is +skipped without it). ``RadixTreeIndex`` itself needs no c_ext. +""" +from __future__ import annotations + +import sys +from pathlib import Path + +import numpy as np +import pytest + +REPO_ROOT = Path(__file__).resolve().parent.parent +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +# Importing radixtree pulls in ``torch`` via the module header. We +# skip the whole file when torch isn't available; the real test +# environment (`flexkv_distreuse` container) always has it. +torch = pytest.importorskip("torch") + +from flexkv.cache.radixtree import RadixTreeIndex # noqa: E402 +from flexkv.common.block import SequenceMeta # noqa: E402 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +TOKENS_PER_BLOCK = 16 + + +def _seq(tokens: int, salt: int = 0) -> SequenceMeta: + """Build a minimal SequenceMeta with deterministic but distinct + token ids (salted so that different inserts don't collide in the + radix tree).""" + token_ids = np.arange(tokens, dtype=np.int64) + salt * 1_000_000 + return SequenceMeta( + token_ids=token_ids, + tokens_per_block=TOKENS_PER_BLOCK, + ) + + +def _insert(idx: RadixTreeIndex, token_count: int, phys_start: int) -> np.ndarray: + """Insert ``token_count`` tokens (multiple of TOKENS_PER_BLOCK) + mapping to consecutive physical block ids starting at + ``phys_start``. 
Returns the physical block array for later + cross-checking. + """ + assert token_count % TOKENS_PER_BLOCK == 0 + num_blocks = token_count // TOKENS_PER_BLOCK + phys = np.arange(phys_start, phys_start + num_blocks, dtype=np.int64) + sm = _seq(token_count, salt=phys_start) + sm.gen_hashes() + idx.insert(sm, phys, num_insert_blocks=num_blocks, is_ready=True) + return phys + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- +def test_evict_without_guard_is_legacy_behaviour(): + """Default call ``evict(n)`` (no predicate) must behave exactly as + before — zero regression for all existing callers.""" + idx = RadixTreeIndex(tokens_per_block=TOKENS_PER_BLOCK) + p1 = _insert(idx, TOKENS_PER_BLOCK * 3, phys_start=0) # blocks 0..2 + p2 = _insert(idx, TOKENS_PER_BLOCK * 2, phys_start=100) # blocks 100..101 + + evicted, _ = idx.evict(4) + assert len(evicted) == 4 + # The evicted set must be a subset of the blocks we inserted. + assert set(evicted.tolist()).issubset(set(p1.tolist()) | set(p2.tolist())) + + +def test_evict_guard_pins_blocks_with_refcount_gt_zero(): + """When the predicate returns False for a subset of block ids, + those ids must NOT appear in the returned ``evicted_blocks``.""" + idx = RadixTreeIndex(tokens_per_block=TOKENS_PER_BLOCK) + p1 = _insert(idx, TOKENS_PER_BLOCK * 3, phys_start=0) # 0,1,2 + p2 = _insert(idx, TOKENS_PER_BLOCK * 3, phys_start=10) # 10,11,12 + + pinned = {1, 11} + + def guard(block_id: int) -> bool: + return block_id not in pinned + + evicted, _ = idx.evict(6, is_evictable_fn=guard) + # None of the pinned blocks should ever be evicted. + assert pinned.isdisjoint(set(evicted.tolist())) + # The guard does not increase capacity; at most we can return + # ``total_blocks - len(pinned)`` = 4 blocks. 
+ assert len(evicted) <= 4 + + +def test_evict_guard_all_pinned_returns_empty(): + """If every candidate is pinned, ``evict`` must return an empty + array (no crash, no partial exception).""" + idx = RadixTreeIndex(tokens_per_block=TOKENS_PER_BLOCK) + _insert(idx, TOKENS_PER_BLOCK * 2, phys_start=0) + + def guard(_block_id: int) -> bool: + return False + + evicted, block_hashes = idx.evict(10, is_evictable_fn=guard) + assert evicted.size == 0 + assert block_hashes.size == 0 + + +def test_evict_guard_exception_falls_back_to_allow(): + """A buggy guard must not wedge eviction. Design intent: defensive + fallback logs + treats the block as evictable (safer than dead- + locking the allocator). + """ + idx = RadixTreeIndex(tokens_per_block=TOKENS_PER_BLOCK) + _insert(idx, TOKENS_PER_BLOCK * 2, phys_start=0) + + def bad_guard(_block_id: int) -> bool: + raise RuntimeError("boom") + + evicted, _ = idx.evict(1, is_evictable_fn=bad_guard) + # Exception-fallback treats blocks as evictable, so we still get a + # non-empty eviction set. + assert len(evicted) >= 1 + + +def test_evict_guard_shrink_path_respects_pins(): + """Regression: the "node.size() > remaining" branch (``shrink``) + must also drop pinned ids from its partial output.""" + idx = RadixTreeIndex(tokens_per_block=TOKENS_PER_BLOCK) + # One long leaf — eviction will hit the shrink branch on it. + p = _insert(idx, TOKENS_PER_BLOCK * 5, phys_start=0) # 0..4 + + # Pin block id 2 (in the middle). + pinned = {2} + + def guard(block_id: int) -> bool: + return block_id not in pinned + + evicted, _ = idx.evict(3, is_evictable_fn=guard) + assert pinned.isdisjoint(set(evicted.tolist())) + # And we should still get up to 3 of the remaining 4 evictable + # blocks (whichever ones shrink picks). 
+ assert len(evicted) <= 3 + + +def test_evict_guard_kwarg_is_optional_for_callers(): + """Call sites that haven't been ported yet must keep working — + ``is_evictable_fn`` is a keyword-only optional parameter.""" + idx = RadixTreeIndex(tokens_per_block=TOKENS_PER_BLOCK) + _insert(idx, TOKENS_PER_BLOCK, phys_start=0) + # No predicate kw at all. + evicted, _ = idx.evict(1) + assert len(evicted) == 1 diff --git a/tests/test_failure_detector.py b/tests/test_failure_detector.py new file mode 100644 index 0000000000..cdebe200fd --- /dev/null +++ b/tests/test_failure_detector.py @@ -0,0 +1,327 @@ +"""Unit tests for ``flexkv.common.dist_reuse.failure_detector`` (Phase 0 task 0-L). + +Uses an in-memory Redis fake — no real Redis or network required. +""" +from __future__ import annotations + +import json +import threading +import time +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import pytest + +from flexkv.common.dist_reuse.failure_detector import ( + FailureDetector, + InstanceSession, + RedisSessionClient, + make_session_epoch, +) +from flexkv.common.dist_reuse.sharing_domain_namespace import SharingDomainNamespace + + +# --------------------------------------------------------------------------- +# In-memory Redis fake +# --------------------------------------------------------------------------- +class FakeRedis: + """Just enough of the redis-py surface for FailureDetector / RedisSessionClient.""" + + def __init__(self, time_fn=time.monotonic) -> None: + self._store: Dict[str, str] = {} + self._expiry: Dict[str, float] = {} + self._time = time_fn + self._lock = threading.RLock() + + # ---- core ops ---- + def set(self, name: str, value: str, ex: Optional[int] = None) -> bool: + with self._lock: + self._store[name] = value + if ex is not None: + self._expiry[name] = self._time() + float(ex) + else: + self._expiry.pop(name, None) + return True + + def get(self, name: str) -> Optional[str]: + with self._lock: + self._gc(name) + return self._store.get(name) + + 
def expire(self, name: str, ex: int) -> bool: + with self._lock: + self._gc(name) + if name not in self._store: + return False + self._expiry[name] = self._time() + float(ex) + return True + + def delete(self, *names: str) -> int: + with self._lock: + n = 0 + for k in names: + if k in self._store: + self._store.pop(k, None) + self._expiry.pop(k, None) + n += 1 + return n + + def scan_iter(self, match: Optional[str] = None, count: Optional[int] = None) -> Iterable[str]: + with self._lock: + # GC any expired keys before scanning. + for k in list(self._store): + self._gc(k) + keys = list(self._store) + if match is None: + return iter(keys) + # Translate Redis glob '*' → fnmatch. We only ever use "prefix*suffix" patterns. + import fnmatch + return iter(k for k in keys if fnmatch.fnmatchcase(k, match)) + + # ---- helpers ---- + def _gc(self, name: str) -> None: + exp = self._expiry.get(name) + if exp is not None and self._time() >= exp: + self._store.pop(name, None) + self._expiry.pop(name, None) + + def force_expire(self, name: str) -> None: + """Test helper — drop a key as if its TTL had elapsed.""" + with self._lock: + self._store.pop(name, None) + self._expiry.pop(name, None) + + +# --------------------------------------------------------------------------- +# Manual clock +# --------------------------------------------------------------------------- +class ManualClock: + def __init__(self, start: float = 0.0) -> None: + self.now = start + + def __call__(self) -> float: + return self.now + + def advance(self, dt: float) -> None: + self.now += dt + + +# --------------------------------------------------------------------------- +# make_session_epoch +# --------------------------------------------------------------------------- +class TestSessionEpoch: + def test_format(self): + e = make_session_epoch() + assert isinstance(e, str) + assert "-" in e + ms_part, rand_part = e.split("-", 1) + assert len(ms_part) == 12 + assert len(rand_part) == 8 + # All hex + 
int(ms_part, 16) + int(rand_part, 16) + + def test_unique_per_call(self): + seen = {make_session_epoch() for _ in range(100)} + assert len(seen) == 100 + + +# --------------------------------------------------------------------------- +# RedisSessionClient +# --------------------------------------------------------------------------- +class TestSessionClient: + def _client(self, **overrides): + clock = ManualClock() + fake = FakeRedis(time_fn=clock) + kwargs = dict( + instance_id="inst-A", + epoch="epoch-1", + ttl_seconds=5, + master_zmq_addr="tcp://10.0.0.1:6666", + node_ids=[1, 2, 3], + mooncake_addrs_by_sd={"sd0": "10.0.0.1:5555"}, + ) + kwargs.update(overrides) + return RedisSessionClient(fake, **kwargs), fake, clock + + def test_register_writes_payload(self): + sc, fake, _ = self._client() + sc.register() + raw = fake.get(sc.key) + assert raw is not None + payload = json.loads(raw) + assert payload["instance_id"] == "inst-A" + assert payload["epoch"] == "epoch-1" + assert payload["node_ids"] == [1, 2, 3] + assert payload["mooncake_addrs_by_sd"] == {"sd0": "10.0.0.1:5555"} + + def test_renew_extends_ttl(self): + sc, fake, clock = self._client(ttl_seconds=5) + sc.register() + clock.advance(3.0) + sc.renew() + clock.advance(3.0) # 6s total since register, but renewed at 3s → still alive + assert fake.get(sc.key) is not None + + def test_renew_revives_expired_key(self): + sc, fake, clock = self._client(ttl_seconds=2) + sc.register() + clock.advance(5.0) + # Key has expired (lazy GC inside fake.get). + assert fake.get(sc.key) is None + # renew() should fall back to register(). 
+ sc.renew() + assert fake.get(sc.key) is not None + + def test_unregister(self): + sc, fake, _ = self._client() + sc.register() + sc.unregister() + assert fake.get(sc.key) is None + + def test_bad_ttl(self): + with pytest.raises(ValueError): + RedisSessionClient(FakeRedis(), instance_id="x", epoch="e", ttl_seconds=0) + + +# --------------------------------------------------------------------------- +# FailureDetector +# --------------------------------------------------------------------------- +class TestFailureDetector: + def _seed(self, fake: FakeRedis, instance_id: str, epoch: str, ttl: int = 60): + key = SharingDomainNamespace.instance_session_key(instance_id) + payload = { + "instance_id": instance_id, + "epoch": epoch, + "master_zmq_addr": "tcp://x:1", + "node_ids": [], + "mooncake_addrs_by_sd": {}, + } + fake.set(key, json.dumps(payload), ex=ttl) + + def _detector(self, *, lost_log: List[str], seen_log: List[Tuple[str, str]]): + clock = ManualClock() + fake = FakeRedis(time_fn=clock) + + def on_lost(pid: str) -> None: + lost_log.append(pid) + + def on_seen(pid: str, session: InstanceSession) -> None: + seen_log.append((pid, session.epoch)) + + fd = FailureDetector( + fake, + self_instance_id="self", + poll_interval_seconds=0.5, + on_peer_lost=on_lost, + on_peer_seen=on_seen, + time_fn=clock, + ) + return fd, fake, clock + + def test_detects_new_peer(self): + lost: List[str] = [] + seen: List[Tuple[str, str]] = [] + fd, fake, _ = self._detector(lost_log=lost, seen_log=seen) + self._seed(fake, "peer-A", "e1") + fd.poll_once() + assert seen == [("peer-A", "e1")] + assert lost == [] + # Second poll: no new event for the same peer/epoch. + seen.clear() + fd.poll_once() + assert seen == [] + + def test_detects_disappearance(self): + lost: List[str] = [] + seen: List[Tuple[str, str]] = [] + fd, fake, _ = self._detector(lost_log=lost, seen_log=seen) + self._seed(fake, "peer-A", "e1") + fd.poll_once() + # Simulate TTL expiry. 
+ fake.force_expire(SharingDomainNamespace.instance_session_key("peer-A")) + fd.poll_once() + assert lost == ["peer-A"] + + def test_detects_epoch_change(self): + lost: List[str] = [] + seen: List[Tuple[str, str]] = [] + fd, fake, _ = self._detector(lost_log=lost, seen_log=seen) + self._seed(fake, "peer-A", "e1") + fd.poll_once() + # Restart: same instance_id, new epoch. + self._seed(fake, "peer-A", "e2") + fd.poll_once() + assert lost == ["peer-A"] + # Two seen events: initial appear + post-restart re-appear. + assert seen == [("peer-A", "e1"), ("peer-A", "e2")] + + def test_ignores_self(self): + lost: List[str] = [] + seen: List[Tuple[str, str]] = [] + fd, fake, _ = self._detector(lost_log=lost, seen_log=seen) + self._seed(fake, "self", "self-epoch") + fd.poll_once() + assert seen == [] + assert lost == [] + + def test_skips_malformed_payload(self): + lost: List[str] = [] + seen: List[Tuple[str, str]] = [] + fd, fake, _ = self._detector(lost_log=lost, seen_log=seen) + # Garbage payload at a valid-looking key. 
+ fake.set( + SharingDomainNamespace.instance_session_key("peer-A"), + "not-json", ex=60, + ) + fd.poll_once() # must not raise + assert seen == [] + assert lost == [] + + def test_invalid_constructor_args(self): + fake = FakeRedis() + with pytest.raises(ValueError): + FailureDetector(fake, "self", poll_interval_seconds=0) + with pytest.raises(ValueError): + FailureDetector(fake, "") + + def test_known_peers_view(self): + lost: List[str] = [] + seen: List[Tuple[str, str]] = [] + fd, fake, _ = self._detector(lost_log=lost, seen_log=seen) + self._seed(fake, "peer-A", "e1") + self._seed(fake, "peer-B", "e1") + fd.poll_once() + assert fd.known_peers() == {"peer-A", "peer-B"} + + def test_lifecycle(self): + """Light smoke test of start()/stop() — keeps the polling thread + cycle short and verifies the thread terminates cleanly.""" + lost: List[str] = [] + seen: List[Tuple[str, str]] = [] + clock = ManualClock() + fake = FakeRedis(time_fn=clock) + fd = FailureDetector( + fake, "self", + poll_interval_seconds=0.05, + on_peer_lost=lambda pid: lost.append(pid), + on_peer_seen=lambda pid, s: seen.append((pid, s.epoch)), + time_fn=clock, + ) + fd.start() + # Seed AFTER start() so the polling thread observes the event. + self._seed(fake, "peer-X", "e1") + # Give the loop a few iterations. + time.sleep(0.3) + fd.stop(timeout=2.0) + assert "peer-X" in {pid for pid, _ in seen} + + def test_double_start_raises(self): + lost: List[str] = [] + seen: List[Tuple[str, str]] = [] + fd, _, _ = self._detector(lost_log=lost, seen_log=seen) + fd.start() + try: + with pytest.raises(RuntimeError): + fd.start() + finally: + fd.stop(timeout=2.0) diff --git a/tests/test_flexkv_redis_db.py b/tests/test_flexkv_redis_db.py new file mode 100644 index 0000000000..4fdfe4a239 --- /dev/null +++ b/tests/test_flexkv_redis_db.py @@ -0,0 +1,247 @@ +"""Tests for ``CacheConfig.flexkv_redis_db`` wiring. 
+ +Phase D follow-up: verify that the single ``flexkv_redis_db`` config option +flows through every FlexKV Redis client: + + * ``CacheConfig`` defaults + ``FlexKVUserConfig`` override merge. + * ``RedisMeta`` / ``RedisNodeInfo`` constructor accepts ``db=`` and the + raw ``redis-py`` client is bound to that db. + * ``RedisMetaChannel`` (Python wrapper) forwards ``db`` to the C++ ctor. + * ``make_redis_client_from_cache_config`` reads the right attr. + +These tests do **not** require a live Redis — they only check the +Python-visible plumbing. For real-server behaviour (``SELECT <db>`` on +the wire + key isolation across dbs) see ``test_redis_db_integration.py``. +""" +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +REPO_ROOT = Path(__file__).resolve().parent.parent +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + + +def _load_module_from_source(path: Path, name: str): + spec = importlib.util.spec_from_file_location(name, str(path)) + assert spec is not None and spec.loader is not None + mod = importlib.util.module_from_spec(spec) + sys.modules[name] = mod + spec.loader.exec_module(mod) + return mod + + +# Load redis_meta.py directly (avoids flexkv.cache.__init__'s c_ext import). +_rm = _load_module_from_source( + REPO_ROOT / "flexkv" / "cache" / "redis_meta.py", + "_rm_db_wiring_test", +) +RedisMeta = _rm.RedisMeta +RedisNodeInfo = _rm.RedisNodeInfo + + +# Load config directly so we don't depend on whether config.so shadows .py. 
+_cfg = _load_module_from_source( + REPO_ROOT / "flexkv" / "common" / "config.py", + "_cfg_db_wiring_test", +) +CacheConfig = _cfg.CacheConfig + + +# --------------------------------------------------------------------------- +# CacheConfig default + override +# --------------------------------------------------------------------------- +class TestCacheConfigFlexkvRedisDb: + def test_default_is_zero(self): + cc = CacheConfig() + assert cc.flexkv_redis_db == 0 + + def test_can_set_db_explicitly(self): + cc = CacheConfig(flexkv_redis_db=7) + assert cc.flexkv_redis_db == 7 + + def test_enable_sharing_domain_does_not_affect_db(self): + cc = CacheConfig(enable_p2p_cpu=True, flexkv_redis_db=3) + # Auto-enable kicks in, but the db override stays. + assert cc.enable_sharing_domain is True + assert cc.flexkv_redis_db == 3 + + +# --------------------------------------------------------------------------- +# RedisNodeInfo — db arg reaches the redis-py Redis() constructor +# --------------------------------------------------------------------------- +class TestRedisNodeInfoDbForwarding: + def test_default_db_is_zero(self): + ni = RedisNodeInfo( + host="127.0.0.1", port=6379, local_ip="127.0.0.1", + ) + assert ni.db == 0 + + def test_db_arg_is_stored(self): + ni = RedisNodeInfo( + host="127.0.0.1", port=6379, local_ip="127.0.0.1", + db=15, + ) + assert ni.db == 15 + + def test_get_client_passes_db_kwarg(self): + """_get_client must call redis.Redis(..., db=self.db).""" + ni = RedisNodeInfo( + host="127.0.0.1", port=6379, local_ip="127.0.0.1", + db=7, + ) + with patch.object(_rm, "_redis") as redis_mod_mock: + redis_mod_mock.Redis = MagicMock(return_value="fake_client") + client = ni._get_client() + redis_mod_mock.Redis.assert_called_once() + _, kwargs = redis_mod_mock.Redis.call_args + assert kwargs.get("db") == 7 + assert kwargs.get("host") == "127.0.0.1" + assert kwargs.get("port") == 6379 + assert client == "fake_client" + + +# 
--------------------------------------------------------------------------- +# RedisMeta — db propagates to nodeinfo AND to its own `_client()` +# --------------------------------------------------------------------------- +class TestRedisMetaDbForwarding: + def test_default_db_is_zero(self): + meta = RedisMeta( + host="127.0.0.1", port=6379, local_ip="127.0.0.1", + ) + assert meta.db == 0 + assert meta.nodeinfo.db == 0 + + def test_db_arg_propagates_to_nodeinfo(self): + meta = RedisMeta( + host="127.0.0.1", port=6379, local_ip="127.0.0.1", + db=11, + ) + assert meta.db == 11 + # Must propagate to the inner RedisNodeInfo so node-heartbeat + # and block-metadata land on the same logical db. + assert meta.nodeinfo.db == 11 + + def test_client_uses_configured_db(self): + """_client() must invoke redis.Redis(..., db=self.db).""" + meta = RedisMeta( + host="127.0.0.1", port=6379, local_ip="127.0.0.1", + db=9, + ) + with patch.object(_rm, "_redis") as redis_mod_mock: + redis_mod_mock.Redis = MagicMock(return_value="fake_client") + meta._client() + _, kwargs = redis_mod_mock.Redis.call_args + assert kwargs.get("db") == 9 + + +# --------------------------------------------------------------------------- +# RedisMetaChannel wrapper — db flows to the C++ constructor +# --------------------------------------------------------------------------- +class TestRedisMetaChannelDbForwarding: + def _make_wrapper(self, **kwargs): + """Build the Python wrapper with the C++ class stubbed out.""" + fake_c = MagicMock() + # By default the stub accepts any arg shape; individual tests + # override this to simulate legacy builds. 
+ with patch.object(_rm, "_CRedisMetaChannel", MagicMock(return_value=fake_c)): + return _rm.RedisMetaChannel(**kwargs), fake_c + + def test_db_is_forwarded_to_cpp_ctor(self): + captured_args = {} + + def fake_ctor(*args): + captured_args["args"] = args + return MagicMock() + + with patch.object(_rm, "_CRedisMetaChannel", side_effect=fake_ctor): + _rm.RedisMetaChannel( + host="h", port=6379, node_id=1, local_ip="127.0.0.1", + blocks_key="sd:xx:CPUB", password="", db=5, + ) + # Last positional arg must be db=5. + assert captured_args["args"][-1] == 5 + assert captured_args["args"][0] == "h" + assert captured_args["args"][1] == 6379 + assert captured_args["args"][4] == "sd:xx:CPUB" + + def test_legacy_cpp_build_accepts_default_db_zero(self): + """Legacy C++ build (6-arg ctor) must still work for db=0.""" + call_counter = {"n": 0} + + def legacy_ctor(*args): + call_counter["n"] += 1 + # First call raises TypeError (simulating 6-arg legacy signature); + # second call (with 6 args) succeeds. 
+ if call_counter["n"] == 1: + raise TypeError("too many arguments") + return MagicMock() + + with patch.object(_rm, "_CRedisMetaChannel", side_effect=legacy_ctor): + w = _rm.RedisMetaChannel( + host="h", port=6379, node_id=1, local_ip="127.0.0.1", + blocks_key="sd:xx:CPUB", password="", db=0, + ) + assert call_counter["n"] == 2 # first raised, second succeeded + assert w._db == 0 + + def test_legacy_cpp_build_rejects_nonzero_db(self): + """db != 0 on a legacy build must raise ImportError loudly.""" + def legacy_ctor(*args): + raise TypeError("too many arguments") + + with patch.object(_rm, "_CRedisMetaChannel", side_effect=legacy_ctor): + with pytest.raises(ImportError, match="rebuild FlexKV"): + _rm.RedisMetaChannel( + host="h", port=6379, node_id=1, local_ip="127.0.0.1", + blocks_key="sd:xx:CPUB", password="", db=5, + ) + + +# --------------------------------------------------------------------------- +# make_redis_client_from_cache_config — single source of truth +# --------------------------------------------------------------------------- +class TestMakeRedisClientHelper: + def test_helper_reads_flexkv_redis_db(self): + from flexkv.common.dist_reuse import make_redis_client_from_cache_config + + class _FakeCfg: + redis_host = "1.2.3.4" + redis_port = 6400 + redis_password = "secret" + flexkv_redis_db = 13 + + import redis as _r + with patch.object(_r, "Redis") as mock_redis: + mock_redis.return_value = "ok" + client = make_redis_client_from_cache_config(_FakeCfg()) + assert client == "ok" + _, kwargs = mock_redis.call_args + assert kwargs["host"] == "1.2.3.4" + assert kwargs["port"] == 6400 + assert kwargs["db"] == 13 + assert kwargs["password"] == "secret" + assert kwargs["decode_responses"] is True + + def test_helper_falls_back_when_attr_missing(self): + """Duck-typed config without flexkv_redis_db → db=0.""" + from flexkv.common.dist_reuse import make_redis_client_from_cache_config + + class _MinimalCfg: + redis_host = "h" + redis_port = 6379 + 
redis_password = None + # no flexkv_redis_db attr + + import redis as _r + with patch.object(_r, "Redis") as mock_redis: + mock_redis.return_value = "ok" + make_redis_client_from_cache_config(_MinimalCfg()) + _, kwargs = mock_redis.call_args + assert kwargs["db"] == 0 + assert "password" not in kwargs # None → omitted diff --git a/tests/test_master_coordinator.py b/tests/test_master_coordinator.py new file mode 100644 index 0000000000..a4540e600d --- /dev/null +++ b/tests/test_master_coordinator.py @@ -0,0 +1,422 @@ +"""Unit tests for ``flexkv.common.dist_reuse.master_coordinator`` and +``remote_init`` (Phase 0 Batch C — tasks 0-F / 0-G / 0-H-integration / 0-K). + +These tests stay 100% at the pure-Python layer — no transfer_manager, +kvtask, or cache_engine imports. Instead they exercise the three helper +modules that Batch C introduces and the `TransferManagerHandle` wiring +through a minimal stub Master/Remote handshake. +""" +from __future__ import annotations + +import sys +import unittest.mock as mock +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List + +import pytest + +from flexkv.common.dist_reuse import ( + BootstrapResult, + MasterCoordinator, + RemoteDistReuseInitializer, + RemoteReadyMsg, + SharingDomainHandleSpec, + SharingDomainKey, + SharingDomainNamespace, + build_sharing_domain_handles, + find_endpoint_for_sd, + graph_needs_gpu_clear, + make_session_epoch, +) + +sys.path.insert(0, str(Path(__file__).parent)) +from _dist_reuse_fakes import FakeRedis, ManualClock # noqa: E402 + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- +def _sd(ppn_idx=0, ppn_count=1, tpn_idx=0, tpn_count=1, is_nsa=False): + """Build an SD key with the simplified, node-granularity PP schema.""" + return SharingDomainKey( + model_id="m", + pp_node_idx=ppn_idx, pp_node_count=ppn_count, + tp_node_idx=tpn_idx, 
tp_node_count=tpn_count, + is_nsa=is_nsa, + ) + + +@dataclass +class _StubEndpoint: + ip: str + gpu_register_port: str + command_port: str + result_port: str + + +# --------------------------------------------------------------------------- +# build_sharing_domain_handles +# --------------------------------------------------------------------------- +class TestBuildHandles: + def test_master_only(self): + specs = build_sharing_domain_handles( + self_sd=_sd(), # 1 SD total + remote_endpoints_by_sd={}, + ) + assert len(specs) == 1 + assert specs[0].mode == "process" + assert specs[0].sd_key.is_master() + assert specs[0].endpoint is None + + def test_multi_sd(self): + self_sd = _sd(ppn_count=2, tpn_count=2) # 4 SDs total (pp_node_count × tpn) + # Build endpoints for all 3 non-master SDs. + endpoints = {} + for peer in self_sd.enumerate_peers(): + if peer == self_sd: + continue + endpoints[peer.serialize()] = _StubEndpoint( + ip=f"10.0.{peer.pp_node_idx}.{peer.tp_node_idx}", + gpu_register_port="6001", + command_port="6002", + result_port="6003", + ) + specs = build_sharing_domain_handles( + self_sd=self_sd, remote_endpoints_by_sd=endpoints, + ) + assert len(specs) == 4 + # Master first. + assert specs[0].mode == "process" + assert specs[0].sd_key.is_master() + # All others remote. 
+ for spec in specs[1:]: + assert spec.mode == "remote" + assert spec.endpoint is not None + assert spec.endpoint.ip.startswith("10.0.") + + def test_missing_endpoint_raises(self): + with pytest.raises(KeyError, match="missing endpoint"): + build_sharing_domain_handles( + self_sd=_sd(ppn_count=2), + remote_endpoints_by_sd={}, + ) + + +class TestFindEndpointForSd: + def test_found(self): + sd = _sd(ppn_idx=1, ppn_count=2) + ep = _StubEndpoint("10.0.0.1", "1", "2", "3") + cc = mock.MagicMock(remote_endpoints_by_sd={sd.serialize(): ep}) + assert find_endpoint_for_sd(cc, sd) is ep + + def test_missing_raises(self): + cc = mock.MagicMock(remote_endpoints_by_sd={}) + with pytest.raises(KeyError): + find_endpoint_for_sd(cc, _sd(ppn_idx=1, ppn_count=2)) + + +# --------------------------------------------------------------------------- +# graph_needs_gpu_clear +# --------------------------------------------------------------------------- +class TestGraphGpuClear: + @pytest.fixture + def self_sd(self): + return _sd() # master + + def test_same_sd(self, self_sd): + assert graph_needs_gpu_clear(self_sd, self_sd) is False + + def test_pp_differs(self, self_sd): + peer = _sd(ppn_idx=1, ppn_count=2) + assert graph_needs_gpu_clear(self_sd, peer) is True + + # Note: under the simplified design CP is not part of the SD key, so + # there is no "cp_differs" case to test here — the dispatch decision + # for CP is handled by the connector-layer sync_leader scatter, not + # by the Master→Remote graph clear predicate. + + def test_tp_node_differs_only(self, self_sd): + # tp_node_idx split alone doesn't force a clear (same slot_mapping). 
+ peer = _sd(tpn_idx=1, tpn_count=2) + assert graph_needs_gpu_clear(self_sd, peer) is False + + def test_pp_and_tp_both_differ(self, self_sd): + peer = _sd(ppn_idx=1, ppn_count=2, tpn_idx=1, tpn_count=2) + assert graph_needs_gpu_clear(self_sd, peer) is True + + +# --------------------------------------------------------------------------- +# MasterCoordinator +# --------------------------------------------------------------------------- +class TestMasterCoordinator: + def test_smoke(self): + sd = _sd(ppn_count=2) + mc = MasterCoordinator(self_sd=sd, instance_id="inst-A") + mc.expect_remotes(1) + assert not mc.all_remotes_ready() + + def test_on_remote_ready_completes(self): + sd = _sd(ppn_count=2) + mc = MasterCoordinator(self_sd=sd, instance_id="inst-A") + mc.expect_remotes(1) + peer_sd = _sd(ppn_idx=1, ppn_count=2) + msg = RemoteReadyMsg( + sender_instance_id="inst-A", + sender_epoch="e1", + request_id=-1, + sd_key=peer_sd.serialize(), + distributed_node_id=42, + mooncake_addr="10.0.0.1:5555", + zmq_addr="tcp://10.0.0.1:6666", + ) + completed = mc.on_remote_ready(msg) + assert completed is True + assert mc.all_remotes_ready() + + def test_build_sd_to_nid(self): + sd = _sd(ppn_count=2, tpn_count=2) # 4 SDs total + mc = MasterCoordinator(self_sd=sd, instance_id="inst-A") + mc.expect_remotes(3) + # Seed fake remote ack'ed SDs. + for peer in sd.enumerate_peers(): + if peer == sd: + continue + mc.on_remote_ready(RemoteReadyMsg( + sender_instance_id="inst-A", + sender_epoch="e1", + sd_key=peer.serialize(), + distributed_node_id=100 + hash(peer.serialize()) % 1000, + )) + mapping = mc.build_sd_to_nid_map(self_node_id=1) + assert len(mapping) == 4 + assert mapping[sd.serialize()] == 1 + # All other keys present too. + for peer in sd.enumerate_peers(): + assert peer.serialize() in mapping + + def test_expect_remotes_before_on_ready(self): + sd = _sd(ppn_count=2) + mc = MasterCoordinator(self_sd=sd, instance_id="inst-A") + # Forgot to call expect_remotes first. 
+ msg = RemoteReadyMsg( + sd_key=_sd(ppn_idx=1, ppn_count=2).serialize(), + ) + with pytest.raises(RuntimeError, match="expect_remotes"): + mc.on_remote_ready(msg) + + def test_aggregate_radix_hooks(self): + sd = _sd(ppn_count=2) + mc = MasterCoordinator(self_sd=sd, instance_id="inst-A") + mc.expect_remotes(1) + mc.acquire_blocks([1, 2, 3]) + assert not mc.is_evictable(1) + mc.release_blocks([1, 2, 3]) + assert mc.is_evictable(1) + + def test_mark_sd_ready_flow(self): + sd = _sd(ppn_count=2) + mc = MasterCoordinator(self_sd=sd, instance_id="inst-A") + mc.expect_remotes(1) + # Only master SD has acked. + ok1 = mc.mark_sd_ready(0xAA, sd.serialize(), [10, 20]) + assert ok1 is False # not yet fully ready + # Second SD acks. + ok2 = mc.mark_sd_ready(0xAA, "peer-sd-key", [10, 20]) + assert ok2 is True + assert mc.match_fully_ready(0xAA) is not None + + def test_invalidate_prefix(self): + sd = _sd() # single SD + mc = MasterCoordinator(self_sd=sd, instance_id="inst-A") + mc.expect_remotes(0) + mc.mark_sd_ready(0xAA, sd.serialize(), [10]) + assert mc.invalidate_prefix(0xAA) is True + assert mc.match_fully_ready(0xAA) is None + + def test_scan_leaked_refcount(self): + sd = _sd() + clock = ManualClock() + mc = MasterCoordinator( + self_sd=sd, instance_id="inst-A", + refcount_leak_timeout_seconds=10.0, + ) + # Inject our manual clock into the aggregate. + mc._aggregate._time_fn = clock # direct attribute poke for testing + mc.expect_remotes(0) + mc.acquire_blocks([7, 8]) + clock.advance(20.0) + leaked = mc.scan_leaked_refcount() + assert sorted(leaked) == [7, 8] + # Force-released: now evictable. + assert mc.is_evictable(7) + assert mc.is_evictable(8) + + def test_peer_loss_invalidates(self): + sd = _sd() + mc = MasterCoordinator(self_sd=sd, instance_id="inst-A") + mc.expect_remotes(0) + mc.mark_sd_ready(0xAA, sd.serialize(), [10], contributing_peer="peer-X") + # Simulate failure-detector callback. 
+ mc._on_peer_lost("peer-X") + assert mc.match_fully_ready(0xAA) is None + + def test_register_instance_discoverables(self): + sd = _sd(ppn_count=2) + mc = MasterCoordinator(self_sd=sd, instance_id="inst-A") + mc.expect_remotes(1) + peer_sd = _sd(ppn_idx=1, ppn_count=2) + mc.on_remote_ready(RemoteReadyMsg( + sender_instance_id="inst-A", + sender_epoch="e1", + sd_key=peer_sd.serialize(), + distributed_node_id=99, + )) + # Stub RedisMeta + fake_redis = FakeRedis() + stub_redis_meta = mock.MagicMock() + stub_redis_meta._client = lambda: fake_redis + stub_redis_meta.register_instance_sd_nodes = mock.MagicMock() + + mc.register_instance_discoverables( + redis_meta=stub_redis_meta, + self_node_id=1, + master_zmq_addr="tcp://master:5555", + ttl_seconds=10, + ) + stub_redis_meta.register_instance_sd_nodes.assert_called_once() + args, _ = stub_redis_meta.register_instance_sd_nodes.call_args + pid, mapping = args + assert pid == "inst-A" + assert mapping[sd.serialize()] == 1 + assert mapping[peer_sd.serialize()] == 99 + + def test_register_discoverables_before_remotes_raises(self): + sd = _sd(ppn_count=2) + mc = MasterCoordinator(self_sd=sd, instance_id="inst-A") + mc.expect_remotes(1) # but never acked + with pytest.raises(RuntimeError, match="requires all Remotes"): + mc.register_instance_discoverables( + redis_meta=mock.MagicMock(_client=lambda: FakeRedis()), + self_node_id=1, + master_zmq_addr="tcp://x:1", + ) + + +# --------------------------------------------------------------------------- +# RemoteDistReuseInitializer +# --------------------------------------------------------------------------- +class TestRemoteDistReuseInitializer: + def _stub_cache_config(self): + return mock.MagicMock( + redis_host="127.0.0.1", redis_port=6379, redis_password=None, + local_ip="10.0.0.2", node_ttl_seconds=10, + mooncake_config_path="/tmp/mc.json", + ) + + def _stub_mooncake(self): + engine = mock.MagicMock() + engine.regist_buffer = mock.MagicMock(return_value=0) + 
engine.get_engine_addr = mock.MagicMock(return_value="10.0.0.2:5555") + return engine + + def _stub_redis_meta(self, node_id=42): + redis_meta = mock.MagicMock() + redis_meta.init_meta = mock.MagicMock(return_value=node_id) + redis_meta.regist_buffer = mock.MagicMock(return_value=1) + redis_meta.regist_node_meta = mock.MagicMock() + return redis_meta + + def test_bootstrap_happy_path(self): + sd = _sd(ppn_idx=1, ppn_count=2) + redis_meta = self._stub_redis_meta(node_id=42) + mooncake = self._stub_mooncake() + init = RemoteDistReuseInitializer( + cache_config=self._stub_cache_config(), + sd_key_str=sd.serialize(), + instance_id="inst-A", + session_epoch="epoch-1", + cpu_buffer_ptr=0x1000, + cpu_buffer_size=4096, + local_zmq_addr="tcp://10.0.0.2:6666", + redis_meta_factory=lambda cc, ns: redis_meta, + mooncake_engine_factory=lambda cc: mooncake, + ) + result = init.bootstrap() + assert isinstance(result, BootstrapResult) + assert result.distributed_node_id == 42 + assert result.sd_key == sd + assert result.redis_meta is redis_meta + assert result.mooncake_engine is mooncake + assert result.ready_msg.sd_key == sd.serialize() + assert result.ready_msg.distributed_node_id == 42 + assert result.ready_msg.mooncake_addr == "10.0.0.2:5555" + + # Side effects + redis_meta.init_meta.assert_called_once() + mooncake.regist_buffer.assert_called_once_with(0x1000, 4096) + redis_meta.regist_buffer.assert_called_once() + redis_meta.regist_node_meta.assert_called_once() + + def test_redis_init_failure_raises(self): + sd = _sd() + redis_meta = self._stub_redis_meta(node_id=None) + redis_meta.get_init_error = mock.MagicMock(return_value=RuntimeError("boom")) + init = RemoteDistReuseInitializer( + cache_config=self._stub_cache_config(), + sd_key_str=sd.serialize(), + instance_id="inst-A", + session_epoch="e1", + cpu_buffer_ptr=0x1000, + cpu_buffer_size=4096, + local_zmq_addr="tcp://x:1", + redis_meta_factory=lambda cc, ns: redis_meta, + mooncake_engine_factory=lambda cc: 
self._stub_mooncake(), + ) + with pytest.raises(RuntimeError, match="init_meta"): + init.bootstrap() + + def test_mooncake_without_regist_buffer_raises(self): + sd = _sd() + bad_mooncake = mock.MagicMock(spec=[]) # no attributes + init = RemoteDistReuseInitializer( + cache_config=self._stub_cache_config(), + sd_key_str=sd.serialize(), + instance_id="inst-A", + session_epoch="e1", + cpu_buffer_ptr=0x1000, + cpu_buffer_size=4096, + local_zmq_addr="tcp://x:1", + redis_meta_factory=lambda cc, ns: self._stub_redis_meta(), + mooncake_engine_factory=lambda cc: bad_mooncake, + ) + with pytest.raises(AttributeError, match="regist_buffer"): + init.bootstrap() + + def test_encode_ready(self): + msg = RemoteReadyMsg( + sender_instance_id="inst", sender_epoch="e", + sd_key="m:ppn0/1:tpn0/1:nsa0", + distributed_node_id=1, + ) + out = RemoteDistReuseInitializer.encode_ready(msg) + assert out["type"] == "remote_ready" + assert out["sd_key"] == "m:ppn0/1:tpn0/1:nsa0" + + +# --------------------------------------------------------------------------- +# SharingDomainHandleSpec +# --------------------------------------------------------------------------- +class TestHandleSpec: + def test_basic(self): + sd = _sd() + spec = SharingDomainHandleSpec(sd_key=sd, mode="process") + assert spec.mode == "process" + assert spec.endpoint is None + + def test_remote_mode_with_endpoint(self): + sd = _sd(ppn_idx=1, ppn_count=2) + ep = _StubEndpoint("10.0.0.1", "1", "2", "3") + spec = SharingDomainHandleSpec(sd_key=sd, mode="remote", endpoint=ep) + assert spec.mode == "remote" + assert spec.endpoint is ep diff --git a/tests/test_metrics_dist_reuse.py b/tests/test_metrics_dist_reuse.py new file mode 100644 index 0000000000..a2ba95811a --- /dev/null +++ b/tests/test_metrics_dist_reuse.py @@ -0,0 +1,153 @@ +"""Smoke tests for the dist-reuse Prometheus metrics added in 2026-05-14. + +These tests validate the **wiring**: + * Collector exposes the 5 new metrics. 
+ * Dummy fallback works when prometheus_client is missing OR + ``FLEXKV_ENABLE_METRICS`` is unset. + * Recording methods are no-op safe under both paths. + +The tests do NOT validate that the worker.py call site actually emits +samples — that requires a running mooncake engine and is covered by the +e2e harness in ``tests/multinode/``. + +See ``docs/dist_reuse/METRICS_dist_reuse.md`` for the full operator +context. +""" + +from __future__ import annotations + +import os +import unittest +from unittest import mock + +# Force-disable the metrics server / prometheus client to exercise the +# dummy path first, then re-enable for the real path. Both must work. + + +class TestDistReuseMetricsDummyPath(unittest.TestCase): + """When metrics are disabled (default), every record_* method must be + a silent no-op — we cannot raise from the data path.""" + + @mock.patch.dict(os.environ, {"FLEXKV_ENABLE_METRICS": "0"}) + def setUp(self): + # Re-import in a clean namespace so the env var takes effect. 
+ import importlib + + from flexkv.common import config as _cfg + importlib.reload(_cfg) + from flexkv.metrics import collector as _coll + importlib.reload(_coll) + self._coll_mod = _coll + self.collector = _coll.FlexKVMetricsCollector() + + def test_collector_disabled(self): + self.assertFalse(self.collector.enabled) + + def test_lease_nullptr_record_is_noop(self): + # Must not raise even with negative / zero counts + self.collector.record_dist_reuse_lease_nullptr("cpu", 5) + self.collector.record_dist_reuse_lease_nullptr("cpu", 0) + self.collector.record_dist_reuse_lease_nullptr("ssd", -1) + + def test_about_to_evict_record_is_noop(self): + self.collector.record_dist_reuse_about_to_evict("cpu", 100) + self.collector.record_dist_reuse_about_to_evict("ssd", 0) + + def test_mooncake_read_observe_is_noop(self): + self.collector.observe_dist_reuse_peer_mooncake_read( + 0.0123, success=True + ) + self.collector.observe_dist_reuse_peer_mooncake_read( + 0.5, success=False, reason="mooncake_error" + ) + self.collector.observe_dist_reuse_peer_mooncake_read( + 0.001, success=False, reason="zero_byte_transfer" + ) + + +class TestDistReuseMetricsEnabledPath(unittest.TestCase): + """When ``FLEXKV_ENABLE_METRICS=1`` and ``prometheus_client`` is + installed, the metric names must register on the default registry and + record/observe must mutate the underlying samples.""" + + @mock.patch.dict(os.environ, {"FLEXKV_ENABLE_METRICS": "1"}) + def setUp(self): + try: + import prometheus_client # noqa: F401 + except ImportError: + self.skipTest("prometheus_client not installed") + + # Reload modules so the env var is picked up and the registry is + # populated cleanly for this test (each setUp creates its own + # collector with its own metric objects via re-registration). 
+ import importlib + + from flexkv.common import config as _cfg + importlib.reload(_cfg) + from flexkv.metrics import collector as _coll + importlib.reload(_coll) + + # prometheus_client uses a global default REGISTRY; the second + # _init_metrics call would raise on duplicate registration. Use a + # fresh CollectorRegistry to avoid global pollution. + from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram + self._registry = CollectorRegistry() + + # Patch the module-level Counter/Gauge/Histogram to use our private + # registry so we don't pollute the default one across tests. + with mock.patch.object(_coll, "Counter", + lambda **kw: Counter(registry=self._registry, **kw)), \ + mock.patch.object(_coll, "Gauge", + lambda **kw: Gauge(registry=self._registry, **kw)), \ + mock.patch.object(_coll, "Histogram", + lambda **kw: Histogram(registry=self._registry, **kw)): + self.collector = _coll.FlexKVMetricsCollector() + + def test_collector_enabled(self): + self.assertTrue(self.collector.enabled) + + def test_dist_reuse_metrics_attributes_exist(self): + # The 5 new attributes must be present whether enabled or dummy. 
+ for attr in ( + "dist_reuse_lease_meta_nullptr_total", + "dist_reuse_about_to_evict_total", + "dist_reuse_peer_mooncake_read_seconds", + "dist_reuse_peer_mooncake_read_failures_total", + "dist_reuse_peer_mooncake_read_success_total", + ): + self.assertTrue( + hasattr(self.collector, attr), + f"collector missing metric attribute: {attr}", + ) + + def test_lease_nullptr_increments_under_enabled(self): + before = self.collector.dist_reuse_lease_meta_nullptr_total \ + .labels(device="cpu")._value.get() + self.collector.record_dist_reuse_lease_nullptr("cpu", 7) + after = self.collector.dist_reuse_lease_meta_nullptr_total \ + .labels(device="cpu")._value.get() + self.assertEqual(after - before, 7) + + def test_mooncake_read_failure_records_with_reason(self): + before = self.collector.dist_reuse_peer_mooncake_read_failures_total \ + .labels(reason="zero_byte_transfer")._value.get() + self.collector.observe_dist_reuse_peer_mooncake_read( + 0.5, success=False, reason="zero_byte_transfer", + ) + after = self.collector.dist_reuse_peer_mooncake_read_failures_total \ + .labels(reason="zero_byte_transfer")._value.get() + self.assertEqual(after - before, 1) + + def test_mooncake_read_success_records(self): + before = self.collector.dist_reuse_peer_mooncake_read_success_total \ + ._value.get() + self.collector.observe_dist_reuse_peer_mooncake_read( + 0.05, success=True, + ) + after = self.collector.dist_reuse_peer_mooncake_read_success_total \ + ._value.get() + self.assertEqual(after - before, 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_multinode_flags.py b/tests/test_multinode_flags.py new file mode 100644 index 0000000000..e46c6dc5d1 --- /dev/null +++ b/tests/test_multinode_flags.py @@ -0,0 +1,185 @@ +"""Phase 1 task 1-A — decoupled multinode flags on ``ModelConfig``. 
+ +These flags (``is_multinode_tp`` / ``is_multinode_cp``) intentionally cover +*orthogonal* physical situations that historically were conflated under a +single ``is_multinode`` switch in the sglang connector. This regression +suite locks down the documented semantics: + + * ``is_multinode_tp`` is True when one TP group physically spans + > 1 node. This is the *only* dimension that affects SD-Remote + construction (since ``tp_node_count > 1`` enters the SD key). + + * ``is_multinode_cp`` is True when CP > 1 AND the CP group physically + crosses node boundaries. Under sglang's standard megatron-style + topology this is *always False* (see ``ModelConfig.is_multinode_cp`` + docstring); we still test the flag so future deployments that break + that assumption are caught early. + +Critically, *neither* flag must influence the ``SharingDomainKey`` — +``is_multinode_cp`` is a transport-layer hint, not an SD identity. +""" + +from __future__ import annotations + +import pytest + +from flexkv.common.config import ModelConfig + + +def _make( + *, + tp_size: int, + nnodes: int, + pp_size: int = 1, + dp_size: int = 1, + attn_cp_size: int = 1, + enable_dp_attention: bool = False, +) -> ModelConfig: + """Build a ModelConfig with just the topology fields needed for these + properties. The other fields keep their defaults — they don't influence + ``is_multinode_*``. + + NOTE: ``tp_rank`` and ``node_rank`` were moved out of ``ModelConfig`` + into ``RankInfo`` by the RankInfo refactor (PR #165), so we no + longer pass them here. ``is_multinode_tp`` / ``is_multinode_cp`` + are pure cluster-topology properties — they read ``tp_size``, + ``pp_size``, ``nnodes`` and ``attn_cp_size`` only. 
+ """ + return ModelConfig( + num_layers=1, num_kv_heads=1, head_size=1, + tp_size=tp_size, + pp_size=pp_size, + dp_size=dp_size, + nnodes=nnodes, + attn_cp_size=attn_cp_size, + enable_dp_attention=enable_dp_attention, + ) + + +# =========================================================================== +# is_multinode_tp +# =========================================================================== +class TestIsMultinodeTp: + def test_single_node_tp(self): + # 8-way TP on one node — TP fits inside a single host. + mc = _make(tp_size=8, nnodes=1) + assert mc.is_multinode_tp is False + assert mc.tp_node_count == 1 + + def test_cross_node_tp(self): + # 16-way TP across 2 nodes (8 GPUs per node). + mc = _make(tp_size=16, nnodes=2) + assert mc.is_multinode_tp is True + assert mc.tp_node_count == 2 + + def test_pp_alone_does_not_imply_multinode_tp(self): + # PP=2 across 2 nodes (8-way TP each), TP itself is single-node. + mc = _make(tp_size=8, nnodes=2, pp_size=2) + assert mc.is_multinode_tp is False + assert mc.nnodes_per_tp_group == 1 + + +# =========================================================================== +# is_multinode_cp +# =========================================================================== +class TestIsMultinodeCp: + """In FlexKV's topology model, CP is a sub-partition *inside* TP, not a + top-level GPU dimension. ``total_gpus = tp × pp × dp`` (CP excluded). + A CP group occupies ``attn_cp_size`` GPUs carved out of the + ``tp_size`` assigned to one TP group, so CP crosses nodes iff + ``attn_cp_size > tp_size_per_node``. + """ + + def test_cp_disabled_is_false(self): + mc = _make(tp_size=8, nnodes=1, attn_cp_size=1) + assert mc.is_multinode_cp is False + + def test_cp_fits_inside_one_tp_node_share(self): + # Single-node TP=8, CP=2 — tp_size_per_node=8 ≥ cp=2 ⇒ CP fits. 
+ mc = _make(tp_size=8, nnodes=1, attn_cp_size=2) + assert mc.is_multinode_cp is False + assert mc.tp_size_per_node == 8 + + def test_cp_equals_tp_per_node_share(self): + # tp_size_per_node == attn_cp_size: CP exactly fills one node's TP slice. + # Strictly intra-node — not crossing. + mc = _make(tp_size=4, nnodes=2, attn_cp_size=2, pp_size=1) + # nnodes_per_tp_group = 2 → tp_size_per_node = 4/2 = 2 == cp ⇒ False + assert mc.tp_size_per_node == 2 + assert mc.is_multinode_cp is False + + def test_cp_exceeds_tp_per_node_share_crosses(self): + # tp_size_per_node < attn_cp_size: CP must cross nodes. + # tp=4, nnodes=2, pp=1 → tp_size_per_node = 2; cp=4 > 2 ⇒ True. + mc = _make(tp_size=4, nnodes=2, attn_cp_size=4, pp_size=1) + assert mc.tp_size_per_node == 2 + assert mc.is_multinode_cp is True + + def test_megatron_typical_deployments_are_intra_node(self): + """Production-style configurations (CP ≤ tp_size_per_node) — all False.""" + for tp, nnodes, pp, cp in [ + (8, 1, 1, 2), # single-node TP=8, CP=2 + (8, 2, 1, 2), # tp_size_per_node=4, CP=2 → fits + (8, 2, 2, 2), # PP=2, tp_size_per_node=8, CP=2 → fits (per-PP-stage TP single-node) + (16, 2, 1, 2), # tp_size_per_node=8, CP=2 → fits + ]: + mc = _make(tp_size=tp, nnodes=nnodes, pp_size=pp, attn_cp_size=cp) + assert mc.is_multinode_cp is False, ( + f"is_multinode_cp should be False (tp={tp}, nnodes={nnodes}, " + f"pp={pp}, cp={cp}); got True " + f"(tp_size_per_node={mc.tp_size_per_node})" + ) + + +# =========================================================================== +# Independence +# =========================================================================== +class TestIndependence: + """is_multinode_tp and is_multinode_cp must be physically independent.""" + + def test_only_tp_multinode(self): + mc = _make(tp_size=16, nnodes=2) # cross-node TP, no CP + assert mc.is_multinode_tp is True + assert mc.is_multinode_cp is False + + def test_both_false_baseline(self): + mc = _make(tp_size=4, nnodes=1) + assert 
mc.is_multinode_tp is False + assert mc.is_multinode_cp is False + + +# =========================================================================== +# Property: flags do NOT influence SharingDomainKey +# =========================================================================== +class TestNoSdKeyLeak: + """Regression: changing CP must never alter ``SharingDomainKey``. + + The simplified design (§4.5) explicitly excludes CP from the SD key. + We assert that here so any future change which accidentally feeds + CP-dimension info into ``from_model_config`` will fail this test. + """ + + def test_cp_size_does_not_affect_sd_key(self): + from flexkv.common.dist_reuse.sharing_domain import SharingDomainKey + + mc_no_cp = _make(tp_size=4, nnodes=1, attn_cp_size=1) + mc_with_cp = _make(tp_size=4, nnodes=1, attn_cp_size=2) + # The two configs differ in CP only — but SD must not know. + assert mc_no_cp.attn_cp_size != mc_with_cp.attn_cp_size + sd1 = SharingDomainKey.from_model_config(mc_no_cp) + sd2 = SharingDomainKey.from_model_config(mc_with_cp) + assert sd1 == sd2 + assert sd1.serialize() == sd2.serialize() + + def test_serialized_sd_never_contains_cp(self): + """Belt-and-braces check on the on-the-wire format itself.""" + from flexkv.common.dist_reuse.sharing_domain import SharingDomainKey + + mc = _make(tp_size=4, nnodes=1, attn_cp_size=2) + sd = SharingDomainKey.from_model_config(mc) + s = sd.serialize() + assert ":cp" not in s, f"Serialized SD must never contain ':cp...' segment, got: {s}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_multinode_role_policy.py b/tests/test_multinode_role_policy.py new file mode 100644 index 0000000000..2b4a8ce75a --- /dev/null +++ b/tests/test_multinode_role_policy.py @@ -0,0 +1,186 @@ +"""§2.4 — Multi-node role decision policy tests. + +The connector-side split of ``is_multinode_tp`` vs. 
``is_multinode_cp`` +is currently *not* plumbed into ``flexkv_connector.py`` (sglang/) — +doing that requires a two-machine GPU setup to verify. Until then we +pin the decision table at the policy-function level so that when the +actual connector swap lands, it already has a stable, tested contract +to call into. + +These tests are torch-free by design: pure logic over ``RankTopology``. +""" +from __future__ import annotations + +import sys +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parent.parent +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from flexkv.integration.multinode_policy import ( # noqa: E402 + RankTopology, + RemoteProcessRole, + decide_remote_role, + is_sync_leader, +) + + +# --------------------------------------------------------------------------- +# Trivial validation +# --------------------------------------------------------------------------- +class TestValidation: + def test_nnodes_must_be_positive(self): + with pytest.raises(ValueError): + decide_remote_role(RankTopology( + nnodes=0, node_rank=0, local_rank=0, + is_multinode_tp=False, is_multinode_cp=False, + )) + + def test_node_rank_in_range(self): + with pytest.raises(ValueError): + decide_remote_role(RankTopology( + nnodes=2, node_rank=5, local_rank=0, + is_multinode_tp=True, is_multinode_cp=False, + )) + + def test_local_rank_non_negative(self): + with pytest.raises(ValueError): + decide_remote_role(RankTopology( + nnodes=2, node_rank=1, local_rank=-1, + is_multinode_tp=True, is_multinode_cp=False, + )) + + +# --------------------------------------------------------------------------- +# Single-node instance → NO_REMOTE regardless of any flag. 
+# --------------------------------------------------------------------------- +class TestSingleNode: + @pytest.mark.parametrize("tp,cp", [(False, False), (True, False), + (False, True), (True, True)]) + def test_single_node_never_spawns_remote(self, tp, cp): + topo = RankTopology( + nnodes=1, node_rank=0, local_rank=0, + is_multinode_tp=tp, is_multinode_cp=cp, + ) + assert decide_remote_role(topo) is RemoteProcessRole.NO_REMOTE + + +# --------------------------------------------------------------------------- +# Master node (node_rank == 0) never spawns a remote itself. +# --------------------------------------------------------------------------- +class TestMasterNode: + def test_master_with_multinode_tp_is_master(self): + topo = RankTopology( + nnodes=2, node_rank=0, local_rank=0, + is_multinode_tp=True, is_multinode_cp=False, + ) + assert decide_remote_role(topo) is RemoteProcessRole.MASTER + + def test_master_with_multinode_cp_is_master(self): + topo = RankTopology( + nnodes=2, node_rank=0, local_rank=0, + is_multinode_tp=False, is_multinode_cp=True, + ) + assert decide_remote_role(topo) is RemoteProcessRole.MASTER + + def test_master_with_both_flags_is_master(self): + topo = RankTopology( + nnodes=2, node_rank=0, local_rank=0, + is_multinode_tp=True, is_multinode_cp=True, + ) + assert decide_remote_role(topo) is RemoteProcessRole.MASTER + + def test_master_with_nothing_crossing_nodes_is_no_remote(self): + """Multi-node deployment but THIS instance is single-node + (only DP crosses; each DP instance stays on one node).""" + topo = RankTopology( + nnodes=2, node_rank=0, local_rank=0, + is_multinode_tp=False, is_multinode_cp=False, + ) + assert decide_remote_role(topo) is RemoteProcessRole.NO_REMOTE + + +# --------------------------------------------------------------------------- +# Off-master nodes — the interesting routing table. 
+# --------------------------------------------------------------------------- +class TestOffMasterRouting: + def test_multinode_tp_only_is_full_sd_remote(self): + topo = RankTopology( + nnodes=2, node_rank=1, local_rank=0, + is_multinode_tp=True, is_multinode_cp=False, + ) + assert decide_remote_role(topo) is RemoteProcessRole.SD_REMOTE_FULL + + def test_multinode_cp_only_is_cp_registration_stub(self): + topo = RankTopology( + nnodes=2, node_rank=1, local_rank=0, + is_multinode_tp=False, is_multinode_cp=True, + ) + assert decide_remote_role(topo) is RemoteProcessRole.CP_PEER_REGISTRATION_ONLY + + def test_multinode_tp_wins_over_cp(self): + """When BOTH flags are True on an off-master rank, TP takes + priority — TP-split SDs cannot be served by a CP-only stub.""" + topo = RankTopology( + nnodes=2, node_rank=1, local_rank=0, + is_multinode_tp=True, is_multinode_cp=True, + ) + assert decide_remote_role(topo) is RemoteProcessRole.SD_REMOTE_FULL + + def test_neither_tp_nor_cp_multinode_still_spawns_full_remote_today(self): + """Legacy bug-compat: today's connector treats any + ``nnodes>1 and node_rank>0 and local_rank==0`` case as + ``SD_REMOTE_FULL`` (PP crossing nodes uses this path). We + preserve that during the migration; the TODO in + multinode_policy.py tracks the eventual cleanup. + """ + topo = RankTopology( + nnodes=2, node_rank=1, local_rank=0, + is_multinode_tp=False, is_multinode_cp=False, + ) + assert decide_remote_role(topo) is RemoteProcessRole.SD_REMOTE_FULL + + +# --------------------------------------------------------------------------- +# Sync-leader helper. 
+# --------------------------------------------------------------------------- +class TestSyncLeader: + def test_default_rule_is_local_and_node_rank_zero(self): + topo = RankTopology( + nnodes=2, node_rank=0, local_rank=0, + is_multinode_tp=True, is_multinode_cp=False, + ) + assert is_sync_leader(topo) is True + + topo2 = RankTopology( + nnodes=2, node_rank=1, local_rank=0, + is_multinode_tp=True, is_multinode_cp=False, + ) + assert is_sync_leader(topo2) is False + + topo3 = RankTopology( + nnodes=2, node_rank=0, local_rank=1, + is_multinode_tp=True, is_multinode_cp=False, + ) + assert is_sync_leader(topo3) is False + + def test_explicit_hint_overrides_default(self): + """If the caller hands us ``is_sync_leader=True`` (e.g. coming + from sglang's own group metadata), respect it even if the + default heuristic would say otherwise.""" + topo = RankTopology( + nnodes=2, node_rank=1, local_rank=7, + is_multinode_tp=True, is_multinode_cp=False, + is_sync_leader=True, + ) + assert is_sync_leader(topo) is True + + topo2 = RankTopology( + nnodes=2, node_rank=0, local_rank=0, + is_multinode_tp=True, is_multinode_cp=False, + is_sync_leader=False, + ) + assert is_sync_leader(topo2) is False diff --git a/tests/test_phase2_combinations.py b/tests/test_phase2_combinations.py new file mode 100644 index 0000000000..7b84cb76dd --- /dev/null +++ b/tests/test_phase2_combinations.py @@ -0,0 +1,200 @@ +"""Phase 2 / Phase 3 combinations — simplified schema (CP not in SD key). + +Design doc §4.10.1 (simplified): + +* Current deployment (prefill crosses ≤ 2 nodes): + - ``pp_node_count=2, tp_node_count=1`` — 2 SDs (PP-Remote lives on + another node); + - ``pp_node_count=1, tp_node_count=2`` — 2 SDs (TP-Remote lives on + another node). +* Future-extension scope (not currently deployed): + - ``pp_node_count=2, tp_node_count=2`` — 4 SDs. 
+ +CP (``cp_size``) does **not** multiply the SD count — each cp_rank inside +the same CP group shares the SD with the sync_leader, and data is dispatched +in-process via sglang's scatter (see simplified design doc §4.5). This is +the main regression the tests below guard against. + +Note on the rank→node collapse: under the simplified schema the PP axis +in the SD key is keyed by **physical node** (``pp_node_idx``), not by +PP rank. Co-located PP ranks on the same machine fold into the same SD, +so a single-node ``pp_size=2`` deployment is *not* representable in this +test file — for that case the Master-side factory raises and the data +plane stays in legacy single-SD mode. +""" +from __future__ import annotations + +from typing import Set, Tuple + +import pytest + +from flexkv.common.dist_reuse import ( + SharingDomainKey, + SharingDomainNamespace, + build_sharing_domain_handles, + graph_needs_gpu_clear, +) + + +MODEL_ID = "phase2-simplified" + + +def _sd(ppn_idx=0, ppn_count=1, tpn_idx=0, tpn_count=1, + is_nsa=False) -> SharingDomainKey: + """Build an SD key with the simplified, node-granularity PP schema.""" + return SharingDomainKey( + model_id=MODEL_ID, + pp_node_idx=ppn_idx, pp_node_count=ppn_count, + tp_node_idx=tpn_idx, tp_node_count=tpn_count, + is_nsa=is_nsa, + ) + + +# =========================================================================== +# §4.10.1 row: pp_node_count=2, tp_node_count=1 → 2 SDs (current deployment case A) +# =========================================================================== +class TestPhase2PPOnly: + SELF_SD = _sd(ppn_idx=0, ppn_count=2) + EXPECTED_SD_COUNT = 2 + + def test_enumerate_yields_2_distinct_sds(self): + peers = self.SELF_SD.enumerate_peers() + assert len(peers) == self.EXPECTED_SD_COUNT + # Both (ppn=0) and (ppn=1) present, each with tp_node_idx=0. 
+ pp_node_indices: Set[int] = set() + for p in peers: + pp_node_indices.add(p.pp_node_idx) + assert p.tp_node_idx == 0 and p.tp_node_count == 1 + assert p.pp_node_count == 2 + assert pp_node_indices == {0, 1} + + def test_exactly_one_master(self): + peers = self.SELF_SD.enumerate_peers() + masters = [p for p in peers if p.is_master()] + assert len(masters) == 1 + assert masters[0].pp_node_idx == 0 + + def test_non_master_peer_needs_gpu_clear(self): + peers = [p for p in self.SELF_SD.enumerate_peers() if p != self.SELF_SD] + assert len(peers) == 1 + # Peer differs in pp_node_idx → clear required. + assert graph_needs_gpu_clear(self.SELF_SD, peers[0]) is True + + +# =========================================================================== +# §4.10.1 row: pp_node_count=1, tp_node_count=2 → 2 SDs (current deployment case B) +# =========================================================================== +class TestPhase2CrossNodeTPOnly: + SELF_SD = _sd(tpn_idx=0, tpn_count=2) + EXPECTED_SD_COUNT = 2 + + def test_enumerate_yields_2_distinct_sds(self): + peers = self.SELF_SD.enumerate_peers() + assert len(peers) == self.EXPECTED_SD_COUNT + tp_nodes: Set[int] = set() + for p in peers: + tp_nodes.add(p.tp_node_idx) + assert p.pp_node_idx == 0 and p.pp_node_count == 1 + assert p.tp_node_count == 2 + assert tp_nodes == {0, 1} + + def test_non_master_peer_does_not_need_gpu_clear(self): + """Cross-TP-node only differs in the head shard, not in slot_mapping, + so the graph can be forwarded as-is without a GPU clear.""" + peers = [p for p in self.SELF_SD.enumerate_peers() if p != self.SELF_SD] + assert len(peers) == 1 + assert graph_needs_gpu_clear(self.SELF_SD, peers[0]) is False + + +# =========================================================================== +# §4.10.1 future-extension row: pp_node_count=2 × tp_node_count=2 → 4 SDs +# =========================================================================== +class TestPhase2MaxConfig: + SELF_SD = _sd(ppn_count=2, 
tpn_count=2) + EXPECTED_SD_COUNT = 2 * 2 # 4 + + def test_enumerate_yields_4_distinct_sds(self): + peers = self.SELF_SD.enumerate_peers() + assert len(peers) == self.EXPECTED_SD_COUNT + serialized = {p.serialize() for p in peers} + assert len(serialized) == self.EXPECTED_SD_COUNT + # Every (ppn, tpn) pair present exactly once. + pairs: Set[Tuple[int, int]] = set() + for p in peers: + pairs.add((p.pp_node_idx, p.tp_node_idx)) + assert pairs == {(ppn, tpn) for ppn in range(2) for tpn in range(2)} + + def test_clear_decision_counts(self): + """In a 4-SD instance with self at (ppn=0, tpn=0): + + peer (ppn=0, tpn=1) → TP-only differ → no clear + peer (ppn=1, tpn=0) → ppn differs → clear + peer (ppn=1, tpn=1) → ppn differs → clear + + So 2 of 3 non-self peers require gpu-clear. + """ + clear_count = 0 + no_clear_count = 0 + for peer in self.SELF_SD.enumerate_peers(): + if peer == self.SELF_SD: + continue + if graph_needs_gpu_clear(self.SELF_SD, peer): + clear_count += 1 + else: + no_clear_count += 1 + total_peers = self.EXPECTED_SD_COUNT - 1 # 3 non-self + assert clear_count + no_clear_count == total_peers + assert no_clear_count == 1, "only the TP-only-differ peer avoids gpu clear" + assert clear_count == 2 + + def test_namespace_prefixes_are_pairwise_disjoint(self): + """No SD's Redis key prefix is a prefix of another's (would leak SCANs).""" + peers = self.SELF_SD.enumerate_peers() + prefixes = [SharingDomainNamespace(p).prefix for p in peers] + # All unique. + assert len(set(prefixes)) == len(prefixes) + # None is a proper prefix of another (followed by ':'). 
+ for i, a in enumerate(prefixes): + for j, b in enumerate(prefixes): + if i == j: + continue + assert not b.startswith(a + ":"), ( + f"prefix '{a}' is a proper prefix of '{b}' — Redis SCANs would leak" + ) + + +# =========================================================================== +# build_sharing_domain_handles under the current-deployment cap (2 SDs) +# =========================================================================== +class TestBuildHandlesCurrentDeployment: + def test_pp2_one_remote_handle(self): + self_sd = _sd(ppn_count=2) + peer = _sd(ppn_idx=1, ppn_count=2) + specs = build_sharing_domain_handles( + self_sd=self_sd, + remote_endpoints_by_sd={ + peer.serialize(): _FakeEndpoint( + ip="10.0.0.1", + gpu_register_port="5000", + command_port="5001", + result_port="5002", + ), + }, + ) + assert len(specs) == 2 + assert specs[0].mode == "process" and specs[0].sd_key.is_master() + assert specs[1].mode == "remote" and specs[1].sd_key == peer + + +# =========================================================================== +# Supporting test doubles +# =========================================================================== +class _FakeEndpoint: + """Duck-type of flexkv.common.config.RemoteEndpoint, sufficient for + build_sharing_domain_handles to treat us as a valid endpoint.""" + + def __init__(self, ip, gpu_register_port, command_port, result_port): + self.ip = ip + self.gpu_register_port = gpu_register_port + self.command_port = command_port + self.result_port = result_port diff --git a/tests/test_redis_db_integration.py b/tests/test_redis_db_integration.py new file mode 100644 index 0000000000..911aa52b36 --- /dev/null +++ b/tests/test_redis_db_integration.py @@ -0,0 +1,202 @@ +"""Integration test against a real Redis: verify ``flexkv_redis_db`` +actually switches the logical database at the protocol level. + +Auto-skips when Redis is not reachable. 
Uses two different db numbers +(primary db ``0`` vs override db ``15`` — both always exist on a stock +Redis) and asserts that: + + * a key written from a ``RedisMeta`` bound to db=0 is **not** visible + to a client bound to db=15, and vice-versa; + * ``RedisSessionClient`` built through the factory respects db=15. +""" +from __future__ import annotations + +import importlib.util +import os +import sys +import time +import uuid +from pathlib import Path +from typing import Iterator, List + +import pytest + +REDIS_HOST = os.environ.get("FLEXKV_TEST_REDIS_HOST", "127.0.0.1") +REDIS_PORT = int(os.environ.get("FLEXKV_TEST_REDIS_PORT", "6379")) + +try: + import redis as _redis # type: ignore[import-not-found] +except ImportError: # pragma: no cover + pytest.skip("redis-py not installed", allow_module_level=True) + + +def _probe_db(db: int) -> bool: + try: + c = _redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=db, + socket_connect_timeout=1.0, decode_responses=True) + return bool(c.ping()) + except Exception: + return False + + +if not (_probe_db(0) and _probe_db(15)): + pytest.skip( + f"Redis not reachable at {REDIS_HOST}:{REDIS_PORT} on both db=0 and db=15", + allow_module_level=True, + ) + + +REPO_ROOT = Path(__file__).resolve().parent.parent +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + + +def _load(name: str, path: Path): + spec = importlib.util.spec_from_file_location(name, str(path)) + assert spec is not None and spec.loader is not None + mod = importlib.util.module_from_spec(spec) + sys.modules[name] = mod + spec.loader.exec_module(mod) + return mod + + +_rm = _load("_rm_db_it", REPO_ROOT / "flexkv" / "cache" / "redis_meta.py") +RedisMeta = _rm.RedisMeta + +from flexkv.common.dist_reuse import ( # noqa: E402 + RedisSessionClient, + SharingDomainKey, + SharingDomainNamespace, + make_redis_client_from_cache_config, + make_session_epoch, +) + + +@pytest.fixture +def tracked_db0() -> Iterator["_redis.Redis"]: + c = 
_redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=0, + decode_responses=True) + keys: List[str] = [] + c._test_tracked_keys = keys # type: ignore[attr-defined] + yield c + for k in keys: + try: + c.delete(k) + except Exception: + pass + + +@pytest.fixture +def tracked_db15() -> Iterator["_redis.Redis"]: + c = _redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=15, + decode_responses=True) + keys: List[str] = [] + c._test_tracked_keys = keys # type: ignore[attr-defined] + yield c + for k in keys: + try: + c.delete(k) + except Exception: + pass + + +def _track(client, *keys: str) -> None: + client._test_tracked_keys.extend(keys) # type: ignore[attr-defined] + + +class _FakeCacheConfig: + """Minimal stand-in for ``CacheConfig`` — keeps this test file free of + the torch/zmq import chain that the real ``flexkv.common.config`` drags + in (same trick as ``test_redis_meta_namespace.py``).""" + + def __init__(self, *, db: int): + self.redis_host = REDIS_HOST + self.redis_port = REDIS_PORT + self.redis_password = None + self.flexkv_redis_db = db + + +# =========================================================================== +# Real-Redis: writes to db=15 are invisible on db=0 (and vice-versa) +# =========================================================================== +def test_redismeta_db15_key_not_visible_on_db0(tracked_db0, tracked_db15): + sd = SharingDomainKey( + model_id="db-test-" + uuid.uuid4().hex[:10], + pp_node_idx=0, pp_node_count=1, tp_node_idx=0, tp_node_count=1, + is_nsa=False, + ) + ns = SharingDomainNamespace(sd) + + meta_db15 = RedisMeta( + host=REDIS_HOST, port=REDIS_PORT, local_ip="10.0.99.31", + namespace=ns, db=15, + ) + try: + nid = meta_db15.init_meta() + assert nid is not None, "init_meta failed on db=15" + node_key = ns.node_key(nid) + _track(tracked_db15, node_key) + + # Key is present on db=15 but NOT on db=0. 
+ assert tracked_db15.exists(node_key) == 1, ( + "Expected node key on db=15 after RedisMeta(db=15).init_meta()" + ) + assert tracked_db0.exists(node_key) == 0, ( + f"Key {node_key} leaked to db=0 — db selection is broken!" + ) + + # Conversely: drop a sentinel on db=0 and assert db=15 doesn't see it. + sentinel_key = f"sentinel:{uuid.uuid4().hex[:8]}" + tracked_db0.set(sentinel_key, "db0-value") + _track(tracked_db0, sentinel_key) + assert tracked_db0.get(sentinel_key) == "db0-value" + assert tracked_db15.exists(sentinel_key) == 0, ( + "db=15 client saw a db=0 sentinel — dbs are not isolated!" + ) + finally: + meta_db15.unregister_node() + + +def test_make_redis_client_from_cache_config_uses_db(tracked_db0, tracked_db15): + """RedisSessionClient built with factory client must land on the + right db.""" + cfg = _FakeCacheConfig(db=15) + client = make_redis_client_from_cache_config(cfg) + + instance_id = f"dbsess-{uuid.uuid4().hex[:8]}" + sess = RedisSessionClient( + redis_client=client, + instance_id=instance_id, + epoch=make_session_epoch(), + ttl_seconds=5, + ) + sess.register() + _track(tracked_db15, sess.key) + + # Session key present on db=15, absent on db=0. + assert tracked_db15.exists(sess.key) == 1 + assert tracked_db0.exists(sess.key) == 0 + + +def test_redismetachannel_python_wrapper_records_db(): + """Pure-python smoke test — verify the C++ ctor is invoked with db=N. + + We don't actually exercise the C++ binary here (would need a GPU build); + we just assert the Python wrapper forwards the arg correctly. The real + C++ SELECT behaviour is exercised by + ``test_redismeta_db15_key_not_visible_on_db0`` above, which uses only + the Python ``RedisMeta`` path (redis-py supports SELECT natively). + """ + # The wrapper ctor may raise ImportError when _CRedisMetaChannel is None + # (no FLEXKV_ENABLE_P2P build). Skip gracefully in that case. 
+ if _rm._CRedisMetaChannel is None: + pytest.skip("flexkv.c_ext.RedisMetaChannel not built (FLEXKV_ENABLE_P2P=0)") + # Otherwise try the real construction against Redis on db=15. + ch = _rm.RedisMetaChannel( + host=REDIS_HOST, port=REDIS_PORT, node_id=99999, + local_ip="127.0.0.1", blocks_key="sd:dbsmoke:CPUB", + password="", db=15, + ) + assert ch._db == 15 + # connect() should succeed even on db=15 (assuming server allows it). + assert ch.connect() is True diff --git a/tests/test_redis_integration.py b/tests/test_redis_integration.py new file mode 100644 index 0000000000..f326aa97c3 --- /dev/null +++ b/tests/test_redis_integration.py @@ -0,0 +1,533 @@ +"""Integration tests against a **real** Redis server. + +Complement to the FakeRedis-based unit tests — FakeRedis cannot reproduce +real TTL expiry timing, real HMGET/SCAN pipelining, or HGETALL ordering, +so the Phase 0 / Phase 1 code paths that depend on those behaviours need a +live Redis to be exercised end-to-end. + +Usage:: + + FLEXKV_TEST_REDIS_HOST=127.0.0.1 \ + FLEXKV_TEST_REDIS_PORT=6379 \ + pytest tests/test_redis_integration.py -v + +All tests auto-skip when Redis is not reachable. Each test uses a +randomly-generated ``model_id`` / ``instance_id`` so the shared Redis +(which may already contain ~1000 keys from other tenants) is not disturbed, +and teardown deletes only keys under the test's own SD / instance prefix. + +``RedisMeta`` is loaded via ``importlib.util`` to bypass +``flexkv.cache.__init__``'s forced ``import flexkv.c_ext`` — consistent +with the other dist_reuse unit tests. +""" +from __future__ import annotations + +import importlib.util +import json +import os +import sys +import time +import uuid +from pathlib import Path +from typing import Iterator, List, Tuple + +import pytest + +# --------------------------------------------------------------------------- +# Redis availability probe — gates the whole module. 
# ---------------------------------------------------------------------------
# Connection settings for the live server, overridable through the
# FLEXKV_TEST_REDIS_* environment variables.
REDIS_HOST = os.getenv("FLEXKV_TEST_REDIS_HOST", "127.0.0.1")
REDIS_PORT = int(os.getenv("FLEXKV_TEST_REDIS_PORT", "6379"))
# The RedisMeta tests in this module construct RedisMeta without a ``db=``
# argument, so this value has to match whichever database RedisMeta selects
# on its own (presumably 0 -- TODO confirm against redis_meta.py); otherwise
# the raw client used for assertions inspects a different database.
REDIS_DB = int(os.getenv("FLEXKV_TEST_REDIS_DB", "0"))

# redis-py is an optional dependency: without it the whole module is skipped.
try:
    import redis as _redis  # type: ignore[import-not-found]
except ImportError:  # pragma: no cover
    pytest.skip("redis-py not installed", allow_module_level=True)


def _redis_available() -> bool:
    """Return ``True`` when a PING against the configured Redis succeeds.

    The probe client uses a 1-second connect timeout. Any exception raised
    while building the client or pinging is swallowed and reported as
    "not available" so the caller can skip instead of erroring out.
    """
    try:
        probe = _redis.Redis(
            host=REDIS_HOST,
            port=REDIS_PORT,
            db=REDIS_DB,
            socket_connect_timeout=1.0,
            decode_responses=True,
        )
        reachable = bool(probe.ping())
    except Exception:
        return False
    return reachable


# Gate the module: nothing below is collected unless the probe succeeded.
if not _redis_available():
    pytest.skip(
        f"Redis not reachable at {REDIS_HOST}:{REDIS_PORT} db={REDIS_DB}",
        allow_module_level=True,
    )


# ---------------------------------------------------------------------------
# Dynamically load RedisMeta via importlib (bypasses c_ext dependency).
+# --------------------------------------------------------------------------- +REPO_ROOT = Path(__file__).resolve().parent.parent +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + + +def _load_redis_meta_module(): + src = REPO_ROOT / "flexkv" / "cache" / "redis_meta.py" + spec = importlib.util.spec_from_file_location("_redis_meta_it", str(src)) + assert spec is not None and spec.loader is not None + mod = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = mod + spec.loader.exec_module(mod) + return mod + + +_rm = _load_redis_meta_module() +RedisMeta = _rm.RedisMeta +RedisNodeInfo = _rm.RedisNodeInfo + +from flexkv.common.dist_reuse import ( # noqa: E402 + FailureDetector, + MasterCoordinator, + RedisSessionClient, + SharingDomainKey, + SharingDomainNamespace, + make_session_epoch, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- +@pytest.fixture +def raw_client() -> Iterator["_redis.Redis"]: + """Yield a decode_responses=True client + track+cleanup test keys. + + We never ``flushdb`` — the Redis is shared. Tests should ``_track`` any + keys they write so the teardown can remove exactly those. 
+ """ + client = _redis.Redis( + host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, + decode_responses=True, + ) + created_keys: List[str] = [] + client._test_tracked_keys = created_keys # type: ignore[attr-defined] + yield client + for k in created_keys: + try: + client.delete(k) + except Exception: + pass + + +@pytest.fixture +def sd_key() -> SharingDomainKey: + """Produce a unique SD key per test via random ``model_id``.""" + return SharingDomainKey( + model_id="itm" + uuid.uuid4().hex[:12], + pp_node_idx=0, pp_node_count=1, + tp_node_idx=0, tp_node_count=1, + is_nsa=False, + ) + + +@pytest.fixture +def namespace(sd_key: SharingDomainKey) -> SharingDomainNamespace: + return SharingDomainNamespace(sd_key) + + +@pytest.fixture +def instance_id() -> str: + return f"itinst-{uuid.uuid4().hex[:12]}" + + +def _track(client, *keys: str) -> None: + """Mark keys for teardown deletion.""" + client._test_tracked_keys.extend(keys) # type: ignore[attr-defined] + + +def _delete_test_sd(client, namespace: SharingDomainNamespace) -> None: + """Delete every key under the test's SD prefix (scoped SCAN).""" + pattern = f"{namespace.prefix}:*" + cursor = 0 + to_delete: List[str] = [] + while True: + cursor, keys = client.scan(cursor=cursor, match=pattern, count=200) + to_delete.extend(keys) + if cursor == 0: + break + if to_delete: + client.delete(*to_delete) + + +# =========================================================================== +# Group 1 — RedisMeta / RedisNodeInfo on a live Redis +# =========================================================================== +class TestRedisMetaLive: + """Verify the SD-aware Redis layout end-to-end on a real server.""" + + def test_init_meta_writes_sd_scoped_node_key( + self, raw_client, namespace, + ): + meta = RedisMeta( + host=REDIS_HOST, port=REDIS_PORT, local_ip="10.0.99.1", + namespace=namespace, + ) + try: + nid = meta.init_meta() + assert nid is not None and nid >= 0 + node_key = namespace.node_key(nid) + _track(raw_client, 
node_key) + + # Key must exist under sd::node:, with IP field. + assert raw_client.exists(node_key) == 1 + data = raw_client.hgetall(node_key) + assert data.get("ip") == "10.0.99.1" + finally: + meta.unregister_node() + _delete_test_sd(raw_client, namespace) + + def test_two_disjoint_sds_are_isolated(self, raw_client): + """Two different SDs must not see each other's active nodes.""" + sd_a = SharingDomainKey( + model_id="itm-a-" + uuid.uuid4().hex[:8], + pp_node_idx=0, pp_node_count=1, tp_node_idx=0, tp_node_count=1, + is_nsa=False, + ) + sd_b = SharingDomainKey( + model_id="itm-b-" + uuid.uuid4().hex[:8], + pp_node_idx=0, pp_node_count=1, tp_node_idx=0, tp_node_count=1, + is_nsa=False, + ) + ns_a = SharingDomainNamespace(sd_a) + ns_b = SharingDomainNamespace(sd_b) + assert ns_a.prefix != ns_b.prefix + + meta_a = RedisMeta( + host=REDIS_HOST, port=REDIS_PORT, local_ip="10.0.99.1", + namespace=ns_a, + ) + meta_b = RedisMeta( + host=REDIS_HOST, port=REDIS_PORT, local_ip="10.0.99.2", + namespace=ns_b, + ) + try: + nid_a = meta_a.init_meta() + nid_b = meta_b.init_meta() + assert nid_a is not None and nid_b is not None + assert nid_a != nid_b # globally unique from INCR + + # A sees its own node. + meta_a.nodeinfo.current_node_id_set.clear() + meta_a.nodeinfo.scan_active_nodes() + assert nid_a in meta_a.nodeinfo.current_node_id_set + assert nid_b not in meta_a.nodeinfo.current_node_id_set + + # B sees its own node, not A's. 
+ meta_b.nodeinfo.current_node_id_set.clear() + meta_b.nodeinfo.scan_active_nodes() + assert nid_b in meta_b.nodeinfo.current_node_id_set + assert nid_a not in meta_b.nodeinfo.current_node_id_set + finally: + meta_a.unregister_node() + meta_b.unregister_node() + _delete_test_sd(raw_client, ns_a) + _delete_test_sd(raw_client, ns_b) + + def test_regist_node_meta_and_buffer(self, raw_client, namespace): + meta = RedisMeta( + host=REDIS_HOST, port=REDIS_PORT, local_ip="10.0.99.5", + namespace=namespace, + ) + try: + nid = meta.init_meta() + assert nid is not None + + meta.regist_node_meta( + node_id=nid, addr="10.0.99.5", + zmq_addr="tcp://10.0.99.5:18000", + cpu_buffer_ptr=0xAABBCC, ssd_buffer_ptr=0, + ) + meta_key = namespace.meta_key(nid) + _track(raw_client, meta_key) + assert raw_client.exists(meta_key) == 1 + + data = raw_client.hgetall(meta_key) + assert data.get("addr") == "10.0.99.5" + assert data.get("zmq_addr") == "tcp://10.0.99.5:18000" + assert int(data.get("cpu_buffer_ptr", "0")) == 0xAABBCC + + loaded = meta.get_node_meta(nid) + assert loaded.get("zmq_addr") == "tcp://10.0.99.5:18000" + + # regist_buffer writes under sd::buffer:: + count = meta.regist_buffer([ + (0xDEADBEEF, 1024), + {"buffer_ptr": 0xC0FFEE, "buffer_size": 2048}, + ]) + assert count == 2 + buf1 = f"{namespace.prefix}:buffer:{nid}:{0xDEADBEEF}" + buf2 = f"{namespace.prefix}:buffer:{nid}:{0xC0FFEE}" + _track(raw_client, buf1, buf2) + assert raw_client.exists(buf1) == 1 + assert raw_client.exists(buf2) == 1 + d1 = raw_client.hgetall(buf1) + d2 = raw_client.hgetall(buf2) + assert int(d1.get("buffer_size", "0")) == 1024 + assert int(d2.get("buffer_size", "0")) == 2048 + finally: + meta.unregister_node() + _delete_test_sd(raw_client, namespace) + + def test_instance_sd_nodes_roundtrip( + self, raw_client, namespace, instance_id, + ): + """register_instance_sd_nodes → load_instance_sd_nodes is symmetric.""" + meta = RedisMeta( + host=REDIS_HOST, port=REDIS_PORT, local_ip="10.0.99.10", + 
namespace=namespace, + ) + try: + sd_map = { + "mmm:ppn0/2:tpn0/1:nsa0": 700, + "mmm:ppn1/2:tpn0/1:nsa0": 701, + } + meta.register_instance_sd_nodes(instance_id, sd_map) + + key = SharingDomainNamespace.instance_sd_nodes_key(instance_id) + _track(raw_client, key) + assert raw_client.exists(key) == 1 + + loaded = meta.load_instance_sd_nodes(instance_id) + assert loaded == sd_map + + # Missing instance returns empty dict. + missing = meta.load_instance_sd_nodes(f"nope-{uuid.uuid4().hex[:8]}") + assert missing == {} + finally: + _delete_test_sd(raw_client, namespace) + + def test_prefix_is_unique_no_crosstalk_with_other_tenants( + self, raw_client, namespace, + ): + """Sanity: our random ``model_id`` ensures no pre-existing collision + with the ~1000 keys from other users of this shared Redis.""" + pattern = f"{namespace.prefix}:*" + cursor = 0 + existing = 0 + while True: + cursor, keys = raw_client.scan(cursor=cursor, match=pattern, count=200) + existing += len(keys) + if cursor == 0: + break + assert existing == 0, ( + f"Unexpected pre-existing keys under {pattern}: {existing}" + ) + + +# =========================================================================== +# Group 2 — RedisSessionClient + FailureDetector with real TTL +# =========================================================================== +class TestSessionTTLLive: + """FakeRedis cannot truly expire keys — this is where we verify the + co-destined-failure model's actual wall-clock behaviour. + """ + + def test_session_register_renew_unregister(self, raw_client, instance_id): + sess = RedisSessionClient( + redis_client=raw_client, + instance_id=instance_id, + epoch=make_session_epoch(), + ttl_seconds=5, + ) + sess.register() + _track(raw_client, sess.key) + + # Key must exist with TTL in (0, 5]. + assert raw_client.exists(sess.key) == 1 + ttl = raw_client.ttl(sess.key) + assert 0 < ttl <= 5 + + # Payload JSON round-trips. 
+ payload = json.loads(raw_client.get(sess.key)) + assert payload["instance_id"] == instance_id + assert payload["epoch"] == sess.epoch + + # renew() bumps TTL back up. + time.sleep(1.5) + old_ttl = raw_client.ttl(sess.key) + sess.renew() + new_ttl = raw_client.ttl(sess.key) + assert new_ttl > old_ttl + + sess.unregister() + assert raw_client.exists(sess.key) == 0 + + def test_session_expires_after_ttl_without_renew( + self, raw_client, instance_id, + ): + """Core property: if renew() stops, the key really vanishes.""" + sess = RedisSessionClient( + redis_client=raw_client, + instance_id=instance_id, + epoch=make_session_epoch(), + ttl_seconds=1, + ) + sess.register() + _track(raw_client, sess.key) + assert raw_client.exists(sess.key) == 1 + + time.sleep(1.6) + assert raw_client.exists(sess.key) == 0 + + def test_failure_detector_fires_peer_lost_on_ttl_expiry(self, raw_client): + """Deterministic via poll_once() — no thread race.""" + self_id = f"detector-{uuid.uuid4().hex[:6]}" + peer_id = f"peer-{uuid.uuid4().hex[:6]}" + + lost_events: List[str] = [] + seen_events: List[Tuple[str, str]] = [] + + peer_sess = RedisSessionClient( + redis_client=raw_client, + instance_id=peer_id, + epoch=make_session_epoch(), + ttl_seconds=1, + ) + peer_sess.register() + _track(raw_client, peer_sess.key) + + detector = FailureDetector( + redis_client=raw_client, + self_instance_id=self_id, + poll_interval_seconds=0.1, + on_peer_lost=lambda pid: lost_events.append(pid), + on_peer_seen=lambda pid, s: seen_events.append((pid, s.epoch)), + ) + # 1st poll: peer is alive → on_peer_seen fires. + detector.poll_once() + assert any(pid == peer_id for pid, _ in seen_events), ( + f"Expected on_peer_seen for {peer_id}; got {seen_events!r}" + ) + + # Let peer expire. + time.sleep(1.6) + assert raw_client.exists(peer_sess.key) == 0 + + # 2nd poll: peer vanished → on_peer_lost fires. 
+ detector.poll_once() + assert peer_id in lost_events, ( + f"Expected {peer_id} in lost_events; got {lost_events!r}" + ) + + def test_detector_ignores_self_and_handles_epoch_change(self, raw_client): + self_id = f"self-{uuid.uuid4().hex[:6]}" + peer_id = f"peer-{uuid.uuid4().hex[:6]}" + + # Both self and peer are alive initially. + self_sess = RedisSessionClient( + redis_client=raw_client, instance_id=self_id, + epoch=make_session_epoch(), ttl_seconds=5, + ) + self_sess.register() + _track(raw_client, self_sess.key) + + peer_epoch_1 = make_session_epoch() + peer_sess = RedisSessionClient( + redis_client=raw_client, instance_id=peer_id, + epoch=peer_epoch_1, ttl_seconds=5, + ) + peer_sess.register() + _track(raw_client, peer_sess.key) + + seen: List[Tuple[str, str]] = [] + lost: List[str] = [] + detector = FailureDetector( + redis_client=raw_client, + self_instance_id=self_id, + poll_interval_seconds=0.1, + on_peer_lost=lambda pid: lost.append(pid), + on_peer_seen=lambda pid, s: seen.append((pid, s.epoch)), + ) + detector.poll_once() + # Self is excluded by contract. + assert not any(pid == self_id for pid, _ in seen) + assert any(pid == peer_id for pid, _ in seen) + + # Restart peer with new epoch (simulating crash + restart). 
+ baseline = len(seen) + peer_sess_2 = RedisSessionClient( + redis_client=raw_client, instance_id=peer_id, + epoch=make_session_epoch(), ttl_seconds=5, + ) + assert peer_sess_2.epoch != peer_epoch_1 + peer_sess_2.register() + + detector.poll_once() + new_events = seen[baseline:] + assert any( + pid == peer_id and epoch == peer_sess_2.epoch + for pid, epoch in new_events + ), f"Expected epoch-change on_peer_seen; got {new_events!r}" + + +# =========================================================================== +# Group 3 — MasterCoordinator end-to-end with live session +# =========================================================================== +class TestMasterCoordinatorLive: + """Exercise the Master-side composition over real Redis sessions.""" + + def test_coordinator_lifecycle_with_live_session( + self, raw_client, instance_id, + ): + self_sd = SharingDomainKey( + model_id="failover-" + uuid.uuid4().hex[:8], + pp_node_idx=0, pp_node_count=1, tp_node_idx=0, tp_node_count=1, is_nsa=False, + ) + epoch = make_session_epoch() + + sess = RedisSessionClient( + redis_client=raw_client, + instance_id=instance_id, epoch=epoch, ttl_seconds=5, + ) + sess.register() + _track(raw_client, sess.key) + + coord = MasterCoordinator( + self_sd=self_sd, + instance_id=instance_id, + session_epoch=epoch, + ) + try: + coord.expect_remotes(0) + # The session key is alive on the server. + assert raw_client.ttl(sess.key) > 0 + # MasterCoordinator composed without exception with a live Redis. 
+ finally: + coord.shutdown() + + def test_register_and_load_instance_mapping_e2e( + self, raw_client, namespace, instance_id, + ): + """Master registers sd_key → node_id mapping in Redis; read it back + via two distinct paths (Python helper + raw HGETALL).""" + meta = RedisMeta( + host=REDIS_HOST, port=REDIS_PORT, local_ip="10.0.99.20", + namespace=namespace, + ) + try: + mapping = { + f"sd-{i}-{uuid.uuid4().hex[:4]}": 800 + i for i in range(4) + } + meta.register_instance_sd_nodes(instance_id, mapping) + key = SharingDomainNamespace.instance_sd_nodes_key(instance_id) + _track(raw_client, key) + + loaded = meta.load_instance_sd_nodes(instance_id) + assert loaded == mapping + + raw_hash = raw_client.hgetall(key) + assert {k: int(v) for k, v in raw_hash.items()} == mapping + finally: + _delete_test_sd(raw_client, namespace) diff --git a/tests/test_redis_meta_namespace.py b/tests/test_redis_meta_namespace.py new file mode 100644 index 0000000000..0eaab6db05 --- /dev/null +++ b/tests/test_redis_meta_namespace.py @@ -0,0 +1,438 @@ +"""Unit tests for ``flexkv.cache.redis_meta`` after Phase 0 task 0-C migration. + +These tests validate that every Redis key emitted by ``RedisMeta`` / +``RedisNodeInfo`` is **SD-scoped** via the new +``SharingDomainNamespace``. They use :mod:`tests._dist_reuse_fakes` to +stand in for a real Redis server so no network (and no Redis install) is +required. + +The test module does **not** import ``flexkv.cache.redis_meta`` via the +normal package path, because ``flexkv.cache.__init__`` unconditionally +loads the CUDA-linked C++ extension (``flexkv.c_ext``). On CPU-only CI +workers that fails before we can get to our code. Instead we load +``redis_meta.py`` directly via :mod:`importlib.util` after patching the +``redis`` module to hand out our :class:`FakeRedis` instances. 
+""" +from __future__ import annotations + +import importlib.util +import os +import sys +import types +import unittest.mock as mock +from pathlib import Path +from typing import Any + +import pytest + +from flexkv.common.dist_reuse import SharingDomainKey, SharingDomainNamespace + +# Shared fake Redis server (see tests/_dist_reuse_fakes.py). +sys.path.insert(0, str(Path(__file__).parent)) +from _dist_reuse_fakes import FakeRedis, ManualClock # noqa: E402 + + +# --------------------------------------------------------------------------- +# Module loader: import ``flexkv.cache.redis_meta`` without importing the +# ``flexkv.cache`` package init (which pulls in CUDA). Returns a fresh copy +# on each call so module-level state can't leak between test functions. +# --------------------------------------------------------------------------- +def _load_redis_meta(fake_client_factory) -> Any: + """Load ``flexkv.cache.redis_meta`` directly from source. + + ``fake_client_factory`` is a callable ``() -> FakeRedis`` used to + replace ``redis.Redis`` for the duration of this module's lifecycle. + Returns the freshly loaded module object. + """ + pkg_root = Path(__file__).resolve().parent.parent + src = pkg_root / "flexkv" / "cache" / "redis_meta.py" + assert src.exists(), f"missing source file {src}" + + # Stub out redis-py so ``import redis as _redis`` inside redis_meta.py + # hands us a module whose ``Redis(...)`` constructor returns our fake. + fake_redis_mod = types.ModuleType("redis") + + def _ctor(*args, **kwargs): # noqa: ARG001 — mimic redis.Redis signature + return fake_client_factory() + + fake_redis_mod.Redis = _ctor # type: ignore[attr-defined] + # redis_meta.py only uses ``redis.Redis``; no other attrs referenced. + + # Patch sys.modules for the import that will happen inside spec.loader.exec_module. 
+ original_redis = sys.modules.get("redis") + original_cache_pkg = sys.modules.get("flexkv.cache.redis_meta") + sys.modules["redis"] = fake_redis_mod + + try: + spec = importlib.util.spec_from_file_location( + "_rm_under_test", str(src), + ) + assert spec is not None and spec.loader is not None + module = importlib.util.module_from_spec(spec) + # ``@dataclass`` looks the module up via ``sys.modules[cls.__module__]``. + # Register the module BEFORE executing it so decorators don't see None. + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + finally: + if original_redis is None: + sys.modules.pop("redis", None) + else: + sys.modules["redis"] = original_redis + if original_cache_pkg is not None: + sys.modules["flexkv.cache.redis_meta"] = original_cache_pkg + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- +@pytest.fixture +def sd(): + return SharingDomainKey( + model_id="abc", + pp_node_idx=0, pp_node_count=2, + tp_node_idx=0, tp_node_count=1, + is_nsa=False, + ) + + +@pytest.fixture +def ns(sd): + return SharingDomainNamespace(sd) + + +@pytest.fixture +def shared_fake(): + """One FakeRedis shared across all ``redis.Redis(...)`` calls in a test, + so the heartbeat/listener threads and the main thread see consistent + state.""" + return FakeRedis() + + +@pytest.fixture +def rm_module(shared_fake): + return _load_redis_meta(lambda: shared_fake) + + +# --------------------------------------------------------------------------- +# _channel_blocks_key +# --------------------------------------------------------------------------- +def test_channel_blocks_key_composition(rm_module, ns): + composed = rm_module._channel_blocks_key(ns, "CPUB") + assert composed == f"{ns.prefix}:CPUB" + + composed_no_device = rm_module._channel_blocks_key(ns, "") + assert composed_no_device == ns.prefix + + +def 
test_channel_blocks_key_with_sd_only_ns(rm_module): + default_ns = SharingDomainNamespace(SharingDomainKey.default()) + assert rm_module._channel_blocks_key(default_ns, "SSDB").startswith("sd:__default__:") + + +# --------------------------------------------------------------------------- +# _resolve_namespace +# --------------------------------------------------------------------------- +def test_resolve_namespace_default(rm_module): + ns = rm_module._resolve_namespace(None) + assert isinstance(ns, SharingDomainNamespace) + assert ns.sd_key.model_id == "__default__" + + +def test_resolve_namespace_accepts_sd_key(rm_module): + sd = SharingDomainKey.default() + ns = rm_module._resolve_namespace(sd) + assert ns.sd_key == sd + + +def test_resolve_namespace_rejects_garbage(rm_module): + with pytest.raises(TypeError): + rm_module._resolve_namespace("not-a-namespace") + + +# --------------------------------------------------------------------------- +# RedisNodeInfo registration +# --------------------------------------------------------------------------- +class TestNodeRegistration: + def _make(self, rm_module, ns, fake): + # Don't let signal handlers interfere with pytest. + with mock.patch("signal.signal"): + info = rm_module.RedisNodeInfo( + host="fake", port=0, + local_ip="10.0.0.1", + password="", + node_ttl_seconds=60, + namespace=ns, + ) + return info + + def test_register_writes_sd_scoped_key(self, rm_module, shared_fake, ns): + info = self._make(rm_module, ns, shared_fake) + + # Bypass connect() / heartbeat thread: just wire up the client. 
+ info._client = shared_fake + node_id = info.register_node() + assert node_id is not None + assert node_id >= 1 # global:node_id incremented + # Verify the key is SD-scoped, NOT bare "node:" + expected_key = ns.node_key(node_id) + assert shared_fake.hget(expected_key, "node_id") == str(node_id) + assert shared_fake.hget(expected_key, "sd_key") == ns.serialized_sd + # And make sure the legacy bare key is *not* present + assert shared_fake.exists(f"node:{node_id}") == 0 + + def test_namespace_property(self, rm_module, shared_fake, ns): + info = self._make(rm_module, ns, shared_fake) + assert info.namespace == ns + assert info.sd_key_str == ns.serialized_sd + + def test_two_infos_isolate_per_sd(self, rm_module, shared_fake): + ns_a = SharingDomainNamespace(SharingDomainKey( + model_id="m", pp_node_idx=0, pp_node_count=1, tp_node_idx=0, + tp_node_count=1, is_nsa=False, + )) + ns_b = SharingDomainNamespace(SharingDomainKey( + model_id="m", pp_node_idx=1, pp_node_count=2, tp_node_idx=0, + tp_node_count=1, is_nsa=False, + )) + info_a = self._make(rm_module, ns_a, shared_fake) + info_b = self._make(rm_module, ns_b, shared_fake) + info_a._client = shared_fake + info_b._client = shared_fake + nid_a = info_a.register_node() + nid_b = info_b.register_node() + assert nid_a != nid_b + + # Scanning via SD-A pattern finds only SD-A's key. + info_a.scan_active_nodes() + info_b.scan_active_nodes() + assert info_a.get_active_node_ids() == [nid_a] + assert info_b.get_active_node_ids() == [nid_b] + + +# --------------------------------------------------------------------------- +# RedisMeta.get_redis_meta_channel — blocks_key composition +# --------------------------------------------------------------------------- +class TestRedisMetaChannelFactory: + def test_device_prefix_composes_into_blocks_key( + self, rm_module, shared_fake, ns, monkeypatch + ): + # Intercept RedisMetaChannel so we don't need the C++ ext to be built. 
+ captured = {} + + class _StubChannel: + def __init__(self, host, port, node_id, local_ip, blocks_key, password, db=0): + captured["blocks_key"] = blocks_key + captured["host"] = host + captured["node_id"] = node_id + + def connect(self): + return True + + monkeypatch.setattr(rm_module, "RedisMetaChannel", _StubChannel) + + with mock.patch("signal.signal"): + meta = rm_module.RedisMeta( + host="h", port=6379, password="pw", + local_ip="10.0.0.1", decode_responses=True, + node_ttl_seconds=60, namespace=ns, + ) + meta._node_id = 42 + + _ = meta.get_redis_meta_channel(device_prefix="CPUB") + assert captured["blocks_key"] == f"{ns.prefix}:CPUB" + + _ = meta.get_redis_meta_channel(device_prefix="") + assert captured["blocks_key"] == ns.prefix + + def test_positional_device_prefix(self, rm_module, shared_fake, ns, monkeypatch): + captured = {} + + class _StubChannel: + def __init__(self, host, port, node_id, local_ip, blocks_key, password, db=0): + captured["blocks_key"] = blocks_key + + def connect(self): + return True + + monkeypatch.setattr(rm_module, "RedisMetaChannel", _StubChannel) + + with mock.patch("signal.signal"): + meta = rm_module.RedisMeta( + host="h", port=6379, password=None, + local_ip="10.0.0.1", namespace=ns, + ) + meta._node_id = 7 + # Legacy call style from hie_cache_engine.py uses positional arg. 
+ _ = meta.get_redis_meta_channel("SSDB") + assert captured["blocks_key"] == f"{ns.prefix}:SSDB" + + +# --------------------------------------------------------------------------- +# RedisMeta.regist_buffer / regist_node_meta use SD-scoped keys +# --------------------------------------------------------------------------- +class TestRedisMetaBufferAndNodeMeta: + def _make(self, rm_module, ns): + with mock.patch("signal.signal"): + meta = rm_module.RedisMeta( + host="h", port=0, password=None, + local_ip="10.0.0.1", namespace=ns, + ) + meta._node_id = 5 + return meta + + def test_regist_buffer_sd_scoped(self, rm_module, shared_fake, ns): + meta = self._make(rm_module, ns) + meta.regist_buffer([{"buffer_ptr": 0xABC, "buffer_size": 1024}]) + # Expected key: sd::buffer:5:2748 + expected_key = ns.buffer_key(5, 0xABC) + assert shared_fake.hget(expected_key, "buffer_size") == "1024" + # No legacy key + assert shared_fake.exists(f"buffer:5:{0xABC}") == 0 + + def test_unregist_buffer(self, rm_module, shared_fake, ns): + meta = self._make(rm_module, ns) + meta.regist_buffer([(0xDEF, 256)]) + ok = meta.unregist_buffer(0xDEF) + assert ok is True + assert shared_fake.exists(ns.buffer_key(5, 0xDEF)) == 0 + # Double unregister returns False + assert meta.unregist_buffer(0xDEF) is False + + def test_regist_node_meta_sd_scoped(self, rm_module, shared_fake, ns): + meta = self._make(rm_module, ns) + meta.regist_node_meta( + node_id=5, + addr="10.0.0.1:5555", + zmq_addr="tcp://10.0.0.1:6666", + cpu_buffer_ptr=0x1000, + ssd_buffer_ptr=0x2000, + ) + expected = ns.meta_key(5) + assert shared_fake.hget(expected, "addr") == "10.0.0.1:5555" + assert shared_fake.hget(expected, "zmq_addr") == "tcp://10.0.0.1:6666" + assert shared_fake.hget(expected, "cpu_buffer_ptr") == "4096" + assert shared_fake.exists(f"meta:5") == 0 # no legacy key + + def test_get_node_meta_round_trip(self, rm_module, shared_fake, ns): + meta = self._make(rm_module, ns) + meta.regist_node_meta(5, "addr", "zmq", 100, 200) 
+ info = meta.get_node_meta(5) + assert info == { + "node_id": 5, "addr": "addr", "zmq_addr": "zmq", + "cpu_buffer_ptr": 100, "ssd_buffer_ptr": 200, + } + + def test_get_node_meta_missing(self, rm_module, shared_fake, ns): + meta = self._make(rm_module, ns) + assert meta.get_node_meta(9999) == {} + + def test_unregist_node_meta(self, rm_module, shared_fake, ns): + meta = self._make(rm_module, ns) + meta.regist_node_meta(5, "addr", "zmq", 100, 200) + assert meta.unregist_node_meta(5) is True + assert meta.unregist_node_meta(5) is False + + +# --------------------------------------------------------------------------- +# Instance-level cross-SD keys (design doc §4.7.1.6) +# --------------------------------------------------------------------------- +class TestInstanceSdNodes: + def _make(self, rm_module, ns): + with mock.patch("signal.signal"): + return rm_module.RedisMeta( + host="h", port=0, password=None, + local_ip="10.0.0.1", namespace=ns, + ) + + def test_register_load_round_trip(self, rm_module, shared_fake, ns): + meta = self._make(rm_module, ns) + meta.register_instance_sd_nodes( + instance_id="inst-001", + sd_to_nid={ + "abc:ppn0/2:tpn0/2:nsa0": 1, + "abc:ppn1/2:tpn0/2:nsa0": 2, + "abc:ppn0/2:tpn1/2:nsa0": 3, + }, + ) + got = meta.load_instance_sd_nodes("inst-001") + assert got == { + "abc:ppn0/2:tpn0/2:nsa0": 1, + "abc:ppn1/2:tpn0/2:nsa0": 2, + "abc:ppn0/2:tpn1/2:nsa0": 3, + } + + def test_missing_instance_returns_empty(self, rm_module, shared_fake, ns): + meta = self._make(rm_module, ns) + assert meta.load_instance_sd_nodes("never-seen") == {} + + def test_unregister(self, rm_module, shared_fake, ns): + meta = self._make(rm_module, ns) + meta.register_instance_sd_nodes("inst-002", {"sd0": 1}) + assert meta.unregister_instance_sd_nodes("inst-002") is True + assert meta.load_instance_sd_nodes("inst-002") == {} + + def test_register_empty_mapping_is_noop(self, rm_module, shared_fake, ns): + meta = self._make(rm_module, ns) + 
meta.register_instance_sd_nodes("inst-003", {}) + # Nothing should have been written. + key = SharingDomainNamespace.instance_sd_nodes_key("inst-003") + assert shared_fake.exists(key) == 0 + + +# --------------------------------------------------------------------------- +# _cleanup_node_data +# --------------------------------------------------------------------------- +class TestCleanupNodeData: + def _make_info(self, rm_module, ns, fake): + with mock.patch("signal.signal"): + info = rm_module.RedisNodeInfo( + host="h", port=0, local_ip="ip", + namespace=ns, + ) + info._client = fake + return info + + def test_cleanup_removes_all_sd_scoped_keys(self, rm_module, shared_fake, ns): + info = self._make_info(rm_module, ns, shared_fake) + # Seed a variety of keys belonging to node 5. + shared_fake.hset(ns.meta_key(5), mapping={"foo": "bar"}) + shared_fake.hset(ns.buffer_key(5, 0x100), mapping={"sz": "1"}) + shared_fake.hset(f"{ns.prefix}:CPUB:block:5:abc", mapping={"x": "1"}) + shared_fake.hset(f"{ns.prefix}:SSDB:block:5:def", mapping={"x": "1"}) + # Keys for a different node — must survive. 
+ shared_fake.hset(ns.meta_key(6), mapping={"foo": "keep"}) + shared_fake.hset(f"{ns.prefix}:CPUB:block:6:abc", mapping={"x": "keep"}) + + info._cleanup_node_data(5) + + # Node 5 is gone + assert shared_fake.exists(ns.meta_key(5)) == 0 + assert shared_fake.exists(ns.buffer_key(5, 0x100)) == 0 + assert shared_fake.exists(f"{ns.prefix}:CPUB:block:5:abc") == 0 + assert shared_fake.exists(f"{ns.prefix}:SSDB:block:5:def") == 0 + # Node 6 is untouched + assert shared_fake.exists(ns.meta_key(6)) == 1 + assert shared_fake.exists(f"{ns.prefix}:CPUB:block:6:abc") == 1 + + +# --------------------------------------------------------------------------- +# PCFS file-nodeid mapping +# --------------------------------------------------------------------------- +class TestPcfsMapping: + def _make(self, rm_module, ns): + with mock.patch("signal.signal"): + meta = rm_module.RedisMeta( + host="h", port=0, local_ip="ip", namespace=ns, + ) + meta._node_id = 1 + return meta + + def test_add_and_load(self, rm_module, shared_fake, ns): + meta = self._make(rm_module, ns) + meta.add_node_ids([10, 20, 30]) + # Simulate another node on the same SD pushing its own list. + shared_fake.rpush(f"{ns.prefix}:pcfs:2", "100", "200") + + loaded = meta.load_pcfs_file_nodeids() + assert loaded == {1: [10, 20, 30], 2: [100, 200]} diff --git a/tests/test_redis_metachannel_sd_prefix.py b/tests/test_redis_metachannel_sd_prefix.py new file mode 100644 index 0000000000..14b0e7d975 --- /dev/null +++ b/tests/test_redis_metachannel_sd_prefix.py @@ -0,0 +1,137 @@ +"""Unit tests for :class:`RedisMetaChannel` SD-prefix derivation logic. + +The C++ ``RedisMetaChannel::list_node_keys`` strips the device suffix from +``blocks_key`` to produce the per-SD node scan pattern. The Python wrapper +``_derive_sd_prefix`` mirrors that logic for the fallback path that runs +when the C++ extension is older than Batch B (or built without +FLEXKV_ENABLE_P2P). 
The two implementations MUST stay in sync — this test +enforces that by exercising every SD + device permutation the design-doc +allows. + +We load ``redis_meta.py`` directly via importlib (see +``tests/test_redis_meta_namespace.py`` for the rationale) and build a +``RedisMetaChannel`` **without** connecting to Redis — the method under +test is purely string manipulation on ``self._blocks_key``. +""" +from __future__ import annotations + +import importlib.util +import sys +import types +import unittest.mock as mock +from pathlib import Path +from typing import Any + +import pytest + +from flexkv.common.dist_reuse import SharingDomainKey, SharingDomainNamespace + +sys.path.insert(0, str(Path(__file__).parent)) +from _dist_reuse_fakes import FakeRedis # noqa: E402 + + +def _load_redis_meta(): + pkg_root = Path(__file__).resolve().parent.parent + src = pkg_root / "flexkv" / "cache" / "redis_meta.py" + fake_redis_mod = types.ModuleType("redis") + fake_redis_mod.Redis = lambda *a, **kw: FakeRedis() # type: ignore[attr-defined] + original_redis = sys.modules.get("redis") + sys.modules["redis"] = fake_redis_mod + try: + spec = importlib.util.spec_from_file_location("_rm_cc_ut", str(src)) + assert spec is not None and spec.loader is not None + mod = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = mod # so @dataclass can resolve the module + spec.loader.exec_module(mod) + return mod + finally: + if original_redis is None: + sys.modules.pop("redis", None) + else: + sys.modules["redis"] = original_redis + + +@pytest.fixture(scope="module") +def rm(): + return _load_redis_meta() + + +def _sd_key(**overrides) -> str: + """Serialize an SD key for test input construction. + + Uses the simplified schema: no cp_rank / cp_size fields. 
+ """ + defaults = dict( + model_id="abc123", + pp_node_idx=0, pp_node_count=1, + tp_node_idx=0, tp_node_count=1, + is_nsa=False, + ) + defaults.update(overrides) + return SharingDomainKey(**defaults).serialize() + + +class TestDeriveSdPrefix: + """Matches the logic in ``csrc/dist/redis_meta_channel.cpp:list_node_keys``. + + Layout (simplified — no cp segment): + blocks_key = sd:<sd_key> → SD-only; 4 colons inside sd_key + blocks_key = sd:<sd_key>:<device> → SD+device; one extra colon + blocks_key = "blocks" / "CPUB" → legacy; 0–1 colons + """ + + def _make_channel(self, rm, blocks_key: str) -> Any: + # Construct without connecting: instantiate the wrapper directly + # around a stub that captures ``blocks_key``. + wrapper = rm.RedisMetaChannel.__new__(rm.RedisMetaChannel) + wrapper._blocks_key = blocks_key + return wrapper + + @pytest.mark.parametrize("device", ["", "CPUB", "SSDB", "PCFSB"]) + @pytest.mark.parametrize("pp,tpn,nsa", [ + (1, 1, False), + (2, 1, False), + (1, 2, False), + (1, 1, True), + (2, 2, True), # upper bound under the simplified design (PP × tpn) + ]) + def test_sd_plus_optional_device(self, rm, device, pp, tpn, nsa): + sd_str = _sd_key(pp_node_count=pp, tp_node_count=tpn, is_nsa=nsa) + blocks_key = f"sd:{sd_str}" + (f":{device}" if device else "") + ch = self._make_channel(rm, blocks_key) + derived = ch._derive_sd_prefix() + expected = f"sd:{sd_str}" + assert derived == expected, f"blocks_key={blocks_key!r} produced {derived!r}" + + def test_legacy_bare_key_collapses_to_empty(self, rm): + for bk in ("blocks", "CPUB", "SSDB", "", "something-else"): + ch = self._make_channel(rm, bk) + assert ch._derive_sd_prefix() == "" + + +class TestListNodeKeysPattern: + """Exercise the pattern that ``RedisMetaChannel.list_node_keys`` + fallback would scan when the C++ side is out-of-date.""" + + def test_sd_aware_pattern(self, rm, monkeypatch): + captured = {} + + class _FakeCExt: + """Stands in for ``flexkv.c_ext.RedisMetaChannel``.""" + def list_node_keys(self): + # Simulate a
pre-Batch-B C++ that still returns bare ``node:*`` + # keys. The wrapper must discard them and fall back to the + # Python scan. + return ["node:1", "node:2"] + + def list_keys(self, pattern): + captured["pattern"] = pattern + return ["sd:abc:ppn0/1:tpn0/1:nsa0:node:99"] + + wrapper = rm.RedisMetaChannel.__new__(rm.RedisMetaChannel) + wrapper._c = _FakeCExt() + wrapper._blocks_key = "sd:abc:ppn0/1:tpn0/1:nsa0:CPUB" + # The fallback should scan the *SD-scoped* pattern, NOT bare node:* + keys = wrapper.list_node_keys() + assert captured["pattern"] == "sd:abc:ppn0/1:tpn0/1:nsa0:node:*" + assert keys == ["sd:abc:ppn0/1:tpn0/1:nsa0:node:99"] diff --git a/tests/test_sd_enumerate_max.py b/tests/test_sd_enumerate_max.py new file mode 100644 index 0000000000..0d78241a33 --- /dev/null +++ b/tests/test_sd_enumerate_max.py @@ -0,0 +1,249 @@ +"""Scalability sanity tests for :class:`SharingDomainKey` under the simplified +dist_reuse schema. + +Design doc §4.1 / §4.5 (simplified): CP is **not** part of the SD key. The +SD-count upper bound is ``pp_node_count × tp_node_count`` (≤ 2 in the +current deployment, ≤ 4 in the future-extension scope ``pp_node_count=2 × +tp_node_count=2``). + +We don't enforce a *hard* cap on those dims at the data-structure layer — +this test confirms that: + +* :meth:`SharingDomainKey.enumerate_peers` returns exactly + ``pp_node_count × tp_node_count`` unique SDs for every shape; +* every enumerated SD serializes to a distinct, round-trippable string; +* the serialization is purely textual so it can be used verbatim as a + Redis key-namespace prefix; +* :func:`graph_needs_gpu_clear` returns True iff the peer's + ``pp_node_idx`` differs from self's (TP-node dim alone keeps the same + slot_mapping, so no GPU clear is required); +* :meth:`total_sd_count` equals ``len(enumerate_peers())``. + +Covers all shapes listed in the simplified design doc §4.1 table, plus +two stress-sized configurations. 
+""" +from __future__ import annotations + +import hashlib +from typing import Dict, Set, Tuple + +import pytest + +from flexkv.common.dist_reuse import ( + SharingDomainKey, + SharingDomainNamespace, + graph_needs_gpu_clear, +) + + +MODEL_ID = "scaletest" + + +def _sd(ppn_idx=0, ppn_count=1, tpn_idx=0, tpn_count=1, + is_nsa=False) -> SharingDomainKey: + """Build an SD key with the simplified, node-granularity PP schema.""" + return SharingDomainKey( + model_id=MODEL_ID, + pp_node_idx=ppn_idx, pp_node_count=ppn_count, + tp_node_idx=tpn_idx, tp_node_count=tpn_count, + is_nsa=is_nsa, + ) + + +# --------------------------------------------------------------------------- +# Shapes drawn from the simplified design doc §4.1 + a few stress cases. +# (ppn, tpn, label) +# --------------------------------------------------------------------------- +SHAPES = [ + (1, 1, "1-SD degenerate"), + (2, 1, "pp_node_count=2 (current scope)"), + (1, 2, "cross-node TP=2 (current scope)"), + (2, 2, "pp_node_count=2 × cross-node TP=2 (future-extension upper bound, 4 SDs)"), + # Stress: confirm no hidden caps. Real deployments never go here. + (4, 1, "pp_node_count=4 (stress)"), + (1, 8, "tp_node_count=8 (stress)"), + (4, 4, "pp_node_count=4 × tp_node_count=4 (stress, 16 SDs)"), +] + + +# =========================================================================== +# Enumerate completeness + uniqueness +# =========================================================================== +@pytest.mark.parametrize("ppn,tpn,label", SHAPES) +def test_enumerate_peers_is_complete_cartesian_product(ppn, tpn, label): + """``enumerate_peers`` must return every SD in the instance exactly once + — product count = ppn × tpn, no duplicates. CP no longer contributes.""" + master = _sd(ppn_count=ppn, tpn_count=tpn) + peers = master.enumerate_peers() + + # Count matches the Cartesian product pp_node_count × tpn. 
+ assert len(peers) == ppn * tpn == master.total_sd_count() + + # Every (pp_node_idx, tp_node_idx) pair appears exactly once. + pairs: Dict[Tuple[int, int], int] = {} + for p in peers: + pairs[(p.pp_node_idx, p.tp_node_idx)] = ( + pairs.get((p.pp_node_idx, p.tp_node_idx), 0) + 1 + ) + assert set(pairs.values()) == {1}, ( + f"[{label}] duplicate SDs detected in enumerate_peers: {pairs}" + ) + # And every coordinate in range. + for (pi, ti), _ in pairs.items(): + assert 0 <= pi < ppn + assert 0 <= ti < tpn + + +@pytest.mark.parametrize("ppn,tpn,label", SHAPES) +def test_enumerated_sds_serialize_to_distinct_strings(ppn, tpn, label): + """Every SD in the same instance must produce a unique serialize(). + + This is the core guarantee that lets us use ``sd.serialize()`` as a + Redis key prefix without worrying about cross-SD key collisions. + """ + master = _sd(ppn_count=ppn, tpn_count=tpn) + serials: Set[str] = {p.serialize() for p in master.enumerate_peers()} + assert len(serials) == ppn * tpn, ( + f"[{label}] serialize() collided for some SDs — Redis key namespaces " + f"would overlap. 
Unique serials: {len(serials)} / expected " + f"{ppn * tpn}" + ) + + +@pytest.mark.parametrize("ppn,tpn,label", SHAPES) +def test_serialize_deserialize_round_trip(ppn, tpn, label): + master = _sd(ppn_count=ppn, tpn_count=tpn) + for p in master.enumerate_peers(): + s = p.serialize() + back = SharingDomainKey.deserialize(s) + assert back == p, ( + f"[{label}] round-trip failed: {p!r} → {s!r} → {back!r}" + ) + + +# =========================================================================== +# Simplified-design contract: CP must not show up in the SD key +# =========================================================================== +def test_no_cp_segment_in_any_serialized_sd(): + """Regression for the simplification: ``serialize()`` must not contain + a ``:cp<...>:`` segment in any enumerated SD, regardless of shape.""" + for ppn, tpn, _label in SHAPES: + for p in _sd(ppn_count=ppn, tpn_count=tpn).enumerate_peers(): + assert ":cp" not in p.serialize(), ( + f"unexpected CP segment in serialized SD: {p.serialize()!r}" + ) + + +def test_total_sd_count_is_only_pp_node_count_times_tpn(): + """Regression: enabling CP at the model-config level must not multiply + the SD count. In the simplified schema only pp_node_count × tpn matters.""" + # The simplified SharingDomainKey doesn't even accept CP kwargs, so we + # just confirm the formula on the most commonly-exercised shapes. + assert _sd(ppn_count=1, tpn_count=1).total_sd_count() == 1 + assert _sd(ppn_count=2, tpn_count=1).total_sd_count() == 2 + assert _sd(ppn_count=1, tpn_count=2).total_sd_count() == 2 + assert _sd(ppn_count=2, tpn_count=2).total_sd_count() == 4 + + +# =========================================================================== +# No hard-coded upper bounds — each dim independently scalable +# =========================================================================== +def test_no_hardcoded_limit_on_any_single_dim(): + """Each dim independently scaled to 16 must still enumerate cleanly. 
+ + If any of the two dims were capped internally (e.g. by a fixed-size + array), this test would blow up. + """ + assert len(_sd(ppn_count=16).enumerate_peers()) == 16 + assert len(_sd(tpn_count=16).enumerate_peers()) == 16 + + +def test_stress_16_sds_enumerate_within_budget(): + """``4 × 4 = 16`` SDs: enumerate is O(N) and cheap.""" + import time + master = _sd(ppn_count=4, tpn_count=4) + t0 = time.perf_counter() + peers = master.enumerate_peers() + elapsed = time.perf_counter() - t0 + assert len(peers) == 16 + # Cheap generator should be near-instant on CPU. + assert elapsed < 0.1, f"enumerate_peers took {elapsed:.3f}s, expected < 0.1s" + + +# =========================================================================== +# graph_needs_gpu_clear semantic stays correct under the simplified schema +# =========================================================================== +def test_graph_needs_gpu_clear_only_pp_node_dim_forces_clear(): + """Full 2-D grid sweep — under the simplified schema the rule is: + + clear? ⇔ peer.pp_node_idx != self.pp_node_idx + + Crossing the TP-node boundary alone does NOT require clearing + (TP shares the slot_mapping across all of its ranks). CP is not in + the SD key any more so there's no CP component to this check. + """ + master = _sd(ppn_count=4, tpn_count=4) + # Self → no clear. 
+ assert graph_needs_gpu_clear(master, master) is False + + for peer in master.enumerate_peers(): + expect_clear = peer.pp_node_idx != master.pp_node_idx + got = graph_needs_gpu_clear(master, peer) + assert got == expect_clear, ( + f"graph_needs_gpu_clear({master!r}, {peer!r}) = {got}; " + f"expected {expect_clear} (rule: pp_node_idx differs)" + ) + + +# =========================================================================== +# Namespace-level: each SD's Redis prefix is a proper subset partition +# =========================================================================== +@pytest.mark.parametrize("ppn,tpn,label", [ + (2, 2, "current upper bound: 4 SDs"), + (4, 4, "stress: 16 SDs"), +]) +def test_each_sd_has_disjoint_redis_prefix(ppn, tpn, label): + master = _sd(ppn_count=ppn, tpn_count=tpn) + prefixes: Set[str] = set() + for p in master.enumerate_peers(): + ns = SharingDomainNamespace(p) + prefixes.add(ns.prefix) + # prefixes[i] is the full leading string for all keys in SD i. + # They must be unique, and none may be a prefix of another (otherwise a + # SCAN sd:<i>:* would leak into SD j).
+ sorted_prefixes = sorted(prefixes) + for i in range(len(sorted_prefixes) - 1): + a, b = sorted_prefixes[i], sorted_prefixes[i + 1] + assert not b.startswith(a + ":"), ( + f"[{label}] prefix '{a}' is a proper prefix of '{b}' — SCANs would leak" + ) + + +# =========================================================================== +# model_id invariance: same topology + same model_arch ⇒ same sd_key strings +# =========================================================================== +def test_model_id_is_stable_across_calls(): + """``model_id`` must be deterministic across runs — serialize() is a + pure function of the 6 dataclass fields (no hidden randomness).""" + sd1 = _sd(ppn_count=4, tpn_count=4, ppn_idx=1, tpn_idx=2) + sd2 = _sd(ppn_count=4, tpn_count=4, ppn_idx=1, tpn_idx=2) + assert sd1.serialize() == sd2.serialize() + h1 = hashlib.sha1(sd1.serialize().encode()).hexdigest() + h2 = hashlib.sha1(sd2.serialize().encode()).hexdigest() + assert h1 == h2 + + +# =========================================================================== +# Master role count: exactly ONE SD in the instance is the Master +# =========================================================================== +@pytest.mark.parametrize("ppn,tpn,label", SHAPES) +def test_exactly_one_master_in_any_instance(ppn, tpn, label): + master = _sd(ppn_count=ppn, tpn_count=tpn) + peers = master.enumerate_peers() + masters = [p for p in peers if p.is_master()] + assert len(masters) == 1, ( + f"[{label}] expected exactly 1 Master SD in the instance; got {len(masters)}" + ) + # And it has all-zero ranks. + m = masters[0] + assert m.pp_node_idx == 0 and m.tp_node_idx == 0 diff --git a/tests/test_sharing_domain_key.py b/tests/test_sharing_domain_key.py new file mode 100644 index 0000000000..e4283251ca --- /dev/null +++ b/tests/test_sharing_domain_key.py @@ -0,0 +1,416 @@ +"""Unit tests for ``flexkv.common.dist_reuse.sharing_domain``. 
+ +Phase 0 — simplified design (CP not in sd_key, ``is_nsa_cp`` renamed to +``is_nsa``). See +``docs/dist_reuse/dist_reuse_with_cp_pp_multinode_tp_simplified.md`` for the +authoritative definition of the current schema. +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import pytest + +from flexkv.common.dist_reuse.sharing_domain import ( + DEFAULT_MODEL_ID, + SharingDomainKey, + derive_model_id, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +@dataclass +class _FakeModelConfig: + """Minimal stub of ``ModelConfig`` for ``from_model_config`` tests. + + Matches the post-rename field set: uses ``is_nsa`` (not ``is_nsa_cp``). + Also carries ``nnodes`` so the node-granularity PP collapse can be + exercised (``nnodes=pp_size`` for the typical pipelined-across-nodes + case; ``nnodes=1, pp_size>1`` is rejected by the factory). + """ + + num_layers: int = 32 + num_kv_heads: int = 8 + head_size: int = 128 + use_mla: bool = False + dtype: Any = "bfloat16" + pp_rank: int = 0 + pp_size: int = 1 + nnodes: int = 1 + tp_node_idx: int = 0 + tp_node_count: int = 1 + # CP info is *not* in the SD key (simplified design §4.5), but + # ModelConfig still carries attn_cp_* fields for other purposes — the + # from_model_config factory just doesn't read them. 
+ attn_cp_rank: int = 0 + attn_cp_size: int = 1 + is_nsa: bool = False + model_id: Any = None + + +# --------------------------------------------------------------------------- +# derive_model_id +# --------------------------------------------------------------------------- +class TestDeriveModelId: + def test_stable_across_calls(self): + a = derive_model_id(num_layers=32, num_kv_heads=8, head_size=128, + dtype="bfloat16", use_mla=False) + b = derive_model_id(num_layers=32, num_kv_heads=8, head_size=128, + dtype="bfloat16", use_mla=False) + assert a == b + assert len(a) == 16 + # All hex chars + int(a, 16) + + def test_changes_with_field(self): + ref = derive_model_id(num_layers=32, num_kv_heads=8, head_size=128, + dtype="bfloat16", use_mla=False) + for tweak in ( + dict(num_layers=64), + dict(num_kv_heads=16), + dict(head_size=256), + dict(dtype="float16"), + dict(use_mla=True), + ): + kwargs = dict(num_layers=32, num_kv_heads=8, head_size=128, + dtype="bfloat16", use_mla=False) + kwargs.update(tweak) + assert derive_model_id(**kwargs) != ref, f"tweak {tweak} should change model_id" + + +# --------------------------------------------------------------------------- +# Construction & validation +# --------------------------------------------------------------------------- +class TestSharingDomainKeyValidation: + @pytest.fixture + def base_kwargs(self): + return dict( + model_id="abc123", + pp_node_idx=0, pp_node_count=1, + tp_node_idx=0, tp_node_count=1, + is_nsa=False, + ) + + def test_construct_ok(self, base_kwargs): + sd = SharingDomainKey(**base_kwargs) + assert sd.model_id == "abc123" + assert sd.is_master() + + def test_cp_fields_rejected(self, base_kwargs): + """CP is not in the SD key any more — legacy cp_* kwargs must fail.""" + base_kwargs["cp_rank"] = 0 + base_kwargs["cp_size"] = 4 + with pytest.raises(TypeError): + SharingDomainKey(**base_kwargs) + + def test_is_nsa_cp_kwarg_rejected(self, base_kwargs): + """Old ``is_nsa_cp`` kwarg must not silently slip 
through — the new + canonical name is ``is_nsa`` and constructors are strict.""" + base_kwargs.pop("is_nsa") + base_kwargs["is_nsa_cp"] = False + with pytest.raises(TypeError): + SharingDomainKey(**base_kwargs) + + @pytest.mark.parametrize("bad_id", ["", "has:colon", "has space", None, 123]) + def test_bad_model_id(self, base_kwargs, bad_id): + base_kwargs["model_id"] = bad_id + with pytest.raises((ValueError, TypeError)): + SharingDomainKey(**base_kwargs) + + @pytest.mark.parametrize("field,value", [ + ("pp_node_count", 0), + ("tp_node_count", -1), + ]) + def test_bad_count(self, base_kwargs, field, value): + base_kwargs[field] = value + with pytest.raises((ValueError, TypeError)): + SharingDomainKey(**base_kwargs) + + def test_idx_out_of_range(self, base_kwargs): + base_kwargs.update(pp_node_idx=2, pp_node_count=2) # max idx = count - 1 = 1 + with pytest.raises(ValueError): + SharingDomainKey(**base_kwargs) + + def test_negative_idx(self, base_kwargs): + base_kwargs["tp_node_idx"] = -1 + with pytest.raises(ValueError): + SharingDomainKey(**base_kwargs) + + def test_is_nsa_must_be_bool(self, base_kwargs): + base_kwargs["is_nsa"] = 1 # type: ignore[arg-type] + with pytest.raises(ValueError): + SharingDomainKey(**base_kwargs) + + +# --------------------------------------------------------------------------- +# Serialization round-trip +# --------------------------------------------------------------------------- +class TestSharingDomainKeySerialization: + @pytest.mark.parametrize("sd", [ + SharingDomainKey.default(), + SharingDomainKey( + model_id="abc123", + pp_node_idx=0, pp_node_count=2, + tp_node_idx=1, tp_node_count=2, + is_nsa=True, + ), + SharingDomainKey( + model_id="model-v1.2_x", + pp_node_idx=0, pp_node_count=1, + tp_node_idx=0, tp_node_count=1, + is_nsa=False, + ), + ]) + def test_round_trip(self, sd): + s = sd.serialize() + # Sanity: contains all four segments separated by ':' + assert s.count(":") == 3 + out = SharingDomainKey.deserialize(s) + assert 
out == sd + assert out.serialize() == s + + def test_serialize_format(self): + sd = SharingDomainKey( + model_id="abc123", + pp_node_idx=1, pp_node_count=2, + tp_node_idx=0, tp_node_count=2, + is_nsa=True, + ) + assert sd.serialize() == "abc123:ppn1/2:tpn0/2:nsa1" + + def test_serialize_no_cp_segment(self): + """Regression: the key must never contain ':cp<...>' — CP is gone.""" + sd = SharingDomainKey( + model_id="m", pp_node_idx=0, pp_node_count=1, + tp_node_idx=0, tp_node_count=1, is_nsa=False, + ) + assert ":cp" not in sd.serialize() + + def test_serialize_uses_ppn_prefix_not_pp(self): + """Regression: the PP segment uses the node-granularity prefix + ``ppn``. The legacy rank-granularity prefix ``pp<rank>/<size>`` + is gone.""" + sd = SharingDomainKey( + model_id="m", pp_node_idx=0, pp_node_count=2, + tp_node_idx=0, tp_node_count=1, is_nsa=False, + ) + s = sd.serialize() + assert ":ppn0/2:" in s + assert ":pp0/2:" not in s + + @pytest.mark.parametrize("bad", [ + "abc", # missing fields + "abc:ppn0/1:tpn0/1", # missing nsa + "abc:ppn0/1:tpn0/1:cp0/1:nsa0", # legacy 5-segment form with cp + "abc:foo0/1:tpn0/1:nsa0", # wrong ppn prefix + "abc:pp0/1:tpn0/1:nsa0", # legacy rank-granularity prefix + "abc:ppn0:tpn0/1:nsa0", # missing '/' + "abc:ppn0/1:tpn0/1:nsa2", # bad nsa value + "abc:ppnX/1:tpn0/1:nsa0", # non-int + ]) + def test_deserialize_rejects_bad(self, bad): + with pytest.raises(ValueError): + SharingDomainKey.deserialize(bad) + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- +class TestFromModelConfig: + def test_basic(self): + mc = _FakeModelConfig() + sd = SharingDomainKey.from_model_config(mc) + assert sd.is_master() + assert sd.pp_node_count == 1 and sd.tp_node_count == 1 + # model_id should be a 16-char hex digest, not the placeholder.
+ assert sd.model_id != DEFAULT_MODEL_ID + assert len(sd.model_id) == 16 + + def test_explicit_model_id_takes_precedence(self): + mc = _FakeModelConfig(model_id="my-model") + sd = SharingDomainKey.from_model_config(mc) + assert sd.model_id == "my-model" + + def test_overrides(self): + # PP=2 across nnodes=2 — a legitimate cross-node PP layout. + mc = _FakeModelConfig(pp_size=2, nnodes=2) + sd_pp1 = SharingDomainKey.from_model_config( + mc, overrides={"pp_node_idx": 1}, + ) + assert sd_pp1.pp_node_idx == 1 + assert not sd_pp1.is_master() + + def test_single_node_pp_gt_1_rejected(self): + """Single-node PP>1 must not be allowed to derive an SD key — the + per-rank CPU pool only stores a layer slice, so a single sd_key + would alias incompatible bytes across PP ranks. Until the CPU + pool is reworked to full-layer storage, this combo is hard-blocked + at the SD-key factory.""" + for pp in (2, 4, 8): + mc = _FakeModelConfig(pp_size=pp, nnodes=1) + with pytest.raises(ValueError, match="single-node PP>1"): + SharingDomainKey.from_model_config(mc) + + def test_pp_collapses_to_node_granularity(self): + """Cross-node PP=4 over nnodes=2 should collapse rank pairs to the + same ``pp_node_idx`` (rank 0/1 → node 0, rank 2/3 → node 1).""" + # Rank 1 lives on node 0 (pp_per_node = 4//2 = 2 → 1 // 2 == 0). + mc = _FakeModelConfig(pp_size=4, nnodes=2, pp_rank=1) + sd_node0 = SharingDomainKey.from_model_config(mc) + assert sd_node0.pp_node_idx == 0 + assert sd_node0.pp_node_count == 2 + + # Rank 2 lives on node 1 (2 // 2 == 1). 
+ mc = _FakeModelConfig(pp_size=4, nnodes=2, pp_rank=2) + sd_node1 = SharingDomainKey.from_model_config(mc) + assert sd_node1.pp_node_idx == 1 + assert sd_node1.pp_node_count == 2 + + def test_pp1_multinode_keeps_pp_node_count_1(self): + """PP=1 across multiple nodes (cross-node TP only) must keep + pp_node_count=1 — PP doesn't span nodes here.""" + mc = _FakeModelConfig(pp_size=1, nnodes=2, tp_node_count=2) + sd = SharingDomainKey.from_model_config(mc) + assert sd.pp_node_count == 1 and sd.pp_node_idx == 0 + assert sd.tp_node_count == 2 + + def test_unknown_override_rejected(self): + """Unknown overrides must raise — in particular the removed cp_rank.""" + with pytest.raises(ValueError): + SharingDomainKey.from_model_config( + _FakeModelConfig(), overrides={"bogus": 1}, + ) + + def test_cp_rank_override_rejected(self): + """Regression: cp_rank used to be a valid override under the old + schema. After simplification it must be rejected.""" + with pytest.raises(ValueError): + SharingDomainKey.from_model_config( + _FakeModelConfig(), overrides={"cp_rank": 1}, + ) + + def test_legacy_pp_rank_override_rejected(self): + """Regression: ``pp_rank`` was a valid override under the + rank-granularity schema; the SD key now only knows about + ``pp_node_idx`` so the legacy override must be rejected to + catch stale callers.""" + mc = _FakeModelConfig(pp_size=2, nnodes=2) + with pytest.raises(ValueError): + SharingDomainKey.from_model_config( + mc, overrides={"pp_rank": 1}, + ) + + def test_cp_fields_in_model_config_ignored(self): + """Having attn_cp_* on the model config must NOT leak into the SD key.""" + mc = _FakeModelConfig(attn_cp_rank=2, attn_cp_size=4) + sd = SharingDomainKey.from_model_config(mc) + # Cannot read cp_rank from sd — the field does not exist any more. + assert not hasattr(sd, "cp_rank") + assert not hasattr(sd, "cp_size") + # The legacy pp_rank/pp_size rank-granularity fields are gone too. 
+ assert not hasattr(sd, "pp_rank") + assert not hasattr(sd, "pp_size") + # Serialization must also not contain a cp segment. + assert ":cp" not in sd.serialize() + # (ppn=0, tpn=0) is still the master SD. + assert sd.is_master() + + def test_is_nsa_picked_up(self): + mc = _FakeModelConfig(is_nsa=True) + sd = SharingDomainKey.from_model_config(mc) + assert sd.is_nsa is True + + +# --------------------------------------------------------------------------- +# enumerate_peers +# --------------------------------------------------------------------------- +class TestEnumeratePeers: + @pytest.mark.parametrize("ppn,tpn,expected", [ + (1, 1, 1), + (2, 1, 2), + (1, 2, 2), + (2, 2, 4), # current-scope upper bound (pp_node_count=2 × tp_node_count=2) + ]) + def test_count(self, ppn, tpn, expected): + sd = SharingDomainKey( + model_id="m", + pp_node_idx=0, pp_node_count=ppn, + tp_node_idx=0, tp_node_count=tpn, + is_nsa=False, + ) + peers = sd.enumerate_peers() + assert len(peers) == expected + assert sd.total_sd_count() == expected + + def test_upper_bound_is_four(self): + """Regression: the SD count upper bound is now 4 (pp_node_count × tp_node_count), + not 32 like in the legacy design that included CP in the SD key.""" + sd = SharingDomainKey( + model_id="m", pp_node_idx=0, pp_node_count=2, + tp_node_idx=0, tp_node_count=2, is_nsa=False, + ) + assert sd.total_sd_count() == 4 + + def test_unique_serialization(self): + sd = SharingDomainKey( + model_id="m", pp_node_idx=0, pp_node_count=2, + tp_node_idx=0, tp_node_count=2, is_nsa=True, + ) + peers = sd.enumerate_peers() + serialized = {p.serialize() for p in peers} + assert len(serialized) == 4, "4 SDs must produce 4 distinct keys" + # Every SD has the same model_id and is_nsa as the original. + assert all(p.model_id == "m" for p in peers) + assert all(p.is_nsa is True for p in peers) + # Exactly one master. 
+ masters = [p for p in peers if p.is_master()] + assert len(masters) == 1 + + def test_iter_dunder(self): + sd = SharingDomainKey.default() + assert list(sd) == sd.enumerate_peers() + + +# --------------------------------------------------------------------------- +# default() +# --------------------------------------------------------------------------- +class TestDefault: + def test_default_is_master(self): + sd = SharingDomainKey.default() + assert sd.is_master() + assert sd.pp_node_count == sd.tp_node_count == 1 + assert not sd.is_nsa + assert sd.model_id == DEFAULT_MODEL_ID + + def test_default_serialize_round_trip(self): + sd = SharingDomainKey.default() + out = SharingDomainKey.deserialize(sd.serialize()) + assert out == sd + + def test_default_total_sd_count(self): + assert SharingDomainKey.default().total_sd_count() == 1 + + +# --------------------------------------------------------------------------- +# Hashability +# --------------------------------------------------------------------------- +class TestHashable: + def test_can_be_dict_key(self): + a = SharingDomainKey.default() + b = SharingDomainKey.default() + d = {a: 1} + assert d[b] == 1 # equal keys collapse + + def test_distinct_keys_distinct_hash(self): + sd = SharingDomainKey( + model_id="m", pp_node_idx=0, pp_node_count=2, + tp_node_idx=0, tp_node_count=1, + is_nsa=False, + ) + peers = sd.enumerate_peers() + # Set semantics: 2 peers => 2 hashes (modulo collisions, which are + # vanishingly unlikely for tuples of small ints). 
+ assert len({hash(p) for p in peers}) == 2 diff --git a/tests/test_sharing_domain_namespace.py b/tests/test_sharing_domain_namespace.py new file mode 100644 index 0000000000..0a99083237 --- /dev/null +++ b/tests/test_sharing_domain_namespace.py @@ -0,0 +1,136 @@ +"""Unit tests for ``flexkv.common.dist_reuse.sharing_domain_namespace`` (Phase 0 task 0-B).""" +from __future__ import annotations + +import pytest + +from flexkv.common.dist_reuse.sharing_domain import SharingDomainKey +from flexkv.common.dist_reuse.sharing_domain_namespace import ( + INSTANCE_KEY_PREFIX, + SD_KEY_PREFIX, + SharingDomainNamespace, +) + + +@pytest.fixture +def ns(): + sd = SharingDomainKey( + model_id="abc123", + pp_node_idx=1, pp_node_count=2, + tp_node_idx=0, tp_node_count=2, + is_nsa=True, + ) + return SharingDomainNamespace(sd) + + +def test_constructor_rejects_non_sd_key(): + with pytest.raises(TypeError): + SharingDomainNamespace("not-an-sd-key") # type: ignore[arg-type] + + +def test_prefix_format(ns): + assert ns.prefix == f"{SD_KEY_PREFIX}:abc123:ppn1/2:tpn0/2:nsa1" + assert ns.serialized_sd == "abc123:ppn1/2:tpn0/2:nsa1" + + +@pytest.mark.parametrize("builder,fmt", [ + ("node_key", "{prefix}:node:{nid}"), + ("meta_key", "{prefix}:meta:{nid}"), +]) +def test_simple_keys(ns, builder, fmt): + actual = getattr(ns, builder)(7) + assert actual == fmt.format(prefix=ns.prefix, nid=7) + + +def test_buffer_key(ns): + assert ns.buffer_key(7, 0xDEADBEEF) == f"{ns.prefix}:buffer:7:{0xDEADBEEF}" + + +def test_block_key_lower_hex(ns): + # 0x10ab...
should render as 10abcdef, not 0X10ABCDEF + assert ns.block_key(7, 0x10ABCDEF) == f"{ns.prefix}:block:7:10abcdef" + + +def test_block_key_handles_negative_hash(ns): + """C++ side may produce signed int64 hashes; we mask to 64 bits so the + hex never carries a leading sign.""" + h = -1 # 64-bit two's complement => 0xFFFFFFFFFFFFFFFF + assert ns.block_key(0, h) == f"{ns.prefix}:block:0:ffffffffffffffff" + + +def test_aggregate_key(ns): + assert ns.aggregate_key(0xCAFEBABE) == f"{ns.prefix}:aggregate:cafebabe" + + +def test_scan_patterns(ns): + assert ns.node_key_pattern() == f"{ns.prefix}:node:*" + assert ns.meta_key_pattern() == f"{ns.prefix}:meta:*" + assert ns.buffer_key_pattern() == f"{ns.prefix}:buffer:*" + assert ns.block_key_pattern() == f"{ns.prefix}:block:*" + assert ns.block_key_pattern_for_node(7) == f"{ns.prefix}:block:7:*" + + +# --------------------------------------------------------------------------- +# Cross-SD instance keys +# --------------------------------------------------------------------------- +class TestInstanceKeys: + def test_session_key(self): + assert ( + SharingDomainNamespace.instance_session_key("inst-001") + == f"{INSTANCE_KEY_PREFIX}:inst-001:session" + ) + + def test_sd_nodes_key(self): + assert ( + SharingDomainNamespace.instance_sd_nodes_key("inst-001") + == f"{INSTANCE_KEY_PREFIX}:inst-001:sd_nodes" + ) + + def test_session_pattern(self): + assert ( + SharingDomainNamespace.instance_session_key_pattern() + == f"{INSTANCE_KEY_PREFIX}:*:session" + ) + + @pytest.mark.parametrize("bad_id", ["", "has space", "has:colon", "$dollar"]) + def test_rejects_bad_id(self, bad_id): + with pytest.raises(ValueError): + SharingDomainNamespace.instance_session_key(bad_id) + with pytest.raises(ValueError): + SharingDomainNamespace.instance_sd_nodes_key(bad_id) + + def test_parse_round_trip(self): + for pid in ["inst-001", "abc.v2", "x_y_z", "ABC123"]: + key = SharingDomainNamespace.instance_session_key(pid) + parsed = 
SharingDomainNamespace.parse_instance_session_key(key) + assert parsed == pid + + @pytest.mark.parametrize("bad", [ + "not-a-flexkv-key", + "flexkv:instance:foo:notsession", + "flexkv:instance::session", # empty instance_id + ]) + def test_parse_rejects_bad(self, bad): + with pytest.raises(ValueError): + SharingDomainNamespace.parse_instance_session_key(bad) + + +# --------------------------------------------------------------------------- +# Equality / hashability +# --------------------------------------------------------------------------- +class TestEquality: + def test_equal_namespaces_share_hash(self): + sd = SharingDomainKey.default() + a = SharingDomainNamespace(sd) + b = SharingDomainNamespace(sd) + assert a == b + assert hash(a) == hash(b) + + def test_different_sd_keys_unequal(self): + a = SharingDomainNamespace(SharingDomainKey.default()) + b_sd = SharingDomainKey( + model_id="m", pp_node_idx=0, pp_node_count=2, + tp_node_idx=0, tp_node_count=1, + is_nsa=False, + ) + b = SharingDomainNamespace(b_sd) + assert a != b diff --git a/tests/test_single_node_match.py b/tests/test_single_node_match.py new file mode 100644 index 0000000000..910f86c6fd --- /dev/null +++ b/tests/test_single_node_match.py @@ -0,0 +1,335 @@ +"""Python reference implementation of the single-node matching constraint +(design doc §4.7.1.4, implemented in C++ at +``csrc/dist/distributed_radix_tree.cpp::match_prefix`` lines 579-686). + +This test file does **not** exercise the C++ code (that needs a GPU/CUDA +build to link). 
Instead it re-implements the *core invariant* of the +constraint in pure Python and exhaustively checks the algorithm's +externally-visible behaviour on synthetic radix-node sequences: + + * locks onto the FIRST block's ``node_id`` as the canonical + ``matched_node_id`` for the whole match; + * as soon as a block with a different ``node_id`` is encountered, + the match truncates — that block and every subsequent block (in + the same or later nodes) is NOT included; + * the truncation happens at **block granularity**, not node + granularity — a node whose first 3 blocks belong to the locked + node_id but whose 4th block belongs to a different one contributes + only 3 blocks to the result; + * ``matched_node_id`` stays ``-1`` iff the query didn't match any + block at all (empty result). + +These are **exactly** the guarantees the downstream Master needs in +order to issue a coordinated GET to a single Remote peer — see design +doc §4.4.1 / §4.12.2 (1). Any future C++ refactor that changes the +algorithm must keep these guarantees; this reference implementation +serves as the executable spec. +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import pytest + + +# --------------------------------------------------------------------------- +# Minimal radix-node shape (just enough to drive the algorithm). 
+# --------------------------------------------------------------------------- +@dataclass +class _Node: + """One radix node — a run of (hash, physical_block, block_node_id) triples.""" + hashes: List[int] + pbs: List[int] + bnis: List[int] # must be len == len(hashes) + + def __post_init__(self) -> None: + assert len(self.hashes) == len(self.pbs) == len(self.bnis), ( + f"inconsistent node: {len(self.hashes)=} {len(self.pbs)=} {len(self.bnis)=}" + ) + + def size(self) -> int: + return len(self.hashes) + + +@dataclass +class _MatchResult: + prefix_blocks_num: int + physical_blocks: List[int] + matched_node_id: int # -1 if no match + + +# --------------------------------------------------------------------------- +# Python reference implementation of match_prefix. +# Mirrors csrc/dist/distributed_radix_tree.cpp:579-686 line-for-line +# (minus the lease-time / renew-queue machinery which is irrelevant here). +# --------------------------------------------------------------------------- +def match_prefix_ref( + nodes: List[_Node], + query_hashes: List[int], +) -> _MatchResult: + """Reference Python implementation of the single-node matching rule. + + ``nodes`` is the flattened path from the root down to the deepest + reachable node for the query — i.e. at this point in the algorithm + we've already resolved which *sequence* of radix nodes the query + would descend. (In the real C++ code this happens dynamically via + ``lookup_child``; for testing we feed the resolved path directly so + we can focus on the single-node constraint itself.) + """ + prefix_blocks_num = 0 + pb_out: List[int] = [] + matched_node_id = -1 # lock-on-first-block sentinel + + num_query = len(query_hashes) + + for node in nodes: + if prefix_blocks_num >= num_query: + break + node_size = node.size() + remaining = num_query - prefix_blocks_num + blocks_to_check = min(node_size, remaining) + + # Step 1: how many blocks at the head of this node match the query? 
+ matched = 0 + for i in range(blocks_to_check): + if node.hashes[i] == query_hashes[prefix_blocks_num + i]: + matched += 1 + else: + break + + # Step 2: among those ``matched`` blocks, how many can actually be + # taken before the single-node constraint trips? + if matched > 0: + actually_copied = 0 + for i in range(matched): + block_nid = node.bnis[i] + if matched_node_id == -1: + matched_node_id = block_nid # lock the first one + elif block_nid != matched_node_id: + # Constraint trip — truncate AT this block. + matched = actually_copied + break + pb_out.append(node.pbs[i]) + actually_copied += 1 + + prefix_blocks_num += matched + + # Step 3: if we couldn't consume the entire node (either hash + # mismatch or single-node trip), stop descending. + if matched < node_size: + break + + return _MatchResult( + prefix_blocks_num=prefix_blocks_num, + physical_blocks=pb_out, + matched_node_id=matched_node_id, + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- +class TestSingleNodeMatchingConstraint: + def test_empty_query_yields_empty_result(self): + nodes = [_Node([10, 11], [100, 101], [1, 1])] + r = match_prefix_ref(nodes, []) + assert r.prefix_blocks_num == 0 + assert r.physical_blocks == [] + assert r.matched_node_id == -1 + + def test_no_nodes_yields_empty_result(self): + r = match_prefix_ref([], [10, 11]) + assert r.prefix_blocks_num == 0 + assert r.physical_blocks == [] + assert r.matched_node_id == -1 + + def test_full_match_single_node_all_same_nid(self): + """Classic full-hit: 3 blocks all belong to node 7.""" + nodes = [_Node( + hashes=[10, 11, 12], + pbs=[100, 101, 102], + bnis=[7, 7, 7], + )] + r = match_prefix_ref(nodes, [10, 11, 12]) + assert r.prefix_blocks_num == 3 + assert r.physical_blocks == [100, 101, 102] + assert r.matched_node_id == 7 + + def test_partial_match_hash_mismatch_mid_node(self): + """Hash mismatch at index 2: stop 
at index 2 regardless of nid.""" + nodes = [_Node( + hashes=[10, 11, 99, 13], + pbs=[100, 101, 102, 103], + bnis=[5, 5, 5, 5], + )] + r = match_prefix_ref(nodes, [10, 11, 12, 13]) + assert r.prefix_blocks_num == 2 + assert r.physical_blocks == [100, 101] + assert r.matched_node_id == 5 + + def test_single_node_trip_mid_node(self): + """Same node contains blocks from 2 different nids — truncate + at the first different nid.""" + nodes = [_Node( + hashes=[10, 11, 12, 13], + pbs=[100, 101, 102, 103], + bnis=[7, 7, 8, 8], # first 2 are nid=7, next 2 are nid=8 + )] + r = match_prefix_ref(nodes, [10, 11, 12, 13]) + # Locked on nid=7 for the first block; at block index 2 (nid=8) + # the trip fires → only first 2 blocks taken. + assert r.prefix_blocks_num == 2 + assert r.physical_blocks == [100, 101] + assert r.matched_node_id == 7 + + def test_single_node_trip_at_first_block_of_second_node(self): + """Descend to child only if the full first node was consumed — + here node 1 is fully consumed (all nid=5), node 2 starts with + nid=9 which trips the constraint.""" + nodes = [ + _Node([10, 11], [100, 101], [5, 5]), + _Node([12, 13], [102, 103], [9, 9]), + ] + r = match_prefix_ref(nodes, [10, 11, 12, 13]) + # First node consumed (locked nid=5 for both blocks). Second + # node's first block has nid=9 != 5 → trip at block 0 of node 2. + assert r.prefix_blocks_num == 2 + assert r.physical_blocks == [100, 101] + assert r.matched_node_id == 5 + + def test_single_node_trip_happens_before_hash_mismatch_check(self): + """When both constraints would fire, single-node wins first + (determines truncation point). Design doc requires that we + NEVER copy a block with a foreign nid, even if its hash is + otherwise valid.""" + nodes = [_Node( + hashes=[10, 11, 12, 13], + pbs=[100, 101, 102, 103], + bnis=[3, 3, 4, 99], # trip at index 2 due to nid + )] + # Query matches all 4 hashes — but nid constraint stops at index 2. 
+ r = match_prefix_ref(nodes, [10, 11, 12, 13]) + assert r.prefix_blocks_num == 2 + assert r.physical_blocks == [100, 101] + assert r.matched_node_id == 3 + + def test_trip_at_very_first_block_does_not_happen(self): + """The first block always succeeds (it defines matched_node_id).""" + nodes = [_Node( + hashes=[10, 11], + pbs=[100, 101], + bnis=[42, 42], + )] + r = match_prefix_ref(nodes, [10, 11]) + assert r.prefix_blocks_num == 2 + assert r.matched_node_id == 42 # locked on first block's nid + + def test_query_longer_than_available_nodes(self): + """If the query extends past the last node, match stops at end.""" + nodes = [_Node([10, 11], [100, 101], [1, 1])] + r = match_prefix_ref(nodes, [10, 11, 12, 13]) + assert r.prefix_blocks_num == 2 + assert r.physical_blocks == [100, 101] + assert r.matched_node_id == 1 + + def test_query_shorter_than_first_node(self): + """Query consumes only part of the first node — matched_node_id + still reflects the locked nid even if the node has other nids + after the query end.""" + nodes = [_Node( + hashes=[10, 11, 12, 13], + pbs=[100, 101, 102, 103], + bnis=[5, 5, 9, 9], # would trip at index 2 if we got there + )] + r = match_prefix_ref(nodes, [10, 11]) # query of length 2 + assert r.prefix_blocks_num == 2 + assert r.physical_blocks == [100, 101] + assert r.matched_node_id == 5 # locked on first block's nid + + def test_matched_node_id_persists_across_nodes(self): + """Multi-node match: matched_node_id locked on node 1's first + block must still apply to node 2's blocks (they must all be the + same nid or trip).""" + nodes = [ + _Node([10, 11], [100, 101], [3, 3]), + _Node([12, 13], [102, 103], [3, 3]), + ] + r = match_prefix_ref(nodes, [10, 11, 12, 13]) + assert r.prefix_blocks_num == 4 + assert r.physical_blocks == [100, 101, 102, 103] + assert r.matched_node_id == 3 + + def test_matched_node_id_minus_one_iff_no_match(self): + """The only way to get matched_node_id == -1 is zero matched blocks.""" + # Case A: first hash 
mismatch. + nodes_a = [_Node([999], [100], [1])] + r_a = match_prefix_ref(nodes_a, [10, 11]) + assert r_a.prefix_blocks_num == 0 + assert r_a.matched_node_id == -1 + + # Case B: empty nodes list. + r_b = match_prefix_ref([], [10, 11]) + assert r_b.prefix_blocks_num == 0 + assert r_b.matched_node_id == -1 + + def test_pb_write_index_matches_prefix_blocks_num(self): + """The C++ code writes ``pb_write`` physical blocks and returns + a narrow(0, pb_write) tensor. ``len(physical_blocks)`` must + equal ``prefix_blocks_num`` in every branch.""" + # Trip mid-match. + nodes = [_Node( + hashes=[10, 11, 12, 13], + pbs=[100, 101, 102, 103], + bnis=[7, 7, 8, 8], + )] + r = match_prefix_ref(nodes, [10, 11, 12, 13]) + assert len(r.physical_blocks) == r.prefix_blocks_num + + # Full match. + nodes2 = [_Node([20, 21], [200, 201], [9, 9])] + r2 = match_prefix_ref(nodes2, [20, 21]) + assert len(r2.physical_blocks) == r2.prefix_blocks_num + + # No match. + nodes3 = [_Node([999], [100], [1])] + r3 = match_prefix_ref(nodes3, [20]) + assert len(r3.physical_blocks) == r3.prefix_blocks_num == 0 + + +class TestRegressionScenarios: + """Specific scenarios from design doc §4.7.1.4 / §4.4.1.""" + + def test_design_doc_scenario_cross_remote_leaked_into_merge(self): + """Per design doc: even if a single radix node contains blocks + from multiple peer_instances (legal after merge_root_child), the + match must truncate at the boundary so the Master only issues a + coordinated GET to ONE peer.""" + # Node had 5 blocks, first 3 from peer_instance_A (nid 100), + # last 2 merged in from peer_instance_B (nid 200). + nodes = [_Node( + hashes=[1, 2, 3, 4, 5], + pbs=[10, 20, 30, 40, 50], + bnis=[100, 100, 100, 200, 200], + )] + r = match_prefix_ref(nodes, [1, 2, 3, 4, 5]) + # Must stop at 3 blocks, all from peer A. 
+ assert r.prefix_blocks_num == 3 + assert r.physical_blocks == [10, 20, 30] + assert r.matched_node_id == 100 + + @pytest.mark.parametrize("trip_at", [1, 2, 3, 4, 5]) + def test_trip_at_arbitrary_position_is_exact(self, trip_at: int): + """Parametrize the trip position and confirm byte-exact cut.""" + # 6-block node, trip at index ``trip_at``. + n = 6 + bnis = [7] * trip_at + [8] * (n - trip_at) + nodes = [_Node( + hashes=list(range(10, 10 + n)), + pbs=list(range(100, 100 + n)), + bnis=bnis, + )] + r = match_prefix_ref(nodes, list(range(10, 10 + n))) + assert r.prefix_blocks_num == trip_at + assert r.physical_blocks == list(range(100, 100 + trip_at)) + assert r.matched_node_id == 7 From 909e6580b5187e0d83e57e6b857e20154dae4744 Mon Sep 17 00:00:00 2001 From: phaedonsun Date: Wed, 20 May 2026 23:02:24 +0800 Subject: [PATCH 2/2] review-fixes(dist_reuse): consolidated reviewer feedback + dead-code cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Squashes four prior commits (bca6ccc, c8a5a2a, e8da1b8, 0974b1c) addressing reviewer comments on the dist_reuse feature commit (22bc183), plus the follow-up dead-code/docs/test cleanups discovered during review. F1. is_nsa source: read directly from model_config.is_nsa instead of reverse-deriving from enable_nsa_prefill_context_parallel (the latter is a CP toggle, orthogonal to NSA architecture). F2. Control-plane / rank-plane separation: SharingDomainKey.from_model_config now takes an optional rank_info=RankInfo argument and reads pp_rank / tp_node_idx from it; the control plane (KVManager) only constructs self-SD via default() and enumerates peers via enumerate_peers() — no fake rank fabrication. F3. RankTopology factory dropped; reuse the existing RankInfo end-to-end. Integration adapters (vLLM v1 / TRT-LLM / SGLang) plumb the real rank_info through. F4. 
Shell / TransferManagerOnRemote decoupling: revert start_dist_reuse_serving.sh changes; TransferManagerOnRemote stays per-node and is tagged via set_target_sd_key on each handle. F5. Delete unused flexkv.integration.multinode_policy module and its is_multinode_tp / is_multinode_cp / is_multinode_pp helpers; CP never participates in sd_key (attention all-gather makes per-cp_rank pools bit-wise identical), and TP-cross-node is encoded in SharingDomainKey.tp_node_count directly. Verified no external references in the SGLang FlexKVConnector codebase. Dead-code sweep across dist_reuse layer (formerly c8a5a2a): - drop unused helpers / accessors that no production path references - remove the matching dead unit tests (test_iter_dunder, ...) Docs (formerly e8da1b8): - drop stale phase tag in coordination_protocol module docstring. Tests: full 18-suite dist_reuse subset (332/332) passes on both GPU executors (gpu-146.56.224.46 and gpu-129.211.162.213). --- .../benchmark_dist_reuse_smoke.py | 4 +- docs/dist_reuse/redis_schema.md | 6 +- flexkv/common/config.py | 77 ++++++-- flexkv/common/dist_reuse/__init__.py | 4 - flexkv/common/dist_reuse/aggregate_radix.py | 76 +------ .../dist_reuse/coordination_protocol.py | 15 +- .../common/dist_reuse/master_coordinator.py | 94 --------- flexkv/common/dist_reuse/remote_init.py | 9 - flexkv/common/dist_reuse/sharing_domain.py | 37 +++- .../dist_reuse/sharing_domain_namespace.py | 20 +- flexkv/integration/config.py | 30 ++- flexkv/integration/multinode_policy.py | 185 ----------------- .../tensorrt_llm/trtllm_adapter.py | 3 +- flexkv/integration/vllm/vllm_v1_adapter.py | 3 +- flexkv/kvmanager.py | 17 +- flexkv/kvtask.py | 35 +++- flexkv/server/server.py | 43 +++- .../multi-nodes/start_dist_reuse_serving.sh | 55 +++--- tests/test_aggregate_radix.py | 82 +------- tests/test_coord_protocol.py | 10 - tests/test_dist_reuse_launcher.py | 28 ++- tests/test_master_coordinator.py | 38 +--- tests/test_multinode_role_policy.py | 186 
------------------ tests/test_sharing_domain_key.py | 4 - tests/test_sharing_domain_namespace.py | 7 - 25 files changed, 261 insertions(+), 807 deletions(-) delete mode 100644 flexkv/integration/multinode_policy.py delete mode 100644 tests/test_multinode_role_policy.py diff --git a/benchmarks/dist_benchmark/benchmark_dist_reuse_smoke.py b/benchmarks/dist_benchmark/benchmark_dist_reuse_smoke.py index 082ad7d4a0..ac51f1b98c 100644 --- a/benchmarks/dist_benchmark/benchmark_dist_reuse_smoke.py +++ b/benchmarks/dist_benchmark/benchmark_dist_reuse_smoke.py @@ -259,11 +259,11 @@ def scenario_aggregate_radix_put_hook(args) -> Dict[str, Any]: if entry is not None: results["ok"] += 1 # Refcount protection: the aggregate pins blocks when acquired. - mc.pin_blocks_for_coord_get(block_ids) + mc.acquire_blocks(block_ids) for b in block_ids: assert not mc.is_evictable(b), \ f"block {b} evictable while pinned — refcount guard broken" - mc.unpin_blocks_for_coord_get(block_ids) + mc.release_blocks(block_ids) for b in block_ids: assert mc.is_evictable(b), \ f"block {b} NOT evictable after unpin — refcount stuck" diff --git a/docs/dist_reuse/redis_schema.md b/docs/dist_reuse/redis_schema.md index f2e28ce59a..86d40c8a7e 100644 --- a/docs/dist_reuse/redis_schema.md +++ b/docs/dist_reuse/redis_schema.md @@ -132,13 +132,13 @@ c3a2f91d0bcdef01:ppn0/1:tpn1/2:nsa0 — 跨机 TP=2 第 1 节点 --- -### 2.5 `sd::aggregate:` — 跨 SD 聚合标记(预留) +### 2.5 `sd::aggregate:` — 跨 SD 聚合标记(未实现) | 属性 | 值 | |---|---| -| 类型 | **未启用**(`SharingDomainNamespace.aggregate_key(...)` 已提供构造器) | +| 类型 | **未启用**,且当前没有构造器实现 | -预留供未来把 `MasterCoordinator` 的跨 SD 聚合状态持久化到 Redis(用于 Master 重启恢复)。现阶段 `AggregateRadixTree` 只在内存。 +预留 Redis key 命名空间,供未来把 `MasterCoordinator` 的跨 SD 聚合状态持久化到 Redis(用于 Master 重启恢复)。现阶段 `AggregateRadixTree` 只在内存。如需启用,应在 `SharingDomainNamespace` 上重新实现 key 构造器。 --- diff --git a/flexkv/common/config.py b/flexkv/common/config.py index 658ef91894..5ea44f71f2 100644 --- a/flexkv/common/config.py +++ 
b/flexkv/common/config.py @@ -104,16 +104,16 @@ def freeze(self) -> None: f"[ModelConfig] cannot derive gpus_per_node: " f"total_gpus={self.total_gpus} not divisible by nnodes={self.nnodes}" ) - if self.nnodes_per_tp_group > 2: + if self.tp_node_count > 2: raise ValueError( f"[ModelConfig] only support 2-nodes TP for now, but got " - f"nnodes_per_tp_group={self.nnodes_per_tp_group} " + f"tp_node_count={self.tp_node_count} " f"(tp_size={self.tp_size}, gpus_per_node={self.gpus_per_node})" ) - if self.tp_size % self.nnodes_per_tp_group != 0: + if self.tp_size % self.tp_node_count != 0: raise ValueError( f"[ModelConfig] tp_size={self.tp_size} not divisible by " - f"nnodes_per_tp_group={self.nnodes_per_tp_group}" + f"tp_node_count={self.tp_node_count}" ) if self.instance_num < 1: raise ValueError( @@ -160,13 +160,23 @@ def nnodes_per_pp_rank(self) -> int: @property def nnodes_per_tp_group(self) -> int: - """Number of nodes spanned by one TP group.""" - return self.nnodes_per_pp_rank + """Number of nodes spanned by one TP group. + + .. deprecated:: + Kept as a stable alias of :py:attr:`tp_node_count` for + backwards compatibility with adapter code that pre-dates + the SD-key naming convention. New code should read + ``tp_node_count`` directly — that property carries the + authoritative semantic ("the TP-axis node-count entering + ``SharingDomainKey``") and is the value tracked in the + redis schema (``docs/dist_reuse/redis_schema.md``).
+ """ + return self.tp_node_count @property def tp_size_per_node(self) -> int: """Number of TP ranks on this node within one TP group.""" - return self.tp_size // self.nnodes_per_tp_group + return self.tp_size // self.tp_node_count @property def attn_dp_size(self) -> int: @@ -183,7 +193,7 @@ def attn_tp_size(self) -> int: @property def attn_tp_size_per_node(self) -> int: """Attention-level TP size per node.""" - return self.attn_tp_size // self.nnodes_per_tp_group + return self.attn_tp_size // self.tp_node_count @property def attn_cp_size_per_node(self) -> int: @@ -242,6 +252,43 @@ def is_multinode_tp(self) -> bool: """ return self.tp_node_count > 1 + @property + def is_multinode_pp(self) -> bool: + """PP is the dimension that makes *this instance* cross node boundaries. + + Returns True iff: + + * ``pp_size > 1`` — PP is actually deployed, + * ``nnodes > 1`` — the instance occupies more than one node, + * ``tp_node_count == 1`` — TP **does not** cross nodes + (otherwise TP-multinode is the dominant axis and would + already drive the SD-Remote decision; classifying the same + deployment as "multinode-PP" too would double-count). + + This is the missing third axis next to :py:attr:`is_multinode_tp` + and :py:attr:`is_multinode_cp`. It exists so the connector's + runtime launch logic can stop folding "PP-only crosses nodes" + into the off-master fall-through branch. + + Worked examples: + + * ``pp=4, nnodes=2, tp=8, gpus_per_node=8`` → True + (PP=4 stages × tp=8 = 32 GPUs across 2 nodes; each node + owns 2 PP stages; TP stays inside one node). + * ``pp=1, nnodes=2, tp=16`` → False + (PP single-stage; TP is the multinode axis). + * ``pp=2, nnodes=2, tp=16`` → False + (TP already crosses; PP is *not* the dominant axis here — + we leave the multinode-TP branch to handle this). + * ``pp=2, nnodes=1`` → False + (single node; PP fits in-host).
+ """ + return ( + self.pp_size > 1 + and self.nnodes > 1 + and self.tp_node_count == 1 + ) + @property def is_multinode_cp(self) -> bool: """CP > 1 *and* the CP group spans more than one physical node. @@ -300,9 +347,17 @@ def num_kv_heads_per_node(self) -> int: # ------------------------------------------------------------------ @property def tp_node_count(self) -> int: - """Number of physical nodes one TP group spans (= - ``nnodes_per_tp_group``). ``1`` when TP fits on a single node.""" - return self.nnodes_per_tp_group + """Number of physical nodes one TP group spans. + + Authoritative source for the TP-axis node-count used in + :class:`SharingDomainKey` and ``docs/dist_reuse/redis_schema.md``. + ``1`` when TP fits on a single node. Deprecated alias: + :py:attr:`nnodes_per_tp_group`. + """ + # PP and TP groups share the same per-rank node assignment in the + # current topology (one TP group sits on the same set of nodes as + # one PP stage), so ``nnodes_per_tp_group == nnodes_per_pp_rank``. + return self.nnodes_per_pp_rank # NOTE: ``tp_node_idx`` is a per-rank concept and was moved to # ``RankInfo`` in PR #165 (separate-per-rank-state-into-RankInfo).
diff --git a/flexkv/common/dist_reuse/__init__.py b/flexkv/common/dist_reuse/__init__.py index ba9bb98aeb..125e1d177f 100644 --- a/flexkv/common/dist_reuse/__init__.py +++ b/flexkv/common/dist_reuse/__init__.py @@ -26,14 +26,12 @@ SharingDomainNamespace, ) from .aggregate_radix import ( - AggregateMatchResult, AggregateRadixTree, BlockNotTrackedError, ReadyEntry, ) from .coordination_protocol import ( CoordMsgType, - EpochVerifyError, FailureReportMsg, RemoteReadyMsg, decode_coord_message, @@ -69,13 +67,11 @@ "SD_KEY_PREFIX", "SharingDomainNamespace", # aggregate_radix - "AggregateMatchResult", "AggregateRadixTree", "BlockNotTrackedError", "ReadyEntry", # coordination_protocol (Phase D-4: trimmed to RemoteReady + FailureReport) "CoordMsgType", - "EpochVerifyError", "FailureReportMsg", "RemoteReadyMsg", "decode_coord_message", diff --git a/flexkv/common/dist_reuse/aggregate_radix.py b/flexkv/common/dist_reuse/aggregate_radix.py index 7c71f813ec..04f49d0180 100644 --- a/flexkv/common/dist_reuse/aggregate_radix.py +++ b/flexkv/common/dist_reuse/aggregate_radix.py @@ -26,12 +26,11 @@ import threading import time from dataclasses import dataclass, field -from typing import Callable, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple +from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple __all__ = [ "ReadyEntry", - "AggregateMatchResult", "AggregateRadixTree", "BlockNotTrackedError", ] @@ -42,9 +41,6 @@ class BlockNotTrackedError(KeyError): that the aggregate radix has never seen.""" -# --------------------------------------------------------------------------- -# Data shapes -# --------------------------------------------------------------------------- @dataclass class ReadyEntry: """Per-prefix bookkeeping. 
@@ -93,19 +89,6 @@ def node_id_for_sd(self, sd_key: str) -> Optional[int]: if that SD has not acked yet.""" return self.ready_sds.get(sd_key) - -@dataclass -class AggregateMatchResult: - """Result of :meth:`AggregateRadixTree.match_fully_ready`.""" - - matched_block_ids: Tuple[int, ...] - contributing_peers: FrozenSet[str] - # Always a single value (matches §4.7.1.4 single-Node match constraint - # already enforced by the C++ ``RefRadixTree``). ``None`` when the - # match length is zero. - matched_node_id: Optional[int] = None - - # --------------------------------------------------------------------------- # Refcount entry # --------------------------------------------------------------------------- @@ -113,7 +96,7 @@ class AggregateMatchResult: class _RefCountEntry: count: int = 0 # Wall-clock time (seconds since epoch) of the most recent acquire. - # Used by ``scan_leaked_refcount`` to identify stuck refcounts. + # Reserved for future leak / staleness detectors. last_acquired_at: float = 0.0 @@ -125,10 +108,9 @@ class AggregateRadixTree: Public API surface mirrors design doc §4.3 / §4.3.1: - * :meth:`mark_sd_ready` / :meth:`mark_sd_evicted` — per-SD ack tracker + * :meth:`mark_sd_ready` — per-SD ack tracker * :meth:`match_fully_ready` — query the longest fully-ready prefix * :meth:`acquire` / :meth:`release` / :meth:`is_evictable` — refcount - * :meth:`scan_leaked_refcount` — leak detector * :meth:`invalidate_by_peer_instance` / :meth:`invalidate_prefix` — reactions to failure-detector events """ @@ -165,11 +147,6 @@ def __len__(self) -> int: with self._lock: return len(self._prefixes) - def known_prefixes(self) -> List[int]: - """Snapshot of prefix hashes currently tracked (not necessarily ready).""" - with self._lock: - return list(self._prefixes.keys()) - # ------------------------------------------------------------------ # Per-SD ack tracker # ------------------------------------------------------------------ @@ -241,21 +218,6 @@ def mark_sd_ready( 
entry.first_ready_at = self._time_fn() return became_ready - def mark_sd_evicted(self, prefix_hash: int, sd_key: str) -> None: - """Remove ``sd_key`` from a prefix's ready-set. - - If the set becomes empty the prefix is dropped entirely. Silently - no-op if the prefix is unknown (matches the "Master single-handedly - evicts" semantics in design doc §4.3.1 — Remotes do not need to - observe an eviction).""" - with self._lock: - entry = self._prefixes.get(prefix_hash) - if entry is None: - return - entry.ready_sds.pop(sd_key, None) - if not entry.ready_sds: - self._drop_prefix_locked(entry) - # ------------------------------------------------------------------ # Match # ------------------------------------------------------------------ @@ -328,38 +290,6 @@ def is_evictable(self, block_id: int) -> bool: ent = self._refcounts.get(int(block_id)) return ent is None or ent.count <= 0 - def get_refcount(self, block_id: int) -> int: - """Helper for tests / observability. Never raises.""" - with self._lock: - ent = self._refcounts.get(int(block_id)) - return ent.count if ent is not None else 0 - - def scan_leaked_refcount(self, timeout_seconds: float) -> List[int]: - """Return all block_ids whose refcount has been > 0 longer than - ``timeout_seconds``. - - The Master is expected to call this periodically (design doc §4.3.1 - prerequisite C "refcount timeout safety net") and then forcibly - zero each leaked refcount + invalidate the owning prefix(es). - """ - if timeout_seconds < 0: - raise ValueError(f"timeout_seconds must be >= 0, got {timeout_seconds!r}") - cutoff = self._time_fn() - timeout_seconds - with self._lock: - return [b for b, ent in self._refcounts.items() if ent.last_acquired_at <= cutoff] - - def force_release(self, block_id: int) -> int: - """Hard-reset a block's refcount to zero. - - Returns the previous refcount (0 if block was untracked). 
Designed - to be the second half of the leak-recovery sequence — call it for - every entry returned by :meth:`scan_leaked_refcount`, then call - :meth:`invalidate_prefix` to drop the prefix from the radix. - """ - with self._lock: - ent = self._refcounts.pop(int(block_id), None) - return ent.count if ent is not None else 0 - # ------------------------------------------------------------------ # Invalidation # ------------------------------------------------------------------ diff --git a/flexkv/common/dist_reuse/coordination_protocol.py b/flexkv/common/dist_reuse/coordination_protocol.py index 26aa495dd9..df7d44ff44 100644 --- a/flexkv/common/dist_reuse/coordination_protocol.py +++ b/flexkv/common/dist_reuse/coordination_protocol.py @@ -1,6 +1,6 @@ """Wire format for Master ↔ Remote coordination of distributed KV reuse. -Phase D-4 (proposal_unify_with_graph_dispatch_2026-05-15.md): the +The ``CoordQuery*`` / ``CoordGet*`` / ``CoordPut*`` message types from the early implementation are **deleted** here. They were the dataclasses behind the old per-SD ZMQ coord protocol; the unified graph-dispatch @@ -32,7 +32,6 @@ "CoordMsgType", "RemoteReadyMsg", "FailureReportMsg", - "EpochVerifyError", "encode_coord_message", "decode_coord_message", ] @@ -57,9 +56,9 @@ class CoordMsgType(str, Enum): class _BaseCoordMsg: """Common fields embedded in every wire message. - ``epoch`` carries the *sender's* expected ``session_epoch``; receivers - cross-check it against their own and raise :class:`EpochVerifyError` - when they disagree (design doc §4.3.2 anti-replay rule). + ``epoch`` carries the *sender's* expected ``session_epoch`` and is + propagated end-to-end so future receivers can do anti-replay checks + (design doc §4.3.2); the current decode path does not enforce it.
""" # Class-level discriminator; every concrete subclass overrides this in @@ -129,12 +128,6 @@ class FailureReportMsg(_BaseCoordMsg): # --------------------------------------------------------------------------- # Encoding helpers — protocol-level, not transport-level # --------------------------------------------------------------------------- -class EpochVerifyError(RuntimeError): - """Raised when a receiver detects a stale ``sender_epoch``. The caller - is expected to translate this into a ``STALE_EPOCH`` response and let - the sender invalidate its view of the affected peer.""" - - _TYPE_TO_CLASS: Dict[CoordMsgType, Type[_BaseCoordMsg]] = { CoordMsgType.REMOTE_READY: RemoteReadyMsg, CoordMsgType.FAILURE_REPORT: FailureReportMsg, diff --git a/flexkv/common/dist_reuse/master_coordinator.py b/flexkv/common/dist_reuse/master_coordinator.py index 4df1c2264a..a468480e7a 100644 --- a/flexkv/common/dist_reuse/master_coordinator.py +++ b/flexkv/common/dist_reuse/master_coordinator.py @@ -231,22 +231,10 @@ def __init__( def self_sd(self) -> SharingDomainKey: return self._self_sd - @property - def namespace(self) -> SharingDomainNamespace: - return self._namespace - - @property - def instance_id(self) -> str: - return self._instance_id - @property def session_epoch(self) -> str: return self._session_epoch - @property - def aggregate_radix(self) -> AggregateRadixTree: - return self._aggregate - @property def self_node_id(self) -> int: """Distributed node_id of the Master itself. 
``-1`` until @@ -300,10 +288,6 @@ def all_remotes_ready(self) -> bool: and len(self._ready_remotes) >= self._expected_remote_count ) - def ready_remote_infos(self) -> Dict[str, RemoteReadyMsg]: - with self._lock: - return dict(self._ready_remotes) - def build_sd_to_nid_map(self, self_node_id: int) -> Dict[str, int]: """Produce the ``sd_key → node_id`` mapping for ``RedisMeta.register_instance_sd_nodes``.""" @@ -336,59 +320,6 @@ def get_sd_to_nid_map(self) -> Dict[str, int]: out[sd_key_str] = int(msg.distributed_node_id) return out - # ------------------------------------------------------------ lookup - def lookup_peer_by_node_id(self, node_id: int) -> Optional[Dict[str, str]]: - """Reverse-lookup: given a (global) distributed_node_id, return - the peer SD it belongs to plus the peer instance_id that owns - that SD. - - Returns a dict with keys ``sd_key_str`` and ``instance_id``, - or ``None`` if the node_id isn't one of our ready peers (i.e. - it's either this instance's own node, unknown, or belongs to - an instance that hasn't finished its ready handshake). - - Used by the GET main-path glue - (``GlobalCacheEngine._maybe_attach_multi_sd_peerh2h_ops``, - Phase D-3) and by ``handle_failure_report`` to map an - offending node_id back to the peer SD + instance it sits on. - """ - if node_id is None: - return None - try: - nid = int(node_id) - except Exception: - return None - if nid < 0: - return None - with self._lock: - for sd_key_str, msg in self._ready_remotes.items(): - if int(getattr(msg, "distributed_node_id", -1)) == nid: - return { - "sd_key_str": sd_key_str, - "instance_id": str(getattr(msg, "sender_instance_id", "") or ""), - } - return None - - # -------------------------------------------------------- pin helpers - def pin_blocks_for_coord_get(self, block_ids: Iterable[int]) -> None: - """Refcount-pin ``block_ids`` against Master-side eviction while - an in-flight coord GET is expected to land on them. 
- - Thin alias for :meth:`acquire_blocks` — kept so the GET-path - glue reads as intent rather than the lower-level primitive. - Used in conjunction with the multi-SD PEERH2H fan-out - (``GlobalCacheEngine._maybe_attach_multi_sd_peerh2h_ops``, - Phase D-3). - """ - self._aggregate.acquire(block_ids) - - def unpin_blocks_for_coord_get(self, block_ids: Iterable[int]) -> None: - """Release the refcount pin set by - :meth:`pin_blocks_for_coord_get`. Must be called on both the - success and failure paths of the coord GET. - """ - self._aggregate.release(block_ids) - # ------------------------------------------------------------- discovery def register_instance_discoverables( self, @@ -493,28 +424,9 @@ def mark_sd_ready( node_id=int(node_id), ) - def mark_sd_evicted(self, prefix_hash: int, sd_key_str: str) -> None: - self._aggregate.mark_sd_evicted(prefix_hash, sd_key_str) - def match_fully_ready(self, prefix_hash: int) -> Any: return self._aggregate.match_fully_ready(prefix_hash) - def invalidate_prefix(self, prefix_hash: int) -> bool: - return self._aggregate.invalidate_prefix(prefix_hash) - - # --------------------------------------------------- periodic scans - def scan_leaked_refcount(self) -> List[int]: - """Called periodically by the KVTaskManager's background thread. - - For every block that has been in-flight too long (design doc - §4.3.1 prerequisite C), force-release its refcount and - invalidate any prefix that owns it. 
- """ - leaked = self._aggregate.scan_leaked_refcount(self._refcount_leak_timeout) - for block_id in leaked: - self._aggregate.force_release(block_id) - return leaked - # --------------------------------------------------- failure callbacks def set_peer_lost_hook(self, cb: Optional[Any]) -> None: """Register an **additional** callback to fire when a peer @@ -575,12 +487,6 @@ def handle_failure_report(self, report) -> None: self._instance_id, peer, ) - def peer_failure_count(self, peer_instance_id: str) -> int: - """Number of unescalated failures from ``peer_instance_id`` — - used by tests and ops dashboards.""" - with self._lock: - return self._peer_failure_counts.get(peer_instance_id, 0) - def _on_peer_lost(self, peer_instance_id: str) -> None: """Invoked by the FailureDetector on peer disappearance / epoch bump. diff --git a/flexkv/common/dist_reuse/remote_init.py b/flexkv/common/dist_reuse/remote_init.py index 4069f0dde6..7e8290f6b8 100644 --- a/flexkv/common/dist_reuse/remote_init.py +++ b/flexkv/common/dist_reuse/remote_init.py @@ -31,7 +31,6 @@ RemoteReadyMsg, SharingDomainKey, SharingDomainNamespace, - encode_coord_message, ) @@ -217,14 +216,6 @@ def bootstrap(self) -> BootstrapResult: ready_msg=ready_msg, ) - # ---- encoding -------------------------------------------------- - @staticmethod - def encode_ready(msg: RemoteReadyMsg) -> dict: - """Convenience: turn a :class:`RemoteReadyMsg` into its wire - ``dict`` form (handy when the ZMQ transport prefers JSON).""" - return encode_coord_message(msg) - - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- diff --git a/flexkv/common/dist_reuse/sharing_domain.py b/flexkv/common/dist_reuse/sharing_domain.py index 6cc900ded9..729df3a246 100644 --- a/flexkv/common/dist_reuse/sharing_domain.py +++ b/flexkv/common/dist_reuse/sharing_domain.py @@ -64,7 +64,7 @@ import hashlib import re from dataclasses 
import dataclass, replace -from typing import Any, Iterator, List, Optional +from typing import Any, List, Optional __all__ = [ @@ -360,8 +360,39 @@ def from_model_config( getattr(model_config, "tp_node_idx", 0)) ) else: + # Legacy path: PR #165 moved ``pp_rank`` / ``tp_node_idx`` + # off ``ModelConfig`` onto :class:`RankInfo`. These + # ``getattr(..., 0)`` reads therefore now return ``0`` + # for any post-#165 ``ModelConfig`` instance — i.e. the + # caller silently gets the master-position SD even if it + # is actually on (pp_rank>0, tp_node_idx>0). This is + # only safe for unit-test fakes that explicitly set + # ``pp_rank`` / ``tp_node_idx`` on the stub ModelConfig + # (see ``tests/test_sharing_domain_key.py``); production + # callers should pass ``rank_info=`` explicitly. We log + # a warning each time the heuristic is exercised on + # a multi-rank topology so the error surfaces during + # bring-up instead of silently corrupting Redis keys. _pp_rank = int(getattr(model_config, "pp_rank", 0)) _tp_node_idx = int(getattr(model_config, "tp_node_idx", 0)) + _has_pp_rank = hasattr(model_config, "pp_rank") + _has_tp_node_idx = hasattr(model_config, "tp_node_idx") + if ( + pp_size > 1 or int(getattr(model_config, "tp_node_count", 1)) > 1 + ) and not (_has_pp_rank and _has_tp_node_idx): + # Local import to avoid pulling logger into module + # import-time graph (sharing_domain.py is imported + # very early by config.py via ``derive_model_id``). + from flexkv.common.debug import flexkv_logger + flexkv_logger.warning( + "SharingDomainKey.from_model_config: called without " + "rank_info on a multi-rank topology (pp_size=%d, " + "tp_node_count=%s); per-rank fields default to 0, " + "which only matches the master SD.
Pass rank_info=" + " so the per-rank position is honoured.", + pp_size, + getattr(model_config, "tp_node_count", 1), + ) _pp_node_idx = _pp_rank // pp_per_node @@ -418,10 +449,6 @@ def enumerate_peers(self) -> List["SharingDomainKey"]: out.append(replace(self, pp_node_idx=ppn, tp_node_idx=tpn)) return out - # Iteration sugar — mostly for tests. - def __iter__(self) -> Iterator["SharingDomainKey"]: - return iter(self.enumerate_peers()) - def __str__(self) -> str: # pragma: no cover — purely cosmetic return self.serialize() diff --git a/flexkv/common/dist_reuse/sharing_domain_namespace.py b/flexkv/common/dist_reuse/sharing_domain_namespace.py index 618a3ea664..916647f736 100644 --- a/flexkv/common/dist_reuse/sharing_domain_namespace.py +++ b/flexkv/common/dist_reuse/sharing_domain_namespace.py @@ -101,32 +101,14 @@ def block_key(self, node_id: int, block_hash: int) -> str: h = int(block_hash) & 0xFFFFFFFFFFFFFFFF return f"{self._prefix}:block:{int(node_id)}:{h:x}" - def aggregate_key(self, request_prefix_hash: int) -> str: - """Aggregate-radix marker (design doc §4.7) for tracking - fully-ready prefixes across SDs in this instance.""" - h = int(request_prefix_hash) & 0xFFFFFFFFFFFFFFFF - return f"{self._prefix}:aggregate:{h:x}" - # ------------------------------------------------------------------ # SCAN-friendly patterns # ------------------------------------------------------------------ def node_key_pattern(self) -> str: return f"{self._prefix}:node:*" - def meta_key_pattern(self) -> str: - return f"{self._prefix}:meta:*" - - def buffer_key_pattern(self) -> str: - return f"{self._prefix}:buffer:*" - - def block_key_pattern(self) -> str: - """Match every block in the SD regardless of node_id. 
Used by the - global-SCAN optimization in design doc §4.7.1.2.""" - return f"{self._prefix}:block:*" - def block_key_pattern_for_node(self, node_id: int) -> str: - """Per-node block SCAN pattern (legacy path; the global pattern - above is preferred).""" + """Per-node block SCAN pattern.""" return f"{self._prefix}:block:{int(node_id)}:*" # ------------------------------------------------------------------ diff --git a/flexkv/integration/config.py b/flexkv/integration/config.py index c8055ea021..5b5ed23a34 100644 --- a/flexkv/integration/config.py +++ b/flexkv/integration/config.py @@ -218,8 +218,11 @@ def post_init_from_sglang_config( sglang_config: sglang.srt.configs.model_config.ModelConfig-like object server_args: sglang ServerArgs — source of tp_size, dp_size, nnodes, node_rank, enable_dp_attention, attn_cp_size, - is_nsa (read from server_args.enable_nsa_prefill_context_parallel), - kv_cache_dtype, dist_init_addr + kv_cache_dtype, dist_init_addr. ``is_nsa`` is **not** + read from server_args: see the body below — it is + derived from ``sglang_config.index_head_dim`` instead, + because NSA is a model-layout property orthogonal to + whether CP-prefill is enabled. page_size: KV block size (tokens per block) used by sglang tp_rank: physical tensor parallel rank (runtime, from process group) pp_rank: pipeline parallel rank (runtime, from process group) @@ -234,11 +237,24 @@ def post_init_from_sglang_config( node_rank = server_args.node_rank enable_dp_attention = server_args.enable_dp_attention attn_cp_size = server_args.attn_cp_size - # ``is_nsa`` (NSA model layout flag): True when the model has an - # extra indexer K cache buffer. Sourced from sglang's - # ``enable_nsa_prefill_context_parallel`` server arg, but in dist_reuse - # context the flag represents the *layout*, not whether CP is on. 
- is_nsa = getattr(server_args, 'enable_nsa_prefill_context_parallel', False) + # ``is_nsa`` (NSA model layout flag): True when the model itself has + # an extra indexer K cache buffer. This is a *layout* property of + # the model architecture, **independent** of whether CP is enabled + # at runtime — an NSA model with cp_size=1 still has the indexer K + # cache and must therefore be isolated from non-NSA models in the + # cross-instance reuse namespace (it lives in + # ``SharingDomainKey.serialize`` as the ``nsa<0|1>`` segment). + # + # Detection rule: an NSA/DSA model exposes a positive + # ``index_head_dim`` attribute on its sglang ModelConfig (the same + # signal already consulted ~25 lines below to size the indexer + # head buffer). Falling back to + # ``server_args.enable_nsa_prefill_context_parallel`` was incorrect + # because it conflates the *runtime CP toggle* with the *static + # model layout* — a deployment can run an NSA model with CP=1 + # (no prefill-CP) and still need NSA-isolated namespaces. + index_head_dim = getattr(sglang_config, "index_head_dim", None) + is_nsa = bool(index_head_dim) and int(index_head_dim) > 0 kv_cache_dtype = getattr(server_args, 'kv_cache_dtype', None) dp_rank = 0 if dp_rank is None else int(dp_rank) diff --git a/flexkv/integration/multinode_policy.py b/flexkv/integration/multinode_policy.py deleted file mode 100644 index ebe5bdbc88..0000000000 --- a/flexkv/integration/multinode_policy.py +++ /dev/null @@ -1,185 +0,0 @@ -"""Multi-node role decision helpers for the sglang ↔ FlexKV connector. - -Design doc §4.5.5 splits ``is_multinode`` into **two independent axes**: - -* ``is_multinode_tp`` — one TP group spans >1 physical node. - Each such node runs a full SD-Remote (``TransferManagerOnRemote``) - with its own RedisMeta + Mooncake registration. - -* ``is_multinode_cp`` — CP > 1 and the CP group spans >1 physical node. 
- CP all-gather makes every ``cp_rank``'s KV pool bit-wise identical, - so non-leader CP ranks **do not** run a full SD-Remote; they only - need a GPU-registration stub (``KVTPClient``) + receive coordinated - H2D commands routed by the sync-leader rank. - -The master connector (``flexkv_connector.py``) currently conflates the -two under a ``nnodes > 1 and node_rank > 0 and local_rank == 0`` rule -of thumb. This module provides the policy functions that the -connector **should** call once we can exercise cross-node boots on a -two-machine GPU setup (tracked as §2.4 in -``docs/dist_reuse/implementation_gap_2026-05-11.md``). - -Everything here is pure Python / pure logic so it is unit-testable -without torch, CUDA, or a running sglang process. See -``FlexKV/tests/test_multinode_role_policy.py``. -""" -from __future__ import annotations - -from dataclasses import dataclass -from enum import Enum -from typing import Optional - - -class RemoteProcessRole(str, Enum): - """What, if anything, should this rank's local ``FlexKVConnector`` - spawn as its ``TransferManagerOnRemote`` process? - - The three roles correspond to the three paths in the design doc: - - * ``MASTER``: this rank is the sync-leader of an instance. It runs - the full ``KVManager`` (owns the Master coordinator, writes to - Redis, owns Mooncake TransferEngine). No ``TransferManagerOnRemote`` - process is spawned — the master IS the transfer authority. - - * ``SD_REMOTE_FULL``: this rank sits on a non-master node of a - *cross-node TP/PP* group. It must spawn a full - ``TransferManagerOnRemote`` that: - - - registers its local CPU block pool with Mooncake, - - writes its ``sd::node/block`` entries to Redis, - - replies ``RemoteReadyMsg`` so the Master can discover it, - - serves coordinated GET/PUT commands from the Master. - - * ``CP_PEER_REGISTRATION_ONLY``: this rank is a non-leader CP rank. 
- The design doc §4.5.5 + §4.12.2 (5) says these ranks only need - a lightweight ``TransferManagerOnRemote`` *stub* that: - - - registers its GPU blocks with ``KVTPClient`` (so the sync - leader's H2D path can write to it), AND - - listens for coordinated H2D slot-mapping commands. - - It does **NOT** touch RedisMeta or Mooncake — CP all-gather - already makes this rank's content identical to the sync leader. - - * ``NO_REMOTE``: single-node instance — no ``TransferManagerOnRemote`` - spawn at all. Legacy single-box behaviour. - """ - MASTER = "master" - SD_REMOTE_FULL = "sd_remote_full" - CP_PEER_REGISTRATION_ONLY = "cp_peer_registration_only" - NO_REMOTE = "no_remote" - - -@dataclass(frozen=True) -class RankTopology: - """Topology facts about a single rank, as seen from the connector. - - All fields are scalars so this is trivially serialisable / hashable - for tests. The connector extracts these from ``ModelConfig`` + - ``server_args``; tests construct them directly. - """ - - # Core dimensions - nnodes: int - node_rank: int - local_rank: int # 0..gpus_per_node-1 - - # Rolled-up FlexKV topology (see ModelConfig docstring) - is_multinode_tp: bool # ``tp_node_count > 1`` - is_multinode_cp: bool # CP > 1 and CP crosses node boundary - - # Optional sync-leader hint. If the caller already knows whether - # this rank is the sync leader (from ``sglang`` group metadata), - # pass ``is_sync_leader``. Otherwise the default heuristic kicks - # in: ``(local_rank == 0 and node_rank == 0)``. - is_sync_leader: Optional[bool] = None - - -def decide_remote_role(topo: RankTopology) -> RemoteProcessRole: - """Compute the role of a rank. - - Decision table (see design doc §4.5.5, simplified): - - =================== ================ ================ ================= - Single-node? 
is_multinode_tp is_multinode_cp Role - =================== ================ ================ ================= - yes (nnodes == 1) (ignored) (ignored) NO_REMOTE - no, rank 0 box False False NO_REMOTE - no, rank 0 box any any MASTER - no, off-master box is_multinode_tp any SD_REMOTE_FULL - no, off-master box False is_multinode_cp CP_PEER_REGISTRATION_ONLY - =================== ================ ================ ================= - - Where "rank 0 box" means ``node_rank == 0``. Across both - ``is_multinode_tp`` and ``is_multinode_cp`` axes we place the - Master on ``node_rank==0`` by convention (this matches the current - ``flexkv_connector.py`` assumption that sync leader is - ``node_rank==0``). - - Note on CP + TP combined: when BOTH ``is_multinode_tp`` and - ``is_multinode_cp`` are True for an off-master rank, it runs - ``SD_REMOTE_FULL`` — the TP-side state requires a full SD-Remote; - CP-side reduction is handled *inside* that remote's sync leader - (same as ``is_multinode_tp=True, is_multinode_cp=False``). We - never downgrade a TP-remote to a CP-peer-only stub. - """ - _validate(topo) - - # Single-node instance: nothing to spawn. - if topo.nnodes <= 1: - return RemoteProcessRole.NO_REMOTE - - # Master node — spawn nothing, the in-process KVManager IS the - # transfer authority. - if topo.node_rank == 0: - if not topo.is_multinode_tp and not topo.is_multinode_cp: - # Multi-node deployment but THIS instance spans only one - # node — e.g. DP > 1 across nodes but each DP instance is - # single-node. No remote peer exists in this instance. - return RemoteProcessRole.NO_REMOTE - return RemoteProcessRole.MASTER - - # Off-master nodes — - # TP takes priority: a TP-split SD cannot be served by a CP-only stub. - if topo.is_multinode_tp: - return RemoteProcessRole.SD_REMOTE_FULL - - if topo.is_multinode_cp: - return RemoteProcessRole.CP_PEER_REGISTRATION_ONLY - - # Off-master but neither TP nor CP is multi-node. 
Today's legacy - # code treats this the same as CP-peer (it spawns a - # ``TransferManagerOnRemote`` on every non-master node when - # ``nnodes > 1``). We preserve that behaviour for bug-compat - # during the migration; the ideal long-term answer is NO_REMOTE, - # but flipping it here would break the existing code path that - # the multi-node PP (``is_multinode_tp=False`` but - # ``pp_size>1`` crossing nodes) relies on. - # - # TODO(dist_reuse-§2.4): revisit once ``is_multinode_pp`` has its - # own property on ModelConfig. - return RemoteProcessRole.SD_REMOTE_FULL - - -def is_sync_leader(topo: RankTopology) -> bool: - """Heuristic used when the caller hasn't provided ``is_sync_leader``. - - Today's ``flexkv_connector.py`` infers it as ``local_rank == 0 and - node_rank == 0``. We keep that rule so drop-in replacement is - byte-for-byte equivalent; callers that know better pass - ``is_sync_leader`` explicitly on the ``RankTopology``. - """ - if topo.is_sync_leader is not None: - return bool(topo.is_sync_leader) - return topo.node_rank == 0 and topo.local_rank == 0 - - -def _validate(topo: RankTopology) -> None: - if topo.nnodes <= 0: - raise ValueError(f"nnodes must be > 0, got {topo.nnodes}") - if not (0 <= topo.node_rank < topo.nnodes): - raise ValueError( - f"node_rank out of range: {topo.node_rank} / {topo.nnodes}" - ) - if topo.local_rank < 0: - raise ValueError(f"local_rank must be >= 0, got {topo.local_rank}") diff --git a/flexkv/integration/tensorrt_llm/trtllm_adapter.py b/flexkv/integration/tensorrt_llm/trtllm_adapter.py index c530a62435..65c7e6d7b2 100644 --- a/flexkv/integration/tensorrt_llm/trtllm_adapter.py +++ b/flexkv/integration/tensorrt_llm/trtllm_adapter.py @@ -38,7 +38,8 @@ def __init__(self, config: ExecutorConfig): self.flexkv_manager = KVManager(model_config=rank_info.model_config, cache_config=flexkv_config.cache_config, dp_client_id=rank_info.dp_client_id, - server_recv_port=flexkv_config.server_recv_port) + 
server_recv_port=flexkv_config.server_recv_port, + rank_info=rank_info) self.flexkv_manager.start() # self.dp_client = KVDPClient(self.server_recv_port, self.model_config) diff --git a/flexkv/integration/vllm/vllm_v1_adapter.py b/flexkv/integration/vllm/vllm_v1_adapter.py index 5649dbfe2c..af1733ebf5 100644 --- a/flexkv/integration/vllm/vllm_v1_adapter.py +++ b/flexkv/integration/vllm/vllm_v1_adapter.py @@ -175,7 +175,8 @@ def __init__( cache_config=flexkv_config.cache_config, dp_client_id=self.rank_info.dp_client_id, server_recv_port=flexkv_config.server_recv_port, - event_collector=self.collector) + event_collector=self.collector, + rank_info=self.rank_info) self.flexkv_manager.start() # self.dp_client = KVDPClient(self.server_recv_port, self.model_config) diff --git a/flexkv/kvmanager.py b/flexkv/kvmanager.py index f84ca86d18..b092f27b83 100644 --- a/flexkv/kvmanager.py +++ b/flexkv/kvmanager.py @@ -24,7 +24,7 @@ from flexkv.server.client import KVDPClient from flexkv.server.server import KVServer, DPClient from flexkv.kvtask import KVTaskEngine, KVResponse -from flexkv.common.config import ModelConfig, CacheConfig, GLOBAL_CONFIG_FROM_ENV, MooncakeTransferEngineConfig +from flexkv.common.config import ModelConfig, CacheConfig, RankInfo, GLOBAL_CONFIG_FROM_ENV, MooncakeTransferEngineConfig from flexkv.integration.dynamo.collector import KVEventCollector from flexkv.common.debug import flexkv_logger from flexkv.cache.redis_meta import RedisMeta @@ -37,11 +37,20 @@ def __init__(self, dp_client_id: int = 0, server_recv_port: str = "", gpu_register_port: str = "", - event_collector: Optional[KVEventCollector] = None): + event_collector: Optional[KVEventCollector] = None, + rank_info: Optional[RankInfo] = None): flexkv_logger.info(f"{model_config = }") flexkv_logger.info(f"{cache_config = }") self.model_config = model_config self.cache_config = cache_config + # Per-rank metadata: forwarded into ``KVTaskEngine`` so the + # control-plane can construct the **correct** 
per-rank + # ``SharingDomainKey`` (pp_node_idx / tp_node_idx). Without + # this, a non-master rank would silently use the (0, 0) + # master-position SD and write to the wrong Redis namespace. + # Kept optional for legacy single-SD / single-instance callers + # that don't construct a ``RankInfo`` (e.g. unit-test stubs). + self.rank_info = rank_info if server_recv_port != "": self.server_recv_port = server_recv_port @@ -76,7 +85,8 @@ def __init__(self, cache_config=cache_config, gpu_register_port=self.gpu_register_port, server_recv_port=self.server_recv_port, - inherit_env=False) + inherit_env=False, + rank_info=rank_info) else: self.server_handle = None @@ -109,6 +119,7 @@ def __init__(self, self.gpu_register_port, redis_meta=self.redis_meta_client, event_collector=event_collector, + rank_info=rank_info, ) def start(self) -> None: diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index cd4e16dae8..707b999a84 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -10,7 +10,7 @@ import nvtx import numpy as np -from flexkv.common.config import CacheConfig, ModelConfig, GLOBAL_CONFIG_FROM_ENV +from flexkv.common.config import CacheConfig, ModelConfig, RankInfo, GLOBAL_CONFIG_FROM_ENV from flexkv.common.debug import flexkv_logger from flexkv.common.block import hash_token from flexkv.common.transfer import TransferOpGraph, merge_to_batch_graph, get_nvtx_default_color, CompletedOp @@ -85,10 +85,21 @@ def __init__(self, cache_config: CacheConfig, gpu_register_port: Optional[str] = None, redis_meta: RedisMeta = None, - event_collector: Optional[KVEventCollector] = None + event_collector: Optional[KVEventCollector] = None, + rank_info: Optional[RankInfo] = None, ): self.model_config = model_config self.cache_config = cache_config + # Per-rank info is required for the SD-key constructor in + # ``_setup_sharing_domain_handles`` and the GPU-clear decision + # in ``_compute_gpu_clear_flags`` so the master/remote SD + # reflects this rank's actual (pp_rank, tp_node_idx)
instead + # of silently defaulting to the (0, 0) master position. Kept + # ``Optional`` so that legacy single-SD callers (sharing-domain + # disabled) and unit-test fakes that bypass the integration + # adapter remain valid — they just won't exercise multi-SD + # routing. + self.rank_info = rank_info flexkv_logger.info( f"[KVTaskEngine] topology: {self.model_config}" @@ -218,7 +229,9 @@ def _setup_sharing_domain_handles(self, *, gpu_register_port: Optional[str]) -> make_session_epoch, ) - self_sd = SharingDomainKey.from_model_config(self.model_config) + self_sd = SharingDomainKey.from_model_config( + self.model_config, rank_info=self.rank_info, + ) # Single-SD degenerate case (no sharing) — no extra handles needed. if self_sd.total_sd_count() <= 1: @@ -541,7 +554,9 @@ def _compute_gpu_clear_flags(self) -> List[bool]: SharingDomainKey, graph_needs_gpu_clear, ) - self_sd = SharingDomainKey.from_model_config(self.model_config) + self_sd = SharingDomainKey.from_model_config( + self.model_config, rank_info=self.rank_info, + ) flags: List[bool] = [] for h in self.transfer_handles: if not isinstance(h._handle, TransferManagerMultiNodeHandle): @@ -680,9 +695,17 @@ def __init__(self, cache_config: CacheConfig, gpu_register_port: Optional[str] = None, redis_meta: Optional[RedisMeta] = None, - event_collector: Optional[KVEventCollector] = None + event_collector: Optional[KVEventCollector] = None, + rank_info: Optional[RankInfo] = None, ): - super().__init__(model_config, cache_config, gpu_register_port, redis_meta, event_collector) + super().__init__( + model_config, + cache_config, + gpu_register_port, + redis_meta, + event_collector, + rank_info=rank_info, + ) self.tracer = FlexKVTracer() self.tracer.trace_config(model_config, cache_config, gpu_layout=None) diff --git a/flexkv/server/server.py b/flexkv/server/server.py index a4ee08821f..a5fb902707 100644 --- a/flexkv/server/server.py +++ b/flexkv/server/server.py @@ -14,7 +14,7 @@ import subprocess import textwrap -from 
flexkv.common.config import CacheConfig, ModelConfig +from flexkv.common.config import CacheConfig, ModelConfig, RankInfo from flexkv.common.debug import flexkv_logger from flexkv.cache.redis_meta import RedisMeta from flexkv.common.memory_handle import TensorSharedHandle @@ -149,9 +149,18 @@ def __init__( cache_config: CacheConfig, gpu_register_port: str, server_recv_port: str, + rank_info: Optional[RankInfo] = None, ): self.model_config = model_config + # Per-rank info — pickled across the spawn boundary so the in- + # subprocess ``KVTaskEngine`` can construct the correct + # ``SharingDomainKey`` (pp_node_idx / tp_node_idx). Optional + # for legacy single-instance / single-SD callers; sharing- + # domain mode requires a non-None rank_info to avoid the + # ``from_model_config`` deprecation path that defaults + # per-rank fields to 0. + self.rank_info = rank_info # Init inter-process communication self.context = zmq.Context(2) self.recv_from_client = get_zmq_socket( @@ -181,7 +190,13 @@ def __init__( # update distributed_node_id cache_config.distributed_node_id = self.redis_meta_client.get_node_id() - self.kv_task_engine = KVTaskEngine(model_config, cache_config, gpu_register_port, redis_meta=self.redis_meta_client) + self.kv_task_engine = KVTaskEngine( + model_config, + cache_config, + gpu_register_port, + redis_meta=self.redis_meta_client, + rank_info=rank_info, + ) self.req_counter = 0 self._is_ready = False @@ -215,9 +230,13 @@ def start_server(self) -> None: def _server_process(model_config: ModelConfig, cache_config: CacheConfig, gpu_register_port: str, - server_recv_port: str) -> None: + server_recv_port: str, + rank_info: Optional[RankInfo] = None) -> None: - server = KVServer(model_config, cache_config, gpu_register_port, server_recv_port) + server = KVServer( + model_config, cache_config, gpu_register_port, server_recv_port, + rank_info=rank_info, + ) server.run() @classmethod @@ -227,7 +246,8 @@ def create_server(cls, gpu_register_port: str, 
server_recv_port: Optional[str] = None, child_env: Optional[dict] = None, - inherit_env: bool = True) -> 'KVServerHandle': + inherit_env: bool = True, + rank_info: Optional[RankInfo] = None) -> 'KVServerHandle': # Set spawn method for CUDA compatibility with contextlib.suppress(RuntimeError): @@ -258,7 +278,9 @@ def create_server(cls, env.pop('CUDA_VISIBLE_DEVICES', None) # Serialize arguments - args_data = pickle.dumps((model_config, cache_config, gpu_register_port, server_recv_port)) + args_data = pickle.dumps( + (model_config, cache_config, gpu_register_port, server_recv_port, rank_info) + ) # Start subprocess flexkv_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) @@ -269,8 +291,11 @@ def create_server(cls, from flexkv.server.server import KVServer args_data = {args_data!r} - model_config, cache_config, gpu_register_port, server_recv_port = pickle.loads(args_data) - server = KVServer(model_config, cache_config, gpu_register_port, server_recv_port) + model_config, cache_config, gpu_register_port, server_recv_port, rank_info = pickle.loads(args_data) + server = KVServer( + model_config, cache_config, gpu_register_port, server_recv_port, + rank_info=rank_info, + ) server.run() ''').strip() process = subprocess.Popen([ @@ -285,7 +310,7 @@ def create_server(cls, else: # Use multiprocessing as before process = mp.Process(target=cls._server_process, - args=(model_config, cache_config, gpu_register_port, server_recv_port)) + args=(model_config, cache_config, gpu_register_port, server_recv_port, rank_info)) process.start() flexkv_logger.info( f"KVServer process started, PID: {process.pid}, " diff --git a/scripts/multi-nodes/start_dist_reuse_serving.sh b/scripts/multi-nodes/start_dist_reuse_serving.sh index a0444c8c98..beb4e73f80 100644 --- a/scripts/multi-nodes/start_dist_reuse_serving.sh +++ b/scripts/multi-nodes/start_dist_reuse_serving.sh @@ -30,10 +30,12 @@ # --model / --redis-host # # 备注: -# * 如果该 instance 只用 CP 跨节点(CP 跨机、TP/PP 不跨机),脚本会自动 -# 把 
node_rank>0 的机器放到 CP_PEER_REGISTRATION_ONLY 路径上(不启 -# TransferManagerOnRemote,sglang connector 侧按 multinode_policy -# 策略自己决定)。 +# * 是否在非主节点启动 TransferManagerOnRemote 由 sglang/vllm connector +# 在 Python 侧根据 ``rank_info`` 自行判断 —— 本脚本只负责导出环境变 +# 量并 *始终* 写入一份合法的 Mooncake 配置。早期版本里这里有一段 +# ``NEED_FULL_SD_REMOTE`` 分支 + 空 sentinel 文件的逻辑,等于把 +# connector 的启动判定复制了一份到 bash 里,违反了 +# “TransferManagerOnRemote 应该 per-node 自动兼容” 的设计意图,已删除。 # * 脚本**不直接启动** sglang/vLLM 进程;它只做环境变量和配置文件生 # 成,然后 exec 用户指定的启动命令(--launcher-cmd / 默认是 # sglang 的 router 入口)。这样脚本本身单测友好。 @@ -113,11 +115,10 @@ fi # -------------------------------------------------------- derived topology # Basic rule: MASTER = node_rank 0. Everything else is an off-master role -# whose concrete type is picked by the connector's multinode_policy -# (FlexKV/flexkv/integration/multinode_policy.py). The script only needs -# to decide whether to emit a Mooncake config (full SD-Remote) or just an -# empty config (CP-peer stub); it does that by asking whether the -# instance has any TP-level cross-node spread. +# whose concrete behaviour (full SD-Remote vs. CP-peer stub) is picked by +# the connector at runtime from ``rank_info``. The script always emits a +# real Mooncake config and lets the Python layer decide whether to +# instantiate a TransferEngine on this rank. # tp_node_count = TP_SIZE / gpus_per_node. We auto-detect gpus_per_node # via nvidia-smi (default 8 if unavailable). @@ -142,8 +143,8 @@ IS_MULTINODE_TP="false" if (( TP_NODE_COUNT > 1 )) || (( PP_SIZE > 1 )); then # PP > 1 crossing nodes always needs a full SD-Remote on each PP-peer # node too. We conservatively flag it as multinode_tp for the - # mooncake-config emission branch. The connector's policy module - # still does the fine-grained role decision at runtime. + # mooncake-config emission branch. The connector still does the + # fine-grained role decision at runtime from ``rank_info``. 
IS_MULTINODE_TP="true" fi @@ -164,24 +165,22 @@ CFG_DIR="${SCRIPT_DIR}/gen/dist_reuse_node${NODE_RANK}" mkdir -p "$LOG_DIR" "$CFG_DIR" # ---------------------------------------------------------- mooncake config -# Only emit a real mooncake config when we need a full SD-Remote on this -# node. CP-peer-only nodes don't touch mooncake. -NEED_FULL_SD_REMOTE="false" -if [[ "$NODE_RANK" == "0" ]]; then - # Master always needs mooncake — it owns the TransferEngine in-process. - NEED_FULL_SD_REMOTE="true" -else - if [[ "$IS_MULTINODE_TP" == "true" ]]; then - NEED_FULL_SD_REMOTE="true" - fi - # CP-only multi-node: off-master is peer-stub only, no mooncake. -fi +# Always emit a real mooncake config — the integration connector +# decides at runtime (from ``rank_info`` inside Python) whether to +# actually instantiate a Mooncake TransferEngine on this rank. +# Earlier versions of this script tried to mirror that policy in shell +# (writing an empty sentinel file when the off-master node was a +# CP-peer-only stub) but that duplicated the Python launch logic in +# bash and forced the connector to read a side-channel file existence +# check — exactly the "tight coupling" pointed out in code review. +# The new contract is simple: the shell always provides +# MOONCAKE_CONFIG_PATH, and the Python layer decides whether to read it. 
+NEED_FULL_SD_REMOTE="true" MOONCAKE_ENGINE_PORT=$((MOONCAKE_ENGINE_PORT_BASE + NODE_RANK)) MOONCAKE_CONFIG_FILE="${CFG_DIR}/mooncake_config.json" -if [[ "$NEED_FULL_SD_REMOTE" == "true" ]]; then - cat > "$MOONCAKE_CONFIG_FILE" < "$MOONCAKE_CONFIG_FILE" < "${MOONCAKE_CONFIG_FILE}" -fi # --------------------------------------------------------- flexkv env vars # The sglang connector reads these at import time; we export them so @@ -222,7 +216,6 @@ echo " tp / pp / cp : ${TP_SIZE} / ${PP_SIZE} / ${CP_SIZE}" echo " tp_node_count : ${TP_NODE_COUNT} (gpus_per_node=${GPUS_PER_NODE})" echo " is_multinode_tp : ${IS_MULTINODE_TP}" echo " is_multinode_cp : ${IS_MULTINODE_CP}" -echo " need_full_sd_remote: ${NEED_FULL_SD_REMOTE}" echo " instance_id : ${INSTANCE_ID}" echo " master_ip : ${MASTER_IP}" echo " dist_init : ${MASTER_IP}:${DIST_INIT_PORT}" diff --git a/tests/test_aggregate_radix.py b/tests/test_aggregate_radix.py index d9d04b1913..41db535618 100644 --- a/tests/test_aggregate_radix.py +++ b/tests/test_aggregate_radix.py @@ -11,20 +11,6 @@ ) -# --------------------------------------------------------------------------- -# Manual clock for deterministic time-based assertions -# --------------------------------------------------------------------------- -class _ManualClock: - def __init__(self, start: float = 0.0) -> None: - self.now = start - - def __call__(self) -> float: - return self.now - - def advance(self, dt: float) -> None: - self.now += dt - - # --------------------------------------------------------------------------- # mark_sd_ready / fully-ready transition # --------------------------------------------------------------------------- @@ -131,30 +117,6 @@ def test_node_id_real_value_not_overwritten_by_sentinel(self): assert result.node_id_for_sd("sd0") == 42 -# --------------------------------------------------------------------------- -# mark_sd_evicted -# --------------------------------------------------------------------------- -class TestMarkSdEvicted: - 
def test_evict_single_sd_drops_to_partial(self): - agg = AggregateRadixTree(total_sd_count=2) - agg.mark_sd_ready(1, "sd0", [10]) - agg.mark_sd_ready(1, "sd1", [10]) - assert agg.match_fully_ready(1) is not None - agg.mark_sd_evicted(1, "sd0") - assert agg.match_fully_ready(1) is None # no longer fully ready - - def test_evict_last_sd_drops_entry(self): - agg = AggregateRadixTree(total_sd_count=2) - agg.mark_sd_ready(1, "sd0", [10]) - agg.mark_sd_evicted(1, "sd0") - # No SDs left → entry is gone, not just "partial" - assert 1 not in agg.known_prefixes() - - def test_evict_unknown_prefix_is_noop(self): - agg = AggregateRadixTree(total_sd_count=1) - agg.mark_sd_evicted(99, "sd0") # must not raise - - # --------------------------------------------------------------------------- # Refcount lifecycle # --------------------------------------------------------------------------- @@ -164,18 +126,17 @@ def test_acquire_release_cycle(self): assert agg.is_evictable(7) is True agg.acquire([7]) assert agg.is_evictable(7) is False - assert agg.get_refcount(7) == 1 agg.release([7]) assert agg.is_evictable(7) is True - assert agg.get_refcount(7) == 0 def test_acquire_increments(self): agg = AggregateRadixTree(total_sd_count=1) agg.acquire([1, 1, 1]) - assert agg.get_refcount(1) == 3 + # Three acquires — still pinned after two releases. 
agg.release([1, 1]) - assert agg.get_refcount(1) == 1 assert agg.is_evictable(1) is False + agg.release([1]) + assert agg.is_evictable(1) is True def test_double_release_raises(self): agg = AggregateRadixTree(total_sd_count=1) @@ -190,43 +151,6 @@ def test_release_untracked_raises(self): agg.release([42]) -# --------------------------------------------------------------------------- -# Leak scanner -# --------------------------------------------------------------------------- -class TestLeakScanner: - def test_finds_leaked_blocks(self): - clock = _ManualClock() - agg = AggregateRadixTree(total_sd_count=1, time_fn=clock) - agg.acquire([1]) - clock.advance(40.0) - leaked = agg.scan_leaked_refcount(timeout_seconds=30.0) - assert leaked == [1] - - def test_fresh_acquires_not_leaked(self): - clock = _ManualClock() - agg = AggregateRadixTree(total_sd_count=1, time_fn=clock) - agg.acquire([1]) - leaked = agg.scan_leaked_refcount(timeout_seconds=30.0) - assert leaked == [] - - def test_force_release_clears_refcount(self): - clock = _ManualClock() - agg = AggregateRadixTree(total_sd_count=1, time_fn=clock) - agg.acquire([1, 1, 1]) - prev = agg.force_release(1) - assert prev == 3 - assert agg.is_evictable(1) is True - - def test_force_release_unknown_returns_zero(self): - agg = AggregateRadixTree(total_sd_count=1) - assert agg.force_release(999) == 0 - - def test_negative_timeout_raises(self): - agg = AggregateRadixTree(total_sd_count=1) - with pytest.raises(ValueError): - agg.scan_leaked_refcount(-1.0) - - # --------------------------------------------------------------------------- # Invalidation # --------------------------------------------------------------------------- diff --git a/tests/test_coord_protocol.py b/tests/test_coord_protocol.py index 1172fb6bd9..6cff75744a 100644 --- a/tests/test_coord_protocol.py +++ b/tests/test_coord_protocol.py @@ -16,7 +16,6 @@ from flexkv.common.dist_reuse.coordination_protocol import ( CoordMsgType, - EpochVerifyError, 
FailureReportMsg, RemoteReadyMsg, decode_coord_message, @@ -107,15 +106,6 @@ def test_encode_rejects_non_message(self): encode_coord_message({"type": "fake"}) # type: ignore[arg-type] -# --------------------------------------------------------------------------- -# EpochVerifyError sanity -# --------------------------------------------------------------------------- -def test_epoch_verify_error_is_runtime_error(): - assert issubclass(EpochVerifyError, RuntimeError) - with pytest.raises(EpochVerifyError): - raise EpochVerifyError("stale") - - # --------------------------------------------------------------------------- # Default values # --------------------------------------------------------------------------- diff --git a/tests/test_dist_reuse_launcher.py b/tests/test_dist_reuse_launcher.py index 043a917fe8..4affaa3ac4 100644 --- a/tests/test_dist_reuse_launcher.py +++ b/tests/test_dist_reuse_launcher.py @@ -66,7 +66,6 @@ def test_master_rank_0_emits_mooncake_config(self, tmp_path): "--dry-run", ]) assert "node_rank=0/2" in proc.stdout - assert "need_full_sd_remote: true" in proc.stdout assert "instance_id : unit-test-master" in proc.stdout # Check the generated mooncake config looks like valid JSON. @@ -78,10 +77,19 @@ def test_master_rank_0_emits_mooncake_config(self, tmp_path): assert data["device_name"] == "mlx5_0" assert data["engine_port"] == 12345 # base (+ node_rank 0) - def test_cp_only_offmaster_emits_empty_config(self): - """CP-only cross-node instance: node_rank=1 is CP-peer stub, - mooncake config must be empty (sentinel for connector's - ``CP_PEER_REGISTRATION_ONLY`` path).""" + def test_cp_only_offmaster_still_emits_full_config(self): + """CP-only cross-node instance, node_rank=1. + + Earlier revisions of this script tried to mirror the + connector's role-decision in bash by emitting an *empty* + sentinel mooncake_config.json for the CP-peer-only case so the + Python connector could side-channel on file existence. 
That + coupling was reverted (see commit ``review-fixes(F1+F2+F3)``) + because it duplicated launch-time logic that belongs in the + connector. The new contract is: shell *always* writes a valid + mooncake config, the Python layer decides whether to + instantiate a TransferEngine. This test pins that contract. + """ proc = _run([ "--nnodes", "2", "--node-rank", "1", @@ -96,14 +104,15 @@ def test_cp_only_offmaster_emits_empty_config(self): ]) assert "node_rank=1/2" in proc.stdout assert "is_multinode_cp : true" in proc.stdout - assert "need_full_sd_remote: false" in proc.stdout cfg = (REPO_ROOT / "scripts" / "multi-nodes" / "gen" / "dist_reuse_node1" / "mooncake_config.json") assert cfg.exists() - # Empty sentinel — bytes length must be 0 to match the - # connector's "no mooncake here" contract. - assert cfg.read_text() == "" + # No more empty-sentinel hack — every node gets a valid config + # and the connector decides what to do with it at runtime. + assert cfg.stat().st_size > 0 + data = json.loads(cfg.read_text()) + assert data["metadata_backend"] == "redis" def test_tp_cross_node_offmaster_emits_full_config(self): proc = _run([ @@ -119,7 +128,6 @@ def test_tp_cross_node_offmaster_emits_full_config(self): "--dry-run", ]) assert "is_multinode_tp : true" in proc.stdout - assert "need_full_sd_remote: true" in proc.stdout cfg = (REPO_ROOT / "scripts" / "multi-nodes" / "gen" / "dist_reuse_node1" / "mooncake_config.json") diff --git a/tests/test_master_coordinator.py b/tests/test_master_coordinator.py index a4540e600d..5f51bd9f8e 100644 --- a/tests/test_master_coordinator.py +++ b/tests/test_master_coordinator.py @@ -31,7 +31,7 @@ ) sys.path.insert(0, str(Path(__file__).parent)) -from _dist_reuse_fakes import FakeRedis, ManualClock # noqa: E402 +from _dist_reuse_fakes import FakeRedis # noqa: E402 # --------------------------------------------------------------------------- @@ -226,32 +226,6 @@ def test_mark_sd_ready_flow(self): assert ok2 is True assert 
mc.match_fully_ready(0xAA) is not None - def test_invalidate_prefix(self): - sd = _sd() # single SD - mc = MasterCoordinator(self_sd=sd, instance_id="inst-A") - mc.expect_remotes(0) - mc.mark_sd_ready(0xAA, sd.serialize(), [10]) - assert mc.invalidate_prefix(0xAA) is True - assert mc.match_fully_ready(0xAA) is None - - def test_scan_leaked_refcount(self): - sd = _sd() - clock = ManualClock() - mc = MasterCoordinator( - self_sd=sd, instance_id="inst-A", - refcount_leak_timeout_seconds=10.0, - ) - # Inject our manual clock into the aggregate. - mc._aggregate._time_fn = clock # direct attribute poke for testing - mc.expect_remotes(0) - mc.acquire_blocks([7, 8]) - clock.advance(20.0) - leaked = mc.scan_leaked_refcount() - assert sorted(leaked) == [7, 8] - # Force-released: now evictable. - assert mc.is_evictable(7) - assert mc.is_evictable(8) - def test_peer_loss_invalidates(self): sd = _sd() mc = MasterCoordinator(self_sd=sd, instance_id="inst-A") @@ -393,16 +367,6 @@ def test_mooncake_without_regist_buffer_raises(self): with pytest.raises(AttributeError, match="regist_buffer"): init.bootstrap() - def test_encode_ready(self): - msg = RemoteReadyMsg( - sender_instance_id="inst", sender_epoch="e", - sd_key="m:ppn0/1:tpn0/1:nsa0", - distributed_node_id=1, - ) - out = RemoteDistReuseInitializer.encode_ready(msg) - assert out["type"] == "remote_ready" - assert out["sd_key"] == "m:ppn0/1:tpn0/1:nsa0" - # --------------------------------------------------------------------------- # SharingDomainHandleSpec diff --git a/tests/test_multinode_role_policy.py b/tests/test_multinode_role_policy.py deleted file mode 100644 index 2b4a8ce75a..0000000000 --- a/tests/test_multinode_role_policy.py +++ /dev/null @@ -1,186 +0,0 @@ -"""§2.4 — Multi-node role decision policy tests. - -The connector-side split of ``is_multinode_tp`` vs. ``is_multinode_cp`` -is currently *not* plumbed into ``flexkv_connector.py`` (sglang/) — -doing that requires a two-machine GPU setup to verify. 
Until then we -pin the decision table at the policy-function level so that when the -actual connector swap lands, it already has a stable, tested contract -to call into. - -These tests are torch-free by design: pure logic over ``RankTopology``. -""" -from __future__ import annotations - -import sys -from pathlib import Path - -import pytest - -REPO_ROOT = Path(__file__).resolve().parent.parent -if str(REPO_ROOT) not in sys.path: - sys.path.insert(0, str(REPO_ROOT)) - -from flexkv.integration.multinode_policy import ( # noqa: E402 - RankTopology, - RemoteProcessRole, - decide_remote_role, - is_sync_leader, -) - - -# --------------------------------------------------------------------------- -# Trivial validation -# --------------------------------------------------------------------------- -class TestValidation: - def test_nnodes_must_be_positive(self): - with pytest.raises(ValueError): - decide_remote_role(RankTopology( - nnodes=0, node_rank=0, local_rank=0, - is_multinode_tp=False, is_multinode_cp=False, - )) - - def test_node_rank_in_range(self): - with pytest.raises(ValueError): - decide_remote_role(RankTopology( - nnodes=2, node_rank=5, local_rank=0, - is_multinode_tp=True, is_multinode_cp=False, - )) - - def test_local_rank_non_negative(self): - with pytest.raises(ValueError): - decide_remote_role(RankTopology( - nnodes=2, node_rank=1, local_rank=-1, - is_multinode_tp=True, is_multinode_cp=False, - )) - - -# --------------------------------------------------------------------------- -# Single-node instance → NO_REMOTE regardless of any flag. 
-# --------------------------------------------------------------------------- -class TestSingleNode: - @pytest.mark.parametrize("tp,cp", [(False, False), (True, False), - (False, True), (True, True)]) - def test_single_node_never_spawns_remote(self, tp, cp): - topo = RankTopology( - nnodes=1, node_rank=0, local_rank=0, - is_multinode_tp=tp, is_multinode_cp=cp, - ) - assert decide_remote_role(topo) is RemoteProcessRole.NO_REMOTE - - -# --------------------------------------------------------------------------- -# Master node (node_rank == 0) never spawns a remote itself. -# --------------------------------------------------------------------------- -class TestMasterNode: - def test_master_with_multinode_tp_is_master(self): - topo = RankTopology( - nnodes=2, node_rank=0, local_rank=0, - is_multinode_tp=True, is_multinode_cp=False, - ) - assert decide_remote_role(topo) is RemoteProcessRole.MASTER - - def test_master_with_multinode_cp_is_master(self): - topo = RankTopology( - nnodes=2, node_rank=0, local_rank=0, - is_multinode_tp=False, is_multinode_cp=True, - ) - assert decide_remote_role(topo) is RemoteProcessRole.MASTER - - def test_master_with_both_flags_is_master(self): - topo = RankTopology( - nnodes=2, node_rank=0, local_rank=0, - is_multinode_tp=True, is_multinode_cp=True, - ) - assert decide_remote_role(topo) is RemoteProcessRole.MASTER - - def test_master_with_nothing_crossing_nodes_is_no_remote(self): - """Multi-node deployment but THIS instance is single-node - (only DP crosses; each DP instance stays on one node).""" - topo = RankTopology( - nnodes=2, node_rank=0, local_rank=0, - is_multinode_tp=False, is_multinode_cp=False, - ) - assert decide_remote_role(topo) is RemoteProcessRole.NO_REMOTE - - -# --------------------------------------------------------------------------- -# Off-master nodes — the interesting routing table. 
-# --------------------------------------------------------------------------- -class TestOffMasterRouting: - def test_multinode_tp_only_is_full_sd_remote(self): - topo = RankTopology( - nnodes=2, node_rank=1, local_rank=0, - is_multinode_tp=True, is_multinode_cp=False, - ) - assert decide_remote_role(topo) is RemoteProcessRole.SD_REMOTE_FULL - - def test_multinode_cp_only_is_cp_registration_stub(self): - topo = RankTopology( - nnodes=2, node_rank=1, local_rank=0, - is_multinode_tp=False, is_multinode_cp=True, - ) - assert decide_remote_role(topo) is RemoteProcessRole.CP_PEER_REGISTRATION_ONLY - - def test_multinode_tp_wins_over_cp(self): - """When BOTH flags are True on an off-master rank, TP takes - priority — TP-split SDs cannot be served by a CP-only stub.""" - topo = RankTopology( - nnodes=2, node_rank=1, local_rank=0, - is_multinode_tp=True, is_multinode_cp=True, - ) - assert decide_remote_role(topo) is RemoteProcessRole.SD_REMOTE_FULL - - def test_neither_tp_nor_cp_multinode_still_spawns_full_remote_today(self): - """Legacy bug-compat: today's connector treats any - ``nnodes>1 and node_rank>0 and local_rank==0`` case as - ``SD_REMOTE_FULL`` (PP crossing nodes uses this path). We - preserve that during the migration; the TODO in - multinode_policy.py tracks the eventual cleanup. - """ - topo = RankTopology( - nnodes=2, node_rank=1, local_rank=0, - is_multinode_tp=False, is_multinode_cp=False, - ) - assert decide_remote_role(topo) is RemoteProcessRole.SD_REMOTE_FULL - - -# --------------------------------------------------------------------------- -# Sync-leader helper. 
-# --------------------------------------------------------------------------- -class TestSyncLeader: - def test_default_rule_is_local_and_node_rank_zero(self): - topo = RankTopology( - nnodes=2, node_rank=0, local_rank=0, - is_multinode_tp=True, is_multinode_cp=False, - ) - assert is_sync_leader(topo) is True - - topo2 = RankTopology( - nnodes=2, node_rank=1, local_rank=0, - is_multinode_tp=True, is_multinode_cp=False, - ) - assert is_sync_leader(topo2) is False - - topo3 = RankTopology( - nnodes=2, node_rank=0, local_rank=1, - is_multinode_tp=True, is_multinode_cp=False, - ) - assert is_sync_leader(topo3) is False - - def test_explicit_hint_overrides_default(self): - """If the caller hands us ``is_sync_leader=True`` (e.g. coming - from sglang's own group metadata), respect it even if the - default heuristic would say otherwise.""" - topo = RankTopology( - nnodes=2, node_rank=1, local_rank=7, - is_multinode_tp=True, is_multinode_cp=False, - is_sync_leader=True, - ) - assert is_sync_leader(topo) is True - - topo2 = RankTopology( - nnodes=2, node_rank=0, local_rank=0, - is_multinode_tp=True, is_multinode_cp=False, - is_sync_leader=False, - ) - assert is_sync_leader(topo2) is False diff --git a/tests/test_sharing_domain_key.py b/tests/test_sharing_domain_key.py index e4283251ca..292b5464ee 100644 --- a/tests/test_sharing_domain_key.py +++ b/tests/test_sharing_domain_key.py @@ -369,10 +369,6 @@ def test_unique_serialization(self): masters = [p for p in peers if p.is_master()] assert len(masters) == 1 - def test_iter_dunder(self): - sd = SharingDomainKey.default() - assert list(sd) == sd.enumerate_peers() - # --------------------------------------------------------------------------- # default() diff --git a/tests/test_sharing_domain_namespace.py b/tests/test_sharing_domain_namespace.py index 0a99083237..acfa6ec600 100644 --- a/tests/test_sharing_domain_namespace.py +++ b/tests/test_sharing_domain_namespace.py @@ -57,15 +57,8 @@ def 
test_block_key_handles_negative_hash(ns): assert ns.block_key(0, h) == f"{ns.prefix}:block:0:ffffffffffffffff" -def test_aggregate_key(ns): - assert ns.aggregate_key(0xCAFEBABE) == f"{ns.prefix}:aggregate:cafebabe" - - def test_scan_patterns(ns): assert ns.node_key_pattern() == f"{ns.prefix}:node:*" - assert ns.meta_key_pattern() == f"{ns.prefix}:meta:*" - assert ns.buffer_key_pattern() == f"{ns.prefix}:buffer:*" - assert ns.block_key_pattern() == f"{ns.prefix}:block:*" assert ns.block_key_pattern_for_node(7) == f"{ns.prefix}:block:7:*"