diff --git a/CMakeLists.txt b/CMakeLists.txt
index 749b201b8d..5040ed3d95 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,5 +1,46 @@
 cmake_minimum_required(VERSION 3.10)
-project(MainProject VERSION 1.0)
+
+# Derive the project version from the nearest "v*" git tag. This has to run
+# before project() because the result is passed to project(VERSION ...).
+find_package(Git QUIET)
+
+if(GIT_FOUND)
+  # --long makes git print TAG-DISTANCE-gHASH even when HEAD sits exactly on the tag.
+  execute_process(
+    COMMAND ${GIT_EXECUTABLE} describe --tags --long --match "v*"
+    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+    OUTPUT_VARIABLE GIT_DESCRIBE_OUTPUT
+    ERROR_QUIET
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+  )
+  if(GIT_DESCRIBE_OUTPUT MATCHES "^v([0-9]+\\.[0-9]+\\.[0-9]+)-([0-9]+)-g([0-9a-f]+)$")
+    set(GIT_VERSION "${CMAKE_MATCH_1}")
+    set(GIT_DISTANCE "${CMAKE_MATCH_2}")
+    set(GIT_HASH "${CMAKE_MATCH_3}")
+    if(GIT_DISTANCE STREQUAL "0")
+      set(DETECTED_VERSION "${GIT_VERSION}")
+    else()
+      set(DETECTED_VERSION "${GIT_VERSION}+git${GIT_HASH}")
+    endif()
+    message(STATUS "Version from git tag: ${DETECTED_VERSION}")
+  endif()
+endif()
+
+# Git missing, command failed, or no tag of the form vX.Y.Z: use a placeholder version.
+if(NOT DEFINED DETECTED_VERSION OR DETECTED_VERSION STREQUAL "")
+  set(DETECTED_VERSION "0.0.0")
+  message(WARNING "Could not detect version from git tag, using fallback: ${DETECTED_VERSION}")
+endif()
+
+# Strip +gitXXXXXXX suffix for CMake project VERSION (must be numeric X.Y.Z)
+if(DETECTED_VERSION MATCHES "^([0-9]+\\.[0-9]+\\.[0-9]+)")
+  set(NUMERIC_VERSION "${CMAKE_MATCH_1}")
+else()
+  set(NUMERIC_VERSION "0.0.0")
+endif()
+
+project(MainProject VERSION ${NUMERIC_VERSION})
+message(STATUS "Project version: ${PROJECT_VERSION} (full: ${DETECTED_VERSION})")
 
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
@@ -30,10 +71,22 @@ target_include_directories(xxhash PUBLIC
 install(FILES ${XXHASH_HEADERS} DESTINATION include)
 
 # ==================== prometheus-cpp Library ====================
-# Option to enable/disable monitoring (env FLEXKV_ENABLE_METRICS=0 or -DFLEXKV_ENABLE_MONITORING=OFF)
-set(_FLEXKV_MONITORING_DEFAULT ON)
+# Step 1: Auto-detect default from directory existence
+if(IS_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/third_party/prometheus-cpp")
+  set(_FLEXKV_MONITORING_DEFAULT ON)
+else()
+  set(_FLEXKV_MONITORING_DEFAULT OFF)
+  message(STATUS "third_party/prometheus-cpp not found, Prometheus monitoring defaults to OFF")
+endif()
+
+# Step 2: Environment variable override (FLEXKV_ENABLE_METRICS=1/0).
+# ON/OFF, YES/NO and TRUE/FALSE are accepted too; matching ignores letter case.
 if(DEFINED ENV{FLEXKV_ENABLE_METRICS})
-  if("$ENV{FLEXKV_ENABLE_METRICS}" STREQUAL "0" OR "$ENV{FLEXKV_ENABLE_METRICS}" STREQUAL "OFF"
-     OR "$ENV{FLEXKV_ENABLE_METRICS}" STREQUAL "NO" OR "$ENV{FLEXKV_ENABLE_METRICS}" STREQUAL "FALSE")
+  string(TOUPPER "$ENV{FLEXKV_ENABLE_METRICS}" _FLEXKV_METRICS_ENV)
+  if(_FLEXKV_METRICS_ENV MATCHES "^(1|ON|YES|TRUE)$")
+    set(_FLEXKV_MONITORING_DEFAULT ON)
+  elseif(_FLEXKV_METRICS_ENV MATCHES "^(0|OFF|NO|FALSE)$")
     set(_FLEXKV_MONITORING_DEFAULT OFF)
+  else()
+    message(WARNING "Ignoring unrecognized FLEXKV_ENABLE_METRICS value: '$ENV{FLEXKV_ENABLE_METRICS}'")
   endif()
diff --git a/VERSION b/VERSION
deleted file mode 100644
index 3eefcb9dd5..0000000000
--- a/VERSION
+++ /dev/null
@@ -1 +0,0 @@
-1.0.0
diff --git a/benchmarks/benchmark_dist_kvcache.py b/benchmarks/benchmark_dist_kvcache.py
new file mode 100644
index 0000000000..26e6c84426
--- /dev/null
+++ b/benchmarks/benchmark_dist_kvcache.py
@@ -0,0 +1,637 @@
+"""
+Benchmark for FlexKV distributed KVCache in server_client_mode.
+
+This script tests the put/get performance of FlexKV when running in
+server_client_mode with distributed KVCache sharing enabled (enable_p2p_cpu).
+
+Prerequisites:
+    - A running Redis server (default: 127.0.0.1:6379)
+    - At least 1 GPU available
+    - FlexKV built with distributed support (FLEXKV_ENABLE_P2P=1)
+
+Usage:
+    # Basic usage with default config
+    python benchmarks/benchmark_dist_kvcache.py --config benchmarks/example_dist_config.yml
+
+    # Custom parameters
+    python benchmarks/benchmark_dist_kvcache.py \\
+        --config benchmarks/example_dist_config.yml \\
+        --batch-size 4 \\
+        --sequence-length 2048 \\
+        --cache-ratio 0.5 \\
+        --num-users 10 \\
+        --num-turns 3
+
+    # Multi-turn conversation benchmark only
+    python benchmarks/benchmark_dist_kvcache.py \\
+        --config benchmarks/example_dist_config.yml \\
+        --mode multiturn \\
+        --num-users 20 \\
+        --num-turns 5
+
+    # Cross-node benchmark: Node A (PUT only)
+    python benchmarks/benchmark_dist_kvcache.py \\
+        --config config_a.yml --seed 42 --mode put-only
+
+    # Cross-node benchmark: Node B (GET only, same seed)
+    python benchmarks/benchmark_dist_kvcache.py \\
+        --config config_b.yml --seed 42 --mode get-only
+"""
+import os
+import atexit
+import signal
+import argparse
+import json
+import tempfile
+import time
+from multiprocessing import Process
+from dataclasses import dataclass
+
+import torch
+import numpy as np
+
+from flexkv.server.client import KVTPClient
+from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType
+from flexkv.common.config import (
+    ModelConfig, CacheConfig, UserConfig,
+    update_default_config_from_user_config, parse_path_list,
+    GLOBAL_CONFIG_FROM_ENV,
+)
+from flexkv.common.debug import flexkv_logger
+from flexkv.kvmanager import KVManager
+from flexkv.kvtask import KVResponseStatus
+
+from utils import generate_random_multiturn
+
+flexkv_logger.set_level("INFO")
+
+
+def load_dist_config(config_path: str):
+    """Load config with distributed KVCache support.
+ + Extends the standard load_config to handle distributed-specific fields: + enable_p2p_cpu, enable_p2p_ssd, enable_3rd_remote, + redis_host, redis_port, local_ip, redis_password, + server_client_mode, etc. + """ + import yaml + + with open(config_path) as f: + config = yaml.load(f, Loader=yaml.SafeLoader) + print(f"Loaded config: {config}") + + model_config = ModelConfig() + cache_config = CacheConfig() + user_config = UserConfig() + + # Model config + model_config.num_layers = config["num_layers"] + model_config.num_kv_heads = config["num_kv_heads"] + model_config.head_size = config["head_size"] + model_config.dtype = eval(f"torch.{config['dtype']}") + model_config.use_mla = config["use_mla"] + model_config.tp_size = config["tp_size"] + model_config.dp_size = config["dp_size"] + cache_config.tokens_per_block = config["tokens_per_block"] + + # Cache size config + if "cpu_cache_gb" in config: + user_config.cpu_cache_gb = config["cpu_cache_gb"] + if "ssd_cache_gb" in config: + user_config.ssd_cache_gb = config["ssd_cache_gb"] + if "ssd_cache_dir" in config: + user_config.ssd_cache_dir = parse_path_list(config["ssd_cache_dir"]) + if "enable_gds" in config: + user_config.enable_gds = config["enable_gds"] + + # Distributed KVCache config + if "enable_p2p_cpu" in config: + user_config.enable_p2p_cpu = config["enable_p2p_cpu"] + if "enable_p2p_ssd" in config: + user_config.enable_p2p_ssd = config["enable_p2p_ssd"] + if "enable_3rd_remote" in config: + user_config.enable_3rd_remote = config["enable_3rd_remote"] + + # Redis config + if "redis_host" in config: + user_config.redis_host = config["redis_host"] + if "redis_port" in config: + user_config.redis_port = config["redis_port"] + if "local_ip" in config: + user_config.local_ip = config["local_ip"] + if "redis_password" in config: + user_config.redis_password = config["redis_password"] + + # Auto-generate mooncake config JSON and set MOONCAKE_CONFIG_PATH if P2P is enabled + if config.get("enable_p2p_cpu", False) or 
config.get("enable_p2p_ssd", False): + if "MOONCAKE_CONFIG_PATH" not in os.environ: + mooncake_config = { + "engine_ip": config.get("mooncake_engine_ip", config.get("local_ip", "127.0.0.1")), + "engine_port": config.get("mooncake_engine_port", 5555), + "metadata_backend": config.get("mooncake_metadata_backend", "redis"), + "metadata_server": config.get("mooncake_metadata_server", + f"redis://{config.get('redis_host', '127.0.0.1')}:{config.get('redis_port', 6379)}"), + "metadata_server_auth": config.get("mooncake_metadata_server_auth", + config.get("redis_password", "")), + "protocol": config.get("mooncake_protocol", "tcp"), + "device_name": config.get("mooncake_device_name", ""), + } + # Write to a temp file that persists until process exits + mooncake_config_fd, mooncake_config_path = tempfile.mkstemp( + suffix=".json", prefix="mooncake_config_" + ) + with os.fdopen(mooncake_config_fd, "w") as f: + json.dump(mooncake_config, f, indent=2) + os.environ["MOONCAKE_CONFIG_PATH"] = mooncake_config_path + print(f"[INFO] Auto-generated mooncake config at: {mooncake_config_path}") + print(f"[INFO] Mooncake config: {json.dumps(mooncake_config, indent=2)}") + else: + mooncake_config_path = os.environ['MOONCAKE_CONFIG_PATH'] + print(f"[INFO] Using existing MOONCAKE_CONFIG_PATH: {mooncake_config_path}") + + # Store mooncake_config_path in cache_config so it survives spawn subprocesses via pickle + cache_config.mooncake_config_path = mooncake_config_path + + update_default_config_from_user_config(model_config, cache_config, user_config) + + # Handle server_client_mode from config + if config.get("server_client_mode", False): + os.environ["FLEXKV_SERVER_CLIENT_MODE"] = "1" + GLOBAL_CONFIG_FROM_ENV.server_client_mode = True + + return model_config, cache_config + + +@dataclass +class BenchmarkConfig: + # Single batch benchmark params + batch_size: int = 1 + sequence_length: int = 1024 + cache_ratio: float = 1.0 + clear_cpu_cache: bool = False + + # Multi-turn benchmark params + 
num_users: int = 10 + num_turns: int = 3 + system_prompt_length: int = 100 + input_length: int = 512 + output_length: int = 64 + + # General + mode: str = "all" # "single", "multiturn", "all", "put-only", "get-only" + seed: int = None # Random seed for deterministic token generation (cross-node) + + +def run_tp_client(dp_client_id, tp_rank, gpu_register_port, model_config, cache_config, num_gpu_blocks): + """Run tp_client process to register GPU blocks""" + device_id = tp_rank + dp_client_id * model_config.tp_size + tp_client = KVTPClient(gpu_register_port, dp_client_id, device_id) + + gpu_kv_layout = KVCacheLayout( + type=KVCacheLayoutType.LAYERFIRST, + num_layer=model_config.num_layers, + num_block=num_gpu_blocks, + tokens_per_block=cache_config.tokens_per_block, + num_head=model_config.num_kv_heads, + head_size=model_config.head_size, + is_mla=model_config.use_mla, + ) + + # Create GPU blocks for this tp_rank in the tp_client process + gpu_blocks_for_tp = [] + for _ in range(model_config.num_layers): + gpu_blocks_for_tp.append( + torch.empty(size=tuple(gpu_kv_layout.kv_shape[1:]), dtype=model_config.dtype).cuda(device_id) + ) + tp_client.register_to_server(gpu_blocks_for_tp, gpu_kv_layout) + + # Keep the process running + while True: + time.sleep(1) + + +def shutdown_tp_clients(tp_client_processes): + """Terminate all tp_client processes""" + for tp_process in tp_client_processes: + if tp_process.is_alive(): + tp_process.terminate() + tp_process.join(timeout=5) + if tp_process.is_alive(): + print(f"Force killing tp_client process {tp_process.pid}") + tp_process.kill() + tp_process.join(timeout=2) + + +def benchmark_single_batch(kvmanager, model_config, cache_config, bench_config): + """Benchmark single batch put/get with distributed KVCache""" + print("\n" + "=" * 60) + print(" Single Batch Benchmark (Distributed KVCache)") + print("=" * 60) + + sequence_length = bench_config.sequence_length + batch_size = bench_config.batch_size + cache_length = 
int(sequence_length * bench_config.cache_ratio) + + print(f" batch_size={batch_size}, sequence_length={sequence_length}, " + f"cache_ratio={bench_config.cache_ratio}, cache_length={cache_length}") + if bench_config.seed is not None: + print(f" seed={bench_config.seed}") + + # Generate random sequences (use seed for deterministic cross-node benchmarks) + if bench_config.seed is not None: + torch.manual_seed(bench_config.seed) + batch_sequence_tensor = [] + batch_slot_mapping = [] + for i in range(batch_size): + batch_sequence_tensor.append(torch.randint(0, 100000, (sequence_length,), dtype=torch.int64)) + batch_slot_mapping.append(torch.arange(i * sequence_length, (i + 1) * sequence_length, dtype=torch.int64)) + + results = {} + skip_put = (bench_config.mode == "get-only") + skip_get = (bench_config.mode == "put-only") + + # In get-only mode, wait for remote index to be refreshed from Redis + if skip_put: + rebuild_interval_ms = int(os.environ.get("FLEXKV_REBUILD_INTERVAL_MS", "100")) + # Wait at least 3x rebuild_interval to ensure at least one full refresh cycle + wait_time_s = max(rebuild_interval_ms * 3 / 1000.0, 0.5) + print(f" Waiting {wait_time_s:.2f}s for remote index refresh " + f"(FLEXKV_REBUILD_INTERVAL_MS={rebuild_interval_ms})...") + time.sleep(wait_time_s) + + # ---- Benchmark PUT ---- + if not skip_put: + print("\n--- PUT Phase ---") + start_time = time.time() + batch_put_ids = [] + if bench_config.cache_ratio > 0: + for i in range(batch_size): + task_id = kvmanager.put_async( + batch_sequence_tensor[i][:cache_length], + batch_slot_mapping[i][:cache_length], + token_mask=None, + ) + batch_put_ids.append(task_id) + put_result = kvmanager.wait(batch_put_ids, completely=True) + end_time = time.time() + + elapsed_time_put = end_time - start_time + put_tokens = 0 + for _, response in put_result.items(): + if response.status == KVResponseStatus.SUCCESS: + put_tokens += response.return_mask.sum().item() + transfer_data_size_GB = put_tokens * 
model_config.token_size_in_bytes / (1024 ** 3) + transfer_bandwidth_put = transfer_data_size_GB / elapsed_time_put if elapsed_time_put > 0 else 0 + print(f" PUT: {put_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " + f"time: {elapsed_time_put * 1000:.2f}ms, bandwidth: {transfer_bandwidth_put:.2f} GB/s") + results.update({ + "put_tokens": put_tokens, + "put_time_ms": elapsed_time_put * 1000, + "put_bandwidth_GBs": transfer_bandwidth_put, + }) + else: + print("\n--- PUT Phase SKIPPED (get-only mode) ---") + + if bench_config.clear_cpu_cache: + kvmanager._clear_cpu_cache() + + # ---- Benchmark GET ---- + if not skip_get: + print("\n--- GET Phase ---") + all_tokens = 0 + start_time = time.time() + batch_get_ids = [] + for i in range(batch_size): + all_tokens += len(batch_sequence_tensor[i]) + task_id, _ = kvmanager.get_match(batch_sequence_tensor[i], token_mask=None) + batch_get_ids.append(task_id) + get_match_time = time.time() - start_time + + kvmanager.launch(batch_get_ids, batch_slot_mapping, as_batch=True, layerwise_transfer=False) + get_result = kvmanager.wait(batch_get_ids) + elapsed_time_get = time.time() - start_time + + cached_tokens = 0 + for _, response in get_result.items(): + if response.status == KVResponseStatus.SUCCESS: + cached_tokens += response.return_mask.sum().item() + transfer_data_size_GB = cached_tokens * model_config.token_size_in_bytes / (1024 ** 3) + transfer_bandwidth_get = transfer_data_size_GB / elapsed_time_get if elapsed_time_get > 0 else 0 + print(f" GET: {cached_tokens}/{all_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " + f"cache_ratio: {cached_tokens * 100 / all_tokens:.2f}%, " + f"match time: {get_match_time * 1000:.2f}ms, " + f"e2e time: {elapsed_time_get * 1000:.2f}ms, " + f"bandwidth: {transfer_bandwidth_get:.2f} GB/s") + results.update({ + "get_cached_tokens": cached_tokens, + "get_total_tokens": all_tokens, + "get_cache_ratio": cached_tokens / all_tokens if all_tokens > 0 else 0, + 
"get_match_time_ms": get_match_time * 1000, + "get_e2e_time_ms": elapsed_time_get * 1000, + "get_bandwidth_GBs": transfer_bandwidth_get, + }) + else: + print("\n--- GET Phase SKIPPED (put-only mode) ---") + + return results + + +def benchmark_multiturn(kvmanager, model_config, cache_config, bench_config): + """Benchmark multi-turn conversation with distributed KVCache""" + print("\n" + "=" * 60) + print(" Multi-Turn Conversation Benchmark (Distributed KVCache)") + print("=" * 60) + print(f" num_users={bench_config.num_users}, num_turns={bench_config.num_turns}, " + f"system_prompt_length={bench_config.system_prompt_length}, " + f"input_length={bench_config.input_length}, output_length={bench_config.output_length}") + + # Generate multi-turn requests + reqs = generate_random_multiturn( + num_user_requests=bench_config.num_users, + num_turns=bench_config.num_turns, + system_prompt_length=bench_config.system_prompt_length, + input_length=bench_config.input_length, + output_length=bench_config.output_length, + seed=bench_config.seed, + ) + + total_get_requests = 0 + total_put_requests = 0 + cache_hit_ratios = [] + total_put_time = 0 + total_get_time = 0 + total_put_tokens = 0 + total_get_cached_tokens = 0 + total_get_all_tokens = 0 + + request_id = 0 + for req in reqs: + fake_slot_mapping = torch.arange(req.token_mask.sum(), dtype=torch.int64) + + if req.request_type == "get": + total_get_requests += 1 + total_get_all_tokens += req.token_mask.sum().item() + + start_time = time.time() + task_id, _ = kvmanager.get_match( + req.token_ids, + token_mask=torch.ones_like(torch.from_numpy(req.token_ids) if isinstance(req.token_ids, np.ndarray) else req.token_ids), + ) + kvmanager.launch([task_id], [fake_slot_mapping.numpy()]) + result = kvmanager.wait([task_id]) + elapsed = time.time() - start_time + total_get_time += elapsed + + for _, response in result.items(): + if response.status == KVResponseStatus.SUCCESS and response.return_mask is not None: + cached = 
response.return_mask.sum().item() + total_get_cached_tokens += cached + ratio = cached / req.token_mask.sum().item() + cache_hit_ratios.append(ratio) + else: + cache_hit_ratios.append(0.0) + + elif req.request_type == "put": + total_put_requests += 1 + + start_time = time.time() + task_id = kvmanager.put_async( + req.token_ids, + fake_slot_mapping.numpy(), + token_mask=None, + ) + result = kvmanager.wait([task_id], completely=True) + elapsed = time.time() - start_time + total_put_time += elapsed + + for _, response in result.items(): + if response.status == KVResponseStatus.SUCCESS and response.return_mask is not None: + total_put_tokens += response.return_mask.sum().item() + + request_id += 1 + + # Print results + print(f"\n--- Results ---") + print(f" Total requests: {len(reqs)} (GET: {total_get_requests}, PUT: {total_put_requests})") + print(f" PUT: {total_put_tokens} tokens, total time: {total_put_time * 1000:.2f}ms, " + f"avg time: {total_put_time * 1000 / max(total_put_requests, 1):.2f}ms/req") + print(f" GET: {total_get_cached_tokens}/{total_get_all_tokens} tokens cached, " + f"total time: {total_get_time * 1000:.2f}ms, " + f"avg time: {total_get_time * 1000 / max(total_get_requests, 1):.2f}ms/req") + + if cache_hit_ratios: + sorted_ratios = sorted(cache_hit_ratios) + avg_ratio = sum(sorted_ratios) / len(sorted_ratios) + print(f" Cache hit ratio: avg={avg_ratio * 100:.2f}%, " + f"min={sorted_ratios[0] * 100:.2f}%, " + f"median={sorted_ratios[len(sorted_ratios) // 2] * 100:.2f}%, " + f"max={sorted_ratios[-1] * 100:.2f}%") + + return { + "total_requests": len(reqs), + "get_requests": total_get_requests, + "put_requests": total_put_requests, + "put_tokens": total_put_tokens, + "put_total_time_ms": total_put_time * 1000, + "get_cached_tokens": total_get_cached_tokens, + "get_total_tokens": total_get_all_tokens, + "get_total_time_ms": total_get_time * 1000, + "avg_cache_hit_ratio": sum(cache_hit_ratios) / len(cache_hit_ratios) if cache_hit_ratios else 0, + } + + 
+def main(args): + # Set FLEXKV_REBUILD_INTERVAL_MS for faster cross-node index sync + # NOTE: Must set env var AND update GLOBAL_CONFIG_FROM_ENV because + # GLOBAL_CONFIG_FROM_ENV is evaluated at module import time (before main runs). + # The env var alone is not enough since the Namespace is already frozen. + if args.rebuild_interval_ms is not None: + os.environ["FLEXKV_REBUILD_INTERVAL_MS"] = str(args.rebuild_interval_ms) + GLOBAL_CONFIG_FROM_ENV.rebuild_interval_ms = args.rebuild_interval_ms + print(f"[INFO] Set FLEXKV_REBUILD_INTERVAL_MS={args.rebuild_interval_ms}") + + # Load config + model_config, cache_config = load_dist_config(args.config) + + bench_config = BenchmarkConfig( + batch_size=args.batch_size, + sequence_length=args.sequence_length, + cache_ratio=args.cache_ratio, + clear_cpu_cache=args.clear_cpu_cache, + num_users=args.num_users, + num_turns=args.num_turns, + system_prompt_length=args.system_prompt_length, + input_length=args.input_length, + output_length=args.output_length, + mode=args.mode, + seed=args.seed, + ) + + # Pad sequence length to be divisible by tokens_per_block + bench_config.sequence_length = ( + ((bench_config.sequence_length - 1) // cache_config.tokens_per_block + 1) + * cache_config.tokens_per_block + ) + + num_gpu_blocks = bench_config.sequence_length * bench_config.batch_size // cache_config.tokens_per_block + # Ensure enough GPU blocks for multi-turn mode too + if bench_config.mode in ("multiturn", "all", "put-only", "get-only"): + max_tokens_per_user = ( + bench_config.system_prompt_length + + bench_config.num_turns * (bench_config.input_length + bench_config.output_length) + ) + multiturn_blocks = max_tokens_per_user * bench_config.num_users // cache_config.tokens_per_block + num_gpu_blocks = max(num_gpu_blocks, multiturn_blocks) + # Add some extra blocks for safety + num_gpu_blocks = int(num_gpu_blocks * 1.5) + 64 + + if model_config.tp_size * model_config.dp_size > torch.cuda.device_count(): + raise ValueError( + 
f"tp_size {model_config.tp_size} * dp_size {model_config.dp_size} > " + f"available GPUs {torch.cuda.device_count()}" + ) + + print("=" * 60) + print(" FlexKV Distributed KVCache Benchmark (server_client_mode)") + print("=" * 60) + print(f" model_config: {model_config}") + print(f" cache_config: {cache_config}") + print(f" enable_kv_sharing: {cache_config.enable_kv_sharing}") + print(f" enable_p2p_cpu: {cache_config.enable_p2p_cpu}") + print(f" redis: {cache_config.redis_host}:{cache_config.redis_port}") + print(f" num_gpu_blocks: {num_gpu_blocks}") + print(f" bench_config: {bench_config}") + + # Create KVManager (this will start KVServer in server_client_mode) + kvmanager = KVManager(model_config, cache_config) + kvmanager.start() + + # Start tp_client processes to register GPU blocks + tp_client_processes = [] + + # Register cleanup handler to ensure processes are terminated on exit + def _cleanup(): + shutdown_tp_clients(tp_client_processes) + try: + kvmanager.shutdown() + except Exception: + pass + atexit.register(_cleanup) + + def _signal_handler(signum, frame): + print(f"\nReceived signal {signum}, shutting down...") + _cleanup() + # Re-raise to allow default handler + signal.signal(signum, signal.SIG_DFL) + os.kill(os.getpid(), signum) + signal.signal(signal.SIGINT, _signal_handler) + signal.signal(signal.SIGTERM, _signal_handler) + for tp_rank in range(model_config.tp_size): + tp_process = Process( + target=run_tp_client, + args=(0, tp_rank, kvmanager.gpu_register_port, + model_config, cache_config, num_gpu_blocks), + daemon=True, + ) + tp_process.start() + tp_client_processes.append(tp_process) + + # Wait for system to be ready + print("\nWaiting for FlexKV to be ready...") + wait_start = time.time() + while not kvmanager.is_ready(): + time.sleep(1) + elapsed = time.time() - wait_start + if elapsed > 120: + print("ERROR: Timeout waiting for FlexKV to be ready (120s)") + shutdown_tp_clients(tp_client_processes) + kvmanager.shutdown() + return + if 
int(elapsed) % 10 == 0 and int(elapsed) > 0: + print(f" Still waiting... ({int(elapsed)}s)") + print(f"FlexKV is ready! (took {time.time() - wait_start:.1f}s)") + + try: + results = {} + + if bench_config.mode in ("single", "all", "put-only", "get-only"): + results["single_batch"] = benchmark_single_batch( + kvmanager, model_config, cache_config, bench_config + ) + + if bench_config.mode in ("multiturn", "all"): + results["multiturn"] = benchmark_multiturn( + kvmanager, model_config, cache_config, bench_config + ) + + # Print summary + print("\n" + "=" * 60) + print(" Benchmark Summary") + print("=" * 60) + for name, result in results.items(): + print(f"\n [{name}]") + for k, v in result.items(): + if isinstance(v, float): + print(f" {k}: {v:.4f}") + else: + print(f" {k}: {v}") + + # In put-only mode, keep the process alive so other nodes can GET the data + if bench_config.mode == "put-only": + print("\n" + "-" * 60) + print("Data published to Redis. Press Enter to shutdown " + "(keep running for other nodes to GET)...") + print("-" * 60) + try: + input() + except EOFError: + # Handle non-interactive environments + print("Non-interactive mode detected. 
Sleeping indefinitely (Ctrl+C to stop)...") + while True: + time.sleep(1) + + finally: + print("\nShutting down...") + shutdown_tp_clients(tp_client_processes) + kvmanager.shutdown() + # Unregister atexit handler since we've already cleaned up + atexit.unregister(_cleanup) + print("Done.") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark FlexKV distributed KVCache in server_client_mode" + ) + parser.add_argument("--config", type=str, default="benchmarks/example_dist_config.yml", + help="Path to config YAML file") + parser.add_argument("--mode", type=str, default="all", + choices=["single", "multiturn", "all", "put-only", "get-only"], + help="Benchmark mode: single, multiturn, all, put-only, get-only") + parser.add_argument("--seed", type=int, default=None, + help="Random seed for deterministic token generation (for cross-node benchmarks)") + + # Single batch params + parser.add_argument("--batch-size", type=int, default=1, help="Batch size for single batch benchmark") + parser.add_argument("--sequence-length", type=int, default=1024, help="Sequence length per request") + parser.add_argument("--cache-ratio", type=float, default=1.0, help="Ratio of tokens to cache in PUT phase") + parser.add_argument("--clear-cpu-cache", action="store_true", help="Clear CPU cache between PUT and GET") + + # Multi-turn params + parser.add_argument("--num-users", type=int, default=10, help="Number of simulated users") + parser.add_argument("--num-turns", type=int, default=3, help="Number of conversation turns per user") + parser.add_argument("--system-prompt-length", type=int, default=100, help="System prompt length in tokens") + parser.add_argument("--input-length", type=int, default=512, help="Input length per turn in tokens") + parser.add_argument("--output-length", type=int, default=64, help="Output length per turn in tokens") + + # Cross-node sync params + parser.add_argument("--rebuild-interval-ms", type=int, default=None, + help="Override 
FLEXKV_REBUILD_INTERVAL_MS (default: use env or 100). " + "Recommended: 20 for cross-node benchmarks") + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/benchmarks/benchmark_single_batch.py b/benchmarks/benchmark_single_batch.py index c629b545ea..f05a63d97e 100644 --- a/benchmarks/benchmark_single_batch.py +++ b/benchmarks/benchmark_single_batch.py @@ -139,7 +139,7 @@ def benchmark_flexkv(model_config: ModelConfig, token_mask=None) batch_get_ids.append(task_id) get_match_time = time.time() - start_time - kvmanager.launch(batch_get_ids, batch_slot_mapping) + kvmanager.launch(batch_get_ids, batch_slot_mapping, as_batch=True, layerwise_transfer=False) get_result = kvmanager.wait(batch_get_ids) elapsed_time_get = time.time() - start_time cached_tokens = 0 diff --git a/benchmarks/dist_benchmark/benchmark_dist_direct.py b/benchmarks/dist_benchmark/benchmark_dist_direct.py new file mode 100644 index 0000000000..7e22e51e50 --- /dev/null +++ b/benchmarks/dist_benchmark/benchmark_dist_direct.py @@ -0,0 +1,660 @@ +""" +Benchmark for FlexKV distributed KVCache in direct mode (non-server_client_mode). + +In direct mode, KVManager creates KVTaskEngine directly in the main process +without going through KVServer/KVDPClient IPC. This is simpler and has lower +overhead, suitable for single-instance single-dp benchmarks. 
+ +The key difference from server_client_mode: + - No KVServer subprocess is spawned + - KVTaskEngine runs directly in the main process + - KVTPClient still registers GPU blocks via ZMQ to KVTaskEngine + - RedisMeta is created directly in KVManager (not inside KVServer) + +Prerequisites: + - A running Redis server (default: 127.0.0.1:6379) + - At least 1 GPU available + - FlexKV built with distributed support (FLEXKV_ENABLE_P2P=1) + +Usage: + # Basic usage with default config + python benchmarks/dist_benchmark/benchmark_dist_direct.py \\ + --config benchmarks/dist_benchmark/example_dist_direct_config.yml + + # Custom parameters + python benchmarks/dist_benchmark/benchmark_dist_direct.py \\ + --config benchmarks/dist_benchmark/example_dist_direct_config.yml \\ + --batch-size 4 \\ + --sequence-length 2048 \\ + --cache-ratio 0.5 \\ + --num-users 10 \\ + --num-turns 3 + + # Multi-turn conversation benchmark only + python benchmarks/dist_benchmark/benchmark_dist_direct.py \\ + --config benchmarks/dist_benchmark/example_dist_direct_config.yml \\ + --mode multiturn \\ + --num-users 20 \\ + --num-turns 5 + + # Cross-node benchmark: Node A (PUT only) + python benchmarks/dist_benchmark/benchmark_dist_direct.py \\ + --config config_a.yml --seed 42 --mode put-only + + # Cross-node benchmark: Node B (GET only, same seed) + python benchmarks/dist_benchmark/benchmark_dist_direct.py \\ + --config config_b.yml --seed 42 --mode get-only +""" +import os +import sys +import atexit +import signal +import argparse +import json +import tempfile +import time +from multiprocessing import Process +from dataclasses import dataclass + +import torch +import numpy as np + +# Add parent directory to path so we can import utils +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from flexkv.server.client import KVTPClient +from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType +from flexkv.common.config import ( + ModelConfig, CacheConfig, UserConfig, + 
update_default_config_from_user_config, parse_path_list, + GLOBAL_CONFIG_FROM_ENV, +) +from flexkv.common.debug import flexkv_logger +from flexkv.kvmanager import KVManager +from flexkv.kvtask import KVResponseStatus + +from utils import generate_random_multiturn + +flexkv_logger.set_level("INFO") + + +def load_dist_direct_config(config_path: str): + """Load config for direct mode (non-server_client_mode) distributed KVCache. + + This is similar to load_dist_config in benchmark_dist_kvcache.py, but + ensures server_client_mode is NOT set, so KVManager uses KVTaskEngine + directly instead of going through KVServer IPC. + """ + import yaml + + with open(config_path) as f: + config = yaml.load(f, Loader=yaml.SafeLoader) + print(f"Loaded config: {config}") + + model_config = ModelConfig() + cache_config = CacheConfig() + user_config = UserConfig() + + # Model config + model_config.num_layers = config["num_layers"] + model_config.num_kv_heads = config["num_kv_heads"] + model_config.head_size = config["head_size"] + model_config.dtype = eval(f"torch.{config['dtype']}") + model_config.use_mla = config["use_mla"] + model_config.tp_size = config["tp_size"] + model_config.dp_size = config["dp_size"] + cache_config.tokens_per_block = config["tokens_per_block"] + + # Cache size config + if "cpu_cache_gb" in config: + user_config.cpu_cache_gb = config["cpu_cache_gb"] + if "ssd_cache_gb" in config: + user_config.ssd_cache_gb = config["ssd_cache_gb"] + if "ssd_cache_dir" in config: + user_config.ssd_cache_dir = parse_path_list(config["ssd_cache_dir"]) + if "enable_gds" in config: + user_config.enable_gds = config["enable_gds"] + + # Distributed KVCache config + if "enable_p2p_cpu" in config: + user_config.enable_p2p_cpu = config["enable_p2p_cpu"] + if "enable_p2p_ssd" in config: + user_config.enable_p2p_ssd = config["enable_p2p_ssd"] + if "enable_3rd_remote" in config: + user_config.enable_3rd_remote = config["enable_3rd_remote"] + + # Redis config + if "redis_host" in config: + 
user_config.redis_host = config["redis_host"] + if "redis_port" in config: + user_config.redis_port = config["redis_port"] + if "local_ip" in config: + user_config.local_ip = config["local_ip"] + if "redis_password" in config: + user_config.redis_password = config["redis_password"] + + # Auto-generate mooncake config JSON and set MOONCAKE_CONFIG_PATH if P2P is enabled + if config.get("enable_p2p_cpu", False) or config.get("enable_p2p_ssd", False): + if "MOONCAKE_CONFIG_PATH" not in os.environ: + mooncake_config = { + "engine_ip": config.get("mooncake_engine_ip", config.get("local_ip", "127.0.0.1")), + "engine_port": config.get("mooncake_engine_port", 5555), + "metadata_backend": config.get("mooncake_metadata_backend", "redis"), + "metadata_server": config.get("mooncake_metadata_server", + f"redis://{config.get('redis_host', '127.0.0.1')}:{config.get('redis_port', 6379)}"), + "metadata_server_auth": config.get("mooncake_metadata_server_auth", + config.get("redis_password", "")), + "protocol": config.get("mooncake_protocol", "tcp"), + "device_name": config.get("mooncake_device_name", ""), + } + # Write to a temp file that persists until process exits + mooncake_config_fd, mooncake_config_path = tempfile.mkstemp( + suffix=".json", prefix="mooncake_config_" + ) + with os.fdopen(mooncake_config_fd, "w") as f: + json.dump(mooncake_config, f, indent=2) + os.environ["MOONCAKE_CONFIG_PATH"] = mooncake_config_path + print(f"[INFO] Auto-generated mooncake config at: {mooncake_config_path}") + print(f"[INFO] Mooncake config: {json.dumps(mooncake_config, indent=2)}") + else: + mooncake_config_path = os.environ['MOONCAKE_CONFIG_PATH'] + print(f"[INFO] Using existing MOONCAKE_CONFIG_PATH: {mooncake_config_path}") + + # Store mooncake_config_path in cache_config so it survives spawn subprocesses via pickle + cache_config.mooncake_config_path = mooncake_config_path + + update_default_config_from_user_config(model_config, cache_config, user_config) + + # IMPORTANT: Ensure 
server_client_mode is NOT set for direct mode + # Even if config says server_client_mode: true, we override it here + if config.get("server_client_mode", False): + print("[WARN] server_client_mode is set in config but will be IGNORED in direct mode benchmark.") + os.environ.pop("FLEXKV_SERVER_CLIENT_MODE", None) + GLOBAL_CONFIG_FROM_ENV.server_client_mode = False + + return model_config, cache_config + + +@dataclass +class BenchmarkConfig: + # Single batch benchmark params + batch_size: int = 1 + sequence_length: int = 1024 + cache_ratio: float = 1.0 + clear_cpu_cache: bool = False + + # Multi-turn benchmark params + num_users: int = 10 + num_turns: int = 3 + system_prompt_length: int = 100 + input_length: int = 512 + output_length: int = 64 + + # General + mode: str = "all" # "single", "multiturn", "all", "put-only", "get-only" + seed: int = None # Random seed for deterministic token generation (cross-node) + + +def run_tp_client(dp_client_id, tp_rank, gpu_register_port, model_config, cache_config, num_gpu_blocks): + """Run tp_client process to register GPU blocks. + + In direct mode, KVTPClient still communicates with KVTaskEngine via ZMQ + to register GPU memory blocks. The difference is that KVTaskEngine runs + in the main process (not in a KVServer subprocess). 
+ """ + device_id = tp_rank + dp_client_id * model_config.tp_size + tp_client = KVTPClient(gpu_register_port, dp_client_id, device_id) + + gpu_kv_layout = KVCacheLayout( + type=KVCacheLayoutType.LAYERFIRST, + num_layer=model_config.num_layers, + num_block=num_gpu_blocks, + tokens_per_block=cache_config.tokens_per_block, + num_head=model_config.num_kv_heads, + head_size=model_config.head_size, + is_mla=model_config.use_mla, + ) + + # Create GPU blocks for this tp_rank in the tp_client process + gpu_blocks_for_tp = [] + for _ in range(model_config.num_layers): + gpu_blocks_for_tp.append( + torch.empty(size=tuple(gpu_kv_layout.kv_shape[1:]), dtype=model_config.dtype).cuda(device_id) + ) + tp_client.register_to_server(gpu_blocks_for_tp, gpu_kv_layout) + + # Keep the process running + while True: + time.sleep(1) + + +def shutdown_tp_clients(tp_client_processes): + """Terminate all tp_client processes""" + for tp_process in tp_client_processes: + if tp_process.is_alive(): + tp_process.terminate() + tp_process.join(timeout=5) + if tp_process.is_alive(): + print(f"Force killing tp_client process {tp_process.pid}") + tp_process.kill() + tp_process.join(timeout=2) + + +def benchmark_single_batch(kvmanager, model_config, cache_config, bench_config): + """Benchmark single batch put/get with distributed KVCache (direct mode)""" + print("\n" + "=" * 60) + print(" Single Batch Benchmark (Distributed KVCache - Direct Mode)") + print("=" * 60) + + sequence_length = bench_config.sequence_length + batch_size = bench_config.batch_size + cache_length = int(sequence_length * bench_config.cache_ratio) + + print(f" batch_size={batch_size}, sequence_length={sequence_length}, " + f"cache_ratio={bench_config.cache_ratio}, cache_length={cache_length}") + if bench_config.seed is not None: + print(f" seed={bench_config.seed}") + + # Generate random sequences (use seed for deterministic cross-node benchmarks) + if bench_config.seed is not None: + torch.manual_seed(bench_config.seed) + 
batch_sequence_tensor = [] + batch_slot_mapping = [] + for i in range(batch_size): + batch_sequence_tensor.append(torch.randint(0, 100000, (sequence_length,), dtype=torch.int64)) + batch_slot_mapping.append(torch.arange(i * sequence_length, (i + 1) * sequence_length, dtype=torch.int64)) + + results = {} + skip_put = (bench_config.mode == "get-only") + skip_get = (bench_config.mode == "put-only") + + # In get-only mode, wait for remote index to be refreshed from Redis + if skip_put: + rebuild_interval_ms = int(os.environ.get("FLEXKV_REBUILD_INTERVAL_MS", "100")) + wait_time_s = max(rebuild_interval_ms * 3 / 1000.0, 0.5) + print(f" Waiting {wait_time_s:.2f}s for remote index refresh " + f"(FLEXKV_REBUILD_INTERVAL_MS={rebuild_interval_ms})...") + time.sleep(wait_time_s) + + # ---- Benchmark PUT ---- + if not skip_put: + print("\n--- PUT Phase ---") + start_time = time.time() + batch_put_ids = [] + if bench_config.cache_ratio > 0: + for i in range(batch_size): + task_id = kvmanager.put_async( + batch_sequence_tensor[i][:cache_length], + batch_slot_mapping[i][:cache_length], + token_mask=None, + ) + batch_put_ids.append(task_id) + put_result = kvmanager.wait(batch_put_ids, completely=True) + end_time = time.time() + + elapsed_time_put = end_time - start_time + put_tokens = 0 + for _, response in put_result.items(): + if response.status == KVResponseStatus.SUCCESS: + put_tokens += response.return_mask.sum().item() + transfer_data_size_GB = put_tokens * model_config.token_size_in_bytes / (1024 ** 3) + transfer_bandwidth_put = transfer_data_size_GB / elapsed_time_put if elapsed_time_put > 0 else 0 + print(f" PUT: {put_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " + f"time: {elapsed_time_put * 1000:.2f}ms, bandwidth: {transfer_bandwidth_put:.2f} GB/s") + results.update({ + "put_tokens": put_tokens, + "put_time_ms": elapsed_time_put * 1000, + "put_bandwidth_GBs": transfer_bandwidth_put, + }) + else: + print("\n--- PUT Phase SKIPPED (get-only mode) ---") + + if 
bench_config.clear_cpu_cache: + kvmanager._clear_cpu_cache() + + # ---- Benchmark GET ---- + if not skip_get: + print("\n--- GET Phase ---") + all_tokens = 0 + start_time = time.time() + batch_get_ids = [] + for i in range(batch_size): + all_tokens += len(batch_sequence_tensor[i]) + task_id, _ = kvmanager.get_match(batch_sequence_tensor[i], token_mask=None) + batch_get_ids.append(task_id) + get_match_time = time.time() - start_time + + kvmanager.launch(batch_get_ids, batch_slot_mapping, as_batch=True, layerwise_transfer=False) + get_result = kvmanager.wait(batch_get_ids) + elapsed_time_get = time.time() - start_time + + cached_tokens = 0 + for _, response in get_result.items(): + if response.status == KVResponseStatus.SUCCESS: + cached_tokens += response.return_mask.sum().item() + transfer_data_size_GB = cached_tokens * model_config.token_size_in_bytes / (1024 ** 3) + transfer_bandwidth_get = transfer_data_size_GB / elapsed_time_get if elapsed_time_get > 0 else 0 + print(f" GET: {cached_tokens}/{all_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " + f"cache_ratio: {cached_tokens * 100 / all_tokens:.2f}%, " + f"match time: {get_match_time * 1000:.2f}ms, " + f"e2e time: {elapsed_time_get * 1000:.2f}ms, " + f"bandwidth: {transfer_bandwidth_get:.2f} GB/s") + results.update({ + "get_cached_tokens": cached_tokens, + "get_total_tokens": all_tokens, + "get_cache_ratio": cached_tokens / all_tokens if all_tokens > 0 else 0, + "get_match_time_ms": get_match_time * 1000, + "get_e2e_time_ms": elapsed_time_get * 1000, + "get_bandwidth_GBs": transfer_bandwidth_get, + }) + else: + print("\n--- GET Phase SKIPPED (put-only mode) ---") + + return results + + +def benchmark_multiturn(kvmanager, model_config, cache_config, bench_config): + """Benchmark multi-turn conversation with distributed KVCache (direct mode)""" + print("\n" + "=" * 60) + print(" Multi-Turn Conversation Benchmark (Distributed KVCache - Direct Mode)") + print("=" * 60) + print(f" 
num_users={bench_config.num_users}, num_turns={bench_config.num_turns}, " + f"system_prompt_length={bench_config.system_prompt_length}, " + f"input_length={bench_config.input_length}, output_length={bench_config.output_length}") + + # Generate multi-turn requests + reqs = generate_random_multiturn( + num_user_requests=bench_config.num_users, + num_turns=bench_config.num_turns, + system_prompt_length=bench_config.system_prompt_length, + input_length=bench_config.input_length, + output_length=bench_config.output_length, + seed=bench_config.seed, + ) + + total_get_requests = 0 + total_put_requests = 0 + cache_hit_ratios = [] + total_put_time = 0 + total_get_time = 0 + total_put_tokens = 0 + total_get_cached_tokens = 0 + total_get_all_tokens = 0 + + request_id = 0 + for req in reqs: + fake_slot_mapping = torch.arange(req.token_mask.sum(), dtype=torch.int64) + + if req.request_type == "get": + total_get_requests += 1 + total_get_all_tokens += req.token_mask.sum().item() + + start_time = time.time() + task_id, _ = kvmanager.get_match( + req.token_ids, + token_mask=torch.ones_like(torch.from_numpy(req.token_ids) if isinstance(req.token_ids, np.ndarray) else req.token_ids), + ) + kvmanager.launch([task_id], [fake_slot_mapping.numpy()]) + result = kvmanager.wait([task_id]) + elapsed = time.time() - start_time + total_get_time += elapsed + + for _, response in result.items(): + if response.status == KVResponseStatus.SUCCESS and response.return_mask is not None: + cached = response.return_mask.sum().item() + total_get_cached_tokens += cached + ratio = cached / req.token_mask.sum().item() + cache_hit_ratios.append(ratio) + else: + cache_hit_ratios.append(0.0) + + elif req.request_type == "put": + total_put_requests += 1 + + start_time = time.time() + task_id = kvmanager.put_async( + req.token_ids, + fake_slot_mapping.numpy(), + token_mask=None, + ) + result = kvmanager.wait([task_id], completely=True) + elapsed = time.time() - start_time + total_put_time += elapsed + + for _, 
response in result.items(): + if response.status == KVResponseStatus.SUCCESS and response.return_mask is not None: + total_put_tokens += response.return_mask.sum().item() + + request_id += 1 + + # Print results + print(f"\n--- Results ---") + print(f" Total requests: {len(reqs)} (GET: {total_get_requests}, PUT: {total_put_requests})") + print(f" PUT: {total_put_tokens} tokens, total time: {total_put_time * 1000:.2f}ms, " + f"avg time: {total_put_time * 1000 / max(total_put_requests, 1):.2f}ms/req") + print(f" GET: {total_get_cached_tokens}/{total_get_all_tokens} tokens cached, " + f"total time: {total_get_time * 1000:.2f}ms, " + f"avg time: {total_get_time * 1000 / max(total_get_requests, 1):.2f}ms/req") + + if cache_hit_ratios: + sorted_ratios = sorted(cache_hit_ratios) + avg_ratio = sum(sorted_ratios) / len(sorted_ratios) + print(f" Cache hit ratio: avg={avg_ratio * 100:.2f}%, " + f"min={sorted_ratios[0] * 100:.2f}%, " + f"median={sorted_ratios[len(sorted_ratios) // 2] * 100:.2f}%, " + f"max={sorted_ratios[-1] * 100:.2f}%") + + return { + "total_requests": len(reqs), + "get_requests": total_get_requests, + "put_requests": total_put_requests, + "put_tokens": total_put_tokens, + "put_total_time_ms": total_put_time * 1000, + "get_cached_tokens": total_get_cached_tokens, + "get_total_tokens": total_get_all_tokens, + "get_total_time_ms": total_get_time * 1000, + "avg_cache_hit_ratio": sum(cache_hit_ratios) / len(cache_hit_ratios) if cache_hit_ratios else 0, + } + + +def main(args): + # Set FLEXKV_REBUILD_INTERVAL_MS for faster cross-node index sync + if args.rebuild_interval_ms is not None: + os.environ["FLEXKV_REBUILD_INTERVAL_MS"] = str(args.rebuild_interval_ms) + GLOBAL_CONFIG_FROM_ENV.rebuild_interval_ms = args.rebuild_interval_ms + print(f"[INFO] Set FLEXKV_REBUILD_INTERVAL_MS={args.rebuild_interval_ms}") + + # Load config (ensures server_client_mode is OFF) + model_config, cache_config = load_dist_direct_config(args.config) + + bench_config = BenchmarkConfig( + 
batch_size=args.batch_size, + sequence_length=args.sequence_length, + cache_ratio=args.cache_ratio, + clear_cpu_cache=args.clear_cpu_cache, + num_users=args.num_users, + num_turns=args.num_turns, + system_prompt_length=args.system_prompt_length, + input_length=args.input_length, + output_length=args.output_length, + mode=args.mode, + seed=args.seed, + ) + + # Pad sequence length to be divisible by tokens_per_block + bench_config.sequence_length = ( + ((bench_config.sequence_length - 1) // cache_config.tokens_per_block + 1) + * cache_config.tokens_per_block + ) + + num_gpu_blocks = bench_config.sequence_length * bench_config.batch_size // cache_config.tokens_per_block + # Ensure enough GPU blocks for multi-turn mode too + if bench_config.mode in ("multiturn", "all", "put-only", "get-only"): + max_tokens_per_user = ( + bench_config.system_prompt_length + + bench_config.num_turns * (bench_config.input_length + bench_config.output_length) + ) + multiturn_blocks = max_tokens_per_user * bench_config.num_users // cache_config.tokens_per_block + num_gpu_blocks = max(num_gpu_blocks, multiturn_blocks) + # Add some extra blocks for safety + num_gpu_blocks = int(num_gpu_blocks * 1.5) + 64 + + if model_config.tp_size * model_config.dp_size > torch.cuda.device_count(): + raise ValueError( + f"tp_size {model_config.tp_size} * dp_size {model_config.dp_size} > " + f"available GPUs {torch.cuda.device_count()}" + ) + + print("=" * 60) + print(" FlexKV Distributed KVCache Benchmark (Direct Mode)") + print("=" * 60) + print(f" model_config: {model_config}") + print(f" cache_config: {cache_config}") + print(f" enable_kv_sharing: {cache_config.enable_kv_sharing}") + print(f" enable_p2p_cpu: {cache_config.enable_p2p_cpu}") + print(f" redis: {cache_config.redis_host}:{cache_config.redis_port}") + print(f" num_gpu_blocks: {num_gpu_blocks}") + print(f" bench_config: {bench_config}") + print(f" server_client_mode: False (direct mode)") + + # Create KVManager in direct mode + # In direct mode, 
KVManager creates KVTaskEngine directly (no KVServer subprocess) + # RedisMeta is also created directly in KVManager + kvmanager = KVManager(model_config, cache_config) + kvmanager.start() + + # Verify we are indeed in direct mode + assert not kvmanager.server_client_mode, \ + "Expected direct mode (server_client_mode=False), but got server_client_mode=True. " \ + "Check your config: dp_size must be 1, instance_num must be 1, and " \ + "FLEXKV_SERVER_CLIENT_MODE env var must not be set." + + # Start tp_client processes to register GPU blocks + # Even in direct mode, GPU blocks are registered via KVTPClient -> KVTaskEngine ZMQ + tp_client_processes = [] + + # Register cleanup handler to ensure processes are terminated on exit + def _cleanup(): + shutdown_tp_clients(tp_client_processes) + try: + kvmanager.shutdown() + except Exception: + pass + atexit.register(_cleanup) + + def _signal_handler(signum, frame): + print(f"\nReceived signal {signum}, shutting down...") + _cleanup() + signal.signal(signum, signal.SIG_DFL) + os.kill(os.getpid(), signum) + signal.signal(signal.SIGINT, _signal_handler) + signal.signal(signal.SIGTERM, _signal_handler) + + for tp_rank in range(model_config.tp_size): + tp_process = Process( + target=run_tp_client, + args=(0, tp_rank, kvmanager.gpu_register_port, + model_config, cache_config, num_gpu_blocks), + daemon=True, + ) + tp_process.start() + tp_client_processes.append(tp_process) + + # Wait for system to be ready + print("\nWaiting for FlexKV to be ready (direct mode)...") + wait_start = time.time() + while not kvmanager.is_ready(): + time.sleep(1) + elapsed = time.time() - wait_start + if elapsed > 120: + print("ERROR: Timeout waiting for FlexKV to be ready (120s)") + shutdown_tp_clients(tp_client_processes) + kvmanager.shutdown() + return + if int(elapsed) % 10 == 0 and int(elapsed) > 0: + print(f" Still waiting... ({int(elapsed)}s)") + print(f"FlexKV is ready! 
(took {time.time() - wait_start:.1f}s)") + + try: + results = {} + + if bench_config.mode in ("single", "all", "put-only", "get-only"): + results["single_batch"] = benchmark_single_batch( + kvmanager, model_config, cache_config, bench_config + ) + + if bench_config.mode in ("multiturn", "all"): + results["multiturn"] = benchmark_multiturn( + kvmanager, model_config, cache_config, bench_config + ) + + # Print summary + print("\n" + "=" * 60) + print(" Benchmark Summary (Direct Mode)") + print("=" * 60) + for name, result in results.items(): + print(f"\n [{name}]") + for k, v in result.items(): + if isinstance(v, float): + print(f" {k}: {v:.4f}") + else: + print(f" {k}: {v}") + + # In put-only mode, keep the process alive so other nodes can GET the data + if bench_config.mode == "put-only": + print("\n" + "-" * 60) + print("Data published to Redis. Press Enter to shutdown " + "(keep running for other nodes to GET)...") + print("-" * 60) + try: + input() + except EOFError: + print("Non-interactive mode detected. 
Sleeping indefinitely (Ctrl+C to stop)...") + while True: + time.sleep(1) + + finally: + print("\nShutting down...") + shutdown_tp_clients(tp_client_processes) + kvmanager.shutdown() + atexit.unregister(_cleanup) + print("Done.") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark FlexKV distributed KVCache in direct mode (non-server_client_mode)" + ) + parser.add_argument("--config", type=str, + default="benchmarks/dist_benchmark/example_dist_direct_config.yml", + help="Path to config YAML file") + parser.add_argument("--mode", type=str, default="all", + choices=["single", "multiturn", "all", "put-only", "get-only"], + help="Benchmark mode: single, multiturn, all, put-only, get-only") + parser.add_argument("--seed", type=int, default=None, + help="Random seed for deterministic token generation (for cross-node benchmarks)") + + # Single batch params + parser.add_argument("--batch-size", type=int, default=1, help="Batch size for single batch benchmark") + parser.add_argument("--sequence-length", type=int, default=1024, help="Sequence length per request") + parser.add_argument("--cache-ratio", type=float, default=1.0, help="Ratio of tokens to cache in PUT phase") + parser.add_argument("--clear-cpu-cache", action="store_true", help="Clear CPU cache between PUT and GET") + + # Multi-turn params + parser.add_argument("--num-users", type=int, default=10, help="Number of simulated users") + parser.add_argument("--num-turns", type=int, default=3, help="Number of conversation turns per user") + parser.add_argument("--system-prompt-length", type=int, default=100, help="System prompt length in tokens") + parser.add_argument("--input-length", type=int, default=512, help="Input length per turn in tokens") + parser.add_argument("--output-length", type=int, default=64, help="Output length per turn in tokens") + + # Cross-node sync params + parser.add_argument("--rebuild-interval-ms", type=int, default=None, + help="Override 
FLEXKV_REBUILD_INTERVAL_MS (default: use env or 100). " + "Recommended: 20 for cross-node benchmarks") + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/benchmarks/dist_benchmark/benchmark_dist_kvcache.py b/benchmarks/dist_benchmark/benchmark_dist_kvcache.py new file mode 100644 index 0000000000..26e6c84426 --- /dev/null +++ b/benchmarks/dist_benchmark/benchmark_dist_kvcache.py @@ -0,0 +1,637 @@ +""" +Benchmark for FlexKV distributed KVCache in server_client_mode. + +This script tests the put/get performance of FlexKV when running in +server_client_mode with distributed KVCache sharing enabled (enable_p2p_cpu). + +Prerequisites: + - A running Redis server (default: 127.0.0.1:6379) + - At least 1 GPU available + - FlexKV built with distributed support (FLEXKV_ENABLE_P2P=1) + +Usage: + # Basic usage with default config + python benchmarks/benchmark_dist_kvcache.py --config benchmarks/example_dist_config.yml + + # Custom parameters + python benchmarks/benchmark_dist_kvcache.py \ + --config benchmarks/example_dist_config.yml \ + --batch-size 4 \ + --sequence-length 2048 \ + --cache-ratio 0.5 \ + --num-users 10 \ + --num-turns 3 + + # Multi-turn conversation benchmark only + python benchmarks/benchmark_dist_kvcache.py \\ + --config benchmarks/example_dist_config.yml \\ + --mode multiturn \\ + --num-users 20 \\ + --num-turns 5 + + # Cross-node benchmark: Node A (PUT only) + python benchmarks/benchmark_dist_kvcache.py \\ + --config config_a.yml --seed 42 --mode put-only + + # Cross-node benchmark: Node B (GET only, same seed) + python benchmarks/benchmark_dist_kvcache.py \\ + --config config_b.yml --seed 42 --mode get-only +""" +import os +import atexit +import signal +import argparse +import json +import tempfile +import time +from multiprocessing import Process +from dataclasses import dataclass + +import torch +import numpy as np + +from flexkv.server.client import KVTPClient +from flexkv.common.storage 
import KVCacheLayout, KVCacheLayoutType +from flexkv.common.config import ( + ModelConfig, CacheConfig, UserConfig, + update_default_config_from_user_config, parse_path_list, + GLOBAL_CONFIG_FROM_ENV, +) +from flexkv.common.debug import flexkv_logger +from flexkv.kvmanager import KVManager +from flexkv.kvtask import KVResponseStatus + +from utils import generate_random_multiturn + +flexkv_logger.set_level("INFO") + + +def load_dist_config(config_path: str): + """Load config with distributed KVCache support. + + Extends the standard load_config to handle distributed-specific fields: + enable_p2p_cpu, enable_p2p_ssd, enable_3rd_remote, + redis_host, redis_port, local_ip, redis_password, + server_client_mode, etc. + """ + import yaml + + with open(config_path) as f: + config = yaml.load(f, Loader=yaml.SafeLoader) + print(f"Loaded config: {config}") + + model_config = ModelConfig() + cache_config = CacheConfig() + user_config = UserConfig() + + # Model config + model_config.num_layers = config["num_layers"] + model_config.num_kv_heads = config["num_kv_heads"] + model_config.head_size = config["head_size"] + model_config.dtype = eval(f"torch.{config['dtype']}") + model_config.use_mla = config["use_mla"] + model_config.tp_size = config["tp_size"] + model_config.dp_size = config["dp_size"] + cache_config.tokens_per_block = config["tokens_per_block"] + + # Cache size config + if "cpu_cache_gb" in config: + user_config.cpu_cache_gb = config["cpu_cache_gb"] + if "ssd_cache_gb" in config: + user_config.ssd_cache_gb = config["ssd_cache_gb"] + if "ssd_cache_dir" in config: + user_config.ssd_cache_dir = parse_path_list(config["ssd_cache_dir"]) + if "enable_gds" in config: + user_config.enable_gds = config["enable_gds"] + + # Distributed KVCache config + if "enable_p2p_cpu" in config: + user_config.enable_p2p_cpu = config["enable_p2p_cpu"] + if "enable_p2p_ssd" in config: + user_config.enable_p2p_ssd = config["enable_p2p_ssd"] + if "enable_3rd_remote" in config: + 
user_config.enable_3rd_remote = config["enable_3rd_remote"] + + # Redis config + if "redis_host" in config: + user_config.redis_host = config["redis_host"] + if "redis_port" in config: + user_config.redis_port = config["redis_port"] + if "local_ip" in config: + user_config.local_ip = config["local_ip"] + if "redis_password" in config: + user_config.redis_password = config["redis_password"] + + # Auto-generate mooncake config JSON and set MOONCAKE_CONFIG_PATH if P2P is enabled + if config.get("enable_p2p_cpu", False) or config.get("enable_p2p_ssd", False): + if "MOONCAKE_CONFIG_PATH" not in os.environ: + mooncake_config = { + "engine_ip": config.get("mooncake_engine_ip", config.get("local_ip", "127.0.0.1")), + "engine_port": config.get("mooncake_engine_port", 5555), + "metadata_backend": config.get("mooncake_metadata_backend", "redis"), + "metadata_server": config.get("mooncake_metadata_server", + f"redis://{config.get('redis_host', '127.0.0.1')}:{config.get('redis_port', 6379)}"), + "metadata_server_auth": config.get("mooncake_metadata_server_auth", + config.get("redis_password", "")), + "protocol": config.get("mooncake_protocol", "tcp"), + "device_name": config.get("mooncake_device_name", ""), + } + # Write to a temp file that persists until process exits + mooncake_config_fd, mooncake_config_path = tempfile.mkstemp( + suffix=".json", prefix="mooncake_config_" + ) + with os.fdopen(mooncake_config_fd, "w") as f: + json.dump(mooncake_config, f, indent=2) + os.environ["MOONCAKE_CONFIG_PATH"] = mooncake_config_path + print(f"[INFO] Auto-generated mooncake config at: {mooncake_config_path}") + print(f"[INFO] Mooncake config: {json.dumps(mooncake_config, indent=2)}") + else: + mooncake_config_path = os.environ['MOONCAKE_CONFIG_PATH'] + print(f"[INFO] Using existing MOONCAKE_CONFIG_PATH: {mooncake_config_path}") + + # Store mooncake_config_path in cache_config so it survives spawn subprocesses via pickle + cache_config.mooncake_config_path = mooncake_config_path + + 
update_default_config_from_user_config(model_config, cache_config, user_config) + + # Handle server_client_mode from config + if config.get("server_client_mode", False): + os.environ["FLEXKV_SERVER_CLIENT_MODE"] = "1" + GLOBAL_CONFIG_FROM_ENV.server_client_mode = True + + return model_config, cache_config + + +@dataclass +class BenchmarkConfig: + # Single batch benchmark params + batch_size: int = 1 + sequence_length: int = 1024 + cache_ratio: float = 1.0 + clear_cpu_cache: bool = False + + # Multi-turn benchmark params + num_users: int = 10 + num_turns: int = 3 + system_prompt_length: int = 100 + input_length: int = 512 + output_length: int = 64 + + # General + mode: str = "all" # "single", "multiturn", "all", "put-only", "get-only" + seed: int = None # Random seed for deterministic token generation (cross-node) + + +def run_tp_client(dp_client_id, tp_rank, gpu_register_port, model_config, cache_config, num_gpu_blocks): + """Run tp_client process to register GPU blocks""" + device_id = tp_rank + dp_client_id * model_config.tp_size + tp_client = KVTPClient(gpu_register_port, dp_client_id, device_id) + + gpu_kv_layout = KVCacheLayout( + type=KVCacheLayoutType.LAYERFIRST, + num_layer=model_config.num_layers, + num_block=num_gpu_blocks, + tokens_per_block=cache_config.tokens_per_block, + num_head=model_config.num_kv_heads, + head_size=model_config.head_size, + is_mla=model_config.use_mla, + ) + + # Create GPU blocks for this tp_rank in the tp_client process + gpu_blocks_for_tp = [] + for _ in range(model_config.num_layers): + gpu_blocks_for_tp.append( + torch.empty(size=tuple(gpu_kv_layout.kv_shape[1:]), dtype=model_config.dtype).cuda(device_id) + ) + tp_client.register_to_server(gpu_blocks_for_tp, gpu_kv_layout) + + # Keep the process running + while True: + time.sleep(1) + + +def shutdown_tp_clients(tp_client_processes): + """Terminate all tp_client processes""" + for tp_process in tp_client_processes: + if tp_process.is_alive(): + tp_process.terminate() + 
tp_process.join(timeout=5) + if tp_process.is_alive(): + print(f"Force killing tp_client process {tp_process.pid}") + tp_process.kill() + tp_process.join(timeout=2) + + +def benchmark_single_batch(kvmanager, model_config, cache_config, bench_config): + """Benchmark single batch put/get with distributed KVCache""" + print("\n" + "=" * 60) + print(" Single Batch Benchmark (Distributed KVCache)") + print("=" * 60) + + sequence_length = bench_config.sequence_length + batch_size = bench_config.batch_size + cache_length = int(sequence_length * bench_config.cache_ratio) + + print(f" batch_size={batch_size}, sequence_length={sequence_length}, " + f"cache_ratio={bench_config.cache_ratio}, cache_length={cache_length}") + if bench_config.seed is not None: + print(f" seed={bench_config.seed}") + + # Generate random sequences (use seed for deterministic cross-node benchmarks) + if bench_config.seed is not None: + torch.manual_seed(bench_config.seed) + batch_sequence_tensor = [] + batch_slot_mapping = [] + for i in range(batch_size): + batch_sequence_tensor.append(torch.randint(0, 100000, (sequence_length,), dtype=torch.int64)) + batch_slot_mapping.append(torch.arange(i * sequence_length, (i + 1) * sequence_length, dtype=torch.int64)) + + results = {} + skip_put = (bench_config.mode == "get-only") + skip_get = (bench_config.mode == "put-only") + + # In get-only mode, wait for remote index to be refreshed from Redis + if skip_put: + rebuild_interval_ms = int(os.environ.get("FLEXKV_REBUILD_INTERVAL_MS", "100")) + # Wait at least 3x rebuild_interval to ensure at least one full refresh cycle + wait_time_s = max(rebuild_interval_ms * 3 / 1000.0, 0.5) + print(f" Waiting {wait_time_s:.2f}s for remote index refresh " + f"(FLEXKV_REBUILD_INTERVAL_MS={rebuild_interval_ms})...") + time.sleep(wait_time_s) + + # ---- Benchmark PUT ---- + if not skip_put: + print("\n--- PUT Phase ---") + start_time = time.time() + batch_put_ids = [] + if bench_config.cache_ratio > 0: + for i in 
range(batch_size): + task_id = kvmanager.put_async( + batch_sequence_tensor[i][:cache_length], + batch_slot_mapping[i][:cache_length], + token_mask=None, + ) + batch_put_ids.append(task_id) + put_result = kvmanager.wait(batch_put_ids, completely=True) + end_time = time.time() + + elapsed_time_put = end_time - start_time + put_tokens = 0 + for _, response in put_result.items(): + if response.status == KVResponseStatus.SUCCESS: + put_tokens += response.return_mask.sum().item() + transfer_data_size_GB = put_tokens * model_config.token_size_in_bytes / (1024 ** 3) + transfer_bandwidth_put = transfer_data_size_GB / elapsed_time_put if elapsed_time_put > 0 else 0 + print(f" PUT: {put_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " + f"time: {elapsed_time_put * 1000:.2f}ms, bandwidth: {transfer_bandwidth_put:.2f} GB/s") + results.update({ + "put_tokens": put_tokens, + "put_time_ms": elapsed_time_put * 1000, + "put_bandwidth_GBs": transfer_bandwidth_put, + }) + else: + print("\n--- PUT Phase SKIPPED (get-only mode) ---") + + if bench_config.clear_cpu_cache: + kvmanager._clear_cpu_cache() + + # ---- Benchmark GET ---- + if not skip_get: + print("\n--- GET Phase ---") + all_tokens = 0 + start_time = time.time() + batch_get_ids = [] + for i in range(batch_size): + all_tokens += len(batch_sequence_tensor[i]) + task_id, _ = kvmanager.get_match(batch_sequence_tensor[i], token_mask=None) + batch_get_ids.append(task_id) + get_match_time = time.time() - start_time + + kvmanager.launch(batch_get_ids, batch_slot_mapping, as_batch=True, layerwise_transfer=False) + get_result = kvmanager.wait(batch_get_ids) + elapsed_time_get = time.time() - start_time + + cached_tokens = 0 + for _, response in get_result.items(): + if response.status == KVResponseStatus.SUCCESS: + cached_tokens += response.return_mask.sum().item() + transfer_data_size_GB = cached_tokens * model_config.token_size_in_bytes / (1024 ** 3) + transfer_bandwidth_get = transfer_data_size_GB / elapsed_time_get if 
def benchmark_multiturn(kvmanager, model_config, cache_config, bench_config):
    """Benchmark multi-turn conversation with distributed KVCache.

    Generates a request stream with ``generate_random_multiturn`` and replays
    it against ``kvmanager`` one request at a time:

    * ``"get"`` requests run ``get_match`` -> ``launch`` -> ``wait``;
    * ``"put"`` requests run ``put_async`` -> ``wait(completely=True)``.

    Wall-clock time and token counts are accumulated per request type and a
    summary is printed once the stream is exhausted.

    Args:
        kvmanager: Object exposing ``get_match``, ``launch``, ``put_async``
            and ``wait``.
        model_config: Not read by this function; accepted so that it shares a
            call signature with ``benchmark_single_batch``.
        cache_config: Not read by this function; accepted for the same reason.
        bench_config: Supplies ``num_users``, ``num_turns``,
            ``system_prompt_length``, ``input_length``, ``output_length`` and
            ``seed``.

    Returns:
        dict: Request counts, PUT/GET token counts, total PUT/GET times in
        milliseconds and the average cache hit ratio (0 when no GET response
        was recorded).
    """
    # Local import so this block does not depend on the file's import header.
    import statistics

    print("\n" + "=" * 60)
    print(" Multi-Turn Conversation Benchmark (Distributed KVCache)")
    print("=" * 60)
    print(f" num_users={bench_config.num_users}, num_turns={bench_config.num_turns}, "
          f"system_prompt_length={bench_config.system_prompt_length}, "
          f"input_length={bench_config.input_length}, output_length={bench_config.output_length}")

    # Generate multi-turn requests
    reqs = generate_random_multiturn(
        num_user_requests=bench_config.num_users,
        num_turns=bench_config.num_turns,
        system_prompt_length=bench_config.system_prompt_length,
        input_length=bench_config.input_length,
        output_length=bench_config.output_length,
        seed=bench_config.seed,
    )

    total_get_requests = 0
    total_put_requests = 0
    cache_hit_ratios = []
    total_put_time = 0
    total_get_time = 0
    total_put_tokens = 0
    total_get_cached_tokens = 0
    total_get_all_tokens = 0

    for req in reqs:
        # Number of unmasked tokens in this request; computed once and reused
        # for the slot mapping, the token totals and the hit-ratio denominator.
        num_tokens = int(req.token_mask.sum().item())
        fake_slot_mapping = torch.arange(num_tokens, dtype=torch.int64)

        if req.request_type == "get":
            total_get_requests += 1
            total_get_all_tokens += num_tokens

            start_time = time.time()
            # NOTE(review): the mask handed to get_match is all-ones with the
            # dtype of token_ids (not req.token_mask) — confirm this is intended.
            task_id, _ = kvmanager.get_match(
                req.token_ids,
                token_mask=torch.ones_like(
                    torch.from_numpy(req.token_ids)
                    if isinstance(req.token_ids, np.ndarray)
                    else req.token_ids
                ),
            )
            kvmanager.launch([task_id], [fake_slot_mapping.numpy()])
            result = kvmanager.wait([task_id])
            total_get_time += time.time() - start_time

            for _, response in result.items():
                if response.status == KVResponseStatus.SUCCESS and response.return_mask is not None:
                    cached = response.return_mask.sum().item()
                    total_get_cached_tokens += cached
                    # Guard the denominator: a request with no unmasked tokens
                    # counts as a 0% hit instead of raising ZeroDivisionError.
                    cache_hit_ratios.append(cached / num_tokens if num_tokens > 0 else 0.0)
                else:
                    cache_hit_ratios.append(0.0)

        elif req.request_type == "put":
            total_put_requests += 1

            start_time = time.time()
            task_id = kvmanager.put_async(
                req.token_ids,
                fake_slot_mapping.numpy(),
                token_mask=None,
            )
            result = kvmanager.wait([task_id], completely=True)
            total_put_time += time.time() - start_time

            for _, response in result.items():
                if response.status == KVResponseStatus.SUCCESS and response.return_mask is not None:
                    total_put_tokens += response.return_mask.sum().item()

    # Print results
    print("\n--- Results ---")
    print(f" Total requests: {len(reqs)} (GET: {total_get_requests}, PUT: {total_put_requests})")
    print(f" PUT: {total_put_tokens} tokens, total time: {total_put_time * 1000:.2f}ms, "
          f"avg time: {total_put_time * 1000 / max(total_put_requests, 1):.2f}ms/req")
    print(f" GET: {total_get_cached_tokens}/{total_get_all_tokens} tokens cached, "
          f"total time: {total_get_time * 1000:.2f}ms, "
          f"avg time: {total_get_time * 1000 / max(total_get_requests, 1):.2f}ms/req")

    avg_ratio = sum(cache_hit_ratios) / len(cache_hit_ratios) if cache_hit_ratios else 0
    if cache_hit_ratios:
        # statistics.median averages the two middle values for even-length
        # input, so the figure labelled "median" is a true median.
        print(f" Cache hit ratio: avg={avg_ratio * 100:.2f}%, "
              f"min={min(cache_hit_ratios) * 100:.2f}%, "
              f"median={statistics.median(cache_hit_ratios) * 100:.2f}%, "
              f"max={max(cache_hit_ratios) * 100:.2f}%")

    return {
        "total_requests": len(reqs),
        "get_requests": total_get_requests,
        "put_requests": total_put_requests,
        "put_tokens": total_put_tokens,
        "put_total_time_ms": total_put_time * 1000,
        "get_cached_tokens": total_get_cached_tokens,
        "get_total_tokens": total_get_all_tokens,
        "get_total_time_ms": total_get_time * 1000,
        "avg_cache_hit_ratio": avg_ratio,
    }
max_tokens_per_user = ( + bench_config.system_prompt_length + + bench_config.num_turns * (bench_config.input_length + bench_config.output_length) + ) + multiturn_blocks = max_tokens_per_user * bench_config.num_users // cache_config.tokens_per_block + num_gpu_blocks = max(num_gpu_blocks, multiturn_blocks) + # Add some extra blocks for safety + num_gpu_blocks = int(num_gpu_blocks * 1.5) + 64 + + if model_config.tp_size * model_config.dp_size > torch.cuda.device_count(): + raise ValueError( + f"tp_size {model_config.tp_size} * dp_size {model_config.dp_size} > " + f"available GPUs {torch.cuda.device_count()}" + ) + + print("=" * 60) + print(" FlexKV Distributed KVCache Benchmark (server_client_mode)") + print("=" * 60) + print(f" model_config: {model_config}") + print(f" cache_config: {cache_config}") + print(f" enable_kv_sharing: {cache_config.enable_kv_sharing}") + print(f" enable_p2p_cpu: {cache_config.enable_p2p_cpu}") + print(f" redis: {cache_config.redis_host}:{cache_config.redis_port}") + print(f" num_gpu_blocks: {num_gpu_blocks}") + print(f" bench_config: {bench_config}") + + # Create KVManager (this will start KVServer in server_client_mode) + kvmanager = KVManager(model_config, cache_config) + kvmanager.start() + + # Start tp_client processes to register GPU blocks + tp_client_processes = [] + + # Register cleanup handler to ensure processes are terminated on exit + def _cleanup(): + shutdown_tp_clients(tp_client_processes) + try: + kvmanager.shutdown() + except Exception: + pass + atexit.register(_cleanup) + + def _signal_handler(signum, frame): + print(f"\nReceived signal {signum}, shutting down...") + _cleanup() + # Re-raise to allow default handler + signal.signal(signum, signal.SIG_DFL) + os.kill(os.getpid(), signum) + signal.signal(signal.SIGINT, _signal_handler) + signal.signal(signal.SIGTERM, _signal_handler) + for tp_rank in range(model_config.tp_size): + tp_process = Process( + target=run_tp_client, + args=(0, tp_rank, kvmanager.gpu_register_port, + 
model_config, cache_config, num_gpu_blocks), + daemon=True, + ) + tp_process.start() + tp_client_processes.append(tp_process) + + # Wait for system to be ready + print("\nWaiting for FlexKV to be ready...") + wait_start = time.time() + while not kvmanager.is_ready(): + time.sleep(1) + elapsed = time.time() - wait_start + if elapsed > 120: + print("ERROR: Timeout waiting for FlexKV to be ready (120s)") + shutdown_tp_clients(tp_client_processes) + kvmanager.shutdown() + return + if int(elapsed) % 10 == 0 and int(elapsed) > 0: + print(f" Still waiting... ({int(elapsed)}s)") + print(f"FlexKV is ready! (took {time.time() - wait_start:.1f}s)") + + try: + results = {} + + if bench_config.mode in ("single", "all", "put-only", "get-only"): + results["single_batch"] = benchmark_single_batch( + kvmanager, model_config, cache_config, bench_config + ) + + if bench_config.mode in ("multiturn", "all"): + results["multiturn"] = benchmark_multiturn( + kvmanager, model_config, cache_config, bench_config + ) + + # Print summary + print("\n" + "=" * 60) + print(" Benchmark Summary") + print("=" * 60) + for name, result in results.items(): + print(f"\n [{name}]") + for k, v in result.items(): + if isinstance(v, float): + print(f" {k}: {v:.4f}") + else: + print(f" {k}: {v}") + + # In put-only mode, keep the process alive so other nodes can GET the data + if bench_config.mode == "put-only": + print("\n" + "-" * 60) + print("Data published to Redis. Press Enter to shutdown " + "(keep running for other nodes to GET)...") + print("-" * 60) + try: + input() + except EOFError: + # Handle non-interactive environments + print("Non-interactive mode detected. 
def parse_args(argv=None):
    """Parse command-line options for the distributed KVCache benchmark.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            ``argparse`` reads ``sys.argv[1:]``, which is the original
            behavior; passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace: Parsed options (``config``, ``mode``, ``seed``,
        the single-batch parameters, the multi-turn parameters and
        ``rebuild_interval_ms``).
    """
    parser = argparse.ArgumentParser(
        description="Benchmark FlexKV distributed KVCache in server_client_mode"
    )
    # NOTE(review): the example config added alongside this script lives under
    # benchmarks/dist_benchmark/ — confirm this default path still resolves.
    parser.add_argument("--config", type=str, default="benchmarks/example_dist_config.yml",
                        help="Path to config YAML file")
    parser.add_argument("--mode", type=str, default="all",
                        choices=["single", "multiturn", "all", "put-only", "get-only"],
                        help="Benchmark mode: single, multiturn, all, put-only, get-only")
    parser.add_argument("--seed", type=int, default=None,
                        help="Random seed for deterministic token generation (for cross-node benchmarks)")

    # Single batch params
    parser.add_argument("--batch-size", type=int, default=1, help="Batch size for single batch benchmark")
    parser.add_argument("--sequence-length", type=int, default=1024, help="Sequence length per request")
    parser.add_argument("--cache-ratio", type=float, default=1.0, help="Ratio of tokens to cache in PUT phase")
    parser.add_argument("--clear-cpu-cache", action="store_true", help="Clear CPU cache between PUT and GET")

    # Multi-turn params
    parser.add_argument("--num-users", type=int, default=10, help="Number of simulated users")
    parser.add_argument("--num-turns", type=int, default=3, help="Number of conversation turns per user")
    parser.add_argument("--system-prompt-length", type=int, default=100, help="System prompt length in tokens")
    parser.add_argument("--input-length", type=int, default=512, help="Input length per turn in tokens")
    parser.add_argument("--output-length", type=int, default=64, help="Output length per turn in tokens")

    # Cross-node sync params
    parser.add_argument("--rebuild-interval-ms", type=int, default=None,
                        help="Override FLEXKV_REBUILD_INTERVAL_MS (default: use env or 100). "
                             "Recommended: 20 for cross-node benchmarks")

    return parser.parse_args(argv)
+ +# Model config +num_layers: 4 +num_kv_heads: 8 +head_size: 128 +dtype: bfloat16 +use_mla: false +tp_size: 1 +dp_size: 1 +tokens_per_block: 16 + +# Cache config +cpu_cache_gb: 4 +ssd_cache_gb: 0 + +# Distributed KVCache config +enable_p2p_cpu: true + +# Redis config (for KV sharing metadata) +redis_host: "10.135.1.175" +redis_port: 6379 +redis_password: "123456" +local_ip: "10.135.1.176" + +# Mooncake Transfer Engine config (required for P2P) +mooncake_engine_ip: "10.135.1.176" +mooncake_engine_port: 5555 +mooncake_metadata_backend: "redis" +mooncake_metadata_server: "redis://10.135.1.175:6379" +mooncake_metadata_server_auth: "123456" +mooncake_protocol: "rdma" # "tcp" or "rdma" +mooncake_device_name: "mlx5_0,mlx5_1,mlx5_4,mlx5_5" # RDMA device name, e.g. "mlx5_0"; leave empty for tcp + +# Direct mode (non-server_client_mode) +# In direct mode, KVTaskEngine runs in the main process +server_client_mode: false diff --git a/benchmarks/dist_benchmark/redis_check.py b/benchmarks/dist_benchmark/redis_check.py new file mode 100644 index 0000000000..2c702056f6 --- /dev/null +++ b/benchmarks/dist_benchmark/redis_check.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python3 +""" +FlexKV Redis Data Inspector + +Check what data the put-only node has pushed to Redis. 
def connect_redis(host, port, password):
    """Connect to Redis and verify connectivity.

    Args:
        host: Redis host name or address.
        port: Redis port.
        password: Redis password; an empty/falsy value means no authentication.

    Returns:
        A ``redis.Redis`` client created with ``decode_responses=True`` and a
        5 second connect timeout.

    Exits the process with status 1 when the ``PING`` fails.
    """
    r = redis.Redis(
        host=host, port=port,
        password=password if password else None,
        decode_responses=True,
        socket_connect_timeout=5,
    )
    try:
        r.ping()
        print(f"✅ Connected to Redis at {host}:{port}")
    except redis.RedisError as e:
        # Catch the redis-py base class: a connect timeout surfaces as
        # redis.TimeoutError, which is not a ConnectionError subclass, and
        # would otherwise escape as a raw traceback.
        print(f"❌ Failed to connect to Redis at {host}:{port}: {e}")
        sys.exit(1)
    return r


def scan_keys(r, pattern, count=1000):
    """Return the sorted, de-duplicated keys matching ``pattern``.

    Iterates with the incremental ``SCAN`` command until the server returns
    cursor 0.

    Args:
        r: Client exposing ``scan(cursor=..., match=..., count=...)`` that
            returns a ``(next_cursor, keys)`` pair.
        pattern: Glob-style match pattern passed to ``SCAN``.
        count: Per-call ``COUNT`` hint passed to ``SCAN``.

    Returns:
        list: Unique matching keys in sorted order.
    """
    # SCAN may report the same key more than once during a full iteration;
    # collecting into a set keeps the reported counts accurate.
    found = set()
    cursor = 0
    while True:
        cursor, batch = r.scan(cursor=cursor, match=pattern, count=count)
        found.update(batch)
        if cursor == 0:
            break
    return sorted(found)
def check_registered_nodes(r):
    """Print every ``node:*`` hash with its fields in sorted order."""
    rule = "=" * 60
    print("\n" + rule)
    print(" 2. Registered Nodes (node:*)")
    print(rule)

    node_keys = scan_keys(r, "node:*")
    if len(node_keys) == 0:
        print(" ⚠️ No registered nodes found")
        return

    print(f" Found {len(node_keys)} registered node(s):\n")
    for node_key in node_keys:
        fields = r.hgetall(node_key)
        print(f" 📌 {node_key}:")
        for name in sorted(fields):
            print(f" {name}: {fields[name]}")
        print()


def check_node_meta(r):
    """Print every ``meta:*`` hash; pointer fields are also shown in hex."""
    rule = "=" * 60
    print("\n" + rule)
    print(" 3. Node Meta (meta:*)")
    print(rule)

    meta_keys = scan_keys(r, "meta:*")
    if len(meta_keys) == 0:
        print(" ⚠️ No node meta found")
        print(" → This means PEER2CPUTransferWorker hasn't registered yet,")
        print(" or mooncake transfer engine initialization failed.")
        return

    # Fields holding large integers (pointers) that read better in hex.
    pointer_fields = ("cpu_buffer_ptr", "ssd_buffer_ptr")

    print(f" Found {len(meta_keys)} node meta entry(ies):\n")
    for meta_key in meta_keys:
        fields = r.hgetall(meta_key)
        print(f" 📌 {meta_key}:")
        for name in sorted(fields):
            raw = fields[name]
            line = f" {name}: {raw}"
            if name in pointer_fields:
                try:
                    line = f" {name}: {raw} (0x{int(raw):x})"
                except (ValueError, TypeError):
                    # Not an integer: keep the plain rendering.
                    pass
            print(line)
        print()
def check_block_metadata(r):
    """Check KVCache block metadata - this is the core data from put operations.

    Scans the ``CPUB:*``, ``SSDB:*`` and ``PCFSB:*`` key spaces, groups the
    keys by their second colon-separated field (treated as the node id), and
    prints up to three sample entries per node showing the hash fields
    ``ph``, ``pb``, ``nid``, ``hash``, ``lt`` and ``state``.

    NOTE(review): the launcher script's cleanup patterns use
    ``CPUB:block:*`` etc. If keys really have that layout, the second field
    is the literal ``block`` rather than a node id — confirm the key format.

    Args:
        r: Redis client exposing ``scan``, ``hgetall`` and ``type``.
    """
    print("\n" + "=" * 60)
    print(" 5. KVCache Block Metadata (CPUB/SSDB/PCFSB)")
    print("=" * 60)

    # FlexKV actual block key prefixes (set in hie_cache_engine.py)
    block_prefixes = {
        "CPUB": "CPU",
        "SSDB": "SSD",
        "PCFSB": "PCFS (Remote)",
    }

    grand_total = 0
    for prefix, label in block_prefixes.items():
        keys = scan_keys(r, f"{prefix}:*")
        if not keys:
            print(f"\n [{label}] {prefix}:* — no entries found")
            continue

        grand_total += len(keys)

        # Group keys by the field that follows the prefix.
        node_blocks = {}
        for key in keys:
            parts = key.split(":")
            if len(parts) >= 2:
                node_blocks.setdefault(parts[1], []).append(key)

        print(f"\n [{label}] {prefix}:* — {len(keys)} block(s) across {len(node_blocks)} node(s):")

        # Numeric ids sort by value; anything else sorts as 0.
        for node_id, block_keys in sorted(node_blocks.items(), key=lambda x: int(x[0]) if x[0].isdigit() else 0):
            print(f" 📌 Node {node_id}: {len(block_keys)} block(s)")

            # Show first few blocks as samples
            sample_count = min(3, len(block_keys))
            for key in block_keys[:sample_count]:
                try:
                    data = r.hgetall(key)
                except redis.ResponseError:
                    # HGETALL on a key that is not a hash raises WRONGTYPE;
                    # fall through to the type report instead of aborting
                    # the whole inspection.
                    data = {}
                if data:
                    # BlockMeta fields: ph (physical hash), pb (physical block),
                    # nid (node id), hash, lt (lease time), state
                    ph = data.get("ph", "?")
                    pb = data.get("pb", "?")
                    nid = data.get("nid", "?")
                    hash_val = data.get("hash", "?")
                    lt = data.get("lt", "?")
                    state = data.get("state", "?")
                    print(f" {key}: ph={ph}, pb={pb}, nid={nid}, hash={hash_val}, lt={lt}, state={state}")
                else:
                    key_type = r.type(key)
                    print(f" {key}: type={key_type}, (empty hash)")

            if len(block_keys) > sample_count:
                print(f" ... and {len(block_keys) - sample_count} more block(s)")

    if grand_total == 0:
        print("\n ⚠️ No block metadata found in any prefix (CPUB/SSDB/PCFSB)")
        print(" → This means no KVCache data has been published to Redis yet.")
        print(" The put-only node may still be uploading, or the upload")
        print(" interval (rebuild_interval_ms) hasn't elapsed yet.")
    else:
        print(f"\n ✅ Total block metadata entries: {grand_total}")
def check_all_keys_summary(r):
    """Print how many keys exist in Redis, grouped by the text before the first ':'."""
    rule = "=" * 60
    print("\n" + rule)
    print(" 8. All Keys Summary")
    print(rule)

    every_key = scan_keys(r, "*")
    if len(every_key) == 0:
        print(" ⚠️ Redis is completely empty!")
        return

    print(f" Total keys in Redis: {len(every_key)}\n")

    # Tally keys by prefix; a key without ':' is its own prefix.
    tally = {}
    for name in every_key:
        head = name.partition(":")[0]
        if head in tally:
            tally[head] += 1
        else:
            tally[head] = 1

    print(f" {'Prefix':<30} {'Count':>8}")
    print(f" {'-'*30} {'-'*8}")
    # Largest groups first; ties keep their insertion order.
    ordered = sorted(tally.items(), key=lambda item: -item[1])
    for head, total in ordered:
        print(f" {head:<30} {total:>8}")


def main():
    """Parse the CLI options, connect to Redis and run every inspection in order."""
    cli = argparse.ArgumentParser(
        description="FlexKV Redis Data Inspector - Check put-only node data"
    )
    cli.add_argument("--host", type=str, default="10.135.1.175",
                     help="Redis host (default: 10.135.1.175)")
    cli.add_argument("--port", type=int, default=6379,
                     help="Redis port (default: 6379)")
    cli.add_argument("--password", type=str, default="123456",
                     help="Redis password (default: 123456)")
    opts = cli.parse_args()

    rule = "=" * 60
    print(rule)
    print(" FlexKV Redis Data Inspector")
    print(rule)
    print(f" Target: {opts.host}:{opts.port}")

    client = connect_redis(opts.host, opts.port, opts.password)

    # Inspections run in the numbered order used by their section banners.
    inspections = (
        check_global_node_id,
        check_registered_nodes,
        check_node_meta,
        check_buffer_registrations,
        check_block_metadata,
        check_pcfs_data,
        check_mooncake_keys,
        check_all_keys_summary,
    )
    for inspect in inspections:
        inspect(client)

    print("\n" + rule)
    print(" Inspection Complete")
    print(rule)
Check and start Redis server if not running +# 2. Set up environment variables +# 3. Run the distributed KVCache benchmark +# +# Usage: +# bash benchmarks/run_dist_benchmark.sh [options] +# +# Options (passed through to benchmark_dist_kvcache.py): +# --config Config YAML file (default: benchmarks/example_dist_config.yml) +# --mode Benchmark mode: single, multiturn, or all (default: all) +# --batch-size Batch size (default: 1) +# --sequence-length Sequence length (default: 1024) +# --num-users Number of simulated users (default: 10) +# --num-turns Number of conversation turns (default: 3) +# --clean-redis Clean up FlexKV & Mooncake residual data in Redis before running benchmark +# (removes node:*, meta:*, CPUB:block:*, SSDB:block:*, PCFSB:block:*, +# mooncake/*, mooncake:*, segment:*, endpoint:*, mc:* keys) +# --clean-redis-only Clean up FlexKV & Mooncake residual data in Redis and exit (no benchmark) +# +# Examples: +# # Run with defaults +# bash benchmarks/run_dist_benchmark.sh +# +# # Custom parameters +# bash benchmarks/run_dist_benchmark.sh --batch-size 4 --sequence-length 2048 +# +# # Multi-turn only +# bash benchmarks/run_dist_benchmark.sh --mode multiturn --num-users 20 --num-turns 5 +# +# # Clean Redis residual data before benchmark +# bash benchmarks/run_dist_benchmark.sh --clean-redis +# +# # Only clean Redis residual data (no benchmark) +# bash benchmarks/run_dist_benchmark.sh --clean-redis-only +# ============================================================================= + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "${SCRIPT_DIR}/.." 
&& pwd)" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +info() { echo -e "${BLUE}[INFO]${NC} $*"; } +ok() { echo -e "${GREEN}[OK]${NC} $*"; } +warn() { echo -e "${YELLOW}[WARN]${NC} $*"; } +error() { echo -e "${RED}[ERROR]${NC} $*"; } + +# Default config file +CONFIG_FILE="${SCRIPT_DIR}/example_dist_config.yml" +REDIS_STARTED_BY_US=false +CLEAN_REDIS=false +CLEAN_REDIS_ONLY=false + +# Parse script-specific arguments and --config, pass the rest through to benchmark +BENCH_ARGS=() +prev_arg="" +for arg in "$@"; do + if [[ "$prev_arg" == "--config" ]]; then + CONFIG_FILE="$arg" + BENCH_ARGS+=("$arg") + prev_arg="$arg" + continue + fi + case "$arg" in + --clean-redis) + CLEAN_REDIS=true + ;; + --clean-redis-only) + CLEAN_REDIS=true + CLEAN_REDIS_ONLY=true + ;; + *) + BENCH_ARGS+=("$arg") + ;; + esac + prev_arg="$arg" +done + +# ============================================ +# Step 1: Parse Redis config from YAML +# ============================================ +info "============================================" +info "Step 1: Parsing configuration" +info "============================================" + +# Helper function to parse a YAML value using Python (handles comments, quotes, etc. 
# Read one top-level key from a YAML file using Python/PyYAML.
# Usage: parse_yaml_value KEY FILE [DEFAULT]
# Prints the value, or DEFAULT when the key is missing/null or when the file
# cannot be read or parsed (python3 or PyYAML missing, bad YAML, ...).
parse_yaml_value() {
    local key="$1" file="$2" default="${3:-}"
    local val
    # KEY, FILE and DEFAULT are passed as argv rather than spliced into the
    # Python source, so quotes or backslashes in them cannot break (or inject
    # into) the script.
    val=$(python3 -c '
import sys
import yaml

key, path, default = sys.argv[1:4]
with open(path) as f:
    data = yaml.safe_load(f) or {}
value = data.get(key)
print(default if value is None else value)
' "$key" "$file" "$default" 2>/dev/null) || val="$default"
    printf '%s\n' "$val"
}

# Succeed when the Redis server at REDIS_HOST:REDIS_PORT answers PING with PONG.
# Reads the globals REDIS_HOST, REDIS_PORT and (optionally) REDIS_PASSWORD.
check_redis() {
    # An array keeps a password containing whitespace or glob characters as a
    # single argument to redis-cli.
    local -a auth_args=()
    if [[ -n "${REDIS_PASSWORD:-}" ]]; then
        auth_args=(-a "$REDIS_PASSWORD")
    fi
    redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" ${auth_args[@]+"${auth_args[@]}"} ping 2>/dev/null | grep -q "PONG"
}
# Delete every Redis key matching a glob pattern and print how many were removed.
# Usage: clean_redis_keys PATTERN
# Reads the globals REDIS_HOST, REDIS_PORT and (optionally) REDIS_PASSWORD.
clean_redis_keys() {
    local pattern="$1"
    local count=0
    local cursor=0
    local result keys deleted line
    local -a auth_args=() batch=()

    # An array keeps the password as one argument even if it contains spaces.
    if [[ -n "${REDIS_PASSWORD:-}" ]]; then
        auth_args=(-a "$REDIS_PASSWORD")
    fi

    while true; do
        result=$(redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" ${auth_args[@]+"${auth_args[@]}"} \
            SCAN "$cursor" MATCH "$pattern" COUNT 500 2>/dev/null) || result=""
        cursor=$(printf '%s\n' "$result" | head -1)

        # A failed SCAN (server down, auth error, ...) yields an empty or
        # non-numeric first line; stop instead of looping forever on it.
        if [[ ! "$cursor" =~ ^[0-9]+$ ]]; then
            echo "clean_redis_keys: SCAN failed for pattern '$pattern'" >&2
            break
        fi

        # One key per line after the cursor; collect them into an array so
        # keys containing spaces or quotes are passed to DEL intact.
        keys=$(printf '%s\n' "$result" | tail -n +2)
        batch=()
        while IFS= read -r line; do
            if [[ -n "$line" ]]; then
                batch+=("$line")
            fi
        done <<< "$keys"

        if [[ ${#batch[@]} -gt 0 ]]; then
            deleted=$(redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" ${auth_args[@]+"${auth_args[@]}"} \
                DEL "${batch[@]}" 2>/dev/null) || deleted=""
            # Only add well-formed integer replies to the running total.
            if [[ "$deleted" =~ ^[0-9]+$ ]]; then
                count=$((count + deleted))
            fi
        fi

        if [[ "$cursor" == "0" ]]; then
            break
        fi
    done
    echo "$count"
}
keys + n=$(clean_redis_keys "CPUB:block:*") + [[ $n -gt 0 ]] && info "Deleted $n CPUB:block:* key(s)" + total_deleted=$((total_deleted + n)) + + # Clean SSDB:block:* keys + n=$(clean_redis_keys "SSDB:block:*") + [[ $n -gt 0 ]] && info "Deleted $n SSDB:block:* key(s)" + total_deleted=$((total_deleted + n)) + + # Clean PCFSB:block:* keys + n=$(clean_redis_keys "PCFSB:block:*") + [[ $n -gt 0 ]] && info "Deleted $n PCFSB:block:* key(s)" + total_deleted=$((total_deleted + n)) + + # Clean Mooncake Transfer Engine residual keys + # Mooncake uses Redis as metadata backend to store segment/endpoint info + for mc_pattern in "mooncake/*" "mooncake:*" "segment:*" "endpoint:*" "mc:*"; do + n=$(clean_redis_keys "$mc_pattern") + [[ $n -gt 0 ]] && info "Deleted $n ${mc_pattern} key(s)" + total_deleted=$((total_deleted + n)) + done + + if [[ $total_deleted -gt 0 ]]; then + ok "Cleaned $total_deleted FlexKV & Mooncake residual key(s) from Redis" + else + ok "No FlexKV residual data found in Redis" + fi + + if [[ "$CLEAN_REDIS_ONLY" == "true" ]]; then + ok "Clean-only mode, exiting." + exit 0 + fi +fi + +# ============================================ +# Step 3: Set up environment +# ============================================ +info "============================================" +info "Step 3: Setting up environment" +info "============================================" + +# Detect Python (prefer virtual env) +if [[ -n "$VIRTUAL_ENV" ]]; then + # Prefer 'which python3' to get the actual resolved path in the activated venv, + # because $VIRTUAL_ENV may point to a path that doesn't match the real filesystem + # (e.g. symlinks, home dir aliases like ~ vs /data1/home). + if command -v python3 &>/dev/null; then + PYTHON="$(which python3)" + else + PYTHON="$VIRTUAL_ENV/bin/python3" + fi + if [[ ! 
-x "$PYTHON" ]]; then + error "Python3 not found at $PYTHON (VIRTUAL_ENV=$VIRTUAL_ENV)" + exit 1 + fi + info "Using virtual env Python: $PYTHON" +elif command -v python3 &>/dev/null; then + PYTHON="$(which python3)" + info "Using system Python: $PYTHON" +else + error "Python3 not found!" + exit 1 +fi + +# Set PYTHONPATH to include project root +export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}" + +# Set LD_LIBRARY_PATH for C++ libraries +if [[ -d "${PROJECT_ROOT}/build" ]]; then + export LD_LIBRARY_PATH="${PROJECT_ROOT}/build:${LD_LIBRARY_PATH:-}" +fi + +info "PYTHONPATH=${PYTHONPATH}" +info "LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}" + +# Generate mooncake config JSON and export MOONCAKE_CONFIG_PATH if P2P is enabled +ENABLE_P2P_CPU=$(grep -E "^enable_p2p_cpu:" "$CONFIG_FILE" 2>/dev/null | awk '{print $2}' || echo "false") +ENABLE_P2P_SSD=$(grep -E "^enable_p2p_ssd:" "$CONFIG_FILE" 2>/dev/null | awk '{print $2}' || echo "false") + +if [[ "$ENABLE_P2P_CPU" == "true" ]] || [[ "$ENABLE_P2P_SSD" == "true" ]]; then + if [[ -z "${MOONCAKE_CONFIG_PATH:-}" ]]; then + info "P2P enabled, generating mooncake config..." 
+ + # Parse mooncake config from YAML using helper function + MC_ENGINE_IP=$(parse_yaml_value "mooncake_engine_ip" "$CONFIG_FILE") + MC_ENGINE_PORT=$(parse_yaml_value "mooncake_engine_port" "$CONFIG_FILE") + MC_METADATA_BACKEND=$(parse_yaml_value "mooncake_metadata_backend" "$CONFIG_FILE") + MC_METADATA_SERVER=$(parse_yaml_value "mooncake_metadata_server" "$CONFIG_FILE") + MC_METADATA_SERVER_AUTH=$(parse_yaml_value "mooncake_metadata_server_auth" "$CONFIG_FILE") + MC_PROTOCOL=$(parse_yaml_value "mooncake_protocol" "$CONFIG_FILE") + MC_DEVICE_NAME=$(parse_yaml_value "mooncake_device_name" "$CONFIG_FILE") + LOCAL_IP=$(parse_yaml_value "local_ip" "$CONFIG_FILE" "127.0.0.1") + + # Use defaults if not specified + MC_ENGINE_IP="${MC_ENGINE_IP:-$LOCAL_IP}" + MC_ENGINE_PORT="${MC_ENGINE_PORT:-5555}" + MC_METADATA_BACKEND="${MC_METADATA_BACKEND:-redis}" + MC_METADATA_SERVER="${MC_METADATA_SERVER:-redis://${REDIS_HOST}:${REDIS_PORT}}" + MC_PROTOCOL="${MC_PROTOCOL:-tcp}" + MC_DEVICE_NAME="${MC_DEVICE_NAME:-}" + + # Generate JSON config file + MOONCAKE_CONFIG_FILE=$(mktemp /tmp/mooncake_config_XXXXXX.json) + cat > "$MOONCAKE_CONFIG_FILE" </dev/null || true + ok "Redis stopped." +fi + +if [[ $BENCH_EXIT_CODE -eq 0 ]]; then + echo "" + ok "Benchmark completed successfully!" 
+else + echo "" + error "Benchmark failed with exit code: $BENCH_EXIT_CODE" +fi + +exit $BENCH_EXIT_CODE diff --git a/benchmarks/dist_benchmark/run_dist_direct_benchmark.sh b/benchmarks/dist_benchmark/run_dist_direct_benchmark.sh new file mode 100755 index 0000000000..6a9bd5887a --- /dev/null +++ b/benchmarks/dist_benchmark/run_dist_direct_benchmark.sh @@ -0,0 +1,372 @@ +#!/bin/bash +# ============================================================================= +# FlexKV Distributed KVCache Benchmark (Direct Mode) - One-Click Launch Script +# +# This script runs the distributed KVCache benchmark in direct mode +# (non-server_client_mode), where KVManager creates KVTaskEngine directly +# in the main process without going through KVServer/KVDPClient IPC. +# +# Usage: +# bash benchmarks/dist_benchmark/run_dist_direct_benchmark.sh [options] +# +# Options (passed through to benchmark_dist_direct.py): +# --config Config YAML file (default: benchmarks/dist_benchmark/example_dist_direct_config.yml) +# --mode Benchmark mode: single, multiturn, or all (default: all) +# --batch-size Batch size (default: 1) +# --sequence-length Sequence length (default: 1024) +# --num-users Number of simulated users (default: 10) +# --num-turns Number of conversation turns (default: 3) +# --clean-redis Clean up FlexKV & Mooncake residual data in Redis before running benchmark +# --clean-redis-only Clean up FlexKV & Mooncake residual data in Redis and exit (no benchmark) +# +# Examples: +# # Run with defaults (direct mode) +# bash benchmarks/dist_benchmark/run_dist_direct_benchmark.sh +# +# # Custom parameters +# bash benchmarks/dist_benchmark/run_dist_direct_benchmark.sh --batch-size 4 --sequence-length 2048 +# +# # Multi-turn only +# bash benchmarks/dist_benchmark/run_dist_direct_benchmark.sh --mode multiturn --num-users 20 --num-turns 5 +# +# # Clean Redis residual data before benchmark +# bash benchmarks/dist_benchmark/run_dist_direct_benchmark.sh --clean-redis +# +# # Only clean Redis 
residual data (no benchmark) +# bash benchmarks/dist_benchmark/run_dist_direct_benchmark.sh --clean-redis-only +# ============================================================================= + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +info() { echo -e "${BLUE}[INFO]${NC} $*"; } +ok() { echo -e "${GREEN}[OK]${NC} $*"; } +warn() { echo -e "${YELLOW}[WARN]${NC} $*"; } +error() { echo -e "${RED}[ERROR]${NC} $*"; } + +# Default config file (direct mode config) +CONFIG_FILE="${SCRIPT_DIR}/example_dist_direct_config.yml" +REDIS_STARTED_BY_US=false +CLEAN_REDIS=false +CLEAN_REDIS_ONLY=false + +# Parse script-specific arguments and --config, pass the rest through to benchmark +BENCH_ARGS=() +prev_arg="" +for arg in "$@"; do + if [[ "$prev_arg" == "--config" ]]; then + CONFIG_FILE="$arg" + BENCH_ARGS+=("$arg") + prev_arg="$arg" + continue + fi + case "$arg" in + --clean-redis) + CLEAN_REDIS=true + ;; + --clean-redis-only) + CLEAN_REDIS=true + CLEAN_REDIS_ONLY=true + ;; + *) + BENCH_ARGS+=("$arg") + ;; + esac + prev_arg="$arg" +done + +# ============================================ +# Step 1: Parse Redis config from YAML +# ============================================ +info "============================================" +info "Step 1: Parsing configuration" +info "============================================" + +parse_yaml_value() { + local key="$1" file="$2" default="${3:-}" + local val + val=$(python3 -c " +import yaml, sys +with open('$file') as f: + d = yaml.safe_load(f) +v = d.get('$key') +if v is None: + print('$default') +else: + print(v) +" 2>/dev/null) || val="$default" + echo "$val" +} + +REDIS_HOST=$(parse_yaml_value "redis_host" "$CONFIG_FILE" "127.0.0.1") +REDIS_PORT=$(parse_yaml_value "redis_port" "$CONFIG_FILE" "6379") +REDIS_PASSWORD=$(parse_yaml_value 
"redis_password" "$CONFIG_FILE" "") + +info "Config file: ${CONFIG_FILE}" +info "Redis: ${REDIS_HOST}:${REDIS_PORT}" +info "Mode: Direct (non-server_client_mode)" + +# ============================================ +# Step 2: Check and start Redis +# ============================================ +info "============================================" +info "Step 2: Checking Redis server" +info "============================================" + +check_redis() { + local auth_args="" + if [[ -n "$REDIS_PASSWORD" ]]; then + auth_args="-a $REDIS_PASSWORD" + fi + redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" $auth_args ping 2>/dev/null | grep -q "PONG" +} + +REDIS_AUTH_ARGS="" +if [[ -n "$REDIS_PASSWORD" ]]; then + REDIS_AUTH_ARGS="-a $REDIS_PASSWORD" +fi + +if check_redis; then + ok "Redis is already running at ${REDIS_HOST}:${REDIS_PORT}" +else + warn "Redis is not running at ${REDIS_HOST}:${REDIS_PORT}" + + if [[ "$REDIS_HOST" == "127.0.0.1" ]] || [[ "$REDIS_HOST" == "localhost" ]]; then + if command -v redis-server &>/dev/null; then + info "Starting Redis server on port ${REDIS_PORT}..." + redis-server --port "$REDIS_PORT" --daemonize yes --save "" --appendonly no \ + --protected-mode no --loglevel warning + sleep 1 + + if check_redis; then + ok "Redis server started successfully" + REDIS_STARTED_BY_US=true + else + error "Failed to start Redis server" + error "Please install Redis: sudo apt install redis-server" + exit 1 + fi + else + error "redis-server not found. Please install Redis:" + error " sudo apt install redis-server" + exit 1 + fi + else + error "Redis is not running at ${REDIS_HOST}:${REDIS_PORT}" + error "Please start Redis on the remote host first." 
+ exit 1 + fi +fi + +# ============================================ +# Step 2.5: Clean FlexKV & Mooncake residual data in Redis (if requested) +# ============================================ +if [[ "$CLEAN_REDIS" == "true" ]]; then + info "============================================" + info "Cleaning FlexKV & Mooncake residual data in Redis" + info "============================================" + + clean_redis_keys() { + local pattern="$1" + local count=0 + local cursor=0 + while true; do + local result + result=$(redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" $REDIS_AUTH_ARGS SCAN $cursor MATCH "$pattern" COUNT 500 2>/dev/null) + cursor=$(echo "$result" | head -1) + local keys + keys=$(echo "$result" | tail -n +2) + if [[ -n "$keys" ]]; then + local batch_keys + batch_keys=$(echo "$keys" | tr '\n' ' ') + if [[ -n "$batch_keys" ]]; then + local deleted + deleted=$(echo "$batch_keys" | xargs redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" $REDIS_AUTH_ARGS DEL 2>/dev/null) + count=$((count + deleted)) + fi + fi + if [[ "$cursor" == "0" ]]; then + break + fi + done + echo "$count" + } + + total_deleted=0 + + # Clean FlexKV keys + for pattern in "node:*" "meta:*" "CPUB:block:*" "SSDB:block:*" "PCFSB:block:*"; do + n=$(clean_redis_keys "$pattern") + [[ $n -gt 0 ]] && info "Deleted $n ${pattern} key(s)" + total_deleted=$((total_deleted + n)) + done + + # Clean Mooncake Transfer Engine residual keys + for mc_pattern in "mooncake/*" "mooncake:*" "segment:*" "endpoint:*" "mc:*"; do + n=$(clean_redis_keys "$mc_pattern") + [[ $n -gt 0 ]] && info "Deleted $n ${mc_pattern} key(s)" + total_deleted=$((total_deleted + n)) + done + + if [[ $total_deleted -gt 0 ]]; then + ok "Cleaned $total_deleted FlexKV & Mooncake residual key(s) from Redis" + else + ok "No FlexKV residual data found in Redis" + fi + + if [[ "$CLEAN_REDIS_ONLY" == "true" ]]; then + ok "Clean-only mode, exiting." 
+ exit 0 + fi +fi + +# ============================================ +# Step 3: Set up environment +# ============================================ +info "============================================" +info "Step 3: Setting up environment" +info "============================================" + +if [[ -n "$VIRTUAL_ENV" ]]; then + # Prefer 'which python3' to get the actual resolved path in the activated venv, + # because $VIRTUAL_ENV may point to a path that doesn't match the real filesystem + # (e.g. symlinks, home dir aliases like ~ vs /data1/home). + if command -v python3 &>/dev/null; then + PYTHON="$(which python3)" + else + PYTHON="$VIRTUAL_ENV/bin/python3" + fi + if [[ ! -x "$PYTHON" ]]; then + error "Python3 not found at $PYTHON (VIRTUAL_ENV=$VIRTUAL_ENV)" + exit 1 + fi + info "Using virtual env Python: $PYTHON" +elif command -v python3 &>/dev/null; then + PYTHON="$(which python3)" + info "Using system Python: $PYTHON" +else + error "Python3 not found!" + exit 1 +fi + +export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}" + +if [[ -d "${PROJECT_ROOT}/build" ]]; then + export LD_LIBRARY_PATH="${PROJECT_ROOT}/build:${LD_LIBRARY_PATH:-}" +fi + +# IMPORTANT: Ensure server_client_mode is NOT set for direct mode +unset FLEXKV_SERVER_CLIENT_MODE + +info "PYTHONPATH=${PYTHONPATH}" +info "LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}" +info "FLEXKV_SERVER_CLIENT_MODE= (direct mode)" + +# Generate mooncake config JSON if P2P is enabled +ENABLE_P2P_CPU=$(grep -E "^enable_p2p_cpu:" "$CONFIG_FILE" 2>/dev/null | awk '{print $2}' || echo "false") +ENABLE_P2P_SSD=$(grep -E "^enable_p2p_ssd:" "$CONFIG_FILE" 2>/dev/null | awk '{print $2}' || echo "false") + +if [[ "$ENABLE_P2P_CPU" == "true" ]] || [[ "$ENABLE_P2P_SSD" == "true" ]]; then + if [[ -z "${MOONCAKE_CONFIG_PATH:-}" ]]; then + info "P2P enabled, generating mooncake config..." 
+ + MC_ENGINE_IP=$(parse_yaml_value "mooncake_engine_ip" "$CONFIG_FILE") + MC_ENGINE_PORT=$(parse_yaml_value "mooncake_engine_port" "$CONFIG_FILE") + MC_METADATA_BACKEND=$(parse_yaml_value "mooncake_metadata_backend" "$CONFIG_FILE") + MC_METADATA_SERVER=$(parse_yaml_value "mooncake_metadata_server" "$CONFIG_FILE") + MC_METADATA_SERVER_AUTH=$(parse_yaml_value "mooncake_metadata_server_auth" "$CONFIG_FILE") + MC_PROTOCOL=$(parse_yaml_value "mooncake_protocol" "$CONFIG_FILE") + MC_DEVICE_NAME=$(parse_yaml_value "mooncake_device_name" "$CONFIG_FILE") + LOCAL_IP=$(parse_yaml_value "local_ip" "$CONFIG_FILE" "127.0.0.1") + + MC_ENGINE_IP="${MC_ENGINE_IP:-$LOCAL_IP}" + MC_ENGINE_PORT="${MC_ENGINE_PORT:-5555}" + MC_METADATA_BACKEND="${MC_METADATA_BACKEND:-redis}" + MC_METADATA_SERVER="${MC_METADATA_SERVER:-redis://${REDIS_HOST}:${REDIS_PORT}}" + MC_PROTOCOL="${MC_PROTOCOL:-tcp}" + MC_DEVICE_NAME="${MC_DEVICE_NAME:-}" + + MOONCAKE_CONFIG_FILE=$(mktemp /tmp/mooncake_config_XXXXXX.json) + cat > "$MOONCAKE_CONFIG_FILE" </dev/null || true + ok "Redis stopped." +fi + +if [[ $BENCH_EXIT_CODE -eq 0 ]]; then + echo "" + ok "Benchmark (Direct Mode) completed successfully!" 
+else + echo "" + error "Benchmark failed with exit code: $BENCH_EXIT_CODE" +fi + +exit $BENCH_EXIT_CODE diff --git a/benchmarks/dist_benchmark/utils.py b/benchmarks/dist_benchmark/utils.py new file mode 100644 index 0000000000..d979a27a07 --- /dev/null +++ b/benchmarks/dist_benchmark/utils.py @@ -0,0 +1,131 @@ +import asyncio +import random +import time +from dataclasses import dataclass, field +from typing import Optional, List, Tuple, Any +import yaml + +import torch +import numpy as np +from tqdm import tqdm + +from flexkv.common.config import * +from flexkv.common.storage import KVCacheLayoutType + + +@dataclass +class KVRequest: + user_id: int + turn_id: int + request_type: str # "get" or "put" + token_ids: np.ndarray + token_mask: np.ndarray + slot_mapping: Optional[np.ndarray] = None + + request_id: int = field(init=False) + _request_id_counter: int = field(init=False, default=0) + + def __post_init__(self): + self.request_id = KVRequest._request_id_counter + KVRequest._request_id_counter += 1 + + if isinstance(self.token_ids, torch.Tensor): + self.token_ids = self.token_ids.numpy().astype(np.int64) + if isinstance(self.token_mask, torch.Tensor): + self.token_mask = self.token_mask.numpy().astype(np.int64) + if isinstance(self.slot_mapping, torch.Tensor): + self.slot_mapping = self.slot_mapping.numpy().astype(np.int64) + +def generate_random_multiturn(num_user_requests: int, + num_turns: int, + system_prompt_length: int, + input_length: int, + output_length: int, + num_turns_ratio: float = 0.5, + input_length_ratio: float = 0.5, + output_length_ratio: float = 0.5, + seed: int = None) -> List[KVRequest]: + all_requests = [] + token_id_range = 10000 + # Set seed for deterministic generation (useful for cross-node benchmarks) + if seed is not None: + random.seed(seed) + torch.manual_seed(seed) + system_prompt = torch.randint(0, token_id_range, (system_prompt_length,)) + for i in range(num_user_requests): + user_requests = [] + user_num_turns = 
max(random.randint(int(num_turns_ratio * num_turns), num_turns), 1) + history = system_prompt.clone() + for j in range(user_num_turns): + turn_input_length = random.randint(int(input_length_ratio * input_length), input_length) + turn_output_length = random.randint(int(output_length_ratio * output_length), output_length) + input_tokens = torch.randint(0, token_id_range, (turn_input_length,)) + output_tokens = torch.randint(0, token_id_range, (turn_output_length,)) + request = dict( + user_id=i, + turn_id=j, + input=torch.cat([history, input_tokens], dim=0), + output=output_tokens, + ) + history = torch.cat([history, input_tokens, output_tokens], dim=0) + user_requests.append(request) + all_requests.append(user_requests) + indices = [0] * num_user_requests + kv_requests = [] + while True: + available_lists = [ + i for i in range(num_user_requests) + if indices[i] < len(all_requests[i]) + ] + if not available_lists: + break + user_id = random.choice(available_lists) + request = all_requests[user_id][indices[user_id]] + indices[user_id] += 1 + kv_requests.append(KVRequest( + user_id=request["user_id"], + turn_id=request["turn_id"], + request_type="get", + token_ids=request["input"], + token_mask=torch.ones_like(request["input"]), + )) + kv_requests.append(KVRequest( + user_id=request["user_id"], + turn_id=request["turn_id"], + request_type="put", + token_ids=torch.cat([request["input"], request["output"]], dim=0), + token_mask=torch.ones_like(torch.cat([request["input"], request["output"]], dim=0)), + )) + return kv_requests + +def load_config(config_path: str) -> Tuple[ModelConfig, CacheConfig]: + with open(config_path) as f: + config = yaml.load(f, Loader=yaml.SafeLoader) + print(config) + model_config = ModelConfig() + cache_config = CacheConfig() + user_config = UserConfig() + model_config.num_layers = config["num_layers"] + model_config.num_kv_heads = config["num_kv_heads"] + model_config.head_size = config["head_size"] + model_config.dtype = 
eval(f"torch.{config['dtype']}") + model_config.use_mla = config["use_mla"] + model_config.tp_size = config["tp_size"] + model_config.dp_size = config["dp_size"] + cache_config.tokens_per_block = config["tokens_per_block"] + + if "cpu_cache_gb" in config: + user_config.cpu_cache_gb = config["cpu_cache_gb"] + if "ssd_cache_gb" in config: + user_config.ssd_cache_gb = config["ssd_cache_gb"] + if "ssd_cache_dir" in config: + user_config.ssd_cache_dir = parse_path_list(config["ssd_cache_dir"]) + if "enable_gds" in config: + user_config.enable_gds = config["enable_gds"] + update_default_config_from_user_config(model_config, cache_config, user_config) + return model_config, cache_config + +if __name__ == "__main__": + model_config, cache_config = load_config("./benchmarks/example_config.yml") + print(model_config) + print(cache_config) diff --git a/benchmarks/example_dist_config.yml b/benchmarks/example_dist_config.yml new file mode 100644 index 0000000000..0f7a175344 --- /dev/null +++ b/benchmarks/example_dist_config.yml @@ -0,0 +1,34 @@ +# Distributed KVCache benchmark config (server_client_mode) +# Model config +num_layers: 4 +num_kv_heads: 8 +head_size: 128 +dtype: bfloat16 +use_mla: false +tp_size: 1 +dp_size: 1 +tokens_per_block: 16 + +# Cache config +cpu_cache_gb: 4 +ssd_cache_gb: 0 + +# Distributed KVCache config +enable_p2p_cpu: true + +# Redis config (for KV sharing metadata) +redis_host: "10.135.1.175" +redis_port: 6379 +redis_password: "123456" +local_ip: "10.135.1.176" + +# Mooncake Transfer Engine config (required for P2P) +mooncake_engine_ip: "10.135.1.176" +mooncake_engine_port: 5555 +mooncake_metadata_backend: "redis" +mooncake_metadata_server: "redis://10.135.1.175:6379" +mooncake_metadata_server_auth: "123456" +mooncake_protocol: "rdma" # "tcp" or "rdma" +mooncake_device_name: "mlx5_0,mlx5_1,mlx5_4,mlx5_5" # RDMA device name, e.g. 
"mlx5_0"; leave empty for tcp +# Force server_client_mode +server_client_mode: true diff --git a/benchmarks/redis_check.py b/benchmarks/redis_check.py new file mode 100644 index 0000000000..2c702056f6 --- /dev/null +++ b/benchmarks/redis_check.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python3 +""" +FlexKV Redis Data Inspector + +Check what data the put-only node has pushed to Redis. +This script inspects all FlexKV-related keys in Redis including: + - global:node_id (global node ID counter) + - node: (registered node info) + - meta: (node meta: mooncake engine addr, buffer ptrs) + - buffer::* (RDMA memory region registrations) + - CPUB:: (CPU KVCache block metadata - the actual cached data index) + - SSDB:: (SSD KVCache block metadata) + - PCFSB:: (PCFS remote KVCache block metadata) + - pcfs: (PCFS file node IDs) + - mooncake/* (Mooncake Transfer Engine metadata) + +Usage: + python benchmarks/redis_check.py [--host HOST] [--port PORT] [--password PWD] + + # With defaults from example_dist_config.yml: + python benchmarks/redis_check.py --host 10.135.1.175 --port 6379 --password 123456 +""" + +import argparse +import sys + +try: + import redis +except ImportError: + print("ERROR: redis-py is required. 
Install with: pip install redis") + sys.exit(1) + + +def connect_redis(host, port, password): + """Connect to Redis and verify connectivity.""" + r = redis.Redis( + host=host, port=port, + password=password if password else None, + decode_responses=True, + socket_connect_timeout=5, + ) + try: + r.ping() + print(f"✅ Connected to Redis at {host}:{port}") + except redis.ConnectionError as e: + print(f"❌ Failed to connect to Redis at {host}:{port}: {e}") + sys.exit(1) + return r + + +def scan_keys(r, pattern, count=1000): + """Scan Redis keys matching pattern (non-blocking).""" + keys = [] + cursor = 0 + while True: + cursor, batch = r.scan(cursor=cursor, match=pattern, count=count) + keys.extend(batch) + if cursor == 0: + break + return sorted(keys) + + +def check_global_node_id(r): + """Check the global node ID counter.""" + print("\n" + "=" * 60) + print(" 1. Global Node ID Counter") + print("=" * 60) + val = r.get("global:node_id") + if val is not None: + print(f" global:node_id = {val}") + print(f" → {val} node(s) have been registered in total") + else: + print(" ⚠️ global:node_id not found (no nodes registered yet)") + + +def check_registered_nodes(r): + """Check registered node information.""" + print("\n" + "=" * 60) + print(" 2. Registered Nodes (node:*)") + print("=" * 60) + keys = scan_keys(r, "node:*") + if not keys: + print(" ⚠️ No registered nodes found") + return + + print(f" Found {len(keys)} registered node(s):\n") + for key in keys: + data = r.hgetall(key) + print(f" 📌 {key}:") + for field, value in sorted(data.items()): + print(f" {field}: {value}") + print() + + +def check_node_meta(r): + """Check node meta information (mooncake engine addr, buffer ptrs).""" + print("\n" + "=" * 60) + print(" 3. 
Node Meta (meta:*)") + print("=" * 60) + keys = scan_keys(r, "meta:*") + if not keys: + print(" ⚠️ No node meta found") + print(" → This means PEER2CPUTransferWorker hasn't registered yet,") + print(" or mooncake transfer engine initialization failed.") + return + + print(f" Found {len(keys)} node meta entry(ies):\n") + for key in keys: + data = r.hgetall(key) + print(f" 📌 {key}:") + for field, value in sorted(data.items()): + # Format large integers (pointers) in hex for readability + if field in ("cpu_buffer_ptr", "ssd_buffer_ptr"): + try: + int_val = int(value) + print(f" {field}: {value} (0x{int_val:x})") + except (ValueError, TypeError): + print(f" {field}: {value}") + else: + print(f" {field}: {value}") + print() + + +def check_buffer_registrations(r): + """Check RDMA buffer registrations.""" + print("\n" + "=" * 60) + print(" 4. RDMA Buffer Registrations (buffer:*)") + print("=" * 60) + keys = scan_keys(r, "buffer:*") + if not keys: + print(" ⚠️ No RDMA buffer registrations found") + return + + print(f" Found {len(keys)} buffer registration(s):\n") + for key in keys: + data = r.hgetall(key) + buf_size = data.get("buffer_size", "?") + try: + size_mb = int(buf_size) / (1024 * 1024) + print(f" 📌 {key}: size={buf_size} bytes ({size_mb:.2f} MB)") + except (ValueError, TypeError): + print(f" 📌 {key}: size={buf_size}") + + +def check_block_metadata(r): + """Check KVCache block metadata - this is the core data from put operations. + + FlexKV uses different key prefixes for different device types: + - CPUB:: — CPU block metadata (P2P CPU sharing) + - SSDB:: — SSD block metadata (P2P SSD sharing) + - PCFSB:: — PCFS remote block metadata + Each key is a Redis hash with fields: ph, pb, nid, hash, lt, state. + """ + print("\n" + "=" * 60) + print(" 5. 
KVCache Block Metadata (CPUB/SSDB/PCFSB)") + print("=" * 60) + + # FlexKV actual block key prefixes (set in hie_cache_engine.py) + block_prefixes = { + "CPUB": "CPU", + "SSDB": "SSD", + "PCFSB": "PCFS (Remote)", + } + + grand_total = 0 + for prefix, label in block_prefixes.items(): + keys = scan_keys(r, f"{prefix}:*") + if not keys: + print(f"\n [{label}] {prefix}:* — no entries found") + continue + + grand_total += len(keys) + + # Group by node_id: key format is PREFIX:: + node_blocks = {} + for key in keys: + parts = key.split(":") + if len(parts) >= 2: + node_id = parts[1] + if node_id not in node_blocks: + node_blocks[node_id] = [] + node_blocks[node_id].append(key) + + print(f"\n [{label}] {prefix}:* — {len(keys)} block(s) across {len(node_blocks)} node(s):") + + for node_id, block_keys in sorted(node_blocks.items(), key=lambda x: int(x[0]) if x[0].isdigit() else 0): + print(f" 📌 Node {node_id}: {len(block_keys)} block(s)") + + # Show first few blocks as samples + sample_count = min(3, len(block_keys)) + for key in block_keys[:sample_count]: + data = r.hgetall(key) + if data: + # BlockMeta fields: ph (physical hash), pb (physical block), + # nid (node id), hash, lt (lease time), state + ph = data.get("ph", "?") + pb = data.get("pb", "?") + nid = data.get("nid", "?") + hash_val = data.get("hash", "?") + lt = data.get("lt", "?") + state = data.get("state", "?") + print(f" {key}: ph={ph}, pb={pb}, nid={nid}, hash={hash_val}, lt={lt}, state={state}") + else: + key_type = r.type(key) + print(f" {key}: type={key_type}, (empty hash)") + + if len(block_keys) > sample_count: + print(f" ... 
and {len(block_keys) - sample_count} more block(s)") + + if grand_total == 0: + print("\n ⚠️ No block metadata found in any prefix (CPUB/SSDB/PCFSB)") + print(" → This means no KVCache data has been published to Redis yet.") + print(" The put-only node may still be uploading, or the upload") + print(" interval (rebuild_interval_ms) hasn't elapsed yet.") + else: + print(f"\n ✅ Total block metadata entries: {grand_total}") + + +def check_pcfs_data(r): + """Check PCFS file node IDs.""" + print("\n" + "=" * 60) + print(" 6. PCFS File Node IDs (pcfs:*)") + print("=" * 60) + keys = scan_keys(r, "pcfs:*") + if not keys: + print(" (none found - this is normal if PCFS sharing is not used)") + return + + print(f" Found {len(keys)} PCFS entry(ies):\n") + for key in keys: + values = r.lrange(key, 0, -1) + print(f" 📌 {key}: {len(values)} file node ID(s)") + if values: + sample = values[:10] + print(f" sample: {sample}") + if len(values) > 10: + print(f" ... and {len(values) - 10} more") + + +def check_mooncake_keys(r): + """Check Mooncake Transfer Engine related keys.""" + print("\n" + "=" * 60) + print(" 7. 
Mooncake Transfer Engine Keys") + print("=" * 60) + # Mooncake uses Redis as metadata backend, keys may vary + # Common patterns: segment info, endpoint info + patterns = ["mooncake/*", "mooncake:*", "segment:*", "endpoint:*", "mc:*"] + found_any = False + for pattern in patterns: + keys = scan_keys(r, pattern) + if keys: + found_any = True + print(f"\n Pattern '{pattern}': {len(keys)} key(s)") + for key in keys[:10]: + key_type = r.type(key) + if key_type == "hash": + data = r.hgetall(key) + print(f" 📌 {key} (hash): {data}") + elif key_type == "string": + val = r.get(key) + if val and len(val) > 200: + print(f" 📌 {key} (string): {val[:200]}...") + else: + print(f" 📌 {key} (string): {val}") + elif key_type == "set": + members = r.smembers(key) + print(f" 📌 {key} (set): {members}") + elif key_type == "list": + vals = r.lrange(key, 0, 9) + print(f" 📌 {key} (list): {vals}") + else: + print(f" 📌 {key} (type={key_type})") + if len(keys) > 10: + print(f" ... and {len(keys) - 10} more") + + if not found_any: + print(" (no mooncake-specific keys found)") + + +def check_all_keys_summary(r): + """Show a summary of ALL keys in Redis grouped by prefix.""" + print("\n" + "=" * 60) + print(" 8. 
All Keys Summary") + print("=" * 60) + all_keys = scan_keys(r, "*") + if not all_keys: + print(" ⚠️ Redis is completely empty!") + return + + print(f" Total keys in Redis: {len(all_keys)}\n") + + # Group by prefix (first part before ':') + prefix_counts = {} + for key in all_keys: + prefix = key.split(":")[0] if ":" in key else key + prefix_counts[prefix] = prefix_counts.get(prefix, 0) + 1 + + print(f" {'Prefix':<30} {'Count':>8}") + print(f" {'-'*30} {'-'*8}") + for prefix, count in sorted(prefix_counts.items(), key=lambda x: -x[1]): + print(f" {prefix:<30} {count:>8}") + + +def main(): + parser = argparse.ArgumentParser( + description="FlexKV Redis Data Inspector - Check put-only node data" + ) + parser.add_argument("--host", type=str, default="10.135.1.175", + help="Redis host (default: 10.135.1.175)") + parser.add_argument("--port", type=int, default=6379, + help="Redis port (default: 6379)") + parser.add_argument("--password", type=str, default="123456", + help="Redis password (default: 123456)") + args = parser.parse_args() + + print("=" * 60) + print(" FlexKV Redis Data Inspector") + print("=" * 60) + print(f" Target: {args.host}:{args.port}") + + r = connect_redis(args.host, args.port, args.password) + + check_global_node_id(r) + check_registered_nodes(r) + check_node_meta(r) + check_buffer_registrations(r) + check_block_metadata(r) + check_pcfs_data(r) + check_mooncake_keys(r) + check_all_keys_summary(r) + + print("\n" + "=" * 60) + print(" Inspection Complete") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/run_dist_benchmark.sh b/benchmarks/run_dist_benchmark.sh new file mode 100755 index 0000000000..8e64fdac22 --- /dev/null +++ b/benchmarks/run_dist_benchmark.sh @@ -0,0 +1,405 @@ +#!/bin/bash +# ============================================================================= +# FlexKV Distributed KVCache Benchmark - One-Click Launch Script +# +# This script handles: +# 1. Check and start Redis server if not running +# 2. 
Set up environment variables +# 3. Run the distributed KVCache benchmark +# +# Usage: +# bash benchmarks/run_dist_benchmark.sh [options] +# +# Options (passed through to benchmark_dist_kvcache.py): +# --config Config YAML file (default: benchmarks/example_dist_config.yml) +# --mode Benchmark mode: single, multiturn, or all (default: all) +# --batch-size Batch size (default: 1) +# --sequence-length Sequence length (default: 1024) +# --num-users Number of simulated users (default: 10) +# --num-turns Number of conversation turns (default: 3) +# --clean-redis Clean up FlexKV & Mooncake residual data in Redis before running benchmark +# (removes node:*, meta:*, CPUB:block:*, SSDB:block:*, PCFSB:block:*, +# mooncake/*, mooncake:*, segment:*, endpoint:*, mc:* keys) +# --clean-redis-only Clean up FlexKV & Mooncake residual data in Redis and exit (no benchmark) +# +# Examples: +# # Run with defaults +# bash benchmarks/run_dist_benchmark.sh +# +# # Custom parameters +# bash benchmarks/run_dist_benchmark.sh --batch-size 4 --sequence-length 2048 +# +# # Multi-turn only +# bash benchmarks/run_dist_benchmark.sh --mode multiturn --num-users 20 --num-turns 5 +# +# # Clean Redis residual data before benchmark +# bash benchmarks/run_dist_benchmark.sh --clean-redis +# +# # Only clean Redis residual data (no benchmark) +# bash benchmarks/run_dist_benchmark.sh --clean-redis-only +# ============================================================================= + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "${SCRIPT_DIR}/.." 
&& pwd)" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +info() { echo -e "${BLUE}[INFO]${NC} $*"; } +ok() { echo -e "${GREEN}[OK]${NC} $*"; } +warn() { echo -e "${YELLOW}[WARN]${NC} $*"; } +error() { echo -e "${RED}[ERROR]${NC} $*"; } + +# Default config file +CONFIG_FILE="${SCRIPT_DIR}/example_dist_config.yml" +REDIS_STARTED_BY_US=false +CLEAN_REDIS=false +CLEAN_REDIS_ONLY=false + +# Parse script-specific arguments and --config, pass the rest through to benchmark +BENCH_ARGS=() +prev_arg="" +for arg in "$@"; do + if [[ "$prev_arg" == "--config" ]]; then + CONFIG_FILE="$arg" + BENCH_ARGS+=("$arg") + prev_arg="$arg" + continue + fi + case "$arg" in + --clean-redis) + CLEAN_REDIS=true + ;; + --clean-redis-only) + CLEAN_REDIS=true + CLEAN_REDIS_ONLY=true + ;; + *) + BENCH_ARGS+=("$arg") + ;; + esac + prev_arg="$arg" +done + +# ============================================ +# Step 1: Parse Redis config from YAML +# ============================================ +info "============================================" +info "Step 1: Parsing configuration" +info "============================================" + +# Helper function to parse a YAML value using Python (handles comments, quotes, etc. 
correctly) +# Usage: parse_yaml_value [default] +parse_yaml_value() { + local key="$1" file="$2" default="${3:-}" + local val + val=$(python3 -c " +import yaml, sys +with open('$file') as f: + d = yaml.safe_load(f) +v = d.get('$key') +if v is None: + print('$default') +else: + print(v) +" 2>/dev/null) || val="$default" + echo "$val" +} + +# Simple YAML parser for redis config +REDIS_HOST=$(parse_yaml_value "redis_host" "$CONFIG_FILE" "127.0.0.1") +REDIS_PORT=$(parse_yaml_value "redis_port" "$CONFIG_FILE" "6379") +REDIS_PASSWORD=$(parse_yaml_value "redis_password" "$CONFIG_FILE" "") + +info "Config file: ${CONFIG_FILE}" +info "Redis: ${REDIS_HOST}:${REDIS_PORT}" + +# ============================================ +# Step 2: Check and start Redis +# ============================================ +info "============================================" +info "Step 2: Checking Redis server" +info "============================================" + +check_redis() { + local auth_args="" + if [[ -n "$REDIS_PASSWORD" ]]; then + auth_args="-a $REDIS_PASSWORD" + fi + redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" $auth_args ping 2>/dev/null | grep -q "PONG" +} + +# Build redis-cli auth arguments (reused across the script) +REDIS_AUTH_ARGS="" +if [[ -n "$REDIS_PASSWORD" ]]; then + REDIS_AUTH_ARGS="-a $REDIS_PASSWORD" +fi + +if check_redis; then + ok "Redis is already running at ${REDIS_HOST}:${REDIS_PORT}" +else + warn "Redis is not running at ${REDIS_HOST}:${REDIS_PORT}" + + # Only try to start Redis if it's localhost + if [[ "$REDIS_HOST" == "127.0.0.1" ]] || [[ "$REDIS_HOST" == "localhost" ]]; then + if command -v redis-server &>/dev/null; then + info "Starting Redis server on port ${REDIS_PORT}..." 
+ redis-server --port "$REDIS_PORT" --daemonize yes --save "" --appendonly no \ + --protected-mode no --loglevel warning + sleep 1 + + if check_redis; then + ok "Redis server started successfully" + REDIS_STARTED_BY_US=true + else + error "Failed to start Redis server" + error "Please install Redis: sudo apt install redis-server" + exit 1 + fi + else + error "redis-server not found. Please install Redis:" + error " sudo apt install redis-server" + exit 1 + fi + else + error "Redis is not running at ${REDIS_HOST}:${REDIS_PORT}" + error "Please start Redis on the remote host first." + exit 1 + fi +fi + +# ============================================ +# Step 2.5: Clean FlexKV residual data in Redis (if requested) +# ============================================ +if [[ "$CLEAN_REDIS" == "true" ]]; then + info "============================================" + info "Cleaning FlexKV residual data in Redis" + info "============================================" + + clean_redis_keys() { + local pattern="$1" + local count=0 + local cursor=0 + while true; do + local result + result=$(redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" $REDIS_AUTH_ARGS SCAN $cursor MATCH "$pattern" COUNT 500 2>/dev/null) + cursor=$(echo "$result" | head -1) + local keys + keys=$(echo "$result" | tail -n +2) + if [[ -n "$keys" ]]; then + local batch_keys + batch_keys=$(echo "$keys" | tr '\n' ' ') + if [[ -n "$batch_keys" ]]; then + local deleted + deleted=$(echo "$batch_keys" | xargs redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" $REDIS_AUTH_ARGS DEL 2>/dev/null) + count=$((count + deleted)) + fi + fi + if [[ "$cursor" == "0" ]]; then + break + fi + done + echo "$count" + } + + total_deleted=0 + + # Clean node:* keys + n=$(clean_redis_keys "node:*") + [[ $n -gt 0 ]] && info "Deleted $n node:* key(s)" + total_deleted=$((total_deleted + n)) + + # Clean meta:* keys + n=$(clean_redis_keys "meta:*") + [[ $n -gt 0 ]] && info "Deleted $n meta:* key(s)" + total_deleted=$((total_deleted + n)) + + # Clean CPUB:block:* 
keys + n=$(clean_redis_keys "CPUB:block:*") + [[ $n -gt 0 ]] && info "Deleted $n CPUB:block:* key(s)" + total_deleted=$((total_deleted + n)) + + # Clean SSDB:block:* keys + n=$(clean_redis_keys "SSDB:block:*") + [[ $n -gt 0 ]] && info "Deleted $n SSDB:block:* key(s)" + total_deleted=$((total_deleted + n)) + + # Clean PCFSB:block:* keys + n=$(clean_redis_keys "PCFSB:block:*") + [[ $n -gt 0 ]] && info "Deleted $n PCFSB:block:* key(s)" + total_deleted=$((total_deleted + n)) + + # Clean Mooncake Transfer Engine residual keys + # Mooncake uses Redis as metadata backend to store segment/endpoint info + for mc_pattern in "mooncake/*" "mooncake:*" "segment:*" "endpoint:*" "mc:*"; do + n=$(clean_redis_keys "$mc_pattern") + [[ $n -gt 0 ]] && info "Deleted $n ${mc_pattern} key(s)" + total_deleted=$((total_deleted + n)) + done + + if [[ $total_deleted -gt 0 ]]; then + ok "Cleaned $total_deleted FlexKV & Mooncake residual key(s) from Redis" + else + ok "No FlexKV residual data found in Redis" + fi + + if [[ "$CLEAN_REDIS_ONLY" == "true" ]]; then + ok "Clean-only mode, exiting." + exit 0 + fi +fi + +# ============================================ +# Step 3: Set up environment +# ============================================ +info "============================================" +info "Step 3: Setting up environment" +info "============================================" + +# Detect Python (prefer virtual env) +if [[ -n "$VIRTUAL_ENV" ]]; then + # Prefer 'which python3' to get the actual resolved path in the activated venv, + # because $VIRTUAL_ENV may point to a path that doesn't match the real filesystem + # (e.g. symlinks, home dir aliases like ~ vs /data1/home). + if command -v python3 &>/dev/null; then + PYTHON="$(which python3)" + else + PYTHON="$VIRTUAL_ENV/bin/python3" + fi + if [[ ! 
-x "$PYTHON" ]]; then + error "Python3 not found at $PYTHON (VIRTUAL_ENV=$VIRTUAL_ENV)" + exit 1 + fi + info "Using virtual env Python: $PYTHON" +elif command -v python3 &>/dev/null; then + PYTHON="$(which python3)" + info "Using system Python: $PYTHON" +else + error "Python3 not found!" + exit 1 +fi + +# Set PYTHONPATH to include project root +export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}" + +# Set LD_LIBRARY_PATH for C++ libraries +if [[ -d "${PROJECT_ROOT}/build" ]]; then + export LD_LIBRARY_PATH="${PROJECT_ROOT}/build:${LD_LIBRARY_PATH:-}" +fi + +info "PYTHONPATH=${PYTHONPATH}" +info "LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}" + +# Generate mooncake config JSON and export MOONCAKE_CONFIG_PATH if P2P is enabled +ENABLE_P2P_CPU=$(grep -E "^enable_p2p_cpu:" "$CONFIG_FILE" 2>/dev/null | awk '{print $2}' || echo "false") +ENABLE_P2P_SSD=$(grep -E "^enable_p2p_ssd:" "$CONFIG_FILE" 2>/dev/null | awk '{print $2}' || echo "false") + +if [[ "$ENABLE_P2P_CPU" == "true" ]] || [[ "$ENABLE_P2P_SSD" == "true" ]]; then + if [[ -z "${MOONCAKE_CONFIG_PATH:-}" ]]; then + info "P2P enabled, generating mooncake config..." 
+ + # Parse mooncake config from YAML using helper function + MC_ENGINE_IP=$(parse_yaml_value "mooncake_engine_ip" "$CONFIG_FILE") + MC_ENGINE_PORT=$(parse_yaml_value "mooncake_engine_port" "$CONFIG_FILE") + MC_METADATA_BACKEND=$(parse_yaml_value "mooncake_metadata_backend" "$CONFIG_FILE") + MC_METADATA_SERVER=$(parse_yaml_value "mooncake_metadata_server" "$CONFIG_FILE") + MC_METADATA_SERVER_AUTH=$(parse_yaml_value "mooncake_metadata_server_auth" "$CONFIG_FILE") + MC_PROTOCOL=$(parse_yaml_value "mooncake_protocol" "$CONFIG_FILE") + MC_DEVICE_NAME=$(parse_yaml_value "mooncake_device_name" "$CONFIG_FILE") + LOCAL_IP=$(parse_yaml_value "local_ip" "$CONFIG_FILE" "127.0.0.1") + + # Use defaults if not specified + MC_ENGINE_IP="${MC_ENGINE_IP:-$LOCAL_IP}" + MC_ENGINE_PORT="${MC_ENGINE_PORT:-5555}" + MC_METADATA_BACKEND="${MC_METADATA_BACKEND:-redis}" + MC_METADATA_SERVER="${MC_METADATA_SERVER:-redis://${REDIS_HOST}:${REDIS_PORT}}" + MC_PROTOCOL="${MC_PROTOCOL:-tcp}" + MC_DEVICE_NAME="${MC_DEVICE_NAME:-}" + + # Generate JSON config file + MOONCAKE_CONFIG_FILE=$(mktemp /tmp/mooncake_config_XXXXXX.json) + cat > "$MOONCAKE_CONFIG_FILE" </dev/null || true + ok "Redis stopped." +fi + +if [[ $BENCH_EXIT_CODE -eq 0 ]]; then + echo "" + ok "Benchmark completed successfully!" 
+else + echo "" + error "Benchmark failed with exit code: $BENCH_EXIT_CODE" +fi + +exit $BENCH_EXIT_CODE diff --git a/benchmarks/utils.py b/benchmarks/utils.py index 1ebabc402e..d979a27a07 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -43,9 +43,14 @@ def generate_random_multiturn(num_user_requests: int, output_length: int, num_turns_ratio: float = 0.5, input_length_ratio: float = 0.5, - output_length_ratio: float = 0.5) -> List[KVRequest]: + output_length_ratio: float = 0.5, + seed: int = None) -> List[KVRequest]: all_requests = [] token_id_range = 10000 + # Set seed for deterministic generation (useful for cross-node benchmarks) + if seed is not None: + random.seed(seed) + torch.manual_seed(seed) system_prompt = torch.randint(0, token_id_range, (system_prompt_length,)) for i in range(num_user_requests): user_requests = [] diff --git a/build.sh b/build.sh index 3e021a940c..05078f1277 100755 --- a/build.sh +++ b/build.sh @@ -15,16 +15,62 @@ for arg in "$@"; do BUILD_TYPE="release" shift ;; + --clean) + BUILD_TYPE="clean" + shift + ;; *) # Unknown option ;; esac done +# Handle clean +if [ "$BUILD_TYPE" = "clean" ]; then + echo "=== Cleaning all build artifacts ===" + + # Remove CMake build directory + if [ -d "build" ]; then + rm -rf build + echo "Removed build/" + fi + + # Remove compiled .so files in package directory + find flexkv -name "*.so" -type f -delete -print | sed 's/^/Removed /' + + # Remove copied libs directory + if [ -d "flexkv/lib" ]; then + rm -rf flexkv/lib + echo "Removed flexkv/lib/" + fi + + # Remove Python build artifacts + find . -maxdepth 2 -name "*.egg-info" -type d | while read d; do + rm -rf "$d" + echo "Removed $d" + done + # Only remove top-level dist/ (Python build output), not csrc/dist/ source directory + if [ -d "dist" ]; then + rm -rf dist + echo "Removed dist/" + fi + find . 
-name "__pycache__" -type d | while read d; do + rm -rf "$d" + echo "Removed $d" + done + + echo "=== Clean completed ===" + exit 0 +fi + echo "=== Building in ${BUILD_TYPE} mode ===" # Install submodules -git submodule update --init --recursive +if git rev-parse --is-inside-work-tree >/dev/null 2>&1; then + git submodule update --init --recursive +else + echo "WARNING: Not a git repository, skipping submodule update. If submodules are missing, please clone the repo instead of copying." +fi mkdir -p build cd build diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index 79d68f5bbe..80aa938309 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -32,6 +32,7 @@ #include "dist/lock_free_q.h" #include "dist/redis_meta_channel.h" #endif +#include "layerwise.h" #include "monitoring/metrics_manager.h" #include @@ -45,7 +46,8 @@ void transfer_kv_blocks_binding( int64_t cpu_layer_stride_in_bytes, int64_t cpu_block_stride_in_bytes, int64_t chunk_size_in_bytes, int start_layer_id, int num_layers, int transfer_num_cta = 4, bool is_host_to_device = true, - bool use_ce_transfer = false, bool is_mla = false, int gpu_block_type = 0) { + bool use_ce_transfer = false, bool is_mla = false, int gpu_block_type = 0, + bool sync = true) { int num_blocks = gpu_block_id_tensor.numel(); int64_t *gpu_block_ids = @@ -85,7 +87,7 @@ void transfer_kv_blocks_binding( cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_num_cta, is_host_to_device, - use_ce_transfer, is_mla); + use_ce_transfer, is_mla, sync); break; case flexkv::BackendType::TRTLLM: flexkv::transfer_kv_blocks( @@ -93,7 +95,7 @@ void transfer_kv_blocks_binding( cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_num_cta, is_host_to_device, - use_ce_transfer, is_mla); + use_ce_transfer, is_mla, sync); break; case flexkv::BackendType::SGLANG: 
flexkv::transfer_kv_blocks( @@ -101,7 +103,7 @@ void transfer_kv_blocks_binding( cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_num_cta, is_host_to_device, - use_ce_transfer, is_mla); + use_ce_transfer, is_mla, sync); break; } @@ -403,7 +405,7 @@ PYBIND11_MODULE(c_ext, m) { py::arg("start_layer_id"), py::arg("num_layers"), py::arg("transfer_num_cta") = 4, py::arg("is_host_to_device") = true, py::arg("use_ce_transfer") = false, py::arg("is_mla") = false, - py::arg("gpu_block_type") = 0); + py::arg("gpu_block_type") = 0, py::arg("sync") = true); m.def("transfer_kv_blocks_ssd", &transfer_kv_blocks_ssd_binding, "Transfer KV blocks between SSD and CPU memory", py::arg("ioctx"), py::arg("cpu_layer_id_list"), py::arg("cpu_tensor_ptr"), @@ -414,7 +416,59 @@ PYBIND11_MODULE(c_ext, m) { py::arg("is_read"), py::arg("num_blocks_per_file"), py::arg("round_robin") = 1, py::arg("num_threads_per_device") = 16, py::arg("is_mla") = false); - + py::class_(m, "LayerwiseTransferGroup") + .def(py::init> &, + torch::Tensor &, std::map> &, + int, torch::Tensor &, torch::Tensor &, torch::Tensor &, + torch::Tensor &, int, int, torch::Tensor &, int, + const std::vector> &, + torch::Tensor, torch::Tensor, torch::Tensor, + torch::Tensor, torch::Tensor, + std::map>>(), + py::arg("num_gpus"), py::arg("gpu_blocks"), py::arg("cpu_blocks"), + py::arg("ssd_files"), py::arg("num_layers"), + py::arg("gpu_kv_strides_tensor"), + py::arg("gpu_block_strides_tensor"), + py::arg("gpu_layer_strides_tensor"), + py::arg("gpu_chunk_sizes_tensor"), py::arg("iouring_entries"), + py::arg("iouring_flags"), py::arg("layer_eventfds_tensor"), + py::arg("tp_size"), + py::arg("indexer_gpu_blocks") = std::vector>{}, + py::arg("indexer_cpu_blocks") = torch::Tensor(), + py::arg("indexer_gpu_kv_strides_tensor") = torch::Tensor(), + py::arg("indexer_gpu_block_strides_tensor") = torch::Tensor(), + 
py::arg("indexer_gpu_layer_strides_tensor") = torch::Tensor(), + py::arg("indexer_gpu_chunk_sizes_tensor") = torch::Tensor(), + py::arg("indexer_ssd_files") = std::map>{}) + .def("layerwise_transfer", + &flexkv::LayerwiseTransferGroup::layerwise_transfer, + py::arg("ssd_block_ids"), py::arg("cpu_block_ids_d2h"), + py::arg("ssd_layer_stride_in_bytes"), + py::arg("ssd_kv_stride_in_bytes"), py::arg("num_blocks_per_file"), + py::arg("round_robin"), py::arg("num_threads_per_device"), + py::arg("gpu_block_id_tensor"), py::arg("cpu_block_id_tensor"), + py::arg("cpu_kv_stride_in_bytes"), + py::arg("cpu_layer_stride_in_bytes"), + py::arg("cpu_block_stride_in_bytes"), + py::arg("cpu_chunk_size_in_bytes"), + py::arg("h2d_cpu_kv_stride_in_bytes"), + py::arg("h2d_cpu_layer_stride_in_bytes"), + py::arg("cpu_tp_stride_in_bytes"), py::arg("transfer_cta_num"), + py::arg("use_ce_transfer"), py::arg("num_layers"), + py::arg("layer_granularity"), py::arg("is_mla"), + py::arg("counter_id") = 0, + py::arg("indexer_gpu_block_id_tensor") = torch::Tensor(), + py::arg("indexer_cpu_block_id_tensor") = torch::Tensor(), + py::arg("indexer_cpu_block_stride_in_bytes") = 0, + py::arg("indexer_cpu_layer_stride_in_bytes") = 0, + py::arg("indexer_h2d_cpu_kv_stride_in_bytes") = 0, + py::arg("indexer_h2d_cpu_layer_stride_in_bytes") = 0, + py::arg("indexer_ssd_block_ids") = torch::Tensor(), + py::arg("indexer_cpu_block_ids_d2h") = torch::Tensor(), + py::arg("indexer_ssd_layer_stride_in_bytes") = 0, + py::arg("indexer_ssd_kv_stride_in_bytes") = 0, + py::arg("indexer_cpu_chunk_size_in_bytes") = 0, + py::arg("indexer_num_blocks_per_file") = 0); #ifdef FLEXKV_ENABLE_CFS m.def("transfer_kv_blocks_remote", &transfer_kv_blocks_remote, "Transfer KV blocks between remote and CPU memory", @@ -456,13 +510,13 @@ PYBIND11_MODULE(c_ext, m) { py::init> &, int, int, int>()); py::class_(m, "TPTransferThreadGroup") - .def(py::init &, int, int64_t, int, int, + .def(py::init &, int, int64_t, int, const std::vector &, 
const std::vector &, const std::vector &, const std::vector &, const std::vector &>(), py::arg("num_gpus"), py::arg("gpu_block_ptrs_flat"), py::arg("num_tensors_per_gpu"), py::arg("cpu_blocks_ptr"), - py::arg("dp_group_id"), py::arg("num_layers"), + py::arg("num_layers"), py::arg("gpu_kv_strides_in_bytes"), py::arg("gpu_block_strides_in_bytes"), py::arg("gpu_layer_strides_in_bytes"), @@ -475,19 +529,19 @@ PYBIND11_MODULE(c_ext, m) { py::arg("cpu_block_stride_in_bytes"), py::arg("cpu_tp_stride_in_bytes"), py::arg("transfer_num_cta"), py::arg("is_host_to_device"), py::arg("use_ce_transfer"), - py::arg("layer_id"), py::arg("layer_granularity"), - py::arg("is_mla")); + py::arg("layer_id"), py::arg("layer_granularity"), py::arg("is_mla") + ); #ifdef FLEXKV_ENABLE_GDS py::class_(m, "TPGDSTransferThreadGroup") .def(py::init &, int, - std::map> &, int, int, + std::map> &, int, const std::vector &, const std::vector &, const std::vector &, const std::vector &, const std::vector &>(), py::arg("num_gpus"), py::arg("gpu_block_ptrs_flat"), py::arg("num_tensors_per_gpu"), py::arg("ssd_files"), - py::arg("dp_group_id"), py::arg("num_layers"), + py::arg("num_layers"), py::arg("gpu_kv_strides_in_bytes"), py::arg("gpu_block_strides_in_bytes"), py::arg("gpu_layer_strides_in_bytes"), @@ -606,11 +660,18 @@ PYBIND11_MODULE(c_ext, m) { py::class_>( m, "CMatchResult") .def(py::init()) + torch::Tensor, int32_t>(), + py::arg("num_ready_matched_blocks"), + py::arg("num_matched_blocks"), + py::arg("last_node_matched_length"), + py::arg("last_ready_node"), + py::arg("last_node"), + py::arg("physical_blocks"), + py::arg("matched_node_id") = -1) .def_readonly("last_ready_node", &flexkv::CMatchResult::last_ready_node) .def_readonly("last_node", &flexkv::CMatchResult::last_node) .def_readonly("physical_blocks", &flexkv::CMatchResult::physical_blocks) - .def_readonly("block_node_ids", &flexkv::CMatchResult::block_node_ids) + .def_readonly("matched_node_id", &flexkv::CMatchResult::matched_node_id) 
.def_readonly("num_ready_matched_blocks", &flexkv::CMatchResult::num_ready_matched_blocks) .def_readonly("num_matched_blocks", diff --git a/csrc/dist/distributed_radix_tree.cpp b/csrc/dist/distributed_radix_tree.cpp index 1831896861..04a3b38b35 100644 --- a/csrc/dist/distributed_radix_tree.cpp +++ b/csrc/dist/distributed_radix_tree.cpp @@ -354,8 +354,7 @@ std::shared_ptr DistributedRadixTree::match_prefix( if (idx == nullptr) { // Remote index not yet built - this is normal at startup auto empty_i64 = torch::empty({0}, torch::dtype(torch::kInt64)); - auto empty_u32 = torch::empty({0}, torch::dtype(torch::kInt32)); - return std::make_shared(0, 0, 0, nullptr, nullptr, empty_i64, empty_u32); + return std::make_shared(0, 0, 0, nullptr, nullptr, empty_i64); } // Safely increment reference count while holding the lock @@ -563,8 +562,7 @@ std::shared_ptr RefRadixTree::match_prefix( if (root == nullptr) { auto empty_i64 = torch::empty({0}, torch::dtype(torch::kInt64)); - auto empty_u32 = torch::empty({0}, torch::dtype(torch::kInt32)); - return std::make_shared(0, 0, 0, nullptr, nullptr, empty_i64, empty_u32); + return std::make_shared(0, 0, 0, nullptr, nullptr, empty_i64); } auto current_node = root; @@ -578,10 +576,10 @@ std::shared_ptr RefRadixTree::match_prefix( auto block_hashes_ptr = block_hashes.data_ptr(); HashType child_hash; - // node ids stored as int32 tensor (PyTorch lacks uint32 dtype) - auto node_ids_tensor = torch::empty({num_blocks}, torch::dtype(torch::kInt32)); - auto *ni_out = node_ids_tensor.data_ptr(); - int32_t ni_write = 0; + // Single-node matching constraint: all matched blocks must come from the + // same peer node_id. We lock the node_id on the first valid block and + // stop matching when a different node_id is encountered. 
+ int32_t matched_node_id = -1; // -1 = not yet determined // now in ms struct timeval now_tv; gettimeofday(&now_tv, nullptr); @@ -638,9 +636,20 @@ std::shared_ptr RefRadixTree::match_prefix( if (bnis == nullptr || bnis->size() != pbs.size()) break; + // Single-node constraint: stop at the first block whose node_id + // differs from the already-locked matched_node_id. + int actually_copied = 0; for (int i = 0; i < matched; ++i) { + int32_t block_nid = static_cast((*bnis)[i]); + if (matched_node_id == -1) { + matched_node_id = block_nid; // lock the first node_id + } else if (block_nid != matched_node_id) { + // Different node_id encountered - stop matching here + matched = actually_copied; + break; + } pb_out[pb_write++] = pbs[i]; - ni_out[ni_write++] = (*bnis)[i]; + actually_copied++; } if (current_node->is_ready()) { @@ -672,10 +681,9 @@ std::shared_ptr RefRadixTree::match_prefix( } auto physical_blocks = physical_blocks_tensor.narrow(0, 0, pb_write); - auto node_ids = node_ids_tensor.narrow(0, 0, ni_write); return std::make_shared(prefix_blocks_num, prefix_blocks_num, last_node_matched_length, - last_ready_node, current_node, physical_blocks, node_ids); + last_ready_node, current_node, physical_blocks, matched_node_id); } // Helper function to clean up an orphan tree (not attached to main tree) diff --git a/csrc/gds/tp_gds_transfer_thread_group.cpp b/csrc/gds/tp_gds_transfer_thread_group.cpp index f75e35bfe6..30aecae204 100644 --- a/csrc/gds/tp_gds_transfer_thread_group.cpp +++ b/csrc/gds/tp_gds_transfer_thread_group.cpp @@ -9,7 +9,6 @@ TPGDSTransferThreadGroup::TPGDSTransferThreadGroup( const std::vector &gpu_block_ptrs_flat, int num_tensors_per_gpu, std::map> &ssd_files, - int dp_group_id, int num_layers, const std::vector &gpu_kv_strides_in_bytes, const std::vector &gpu_block_strides_in_bytes, @@ -19,7 +18,6 @@ TPGDSTransferThreadGroup::TPGDSTransferThreadGroup( num_gpus_ = num_gpus; num_tensors_per_gpu_ = num_tensors_per_gpu; - dp_group_id_ = dp_group_id; // 
per-GPU layout parameters gpu_kv_strides_in_bytes_ = new int64_t[num_gpus]; diff --git a/csrc/gds/tp_gds_transfer_thread_group.h b/csrc/gds/tp_gds_transfer_thread_group.h index 616c788b30..359548477a 100644 --- a/csrc/gds/tp_gds_transfer_thread_group.h +++ b/csrc/gds/tp_gds_transfer_thread_group.h @@ -27,7 +27,6 @@ class TPGDSTransferThreadGroup { const std::vector &gpu_block_ptrs_flat, int num_tensors_per_gpu, std::map> &ssd_files, - int dp_group_id, int num_layers, const std::vector &gpu_kv_strides_in_bytes, const std::vector &gpu_block_strides_in_bytes, @@ -54,7 +53,6 @@ class TPGDSTransferThreadGroup { std::future enqueue_for_gpu(int gpu_idx, Task task); int num_gpus_; - int dp_group_id_; std::vector gpu_device_ids_; void **gpu_blocks_; int num_tensors_per_gpu_; diff --git a/csrc/layerwise.cpp b/csrc/layerwise.cpp new file mode 100644 index 0000000000..40a217ec25 --- /dev/null +++ b/csrc/layerwise.cpp @@ -0,0 +1,646 @@ +#include "layerwise.h" +#include +#include +#include +#include +#include +#include +#include + +namespace flexkv { + +struct LayerCallbackData { + int start_layer; + int layers_this_batch; + int num_gpus; + std::atomic *counter; + // Eventfd info for notification + bool enable_eventfd; + int tp_size; + int num_layers; + int *layer_eventfds; // Pointer to eventfds array for current counter set + // NVTX range id for CPU->GPU transfer + nvtxRangeId_t *current_range_id_ptr; // Pointer to current layer's range ID + bool is_last_batch; // Whether this is the last batch + char next_range_name[64]; // Name for next layer's range (if not last batch) + nvtxRangeId_t *next_range_id_ptr; // Pointer to next layer's range ID storage +}; + +static void CUDART_CB layer_done_host_callback(void *userData) { + LayerCallbackData *data = static_cast(userData); + int completed = data->counter->fetch_add(1) + 1; + if (completed == data->num_gpus) { + // Notify via eventfd when all GPUs complete this layer batch + if (data->enable_eventfd && data->layer_eventfds != 
nullptr) { + // Signal each tp_rank's eventfd for completed layers + for (int layer = data->start_layer; + layer < data->start_layer + data->layers_this_batch; ++layer) { + for (int tp_rank = 0; tp_rank < data->tp_size; ++tp_rank) { + int fd = data->layer_eventfds[tp_rank * data->num_layers + layer]; + if (fd >= 0) { + // Write 2 to support both get_key_buffer and get_value_buffer waits + uint64_t val = 2; + ssize_t ret = write(fd, &val, sizeof(val)); + } + } + } + } + // End current NVTX range when all GPUs complete + if (data->current_range_id_ptr != nullptr && *data->current_range_id_ptr != 0) { + nvtxRangeEnd(*data->current_range_id_ptr); + } + // Start next layer's NVTX range (so it begins right after current layer ends) + if (!data->is_last_batch && data->next_range_id_ptr != nullptr) { + *data->next_range_id_ptr = nvtxRangeStartA(data->next_range_name); + } + delete data->counter; + } + delete data; +} + +LayerwiseTransferGroup::LayerwiseTransferGroup( + int num_gpus, const std::vector> &gpu_blocks, + torch::Tensor &cpu_blocks, + std::map> &ssd_files, + int num_layers, torch::Tensor &gpu_kv_strides_tensor, + torch::Tensor &gpu_block_strides_tensor, + torch::Tensor &gpu_layer_strides_tensor, + torch::Tensor &gpu_chunk_sizes_tensor, int iouring_entries, + int iouring_flags, torch::Tensor &layer_eventfds_tensor, int tp_size, + const std::vector> &indexer_gpu_blocks, + torch::Tensor indexer_cpu_blocks, + torch::Tensor indexer_gpu_kv_strides_tensor, + torch::Tensor indexer_gpu_block_strides_tensor, + torch::Tensor indexer_gpu_layer_strides_tensor, + torch::Tensor indexer_gpu_chunk_sizes_tensor, + std::map> indexer_ssd_files) { + + num_gpus_ = num_gpus; + num_layers_ = num_layers; + tp_size_ = tp_size; + current_counter_id_ = 0; + + // Initialize eventfds + enable_eventfd_ = (layer_eventfds_tensor.numel() > 0); + if (enable_eventfd_) { + // layer_eventfds_tensor layout: [num_counters, tp_size, num_layers] + // Index formula: counter_id * tp_size * num_layers + 
tp_rank * num_layers + layer + int total_fds = layer_eventfds_tensor.numel(); + num_counters_ = total_fds / (tp_size * num_layers); + + int32_t *fds_ptr = layer_eventfds_tensor.data_ptr(); + layer_eventfds_.assign(fds_ptr, fds_ptr + total_fds); + + printf("[LayerwiseTransferGroup] Initialized with eventfds: " + "tp_size=%d, num_counters=%d, num_layers=%d, total_fds=%d\n", + tp_size_, num_counters_, num_layers_, total_fds); + } else { + num_counters_ = 0; + printf("[LayerwiseTransferGroup] Initialized without eventfds\n"); + } + + gpu_kv_strides_in_bytes_ = new int64_t[num_gpus]; + gpu_block_strides_in_bytes_ = new int64_t[num_gpus]; + gpu_layer_strides_in_bytes_ = new int64_t[num_gpus]; + gpu_chunk_sizes_in_bytes_ = new int64_t[num_gpus]; + + int64_t *kv_strides_ptr = gpu_kv_strides_tensor.data_ptr(); + int64_t *block_strides_ptr = gpu_block_strides_tensor.data_ptr(); + int64_t *layer_strides_ptr = gpu_layer_strides_tensor.data_ptr(); + int64_t *chunk_sizes_ptr = gpu_chunk_sizes_tensor.data_ptr(); + + for (int i = 0; i < num_gpus; i++) { + gpu_kv_strides_in_bytes_[i] = kv_strides_ptr[i]; + gpu_block_strides_in_bytes_[i] = block_strides_ptr[i]; + gpu_chunk_sizes_in_bytes_[i] = chunk_sizes_ptr[i]; + gpu_layer_strides_in_bytes_[i] = layer_strides_ptr[i]; + } + + num_tensors_per_gpu_ = gpu_blocks[0].size(); + cudaMallocHost((void **)&gpu_blocks_, + num_gpus_ * num_tensors_per_gpu_ * sizeof(void *)); + for (int i = 0; i < num_gpus_; ++i) { + for (int j = 0; j < num_tensors_per_gpu_; ++j) { + gpu_blocks_[i * num_tensors_per_gpu_ + j] = gpu_blocks[i][j].data_ptr(); + } + } + + if (num_tensors_per_gpu_ == 1) { + backend_type_ = BackendType::TRTLLM; + } else if (num_tensors_per_gpu_ == num_layers) { + backend_type_ = BackendType::VLLM; + } else if (num_tensors_per_gpu_ == num_layers * 2) { + backend_type_ = BackendType::SGLANG; + } else { + throw std::runtime_error("Unsupported GPU block type: " + + std::to_string(num_tensors_per_gpu_)); + } + + 
gpu_tensor_handlers_.reserve(num_gpus_); + for (int i = 0; i < num_gpus_; i++) { + int64_t **gpu_blocks_ptr = + reinterpret_cast(gpu_blocks_ + i * num_tensors_per_gpu_); + gpu_tensor_handlers_.emplace_back( + backend_type_, gpu_blocks_ptr, num_layers, gpu_kv_strides_in_bytes_[i], + gpu_block_strides_in_bytes_[i], gpu_layer_strides_in_bytes_[i]); + } + + cpu_blocks_ = cpu_blocks.data_ptr(); + + // Get GPU device IDs from tensors (like tp_transfer_thread_group.cpp) + gpu_device_ids_.resize(num_gpus_); + for (int i = 0; i < num_gpus_; ++i) { + gpu_device_ids_[i] = gpu_blocks[i][0].device().index(); + } + + // Create CUDA streams for each GPU + streams_.resize(num_gpus_); + events_.resize(num_gpus_); + + // Get highest priority (lowest value) + int leastPriority, greatestPriority; + cudaDeviceGetStreamPriorityRange(&leastPriority, &greatestPriority); + + for (int i = 0; i < num_gpus_; i++) { + cudaSetDevice(gpu_device_ids_[i]); + cudaStreamCreateWithPriority(&streams_[i], cudaStreamNonBlocking, greatestPriority); + cudaEventCreate(&events_[i]); + } + + // Initialize SSD IO context if ssd_files is not empty + enable_ssd_ = !ssd_files.empty(); + if (enable_ssd_) { + ioctx_ = std::make_unique(ssd_files, ssd_files.size(), + iouring_entries, iouring_flags); + } + + // Initialize indexer fuse support + enable_indexer_ = !indexer_gpu_blocks.empty(); + if (enable_indexer_) { + indexer_num_tensors_per_gpu_ = indexer_gpu_blocks[0].size(); + cudaMallocHost((void **)&indexer_gpu_blocks_, + num_gpus_ * indexer_num_tensors_per_gpu_ * sizeof(void *)); + for (int i = 0; i < num_gpus_; ++i) { + for (int j = 0; j < indexer_num_tensors_per_gpu_; ++j) { + indexer_gpu_blocks_[i * indexer_num_tensors_per_gpu_ + j] = + indexer_gpu_blocks[i][j].data_ptr(); + } + } + + indexer_cpu_blocks_ = indexer_cpu_blocks.data_ptr(); + + indexer_gpu_kv_strides_in_bytes_ = new int64_t[num_gpus]; + indexer_gpu_block_strides_in_bytes_ = new int64_t[num_gpus]; + indexer_gpu_layer_strides_in_bytes_ = new 
int64_t[num_gpus]; + indexer_gpu_chunk_sizes_in_bytes_ = new int64_t[num_gpus]; + + int64_t *idx_kv_strides_ptr = indexer_gpu_kv_strides_tensor.data_ptr(); + int64_t *idx_block_strides_ptr = indexer_gpu_block_strides_tensor.data_ptr(); + int64_t *idx_layer_strides_ptr = indexer_gpu_layer_strides_tensor.data_ptr(); + int64_t *idx_chunk_sizes_ptr = indexer_gpu_chunk_sizes_tensor.data_ptr(); + + for (int i = 0; i < num_gpus; i++) { + indexer_gpu_kv_strides_in_bytes_[i] = idx_kv_strides_ptr[i]; + indexer_gpu_block_strides_in_bytes_[i] = idx_block_strides_ptr[i]; + indexer_gpu_layer_strides_in_bytes_[i] = idx_layer_strides_ptr[i]; + indexer_gpu_chunk_sizes_in_bytes_[i] = idx_chunk_sizes_ptr[i]; + } + + // Determine indexer backend type from tensor count (symmetric with main KV) + if (indexer_num_tensors_per_gpu_ == 1) { + indexer_backend_type_ = BackendType::TRTLLM; + } else if (indexer_num_tensors_per_gpu_ == num_layers) { + indexer_backend_type_ = BackendType::VLLM; + } else if (indexer_num_tensors_per_gpu_ == num_layers * 2) { + indexer_backend_type_ = BackendType::SGLANG; + } else { + throw std::runtime_error("Unsupported indexer GPU block type: " + + std::to_string(indexer_num_tensors_per_gpu_)); + } + + // Build GTensorHandlers for indexer (symmetric with main KV) + indexer_gpu_tensor_handlers_.reserve(num_gpus_); + for (int i = 0; i < num_gpus_; i++) { + int64_t **idx_gpu_blocks_ptr = reinterpret_cast( + indexer_gpu_blocks_ + i * indexer_num_tensors_per_gpu_); + indexer_gpu_tensor_handlers_.emplace_back( + indexer_backend_type_, idx_gpu_blocks_ptr, num_layers, + indexer_gpu_kv_strides_in_bytes_[i], + indexer_gpu_block_strides_in_bytes_[i], + indexer_gpu_layer_strides_in_bytes_[i]); + } + + fprintf(stderr, "[LayerwiseTransferGroup] Indexer fuse: enabled=true, " + "num_tensors_per_gpu=%d, chunk_size=%ld bytes, backend=%s\n", + indexer_num_tensors_per_gpu_, indexer_gpu_chunk_sizes_in_bytes_[0], + indexer_backend_type_ == BackendType::SGLANG ? 
"SGLANG" : + indexer_backend_type_ == BackendType::VLLM ? "VLLM" : "TRTLLM"); + } + + // Initialize indexer SSD IO context if indexer_ssd_files is not empty + enable_indexer_ssd_ = !indexer_ssd_files.empty(); + if (enable_indexer_ssd_) { + indexer_ioctx_ = std::make_unique( + indexer_ssd_files, indexer_ssd_files.size(), + iouring_entries, iouring_flags); + } +} + +LayerwiseTransferGroup::~LayerwiseTransferGroup() { + for (int i = 0; i < num_gpus_; i++) { + cudaSetDevice(gpu_device_ids_[i]); + cudaStreamDestroy(streams_[i]); + cudaEventDestroy(events_[i]); + } + + cudaFreeHost(gpu_blocks_); + + gpu_tensor_handlers_.clear(); + delete[] gpu_kv_strides_in_bytes_; + delete[] gpu_block_strides_in_bytes_; + delete[] gpu_layer_strides_in_bytes_; + delete[] gpu_chunk_sizes_in_bytes_; + + // Clean up indexer resources + if (enable_indexer_) { + cudaFreeHost(indexer_gpu_blocks_); + indexer_gpu_tensor_handlers_.clear(); + delete[] indexer_gpu_kv_strides_in_bytes_; + delete[] indexer_gpu_block_strides_in_bytes_; + delete[] indexer_gpu_layer_strides_in_bytes_; + delete[] indexer_gpu_chunk_sizes_in_bytes_; + } +} + +void LayerwiseTransferGroup::layer_done_callback(int start_layer, + int layers_this_batch, + nvtxRangeId_t *current_range_id_ptr, + bool is_last_batch, + const char *next_range_name, + nvtxRangeId_t *next_range_id_ptr) { + std::atomic *counter = new std::atomic(0); + + // Get eventfd pointer for current counter set + int *eventfds_ptr = nullptr; + if (enable_eventfd_ && num_counters_ > 0) { + // Offset into layer_eventfds_ for current counter set + int offset = current_counter_id_ * tp_size_ * num_layers_; + eventfds_ptr = layer_eventfds_.data() + offset; + } + + for (int i = 0; i < num_gpus_; ++i) { + LayerCallbackData *data = new LayerCallbackData{ + start_layer, layers_this_batch, num_gpus_, counter, + enable_eventfd_, tp_size_, num_layers_, eventfds_ptr, + current_range_id_ptr, is_last_batch, {0}, next_range_id_ptr}; + // Copy next range name + if (next_range_name 
!= nullptr) { + snprintf(data->next_range_name, sizeof(data->next_range_name), "%s", next_range_name); + } + cudaLaunchHostFunc(streams_[i], layer_done_host_callback, data); + } +} + +void LayerwiseTransferGroup::layerwise_transfer( + const torch::Tensor &ssd_block_ids, const torch::Tensor &cpu_block_ids_d2h, + const int64_t ssd_layer_stride_in_bytes, + const int64_t ssd_kv_stride_in_bytes, const int num_blocks_per_file, + const int round_robin, const int num_threads_per_device, + const torch::Tensor &gpu_block_id_tensor, + const torch::Tensor &cpu_block_id_tensor, + const int64_t cpu_kv_stride_in_bytes, + const int64_t cpu_layer_stride_in_bytes, + const int64_t cpu_block_stride_in_bytes, + const int64_t cpu_chunk_size_in_bytes, + const int64_t h2d_cpu_kv_stride_in_bytes, + const int64_t h2d_cpu_layer_stride_in_bytes, + const int64_t cpu_tp_stride_in_bytes, const int transfer_cta_num, + const bool use_ce_transfer, const int num_layers, + const int layer_granularity, const bool is_mla, + const int counter_id, + const torch::Tensor &indexer_gpu_block_id_tensor, + const torch::Tensor &indexer_cpu_block_id_tensor, + const int64_t indexer_cpu_block_stride_in_bytes, + const int64_t indexer_cpu_layer_stride_in_bytes, + const int64_t indexer_h2d_cpu_kv_stride_in_bytes, + const int64_t indexer_h2d_cpu_layer_stride_in_bytes, + const torch::Tensor &indexer_ssd_block_ids, + const torch::Tensor &indexer_cpu_block_ids_d2h, + const int64_t indexer_ssd_layer_stride_in_bytes, + const int64_t indexer_ssd_kv_stride_in_bytes, + const int64_t indexer_cpu_chunk_size_in_bytes, + const int indexer_num_blocks_per_file) { + + // Set current counter ID for eventfd notification + current_counter_id_ = counter_id; + + int num_blocks = gpu_block_id_tensor.numel(); + int64_t *gpu_block_ids = + static_cast(gpu_block_id_tensor.data_ptr()); + int64_t *cpu_block_ids = + static_cast(cpu_block_id_tensor.data_ptr()); + void *cpu_ptr = cpu_blocks_; + + // Indexer block ids (may be empty if indexer is 
not enabled or not provided) + bool do_indexer_transfer = enable_indexer_ && + indexer_gpu_block_id_tensor.defined() && + indexer_gpu_block_id_tensor.numel() > 0; + int num_indexer_blocks = 0; + int64_t *indexer_gpu_block_ids = nullptr; + int64_t *indexer_cpu_block_ids = nullptr; + if (do_indexer_transfer) { + num_indexer_blocks = indexer_gpu_block_id_tensor.numel(); + indexer_gpu_block_ids = + static_cast(indexer_gpu_block_id_tensor.data_ptr()); + indexer_cpu_block_ids = + static_cast(indexer_cpu_block_id_tensor.data_ptr()); + } + + // Create CUDA events for timing each layer batch (on GPU 0) + int num_batches = (num_layers + layer_granularity - 1) / layer_granularity; + std::vector timing_events(num_batches + 1); // +1 for start event + std::vector batch_start_layers(num_batches); + std::vector batch_layers_count(num_batches); + + cudaSetDevice(gpu_device_ids_[0]); + for (int i = 0; i <= num_batches; ++i) { + cudaEventCreate(&timing_events[i]); + } + + // Record start event + cudaEventRecord(timing_events[0], streams_[0]); + + // Allocate storage for NVTX range IDs (one per batch) + std::vector h2d_range_ids(num_batches, 0); + // Pre-generate all range names with data size info + std::vector h2d_range_names(num_batches); + for (int b = 0; b < num_batches; ++b) { + int sl = b * layer_granularity; + int ltb = std::min(layer_granularity, num_layers - sl); + // Calculate data size for this batch: chunk_size * 2 (K+V) * layers * num_blocks + int64_t bytes_this_batch = 0; + for (int g = 0; g < num_gpus_; ++g) { + bytes_this_batch += gpu_chunk_sizes_in_bytes_[g] * 2 * ltb * num_blocks; + } + // Add indexer bytes if applicable + int64_t indexer_bytes_this_batch = 0; + if (do_indexer_transfer) { + for (int g = 0; g < num_gpus_; ++g) { + indexer_bytes_this_batch += indexer_gpu_chunk_sizes_in_bytes_[g] * ltb * num_indexer_blocks; + } + } + double mb_this_batch = (bytes_this_batch + indexer_bytes_this_batch) / (1024.0 * 1024.0); + char name[256]; + if (do_indexer_transfer) { 
+ snprintf(name, sizeof(name), "CPU->GPU Layer[%d,%d) KV:%.2fMB+Idx:%.2fMB", + sl, sl + ltb, bytes_this_batch / (1024.0 * 1024.0), + indexer_bytes_this_batch / (1024.0 * 1024.0)); + } else { + snprintf(name, sizeof(name), "CPU->GPU Layer[%d,%d) %.2fMB", sl, sl + ltb, + bytes_this_batch / (1024.0 * 1024.0)); + } + h2d_range_names[b] = name; + } + + // Start the first batch's NVTX range in main thread + if (num_batches > 0) { + h2d_range_ids[0] = nvtxRangeStartA(h2d_range_names[0].c_str()); + } + + // Step 0: SSD -> CPU transfer for ALL layers at once (before layerwise loop). + // This is required because the CPU memory uses TP-divided layout where each rank's + // data occupies a contiguous region [rank*tp_stride, (rank+1)*tp_stride). Per-layer-batch + // SSD reads with full strides would land at wrong CPU positions for TP > 1. + if (enable_ssd_ && ssd_block_ids.numel() > 0) { + int num_ssd_blocks = ssd_block_ids.numel(); + int64_t ssd_bytes = cpu_chunk_size_in_bytes * 2 * num_layers * num_ssd_blocks; + double ssd_mb = ssd_bytes / (1024.0 * 1024.0); + char ssd_range_name[128]; + snprintf(ssd_range_name, sizeof(ssd_range_name), + "SSD->CPU AllLayers[0,%d) %.2fMB", num_layers, ssd_mb); + nvtxRangePushA(ssd_range_name); + + torch::Tensor all_layer_ids = + torch::arange(0, num_layers, + torch::TensorOptions().dtype(torch::kInt32)); + transfer_kv_blocks_ssd( + *ioctx_, all_layer_ids, reinterpret_cast(cpu_blocks_), + ssd_block_ids, cpu_block_ids_d2h, cpu_layer_stride_in_bytes, + cpu_kv_stride_in_bytes, ssd_layer_stride_in_bytes, + ssd_kv_stride_in_bytes, cpu_chunk_size_in_bytes, + cpu_block_stride_in_bytes, + true, // is_read: SSD -> CPU + num_blocks_per_file, round_robin, num_threads_per_device, is_mla); + + nvtxRangePop(); + } + + // Indexer SSD -> CPU transfer for ALL layers at once. 
+ if (enable_indexer_ssd_ && indexer_ssd_block_ids.defined() && + indexer_ssd_block_ids.numel() > 0) { + int num_indexer_ssd_blocks = indexer_ssd_block_ids.numel(); + int64_t indexer_ssd_bytes = indexer_cpu_chunk_size_in_bytes * num_layers * num_indexer_ssd_blocks; + double indexer_ssd_mb = indexer_ssd_bytes / (1024.0 * 1024.0); + char idx_ssd_range_name[128]; + snprintf(idx_ssd_range_name, sizeof(idx_ssd_range_name), + "Indexer SSD->CPU AllLayers[0,%d) %.2fMB", num_layers, indexer_ssd_mb); + nvtxRangePushA(idx_ssd_range_name); + + torch::Tensor all_layer_ids = + torch::arange(0, num_layers, + torch::TensorOptions().dtype(torch::kInt32)); + transfer_kv_blocks_ssd( + *indexer_ioctx_, all_layer_ids, + reinterpret_cast(indexer_cpu_blocks_), + indexer_ssd_block_ids, indexer_cpu_block_ids_d2h, + indexer_cpu_layer_stride_in_bytes, + indexer_ssd_kv_stride_in_bytes, + indexer_ssd_layer_stride_in_bytes, + indexer_ssd_kv_stride_in_bytes, + indexer_cpu_chunk_size_in_bytes, + indexer_cpu_block_stride_in_bytes, + true, // is_read: SSD -> CPU + indexer_num_blocks_per_file, round_robin, num_threads_per_device, + true /* is_mla: indexer always MLA */); + + nvtxRangePop(); + } + + int batch_idx = 0; + for (int start_layer = 0; start_layer < num_layers; + start_layer += layer_granularity) { + int layers_this_batch = + std::min(layer_granularity, num_layers - start_layer); + + batch_start_layers[batch_idx] = start_layer; + batch_layers_count[batch_idx] = layers_this_batch; + + // Step 1: CPU -> GPU transfer + // NVTX range for this batch was already started (by main thread for first batch, + // or by previous batch's callback for subsequent batches) + + for (int i = 0; i < num_gpus_; ++i) { + cudaSetDevice(gpu_device_ids_[i]); + int64_t cpu_startoff_inside_chunks = i * cpu_tp_stride_in_bytes; + if (is_mla) { + cpu_startoff_inside_chunks = 0; + } + int64_t gpu_startoff_inside_chunks = 0; + int64_t chunk_size = gpu_chunk_sizes_in_bytes_[i]; + + switch (backend_type_) { + case 
BackendType::VLLM: + flexkv::transfer_kv_blocks( + num_blocks, start_layer, layers_this_batch, gpu_block_ids, + gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, cpu_block_ids, + cpu_ptr, h2d_cpu_kv_stride_in_bytes, h2d_cpu_layer_stride_in_bytes, + cpu_block_stride_in_bytes, cpu_startoff_inside_chunks, chunk_size, + streams_[i], transfer_cta_num, true, use_ce_transfer, is_mla, false); + break; + case BackendType::TRTLLM: + flexkv::transfer_kv_blocks( + num_blocks, start_layer, layers_this_batch, gpu_block_ids, + gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, cpu_block_ids, + cpu_ptr, h2d_cpu_kv_stride_in_bytes, h2d_cpu_layer_stride_in_bytes, + cpu_block_stride_in_bytes, cpu_startoff_inside_chunks, chunk_size, + streams_[i], transfer_cta_num, true, use_ce_transfer, is_mla, false); + break; + case BackendType::SGLANG: + flexkv::transfer_kv_blocks( + num_blocks, start_layer, layers_this_batch, gpu_block_ids, + gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, cpu_block_ids, + cpu_ptr, h2d_cpu_kv_stride_in_bytes, h2d_cpu_layer_stride_in_bytes, + cpu_block_stride_in_bytes, cpu_startoff_inside_chunks, chunk_size, + streams_[i], transfer_cta_num, true, use_ce_transfer, is_mla, false); + break; + } + + // Fused indexer CPU -> GPU transfer on the same stream + // Uses transfer_kv_blocks (symmetric with main KV Step 2) instead of + // hand-written cudaMemcpyAsync loops for backend-agnostic support. + // Note: indexer uses ReplicatedLinear weights with 1 head (is_mla=true), + // so all TP ranks hold identical data. No TP head-partitioning needed, + // cpu_startoff is always 0 (unlike main KV which may offset by tp_stride). 
+ if (do_indexer_transfer) { + int64_t idx_chunk_size = indexer_gpu_chunk_sizes_in_bytes_[i]; + // idx_cpu_startoff = 0: indexer data is not partitioned across TP ranks + int64_t idx_cpu_startoff = 0; + + switch (indexer_backend_type_) { + case BackendType::VLLM: + flexkv::transfer_kv_blocks( + num_indexer_blocks, start_layer, layers_this_batch, + indexer_gpu_block_ids, indexer_gpu_tensor_handlers_[i], + 0 /* gpu_startoff */, indexer_cpu_block_ids, + indexer_cpu_blocks_, + indexer_h2d_cpu_kv_stride_in_bytes, + indexer_h2d_cpu_layer_stride_in_bytes, + indexer_cpu_block_stride_in_bytes, + idx_cpu_startoff, idx_chunk_size, + streams_[i], transfer_cta_num, true /* h2d */, + use_ce_transfer, true /* is_mla */, false /* sync */); + break; + case BackendType::TRTLLM: + flexkv::transfer_kv_blocks( + num_indexer_blocks, start_layer, layers_this_batch, + indexer_gpu_block_ids, indexer_gpu_tensor_handlers_[i], + 0 /* gpu_startoff */, indexer_cpu_block_ids, + indexer_cpu_blocks_, + indexer_h2d_cpu_kv_stride_in_bytes, + indexer_h2d_cpu_layer_stride_in_bytes, + indexer_cpu_block_stride_in_bytes, + idx_cpu_startoff, idx_chunk_size, + streams_[i], transfer_cta_num, true /* h2d */, + use_ce_transfer, true /* is_mla */, false /* sync */); + break; + case BackendType::SGLANG: + flexkv::transfer_kv_blocks( + num_indexer_blocks, start_layer, layers_this_batch, + indexer_gpu_block_ids, indexer_gpu_tensor_handlers_[i], + 0 /* gpu_startoff */, indexer_cpu_block_ids, + indexer_cpu_blocks_, + indexer_h2d_cpu_kv_stride_in_bytes, + indexer_h2d_cpu_layer_stride_in_bytes, + indexer_cpu_block_stride_in_bytes, + idx_cpu_startoff, idx_chunk_size, + streams_[i], transfer_cta_num, true /* h2d */, + use_ce_transfer, true /* is_mla */, false /* sync */); + break; + } + } + } + + // Record event after this batch on GPU 0 + cudaSetDevice(gpu_device_ids_[0]); + cudaEventRecord(timing_events[batch_idx + 1], streams_[0]); + + // NVTX: current range ends in callback, next range starts in callback + bool 
is_last_batch = (batch_idx == num_batches - 1); + const char *next_name = is_last_batch ? nullptr : h2d_range_names[batch_idx + 1].c_str(); + nvtxRangeId_t *next_id_ptr = is_last_batch ? nullptr : &h2d_range_ids[batch_idx + 1]; + + layer_done_callback(start_layer, layers_this_batch, + &h2d_range_ids[batch_idx], is_last_batch, + next_name, next_id_ptr); + batch_idx++; + } + for (int i = 0; i < num_gpus_; ++i) { + cudaError_t err = cudaStreamSynchronize(streams_[i]); + if (err != cudaSuccess) { + throw std::runtime_error("layerwise_transfer failed on GPU " + + std::to_string(i) + ": " + + cudaGetErrorString(err)); + } + } + + // Calculate and print timing for each layer batch + // chunk_size per GPU * num_gpus * 2 (K+V) * layers_this_batch * num_blocks + // fprintf(stderr, "\n[LayerwiseTransfer] CPU->GPU Transfer Timing (num_blocks=%d):\n", num_blocks); + float total_time_ms = 0.0f; + int64_t total_bytes = 0; + + for (int i = 0; i < num_batches; ++i) { + float elapsed_ms = 0.0f; + cudaEventElapsedTime(&elapsed_ms, timing_events[i], timing_events[i + 1]); + + // Calculate bytes transferred for this batch + // For each GPU: chunk_size * 2 (K+V) * layers * num_blocks + int64_t bytes_this_batch = 0; + for (int g = 0; g < num_gpus_; ++g) { + bytes_this_batch += gpu_chunk_sizes_in_bytes_[g] * 2 * batch_layers_count[i] * num_blocks; + } + // Include indexer bytes + int64_t indexer_bytes_batch = 0; + if (do_indexer_transfer) { + for (int g = 0; g < num_gpus_; ++g) { + indexer_bytes_batch += indexer_gpu_chunk_sizes_in_bytes_[g] * batch_layers_count[i] * num_indexer_blocks; + } + bytes_this_batch += indexer_bytes_batch; + } + + double bandwidth_gbps = (bytes_this_batch / (1024.0 * 1024.0 * 1024.0)) / (elapsed_ms / 1000.0); + + // fprintf(stderr, " Layers [%d, %d): time=%.3f ms, size=%.2f MB, bandwidth=%.2f GB/s\n", + // batch_start_layers[i], + // batch_start_layers[i] + batch_layers_count[i], + // elapsed_ms, + // bytes_this_batch / (1024.0 * 1024.0), + // bandwidth_gbps); + 
+ total_time_ms += elapsed_ms; + total_bytes += bytes_this_batch; + } + + double total_bandwidth_gbps = (total_bytes / (1024.0 * 1024.0 * 1024.0)) / (total_time_ms / 1000.0); + // fprintf(stderr, " Total: time=%.3f ms, size=%.2f MB, avg_bandwidth=%.2f GB/s\n\n", + // total_time_ms, total_bytes / (1024.0 * 1024.0), total_bandwidth_gbps); + // fflush(stderr); + + // Cleanup timing events + cudaSetDevice(gpu_device_ids_[0]); + for (int i = 0; i <= num_batches; ++i) { + cudaEventDestroy(timing_events[i]); + } +} + +} // namespace flexkv diff --git a/csrc/layerwise.h b/csrc/layerwise.h new file mode 100644 index 0000000000..2de0a23550 --- /dev/null +++ b/csrc/layerwise.h @@ -0,0 +1,131 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gtensor_handler.cuh" +#include "transfer.cuh" +#include "transfer_ssd.h" + +namespace flexkv { + +class LayerwiseTransferGroup { +public: + LayerwiseTransferGroup( + int num_gpus, const std::vector> &gpu_blocks, + torch::Tensor &cpu_blocks, + std::map> &ssd_files, + int num_layers, torch::Tensor &gpu_kv_strides_tensor, + torch::Tensor &gpu_block_strides_tensor, + torch::Tensor &gpu_layer_strides_tensor, + torch::Tensor &gpu_chunk_sizes_tensor, int iouring_entries, + int iouring_flags, torch::Tensor &layer_eventfds_tensor, int tp_size, + const std::vector> &indexer_gpu_blocks = {}, + torch::Tensor indexer_cpu_blocks = torch::Tensor(), + torch::Tensor indexer_gpu_kv_strides_tensor = torch::Tensor(), + torch::Tensor indexer_gpu_block_strides_tensor = torch::Tensor(), + torch::Tensor indexer_gpu_layer_strides_tensor = torch::Tensor(), + torch::Tensor indexer_gpu_chunk_sizes_tensor = torch::Tensor(), + std::map> indexer_ssd_files = {}); + + ~LayerwiseTransferGroup(); + + // Layerwise transfer: SSD->CPU + CPU->GPU + void layerwise_transfer( + const torch::Tensor + &ssd_block_ids, // SSD source block ids (for disk2host) + const torch::Tensor + &cpu_block_ids_d2h, // 
CPU dest block ids (for disk2host) + const int64_t ssd_layer_stride_in_bytes, + const int64_t ssd_kv_stride_in_bytes, const int num_blocks_per_file, + const int round_robin, const int num_threads_per_device, + const torch::Tensor + &gpu_block_id_tensor, // GPU dest block ids (for host2device) + const torch::Tensor + &cpu_block_id_tensor, // CPU source block ids (for host2device) + const int64_t cpu_kv_stride_in_bytes, + const int64_t cpu_layer_stride_in_bytes, + const int64_t cpu_block_stride_in_bytes, + const int64_t cpu_chunk_size_in_bytes, + const int64_t h2d_cpu_kv_stride_in_bytes, + const int64_t h2d_cpu_layer_stride_in_bytes, + const int64_t cpu_tp_stride_in_bytes, const int transfer_cta_num, + const bool use_ce_transfer, const int num_layers, + const int layer_granularity, const bool is_mla, + const int counter_id = 0, + const torch::Tensor &indexer_gpu_block_id_tensor = torch::Tensor(), + const torch::Tensor &indexer_cpu_block_id_tensor = torch::Tensor(), + const int64_t indexer_cpu_block_stride_in_bytes = 0, + const int64_t indexer_cpu_layer_stride_in_bytes = 0, + const int64_t indexer_h2d_cpu_kv_stride_in_bytes = 0, + const int64_t indexer_h2d_cpu_layer_stride_in_bytes = 0, + const torch::Tensor &indexer_ssd_block_ids = torch::Tensor(), + const torch::Tensor &indexer_cpu_block_ids_d2h = torch::Tensor(), + const int64_t indexer_ssd_layer_stride_in_bytes = 0, + const int64_t indexer_ssd_kv_stride_in_bytes = 0, + const int64_t indexer_cpu_chunk_size_in_bytes = 0, + const int indexer_num_blocks_per_file = 0); + +private: + int num_gpus_; + int dp_group_id_; + void **gpu_blocks_; + void *cpu_blocks_; + int num_tensors_per_gpu_; + int64_t *gpu_kv_strides_in_bytes_; + int64_t *gpu_block_strides_in_bytes_; + int64_t *gpu_layer_strides_in_bytes_; + int64_t *gpu_chunk_sizes_in_bytes_; + + BackendType backend_type_; + std::vector gpu_tensor_handlers_; + + std::vector gpu_device_ids_; + std::vector streams_; + std::vector events_; + + // SSD IO context + bool 
enable_ssd_; + std::unique_ptr ioctx_; + + // Indexer fuse support + bool enable_indexer_ = false; + void **indexer_gpu_blocks_ = nullptr; + void *indexer_cpu_blocks_ = nullptr; + int indexer_num_tensors_per_gpu_ = 0; + int64_t *indexer_gpu_kv_strides_in_bytes_ = nullptr; + int64_t *indexer_gpu_block_strides_in_bytes_ = nullptr; + int64_t *indexer_gpu_layer_strides_in_bytes_ = nullptr; + int64_t *indexer_gpu_chunk_sizes_in_bytes_ = nullptr; + BackendType indexer_backend_type_ = BackendType::SGLANG; + std::vector indexer_gpu_tensor_handlers_; + + // Indexer SSD IO context + bool enable_indexer_ssd_ = false; + std::unique_ptr indexer_ioctx_; + + // Layer eventfds for notification + // Shape: [tp_size, num_counters, num_layers] + bool enable_eventfd_; + int tp_size_; + int num_counters_; + int num_layers_; + std::vector layer_eventfds_; // Flat array + int current_counter_id_; // Current counter set index for this transfer + + void layer_done_callback(int start_layer, int layers_this_batch, + nvtxRangeId_t *current_range_id_ptr, + bool is_last_batch, + const char *next_range_name, + nvtxRangeId_t *next_range_id_ptr); +}; + +} // namespace flexkv diff --git a/csrc/radix_tree.cpp b/csrc/radix_tree.cpp index 04d3429f6b..b27ac43920 100644 --- a/csrc/radix_tree.cpp +++ b/csrc/radix_tree.cpp @@ -520,9 +520,8 @@ CRadixTreeIndex::match_prefix(torch::Tensor &block_hashes, int num_blocks, } auto physical_blocks = physical_blocks_tensor.narrow(0, 0, pb_write); - auto empty_uint32 = torch::Tensor(); return std::make_shared(ready_prefix_blocks_num, prefix_blocks_num, last_node_matched_length, - last_ready_node, current_node, physical_blocks, empty_uint32); + last_ready_node, current_node, physical_blocks); } } // namespace flexkv diff --git a/csrc/radix_tree.h b/csrc/radix_tree.h index 65bad5dc12..4c7c3c5d86 100644 --- a/csrc/radix_tree.h +++ b/csrc/radix_tree.h @@ -227,17 +227,18 @@ class CMatchResult { CRadixNode *last_ready_node; CRadixNode *last_node; torch::Tensor 
physical_blocks; - torch::Tensor block_node_ids; + int32_t matched_node_id; // single node_id for all matched blocks (-1 = no match) CMatchResult(int _num_ready_matched_blocks, int _num_matched_blocks, int _last_node_matched_length, CRadixNode *_last_ready_node, CRadixNode *_last_node, torch::Tensor blocks, - torch::Tensor block_node_ids = torch::Tensor()) + int32_t matched_node_id = -1) : num_ready_matched_blocks(_num_ready_matched_blocks), num_matched_blocks(_num_matched_blocks), last_node_matched_length(_last_node_matched_length), last_ready_node(_last_ready_node), last_node(_last_node), - physical_blocks(blocks), block_node_ids(block_node_ids) {} + physical_blocks(blocks), + matched_node_id(matched_node_id) {} ~CMatchResult() {} }; diff --git a/csrc/tp_transfer_thread_group.cpp b/csrc/tp_transfer_thread_group.cpp index 349cc62cc8..d0fa757244 100644 --- a/csrc/tp_transfer_thread_group.cpp +++ b/csrc/tp_transfer_thread_group.cpp @@ -22,7 +22,7 @@ namespace flexkv { TPTransferThreadGroup::TPTransferThreadGroup( int num_gpus, const std::vector &gpu_block_ptrs_flat, - int num_tensors_per_gpu, int64_t cpu_blocks_ptr, int dp_group_id, + int num_tensors_per_gpu, int64_t cpu_blocks_ptr, int num_layers, const std::vector &gpu_kv_strides_in_bytes, const std::vector &gpu_block_strides_in_bytes, const std::vector &gpu_layer_strides_in_bytes, @@ -30,7 +30,6 @@ TPTransferThreadGroup::TPTransferThreadGroup( const std::vector &gpu_device_ids) { num_gpus_ = num_gpus; num_tensors_per_gpu_ = num_tensors_per_gpu; - dp_group_id_ = dp_group_id; gpu_kv_strides_in_bytes_ = new int64_t[num_gpus]; gpu_block_strides_in_bytes_ = new int64_t[num_gpus]; @@ -167,6 +166,8 @@ void TPTransferThreadGroup::tp_group_transfer( std::vector> futures; futures.reserve(num_gpus_); + bool enable_sharded_d2h = is_mla && !is_host_to_device; + for (int i = 0; i < num_gpus_; ++i) { futures.emplace_back(enqueue_for_gpu(i, [&, i]() { try { @@ -177,23 +178,20 @@ void TPTransferThreadGroup::tp_group_transfer( 
int64_t *cpu_block_ids = static_cast(cpu_block_id_tensor.data_ptr()); void *cpu_ptr = cpu_blocks_; - int64_t cpu_startoff_inside_chunks = i * cpu_tp_stride_in_bytes; - if (is_mla && !is_host_to_device) { + int64_t cpu_startoff_inside_chunks = 0; + if (enable_sharded_d2h) cpu_startoff_inside_chunks = i * gpu_chunk_sizes_in_bytes_[i] / num_gpus_; - } else if (is_mla && is_host_to_device) { - cpu_startoff_inside_chunks = 0; - } + else if (!is_mla) + cpu_startoff_inside_chunks = i * cpu_tp_stride_in_bytes; int64_t gpu_startoff_inside_chunks = - is_mla && !is_host_to_device - ? i * gpu_chunk_sizes_in_bytes_[i] / num_gpus_ - : 0; + enable_sharded_d2h ? i * gpu_chunk_sizes_in_bytes_[i] / num_gpus_ + : 0; // we assume that the chunk size is the same for all gpus, // even if they have different number of gpu_blocks - int64_t chunk_size = is_mla && !is_host_to_device + int64_t chunk_size = enable_sharded_d2h ? gpu_chunk_sizes_in_bytes_[i] / num_gpus_ : gpu_chunk_sizes_in_bytes_[i]; - // Dispatch to the appropriate template based on backend type switch (backend_type_) { case BackendType::VLLM: diff --git a/csrc/tp_transfer_thread_group.h b/csrc/tp_transfer_thread_group.h index 0aceaf2a68..4a4aacd373 100644 --- a/csrc/tp_transfer_thread_group.h +++ b/csrc/tp_transfer_thread_group.h @@ -38,7 +38,7 @@ class TPTransferThreadGroup { TPTransferThreadGroup(int num_gpus, const std::vector &gpu_block_ptrs_flat, int num_tensors_per_gpu, int64_t cpu_blocks_ptr, - int dp_group_id, int num_layers, + int num_layers, const std::vector &gpu_kv_strides_in_bytes, const std::vector &gpu_block_strides_in_bytes, const std::vector &gpu_layer_strides_in_bytes, @@ -63,7 +63,6 @@ class TPTransferThreadGroup { std::future enqueue_for_gpu(int gpu_idx, Task task); int num_gpus_; - int dp_group_id_; std::vector gpu_device_ids_; void **gpu_blocks_; void *cpu_blocks_; diff --git a/csrc/transfer.cu b/csrc/transfer.cu index 60ac276857..f46412e406 100644 --- a/csrc/transfer.cu +++ b/csrc/transfer.cu @@ -87,7 
+87,7 @@ void transfer_kv_blocks( int64_t cpu_layer_stride_in_bytes, int64_t cpu_block_stride_in_bytes, int64_t cpu_startoff_inside_chunks, int64_t chunk_size_in_bytes, cudaStream_t stream, int transfer_num_cta, bool is_host_to_device, - bool use_ce_transfer, bool is_mla) { + bool use_ce_transfer, bool is_mla, bool sync) { int block_size = 1024; @@ -120,8 +120,8 @@ void transfer_kv_blocks( j * cpu_kv_stride_int64 + cpu_block_idx * cpu_block_stride_int64 + cpu_startoff_inside_chunks_int64; - int64_t *gpu_ptr = - ptr_at(gpu_tensor_handler, i, j, gpu_block_idx); + int64_t *gpu_ptr = ptr_at(gpu_tensor_handler, + i + start_layer_id, j, gpu_block_idx); int64_t *gpu_chunk_ptr = reinterpret_cast(gpu_ptr) + gpu_startoff_inside_chunks_int64; @@ -167,7 +167,9 @@ void transfer_kv_blocks( actual_chunk_bytes * static_cast(num_layers) * static_cast(kv_dim) * static_cast(num_blocks)); } - cudaStreamSynchronize(stream); + if (sync) { + cudaStreamSynchronize(stream); + } } // Explicit template instantiations @@ -176,16 +178,16 @@ template void transfer_kv_blocks(int, int, int, int64_t *, int64_t *, void *, int64_t, int64_t, int64_t, int64_t, int64_t, cudaStream_t, int, - bool, bool, bool); + bool, bool, bool, bool); template void transfer_kv_blocks( int, int, int, int64_t *, GTensorHandler, int64_t, int64_t *, void *, int64_t, int64_t, int64_t, int64_t, int64_t, cudaStream_t, int, bool, bool, - bool); + bool, bool); template void transfer_kv_blocks( int, int, int, int64_t *, GTensorHandler, int64_t, int64_t *, void *, int64_t, int64_t, int64_t, int64_t, int64_t, cudaStream_t, int, bool, bool, - bool); + bool, bool); } // namespace flexkv diff --git a/csrc/transfer.cuh b/csrc/transfer.cuh index 7436b2e887..5aab0af1d8 100644 --- a/csrc/transfer.cuh +++ b/csrc/transfer.cuh @@ -30,6 +30,7 @@ void transfer_kv_blocks( int64_t cpu_kv_stride_in_bytes, int64_t cpu_layer_stride_in_bytes, int64_t cpu_block_stride_in_bytes, int64_t cpu_startoff_inside_chunks, int64_t chunk_size_in_bytes, 
cudaStream_t stream, int transfer_num_cta, - bool is_host_to_device, bool use_ce_transfer, bool is_mla); + bool is_host_to_device, bool use_ce_transfer, bool is_mla, + bool sync = true); } // namespace flexkv diff --git a/docs/flexkv_config_reference/README_zh.md b/docs/flexkv_config_reference/README_zh.md index 562f1d8a38..e24f055784 100644 --- a/docs/flexkv_config_reference/README_zh.md +++ b/docs/flexkv_config_reference/README_zh.md @@ -47,6 +47,9 @@ enable_gds: false | `FLEXKV_SSD_CACHE_GB` | int | 0 | SSD 缓存层容量,单位为 GB。建议设置大于 `FLEXKV_CPU_CACHE_GB`并为`FLEXKV_MAX_FILE_SIZE_GB`的整数倍,若仅用CPU缓存则设为 0(此时不启用 SSD 缓存) | | `FLEXKV_SSD_CACHE_DIR` | str | "./flexkv_ssd" | SSD 缓存数据的存放目录。若有多块 SSD,可通过分号 `;` 分隔多个挂载路径。例如 `"/data0/flexkv_ssd/;/data1/flexkv_ssd/"`,以提升带宽 | | `FLEXKV_ENABLE_GDS` | bool | 0 | 是否启用 GPU Direct Storage(GDS)。如硬件和驱动支持,开启后可提升 SSD 到 GPU 的数据吞吐能力。默认关闭,开启请设为 1 | +| `FLEXKV_USE_HUGEPAGE_CPU_BUFFER` | bool | 0 | 是否为通用 CPU KV cache 启用 HugePage。默认关闭,开启请设为 1 | +| `FLEXKV_USE_HUGEPAGE_TMP_BUFFER` | bool | 0 | 是否为 `enable_p2p_ssd` 场景下的 tmp CPU staging buffer 启用 HugePage。默认关闭,开启请设为 1 | +| `FLEXKV_HUGEPAGE_SIZE_BYTES` | int | 2097152 | HugePage 大小,默认 2 MiB。如果宿主机准备的是 1 GiB HugePage,可设为 `1073741824` | --- diff --git a/docs/hugepage/README_en.md b/docs/hugepage/README_en.md new file mode 100644 index 0000000000..f097b1d98b --- /dev/null +++ b/docs/hugepage/README_en.md @@ -0,0 +1,299 @@ +# FlexKV HugePage User Guide + +## 1. Overview + +FlexKV currently exposes three HugePage-related configuration fields: + +- `use_hugepage_cpu_buffer` + Controls whether the main CPU KV cache should be allocated from HugePages. +- `use_hugepage_tmp_buffer` + Controls whether the temporary CPU staging buffer in the `enable_p2p_ssd=true` path should be allocated from HugePages. +- `hugepage_size_bytes` + Controls the HugePage size used by both allocation paths. + +The two HugePage switches are independent. They can be enabled separately or together. 
+ +There is one important implementation constraint today: + +- `use_hugepage_cpu_buffer` for the main CPU KV cache must use a `hugetlbfs`-backed file so the same HugePage mapping can be reopened inside `spawn` workers. +- Anonymous `MAP_HUGETLB` is not sufficient for the main CPU cache path because once that tensor is sent into `TransferEngine` workers, PyTorch serializes ordinary CPU tensors into new shared-memory storage, which breaks the original HugePage backing. +- `use_hugepage_tmp_buffer` does not have that cross-process sharing requirement, so it may still succeed through either anonymous HugePages or `hugetlbfs`. + +--- + +## 2. Recommended Use Cases + +HugePage is recommended in the following situations: + +- The CPU KV cache is large and CPU-side page table or TLB overhead is non-trivial. +- `enable_p2p_ssd=true` is enabled and you want to optimize the temporary staging buffer in the `SSD -> CPU -> GPU` path. +- The host has already been prepared with reserved HugePages and a working hugetlbfs mount. + +HugePage is not recommended when the host has no HugePage reservation or when the target workload does not materially benefit from CPU cache or p2p SSD path optimization. + +--- + +## 3. Prerequisites + +### 3.1 HugePages Must Be Reserved on the Host + +Check the current HugePage status: + +```bash +grep -E 'HugePages_|Hugepagesize' /proc/meminfo +``` + +For 2 MiB HugePages, the following command reserves 4096 pages, which is about 8 GiB: + +```bash +sudo sysctl -w vm.nr_hugepages=4096 +``` + +If you plan to use 1 GiB HugePages, they usually need to be reserved through kernel boot parameters, for example: + +```text +default_hugepagesz=1G hugepagesz=1G hugepages=N +``` + +### 3.2 hugetlbfs Must Be Mounted + +FlexKV uses `/mnt/hugepages` as the default hugetlbfs mount point. + +For `use_hugepage_cpu_buffer`, this is a hard requirement rather than a recommendation. 
If `FLEXKV_HUGETLBFS_DIR` points to a normal filesystem, FlexKV now rejects it and falls back instead of silently treating regular 4 KiB pages as HugePages. + +Check the mount status: + +```bash +mount | grep hugetlbfs +ls -ld /mnt/hugepages +``` + +If hugetlbfs is not mounted yet: + +```bash +sudo mkdir -p /mnt/hugepages +sudo mount -t hugetlbfs none /mnt/hugepages +``` + +If your actual hugetlbfs mount point is different, set it explicitly: + +```bash +export FLEXKV_HUGETLBFS_DIR=/path/to/hugetlbfs +``` + +### 3.3 CUDA Runtime Is Required for the tmp Buffer Path + +The `use_hugepage_tmp_buffer` path performs `cudaHostRegister` after HugePage allocation succeeds. This path therefore requires: + +- A working CUDA runtime +- `libcudart.so` to be discoverable +- A sufficiently large `memlock` limit on the host or in the container + +Basic check: + +```bash +python3 - <<'PY' +import torch +print(torch.cuda.is_available()) +PY +``` + +Note: `use_hugepage_cpu_buffer` does not depend on `cudaHostRegister`. + +Additional note: although `use_hugepage_cpu_buffer` does not require CUDA runtime, it does require a writable `hugetlbfs` mount because the main CPU KV cache must be reopened from the same HugePage-backed file inside spawned workers. + +--- + +## 4. Configuration + +HugePage is now a formal user-facing configuration surface in FlexKV. It can be configured through either configuration files or environment variables. 
+ +### 4.1 Configuration File + +YAML example: + +```yaml +cpu_cache_gb: 32 +ssd_cache_gb: 1024 +ssd_cache_dir: /data/flexkv_ssd/ +enable_p2p_ssd: true +use_hugepage_cpu_buffer: true +use_hugepage_tmp_buffer: true +hugepage_size_bytes: 2097152 +``` + +JSON example: + +```json +{ + "cpu_cache_gb": 32, + "ssd_cache_gb": 1024, + "ssd_cache_dir": "/data/flexkv_ssd/", + "enable_p2p_ssd": true, + "use_hugepage_cpu_buffer": true, + "use_hugepage_tmp_buffer": true, + "hugepage_size_bytes": 2097152 +} +``` + +### 4.2 Environment Variables + +```bash +export FLEXKV_USE_HUGEPAGE_CPU_BUFFER=1 +export FLEXKV_USE_HUGEPAGE_TMP_BUFFER=1 +export FLEXKV_HUGEPAGE_SIZE_BYTES=2097152 +``` + +Meaning: + +- `FLEXKV_USE_HUGEPAGE_CPU_BUFFER=1` + Enables HugePage allocation for the main CPU KV cache. +- `FLEXKV_USE_HUGEPAGE_TMP_BUFFER=1` + Enables HugePage allocation for the temporary CPU staging buffer in the p2p SSD path. +- `FLEXKV_HUGEPAGE_SIZE_BYTES=2097152` + Uses 2 MiB HugePages. + +If the host is configured for 1 GiB HugePages: + +```bash +export FLEXKV_HUGEPAGE_SIZE_BYTES=1073741824 +``` + +### 4.3 How to Choose Between the Two Switches + +- If you only want to optimize the main CPU KV cache, enable `use_hugepage_cpu_buffer`. +- If you only want to optimize the temporary staging buffer in the p2p SSD path, enable `use_hugepage_tmp_buffer` and make sure `enable_p2p_ssd=true` is set. +- If both paths matter, enable both switches. + +--- + +## 5. Recommended Enablement Order + +For a first rollout, the recommended order is: + +1. Prepare 2 MiB HugePages on the host and confirm that hugetlbfs is mounted correctly. +2. Enable `use_hugepage_cpu_buffer=true` first and verify that the main CPU KV cache works correctly. +3. If you also need to validate the p2p SSD path, then enable `use_hugepage_tmp_buffer=true`. +4. Only after the feature is stable should you evaluate switching to 1 GiB HugePages. 
+ +For initial validation, 2 MiB HugePages are recommended because host setup is simpler and troubleshooting is more straightforward. + +--- + +## 6. How to Verify It Is Working + +### 6.1 Check the Logs + +If the tmp staging buffer successfully uses HugePages, logs will contain a message similar to: + +```text +[PEER2CPUTransferWorker] tmp_cpu_buffer allocated on HugePages: 2.000 GB +``` + +If the main CPU KV cache successfully uses HugePages, you will typically also see a log similar to: + +```text +HugePage allocate total_size: ... GB (page_size=2MiB) +``` + +If the HugePage path for the tmp staging buffer fails and falls back, logs will contain a message similar to: + +```text +[PEER2CPUTransferWorker] HugePage allocation for tmp_cpu_buffer failed (...); falling back to torch.empty(pin_memory=True). +``` + +If `use_hugepage_cpu_buffer=true` but the hugetlbfs mount is invalid, logs will contain a message similar to: + +```text +HugePage allocation failed (HugePage: /path is not a hugetlbfs mount ...); falling back to regular CPU memory. +``` + +### 6.2 Check HugePage Counters + +Before and after the service starts, run: + +```bash +grep -E 'HugePages_Total|HugePages_Free|Hugepagesize' /proc/meminfo +``` + +If HugePage allocation is active, you will typically observe: + +- `HugePages_Total` unchanged +- `HugePages_Free` decreased + +After the service exits and releases resources, `HugePages_Free` should return close to its original value. + +### 6.3 Run the Test Suite + +If the machine already satisfies the HugePage and CUDA requirements, run: + +```bash +PYTHONDONTWRITEBYTECODE=1 python3 -m pytest -q tests/hugepage -rs +``` + +This test suite validates: + +- HugePage allocation and release +- HugePage configuration flow for the CPU KV cache +- HugePage configuration flow for the tmp staging buffer +- Fallback behavior when HugePage allocation cannot be used + +--- + +## 7. 
Common Configuration Errors + +### 7.1 `use_hugepage_tmp_buffer` Is Enabled but Does Not Take Effect + +Check the following items in order: + +- `enable_p2p_ssd=true` is actually enabled +- The host has enough HugePages reserved +- hugetlbfs is mounted +- `FLEXKV_HUGETLBFS_DIR` points to the correct mount +- CUDA runtime is available +- `memlock` is not too small + +### 7.2 HugePage Is Enabled but There Is No Error and No Performance Gain + +This usually means the HugePage path has already fallen back to regular memory. + +FlexKV treats HugePage as a best-effort optimization with automatic fallback. Service startup success alone is not sufficient evidence that HugePage is active. You must confirm through logs and `/proc/meminfo`. + +Also distinguish the two common cases: + +- For `use_hugepage_cpu_buffer`, a writable `hugetlbfs` mount is mandatory even if the host has reserved HugePages. +- For `use_hugepage_tmp_buffer`, anonymous HugePages may still work even without `hugetlbfs`. + +### 7.3 1 GiB HugePages Do Not Work After Configuration + +The most common reason is that the host does not actually have a 1 GiB HugePage pool available. Confirm the following: + +- Kernel boot parameters are set correctly +- The host has a real 1 GiB HugePage reservation +- `hugepage_size_bytes` matches the HugePage type actually available on the machine + +For initial rollout, it is better to validate functionality with 2 MiB HugePages first and move to 1 GiB only afterward. + +--- + +## 8. 
Minimal Working Examples + +If you want to validate both the main CPU KV cache and the p2p SSD tmp buffer HugePage paths, the following is a minimal example: + +```yaml +cpu_cache_gb: 32 +ssd_cache_gb: 1024 +ssd_cache_dir: /data/flexkv_ssd/ +enable_p2p_ssd: true +use_hugepage_cpu_buffer: true +use_hugepage_tmp_buffer: true +hugepage_size_bytes: 2097152 +``` + +If you only want to validate the main CPU KV cache path: + +```yaml +cpu_cache_gb: 32 +use_hugepage_cpu_buffer: true +hugepage_size_bytes: 2097152 +``` diff --git a/docs/hugepage/README_zh.md b/docs/hugepage/README_zh.md new file mode 100644 index 0000000000..b4799f7898 --- /dev/null +++ b/docs/hugepage/README_zh.md @@ -0,0 +1,299 @@ +# FlexKV HugePage 使用指南 + +## 一、功能概述 + +FlexKV 当前支持两类 HugePage 配置项: + +- `use_hugepage_cpu_buffer` + 控制通用 CPU KV Cache 是否优先使用 HugePage 分配。 +- `use_hugepage_tmp_buffer` + 控制 `enable_p2p_ssd=true` 场景下临时 CPU staging buffer 是否优先使用 HugePage 分配。 +- `hugepage_size_bytes` + 控制上述两类内存申请时使用的 HugePage 大小。 + +两类开关可以独立启用,也可以同时启用。 + +当前实现上有一个重要限制: + +- `use_hugepage_cpu_buffer` 对主 CPU KV cache 的生效路径,必须依赖 `hugetlbfs` 挂载文件来保证在 `spawn` worker 场景下仍然保持 HugePage backing。 +- 纯匿名 `MAP_HUGETLB` 只用于单进程或不可共享场景;一旦主 CPU cache 被传给 `TransferEngine` 的子进程,PyTorch 会把普通 CPU tensor 序列化成新的 shared-memory storage,匿名映射无法继续作为跨进程共享后端,因此不能满足主 CPU cache 的目标语义。 +- `use_hugepage_tmp_buffer` 仍然可以走匿名 HugePage 或 hugetlbfs,两者都不会经过上述主 cache 的跨进程共享问题。 + +--- + +## 二、适用场景 + +建议在以下场景启用 HugePage: + +- CPU KV Cache 容量较大,希望降低页表和 TLB 开销。 +- 已启用 `enable_p2p_ssd=true`,并希望优化 `SSD -> CPU -> GPU` 数据路径中的临时 staging buffer。 +- 已完成宿主机 HugePage 预留和 hugetlbfs 挂载,具备稳定的系统运行条件。 + +如果机器没有预留 HugePage,或者当前并不依赖 CPU KV Cache / p2p SSD 路径上的性能收益,不建议启用。 + +--- + +## 三、前置条件 + +### 3.1 宿主机已预留 HugePage + +先检查系统状态: + +```bash +grep -E 'HugePages_|Hugepagesize' /proc/meminfo +``` + +以 2 MiB HugePage 为例,预留 4096 个页,即约 8 GiB: + +```bash +sudo sysctl -w vm.nr_hugepages=4096 +``` + +如果使用 1 GiB HugePage,通常需要在内核启动参数中预留,例如: + +```text +default_hugepagesz=1G 
hugepagesz=1G hugepages=N +``` + +### 3.2 宿主机已挂载 hugetlbfs + +FlexKV 默认使用 `/mnt/hugepages` 作为 hugetlbfs 挂载点。 + +说明:对于 `use_hugepage_cpu_buffer`,这一步不是“建议”,而是必须条件。如果 `FLEXKV_HUGETLBFS_DIR` 指向普通文件系统,FlexKV 现在会直接判定失败并回退,不再把普通 4 KiB 页误判成 HugePage 成功。 + +检查挂载状态: + +```bash +mount | grep hugetlbfs +ls -ld /mnt/hugepages +``` + +如果尚未挂载,可执行: + +```bash +sudo mkdir -p /mnt/hugepages +sudo mount -t hugetlbfs none /mnt/hugepages +``` + +如果实际挂载点不是 `/mnt/hugepages`,需要显式设置: + +```bash +export FLEXKV_HUGETLBFS_DIR=/path/to/hugetlbfs +``` + +### 3.3 tmp buffer 场景需要 CUDA 运行时 + +`use_hugepage_tmp_buffer` 对应的 staging buffer 在 HugePage 分配成功后还会执行 `cudaHostRegister`。因此这一路径要求: + +- CUDA runtime 可用 +- `libcudart.so` 可正常加载 +- 容器或宿主机的 `memlock` 限制不要过小 + +可先做基础检查: + +```bash +python3 - <<'PY' +import torch +print(torch.cuda.is_available()) +PY +``` + +说明:`use_hugepage_cpu_buffer` 不依赖 `cudaHostRegister`。 + +补充说明:`use_hugepage_cpu_buffer` 虽然不依赖 CUDA runtime,但依赖可写的 hugetlbfs 挂载点,因为主 CPU KV cache 需要通过该文件在 `spawn` worker 间重新打开同一块 HugePage-backed 映射。 + +--- + +## 四、配置方式 + +FlexKV 已将 HugePage 作为正式用户配置项,支持配置文件和环境变量两种方式。 + +### 4.1 配置文件 + +YAML 示例: + +```yaml +cpu_cache_gb: 32 +ssd_cache_gb: 1024 +ssd_cache_dir: /data/flexkv_ssd/ +enable_p2p_ssd: true +use_hugepage_cpu_buffer: true +use_hugepage_tmp_buffer: true +hugepage_size_bytes: 2097152 +``` + +JSON 示例: + +```json +{ + "cpu_cache_gb": 32, + "ssd_cache_gb": 1024, + "ssd_cache_dir": "/data/flexkv_ssd/", + "enable_p2p_ssd": true, + "use_hugepage_cpu_buffer": true, + "use_hugepage_tmp_buffer": true, + "hugepage_size_bytes": 2097152 +} +``` + +### 4.2 环境变量 + +```bash +export FLEXKV_USE_HUGEPAGE_CPU_BUFFER=1 +export FLEXKV_USE_HUGEPAGE_TMP_BUFFER=1 +export FLEXKV_HUGEPAGE_SIZE_BYTES=2097152 +``` + +说明: + +- `FLEXKV_USE_HUGEPAGE_CPU_BUFFER=1` + 为通用 CPU KV Cache 启用 HugePage。 +- `FLEXKV_USE_HUGEPAGE_TMP_BUFFER=1` + 为 p2p SSD 场景下的临时 CPU staging buffer 启用 HugePage。 +- `FLEXKV_HUGEPAGE_SIZE_BYTES=2097152` + 表示使用 2 MiB HugePage。 + +如果宿主机准备的是 1 GiB 
HugePage,可设置为: + +```bash +export FLEXKV_HUGEPAGE_SIZE_BYTES=1073741824 +``` + +### 4.3 两个开关的选择原则 + +- 只需要优化通用 CPU KV Cache:开启 `use_hugepage_cpu_buffer`。 +- 只需要优化 p2p SSD 的临时 staging buffer:开启 `use_hugepage_tmp_buffer`,同时确保 `enable_p2p_ssd=true`。 +- 两条路径都需要:两个开关同时开启。 + +--- + +## 五、推荐启用顺序 + +首次接入建议按以下顺序进行: + +1. 先在宿主机准备 2 MiB HugePage,并确认 hugetlbfs 挂载正常。 +2. 先只启用 `use_hugepage_cpu_buffer=true`,验证通用 CPU KV Cache 可正常工作。 +3. 如果还需要验证 p2p SSD 路径,再启用 `use_hugepage_tmp_buffer=true`。 +4. 在确认功能稳定后,再根据机器环境评估是否切换到 1 GiB HugePage。 + +推荐第一轮验证优先使用 2 MiB HugePage。它的系统准备成本更低,排障也更直接。 + +--- + +## 六、如何确认已经生效 + +### 6.1 检查日志 + +如果 tmp staging buffer 成功使用 HugePage,日志会出现类似信息: + +```text +[PEER2CPUTransferWorker] tmp_cpu_buffer allocated on HugePages: 2.000 GB +``` + +如果主 CPU KV cache 成功使用 HugePage,通常会先看到类似日志: + +```text +HugePage allocate total_size: ... GB (page_size=2MiB) +``` + +如果 tmp staging buffer 的 HugePage 路径失败并回退,日志会出现类似信息: + +```text +[PEER2CPUTransferWorker] HugePage allocation for tmp_cpu_buffer failed (...); falling back to torch.empty(pin_memory=True). +``` + +如果 `use_hugepage_cpu_buffer=true` 但 hugetlbfs 挂载不正确,日志会出现类似信息: + +```text +HugePage allocation failed (HugePage: /path is not a hugetlbfs mount ...); falling back to regular CPU memory. 
+``` + +### 6.2 检查 HugePage 计数 + +在服务启动前后分别执行: + +```bash +grep -E 'HugePages_Total|HugePages_Free|Hugepagesize' /proc/meminfo +``` + +如果 HugePage 分配生效,通常可以观察到: + +- `HugePages_Total` 不变 +- `HugePages_Free` 下降 + +服务退出并释放资源后,`HugePages_Free` 应恢复到接近启动前的水平。 + +### 6.3 运行测试 + +如果机器已具备 HugePage 和 CUDA 条件,可执行: + +```bash +PYTHONDONTWRITEBYTECODE=1 python3 -m pytest -q tests/hugepage -rs +``` + +该测试集可用于验证: + +- HugePage 分配与释放 +- CPU KV Cache 的 HugePage 配置路径 +- tmp staging buffer 的 HugePage 配置路径 +- HugePage 失败后的回退行为 + +--- + +## 七、常见配置错误 + +### 7.1 开启了 `use_hugepage_tmp_buffer`,但实际上没有生效 + +请依次检查: + +- 是否同时设置了 `enable_p2p_ssd=true` +- 宿主机是否预留了足够的 HugePage +- hugetlbfs 是否已挂载 +- `FLEXKV_HUGETLBFS_DIR` 是否指向正确挂载点 +- CUDA runtime 是否可用 +- `memlock` 限制是否过小 + +### 7.2 开启了 HugePage,但服务没有报错也没有性能收益 + +这通常意味着 HugePage 路径已经回退到普通内存分配。 + +FlexKV 对 HugePage 采用的是“失败自动回退”策略,因此不能仅以服务是否启动成功来判断功能是否生效,必须结合日志和 `/proc/meminfo` 一起确认。 + +另外需要区分两类原因: + +- `use_hugepage_cpu_buffer` 场景下,如果没有可写 hugetlbfs 挂载点,即使系统里预留了 HugePage,也不会被视为可用配置。 +- `use_hugepage_tmp_buffer` 场景下,如果匿名 HugePage 成功,可能不依赖 hugetlbfs 挂载。 + +### 7.3 使用 1 GiB HugePage 后启动失败或无法生效 + +最常见原因是宿主机并未真正准备 1 GiB HugePage 池。请确认: + +- 内核启动参数已正确设置 +- 宿主机已实际预留 1 GiB HugePage +- `hugepage_size_bytes` 与系统中实际可用的 HugePage 类型一致 + +如果是首次接入,建议先回到 2 MiB HugePage 完成功能验证,再切换到 1 GiB。 + +--- + +## 八、最小可用配置示例 + +如果你的目标是同时验证通用 CPU KV Cache 和 p2p SSD tmp buffer 的 HugePage 功能,可以使用以下最小配置: + +```yaml +cpu_cache_gb: 32 +ssd_cache_gb: 1024 +ssd_cache_dir: /data/flexkv_ssd/ +enable_p2p_ssd: true +use_hugepage_cpu_buffer: true +use_hugepage_tmp_buffer: true +hugepage_size_bytes: 2097152 +``` + +如果你当前只想验证通用 CPU KV Cache,则可以只保留: + +```yaml +cpu_cache_gb: 32 +use_hugepage_cpu_buffer: true +hugepage_size_bytes: 2097152 +``` diff --git a/flexkv/__init__.py b/flexkv/__init__.py index 936f51cb3f..97a6943e67 100644 --- a/flexkv/__init__.py +++ b/flexkv/__init__.py @@ -1,15 +1,24 @@ +import ctypes +import glob import os import sys # Add package lib directory to system library 
path def _setup_library_path() -> None: - """Setup library path to find shared libraries in the package""" + """Setup library path to find shared libraries in the package. + + Note: Modifying LD_LIBRARY_PATH at runtime does NOT affect the current + process's dynamic linker (ld.so reads it only at startup). We still set it + for child processes, but for the current process we must pre-load required + shared libraries via ctypes.CDLL with RTLD_GLOBAL so that subsequent + dlopen() calls (e.g. when importing c_ext) can resolve them. + """ package_dir = os.path.dirname(os.path.abspath(__file__)) lib_dir = os.path.join(package_dir, "lib") if os.path.exists(lib_dir): - # Add to LD_LIBRARY_PATH for Linux + # Set LD_LIBRARY_PATH for child processes if sys.platform.startswith('linux'): current_ld_path = os.environ.get('LD_LIBRARY_PATH', '') if lib_dir not in current_ld_path: @@ -18,6 +27,14 @@ def _setup_library_path() -> None: else: os.environ['LD_LIBRARY_PATH'] = lib_dir + # Pre-load shared libraries into the current process so that + # c_ext (loaded via dlopen) can find them. + for so_file in sorted(glob.glob(os.path.join(lib_dir, "*.so*"))): + try: + ctypes.CDLL(so_file, mode=ctypes.RTLD_GLOBAL) + except OSError: + pass # non-critical: library may not be needed + # Add to sys.path for loading if lib_dir not in sys.path: sys.path.insert(0, lib_dir) @@ -25,3 +42,12 @@ def _setup_library_path() -> None: # Call the setup function when the package is imported _setup_library_path() + +# ``flexkv.c_ext`` is a PyTorch C++ extension and dynamically links against +# ``libc10.so`` / ``libtorch*.so`` from the installed ``torch`` package. Those +# libraries live under ``site-packages/torch/lib`` and are NOT on the system +# linker search path. Importing ``torch`` here causes Python to ``dlopen`` them +# (with RTLD_GLOBAL), so any subsequent ``import flexkv.c_ext`` can resolve +# them without requiring the caller to ``import torch`` first or to set +# ``LD_LIBRARY_PATH``.
+import torch # noqa: E402,F401 (side-effect import: load libtorch/libc10) diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index e4a198ccb9..1b78d39272 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -76,7 +76,7 @@ def __init__(self, self.num_total_blocks = num_total_blocks self.evict_ratio = evict_ratio self.evict_start_threshold = evict_start_threshold - + self.event_collector = event_collector self._metrics_collector = metrics_collector @@ -90,15 +90,13 @@ def match(self, sequence_meta: SequenceMeta) -> MatchResultAccel: sequence_meta.num_blocks, True) # physical blocks (torch.Tensor -> numpy, zero-copy on CPU) phys = match_result.physical_blocks.cpu().numpy() - # optional block_node_ids - try: - bnis = getattr(match_result, "block_node_ids", None) - if isinstance(bnis, torch.Tensor) and bnis.numel() > 0: - bnids_np = bnis.cpu().numpy() - else: - bnids_np = None - except Exception: - bnids_np = None + # Extract single matched_node_id (single-node constraint) + raw_nid = getattr(match_result, "matched_node_id", -1) + single_node_id = int(raw_nid) if raw_nid is not None and raw_nid >= 0 else None + # Broadcast matched_node_id to per-block array for downstream compat + bnids_np = None + if single_node_id is not None and len(phys) > 0: + bnids_np = np.full(len(phys), single_node_id, dtype=np.uint32) return MatchResultAccel( num_ready_matched_blocks=match_result.num_ready_matched_blocks, num_matched_blocks=match_result.num_matched_blocks, @@ -106,6 +104,7 @@ def match(self, sequence_meta: SequenceMeta) -> MatchResultAccel: last_node=match_result.last_node, last_node_matched_length=match_result.last_node_matched_length, physical_blocks=phys, + matched_node_id=single_node_id, block_node_ids=bnids_np, matched_pos="remote" if self.device_type == DeviceType.REMOTE else "local", ) @@ -156,25 +155,25 @@ def take(self, strict: bool = True) -> torch.Tensor: # Calculate current utilization utilization = 
(self.mempool.num_total_blocks - self.mempool.num_free_blocks) / self.mempool.num_total_blocks if self.mempool.num_total_blocks > 0 else 0 - + # Proactive eviction: trigger when utilization exceeds threshold OR when blocks are needed should_evict = (utilization >= self.evict_start_threshold) or (num_required_blocks > self.mempool.num_free_blocks) - + if should_evict: if protected_node is not None: self.index.lock(protected_node) - + # Calculate how many blocks to evict # Goal: maintain free blocks above (1 - evict_start_threshold) ratio target_free_blocks = int(self.mempool.num_total_blocks * (1.0 - self.evict_start_threshold)) evict_to_reach_target = max(0, target_free_blocks - self.mempool.num_free_blocks) - + evict_block_num = max( num_required_blocks - self.mempool.num_free_blocks, # At least meet current demand evict_to_reach_target, # Or reach target free ratio int(self.mempool.num_total_blocks * self.evict_ratio) if self.evict_ratio > 0 else 0 # Or minimum evict_ratio ) - + if evict_block_num > 0: target_blocks = torch.zeros(evict_block_num, dtype=torch.int64) evicted_block_hashes = torch.zeros(evict_block_num, dtype=torch.int64) @@ -196,18 +195,18 @@ def take(self, ) if protected_node is not None: self.index.unlock(protected_node) - + if strict and num_required_blocks > self.mempool.num_free_blocks: raise RuntimeError(f"Not enough free blocks to take, " f"required: {num_required_blocks}, " f"available: {self.mempool.num_free_blocks}") num_allocated_blocks = min(num_required_blocks, self.mempool.num_free_blocks) allocated_blocks = self.mempool.allocate_blocks(num_allocated_blocks) - + # Record allocation metrics if self._metrics_collector is not None and num_allocated_blocks > 0: self._metrics_collector.record_allocation(DEVICE_TYPE[self.device_type].lower(), num_allocated_blocks) - + return allocated_blocks def recycle(self, physical_blocks: np.ndarray) -> None: @@ -290,19 +289,19 @@ def take(self, strict: bool = True) -> np.ndarray: # Calculate current 
utilization utilization = (self.mempool.num_total_blocks - self.mempool.num_free_blocks) / self.mempool.num_total_blocks if self.mempool.num_total_blocks > 0 else 0 - + # Proactive eviction: trigger when utilization exceeds threshold OR when blocks are needed should_evict = (utilization >= self.evict_start_threshold) or (num_required_blocks > self.mempool.num_free_blocks) - + if should_evict: if protected_node is not None: self.index.lock(protected_node) - + # Calculate how many blocks to evict # Goal: maintain free blocks above (1 - evict_start_threshold) ratio target_free_blocks = int(self.mempool.num_total_blocks * (1.0 - self.evict_start_threshold)) evict_to_reach_target = max(0, target_free_blocks - self.mempool.num_free_blocks) - + evict_block_num = max( num_required_blocks - self.mempool.num_free_blocks, # At least meet current demand evict_to_reach_target, # Or reach target free ratio @@ -311,28 +310,28 @@ def take(self, if evict_block_num > 0: evicted_blocks, evicted_block_hashes = self.index.evict(evict_block_num) self.mempool.recycle_blocks(evicted_blocks) - + # Record eviction metrics if self._metrics_collector is not None and len(evicted_blocks) > 0: self._metrics_collector.record_eviction(DEVICE_TYPE[self.device_type].lower(), len(evicted_blocks)) - + if self.event_collector is not None: self.event_collector.publish_removed(block_hashes=evicted_block_hashes, medium=DEVICE_TYPE[self.device_type]) if protected_node is not None: self.index.unlock(protected_node) - + if strict and num_required_blocks > self.mempool.num_free_blocks: raise RuntimeError("Not enough free blocks to take, ", f"required: {num_required_blocks}, " f"available: {self.mempool.num_free_blocks}") num_allocated_blocks = min(num_required_blocks, self.mempool.num_free_blocks) allocated_blocks = self.mempool.allocate_blocks(num_allocated_blocks) - + # Record allocation metrics if self._metrics_collector is not None and num_allocated_blocks > 0: 
self._metrics_collector.record_allocation(DEVICE_TYPE[self.device_type].lower(), num_allocated_blocks) - + return allocated_blocks def recycle(self, physical_blocks: np.ndarray) -> None: @@ -351,6 +350,8 @@ class CacheStrategy: DEFAULT_CACHE_STRATEGY = CacheStrategy() +CPUONLY_CACHE_STRATEGY = CacheStrategy(ignore_gpu=False, ignore_ssd=True, ignore_remote=True, ignore_gds=True) + class GlobalCacheEngine: def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, redis_meta: RedisMeta = None, event_collector: Optional[KVEventCollector] = None): @@ -396,7 +397,7 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, redis_m if cache_config.enable_cpu: if cache_config.enable_p2p_cpu: - self.cpu_cache_engine = HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.CPU, meta=self.redis_meta) #TODO + self.cpu_cache_engine = HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.CPU, meta=self.redis_meta) elif self.index_accel: self.cpu_cache_engine = CacheEngineAccel(DeviceType.CPU, cache_config.num_cpu_blocks, @@ -420,7 +421,7 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, redis_m self.cache_engines[DeviceType.CPU] = self.cpu_cache_engine if cache_config.enable_ssd: if cache_config.enable_p2p_ssd: - self.ssd_cache_engine = HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.SSD, meta=self.redis_meta) #TODO + self.ssd_cache_engine = HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.SSD, meta=self.redis_meta) elif self.index_accel: self.ssd_cache_engine = CacheEngineAccel(DeviceType.SSD, cache_config.num_ssd_blocks, @@ -475,7 +476,7 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, redis_m lambda request_id: (TransferOpGraph.create_empty_graph(), [], {}, {}, {}, 0) self._empty_put_return: Callable[[int], Tuple[TransferOpGraph, List[int], Dict, Dict, Dict, int, int]] = \ lambda 
request_id: (TransferOpGraph.create_empty_graph(), [], {}, {}, {}, 0, 0) - + # Update initial mempool stats self._update_mempool_metrics() @@ -507,10 +508,10 @@ def _update_mempool_metrics(self) -> None: engine.mempool.num_total_blocks, engine.mempool.num_free_blocks ) - + def _record_transfer_ops(self, transfer_graph: TransferOpGraph, operation: str) -> None: """Record metrics for all transfer operations in the graph. - + Args: transfer_graph: The transfer operation graph operation: Operation type ("get" or "put") @@ -530,14 +531,15 @@ def get(self, slot_mapping: np.ndarray, layer_num: int = -1, layer_granularity: int = -1, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, temp_cache_strategy: CacheStrategy = DEFAULT_CACHE_STRATEGY, namespace: Optional[List[str]] = None) \ -> Tuple[TransferOpGraph, np.ndarray, Callable, Dict, int]: self._check_input(token_ids, token_mask, slot_mapping) if layer_num == -1: - layer_num = self.model_config.num_layers + layer_num = self.model_config.num_layers_per_pp_stage if layer_granularity == -1: layer_granularity = layer_num @@ -557,6 +559,13 @@ def get(self, aligned_token_ids = token_ids[:aligned_length] token_mask[aligned_length:] = False + if aligned_length == 0 or not token_mask.any(): + transfer_graph = TransferOpGraph.create_empty_graph() + transfer_graph.bind_to_worker(dp_rank, pp_rank) + return_mask = np.zeros_like(token_mask, dtype=np.bool_) + callback = partial(self._transfer_callback, node_to_unlock={}, buffer_to_free={}) + return transfer_graph, return_mask, callback, {}, -1 + block_start_idx, block_end_idx = self._get_block_range(token_mask) assert block_end_idx == aligned_length // self.tokens_per_block gpu_block_ids = self.slot_mapping_to_block_ids(slot_mapping, @@ -607,7 +616,7 @@ def get(self, # finished_ops_ids=finished_ops_ids, # layer_num=layer_num, # layer_granularity=layer_granularity) - transfer_graph.bind_to_dp_group(dp_id) + transfer_graph.bind_to_worker(dp_rank, pp_rank) for device_type in 
node_to_unlock: self.cache_engines[device_type].lock_node(node_to_unlock[device_type][0]) @@ -622,12 +631,12 @@ def get(self, device_type=op_node_to_ready[op_id][0], node_to_ready=op_node_to_ready[op_id][1], ready_length=op_node_to_ready[op_id][2]) - + # Record metrics for GET operation if self._metrics_collector is not None: self._record_transfer_ops(transfer_graph, "get") self._update_mempool_metrics() - + return transfer_graph, return_mask, callback, op_callback_dict, task_end_op_id def _get_impl_global(self, @@ -965,7 +974,7 @@ def _get_impl_local(self, transfer_graph.add_transfer_op(op_peerh2h) #TODO here we dont combine peer cpu or local cpu match results, so we can safely add remote results to local cpu #TODO here assume all matched blocks are ready blocks for peer cpu - if (cpu_matched_result.insert_to_local_cpu_index and + if (cpu_matched_result.insert_to_local_cpu_index and cpu_matched_result.num_ready_matched_blocks >= block_mask_start and cpu_matched_result.num_ready_matched_blocks == cpu_matched_result.num_matched_blocks): cpu_node_to_unlock = self.cpu_cache_engine.insert(sequence_meta, @@ -1064,14 +1073,15 @@ def put(self, token_mask: np.ndarray, slot_mapping: np.ndarray, layer_num : int = -1, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, temp_cache_strategy: CacheStrategy = DEFAULT_CACHE_STRATEGY, namespace: Optional[List[str]] = None) \ -> Tuple[TransferOpGraph, np.ndarray, Callable, Dict, int]: self._check_input(token_ids, token_mask, slot_mapping) if layer_num == -1: - layer_num = self.model_config.num_layers + layer_num = self.model_config.num_layers_per_pp_stage # ignore the last incomplete block aligned_length = (token_ids.shape[0] // self.tokens_per_block) * self.tokens_per_block aligned_token_ids = token_ids[:aligned_length] @@ -1122,7 +1132,7 @@ def put(self, return_mask = np.zeros_like(token_mask, dtype=np.bool_) return_mask[(block_start_idx + skipped_gpu_blocks)* self.tokens_per_block: (block_start_idx + skipped_gpu_blocks + 
num_gpu_blocks_to_transfer) * self.tokens_per_block] = True - transfer_graph.bind_to_dp_group(dp_id) + transfer_graph.bind_to_worker(dp_rank, pp_rank) for device_type in node_to_unlock: self.cache_engines[device_type].lock_node(node_to_unlock[device_type][0]) @@ -1359,11 +1369,11 @@ def _put_impl_local(self, :cpu_matched_result.num_matched_blocks][block_mask_start:block_mask_end] ssd_matched_blocks = ssd_matched_result.physical_blocks[ :ssd_matched_result.num_matched_blocks][block_mask_start:block_mask_end] - + #if len(cpu_matched_blocks) > len(ssd_matched_blocks): # print(f"[PUT_LOCAL] CPU matched blocks are greater than SSD matched blocks, skipping") # return self._empty_put_return(request_id) - + num_skipped_blocks = len(cpu_matched_blocks) fragment12_num_blocks = len(gpu_block_ids) - num_skipped_blocks diff --git a/flexkv/cache/hie_cache_engine.py b/flexkv/cache/hie_cache_engine.py index 52c97b3a2b..fabe741cd9 100644 --- a/flexkv/cache/hie_cache_engine.py +++ b/flexkv/cache/hie_cache_engine.py @@ -1,5 +1,6 @@ from typing import Optional, Tuple, TYPE_CHECKING, List, Dict +import time import numpy as np import torch @@ -31,7 +32,7 @@ def __init__(self, remote_max_num_blocks: int = 4000000, redis_node_id: int = 0, remote_refresh_batch_size: int = 1000, - remote_rebuild_interval_ms: int = 10000, + remote_rebuild_interval_ms: int = 100, remote_idle_sleep_ms: int = 10, local_safety_ttl_ms: int = 100, evict_start_threshold: float = 1.0, @@ -102,17 +103,18 @@ def start(self) -> None: if self._meta is None: raise ValueError("RedisMeta is not provided; ensure from_cache_config stores it or pass it to start().") #TODO can we use like this to distinguish the different tree pairs? 
+ # Determine base block key prefix by device type if self.device_type == DeviceType.REMOTE: - local_ch_block_key = "PCFSB" - remote_ch_block_key = "PCFSB" + base_key = "PCFSB" elif self.device_type == DeviceType.CPU: - local_ch_block_key = "CPUB" - remote_ch_block_key = "CPUB" + base_key = "CPUB" elif self.device_type == DeviceType.SSD: - local_ch_block_key = "SSDB" - remote_ch_block_key = "SSDB" + base_key = "SSDB" else: raise ValueError(f"Invalid device type: {self.device_type}") + + local_ch_block_key = base_key + remote_ch_block_key = base_key self.remote_ch = self._meta.get_redis_meta_channel(remote_ch_block_key) self.local_ch = self._meta.get_redis_meta_channel(local_ch_block_key) # Load and store mapping of node_id -> file_nodeids from Redis @@ -152,7 +154,6 @@ def match_all(self, sequence_meta: SequenceMeta, gpu_matched_blocks: int = 0) -> num_blocks = sequence_meta.num_blocks # Query both local and remote - import time t0 = time.perf_counter() mr_local = self.local_index.match_prefix(block_hashes_t, int(num_blocks), True) t1 = time.perf_counter() @@ -200,37 +201,37 @@ def match_all(self, sequence_meta: SequenceMeta, gpu_matched_blocks: int = 0) -> # physical blocks bnids_np = None + single_node_id = None if chosen is mr_remote: - #try to use DistributedRadixTree's block_node_ids - #if check fails, use LocalRadixTree's match result - nids = chosen.block_node_ids - nps = chosen.physical_blocks - # Convert tensors to numpy views (CPU) if present - if isinstance(nids, torch.Tensor) and nids.numel() > 0: - # For P2P mode (CPU/SSD), no PCFS conversion is needed - # Only convert to PCFS file_nodeids if device_type is REMOTE - if self.device_type == DeviceType.REMOTE: - bnids_np = self.nodeids_to_file_nodeids(nids.cpu().numpy(), nps.cpu().numpy()) - if bnids_np is None: - chosen = mr_local - matched_pos = "local" # Update matched_pos after fallback - else: - # For P2P mode, use node_ids directly - bnids_np = nids.cpu().numpy().astype(np.uint32) - 
#print(f"[REMOTE_MATCH {self.device_type.name}] Using remote data: block_ids={nps.cpu().numpy()[:min(4, len(nps))]}, node_ids={bnids_np[:min(4, len(bnids_np))]}") + # Extract single matched_node_id from CMatchResult (single-node constraint) + raw_node_id = getattr(chosen, "matched_node_id", -1) + if raw_node_id is not None and raw_node_id >= 0: + single_node_id = int(raw_node_id) + nps = chosen.physical_blocks + num_blocks = nps.shape[0] if isinstance(nps, torch.Tensor) else len(nps) + if num_blocks > 0: + # Broadcast single node_id to per-block array for downstream compat + raw_nids = np.full(num_blocks, single_node_id, dtype=np.uint32) + if self.device_type == DeviceType.REMOTE: + bnids_np = self.nodeids_to_file_nodeids(raw_nids, nps.cpu().numpy()) + if bnids_np is None: + chosen = mr_local + matched_pos = "local" + single_node_id = None + else: + bnids_np = raw_nids else: - bnids_np = None + # No valid matched_node_id → fall back to local if mr_remote.num_matched_blocks > 0: - #print(f"[REMOTE_MATCH {self.device_type.name}] Warning: remote matched but block_node_ids is empty, falling back to local") chosen = mr_local - matched_pos = "local" # Update matched_pos after fallback + matched_pos = "local" + single_node_id = None phys_np = chosen.physical_blocks.cpu().numpy() #maybe we should always not insert if self.device_type == DeviceType.CPU and matched_pos == "remote" and mr_local.num_matched_blocks > 0: insert_to_local_cpu_index = False else: insert_to_local_cpu_index = True - #TODO A big question is how to get the node id for peer_cpu and peer_ssd? 
return MatchResultAccel( num_ready_matched_blocks=int(chosen.num_ready_matched_blocks), num_matched_blocks=int(chosen.num_matched_blocks), @@ -238,9 +239,10 @@ def match_all(self, sequence_meta: SequenceMeta, gpu_matched_blocks: int = 0) -> last_node=chosen.last_node, last_node_matched_length=int(chosen.last_node_matched_length), physical_blocks=phys_np, + matched_node_id=single_node_id, block_node_ids=bnids_np, matched_pos=matched_pos, - matched_node_ids=bnids_np, # Set matched_node_ids for P2P transfer + matched_node_ids=bnids_np, # deprecated: kept for backward compat insert_to_local_cpu_index=insert_to_local_cpu_index, ) @@ -565,4 +567,3 @@ def from_cache_config(cls, cache_config: "CacheConfig", node_id: int, device_typ meta=meta, ) raise ValueError("Invalid device type: {cache_config.device_type}") - diff --git a/flexkv/cache/mempool.py b/flexkv/cache/mempool.py index 4decae1b8a..7d3df3797d 100644 --- a/flexkv/cache/mempool.py +++ b/flexkv/cache/mempool.py @@ -43,17 +43,11 @@ def recycle_blocks(self, block_ids: np.ndarray) -> None: if block_ids.ndim != 1 or block_ids.dtype != np.int64: raise ValueError("block_ids must be a 1D tensor of int64") - # Remove duplicates first (same block ID appearing multiple times) block_ids = np.unique(block_ids) - - # Filter out already-free blocks to avoid double-free errors - # This can happen due to race conditions or eviction edge cases + already_free = self._free_mask[block_ids] if already_free.any(): - # Only recycle blocks that are actually in use - block_ids = block_ids[~already_free] - if len(block_ids) == 0: - return # Nothing to recycle + raise ValueError(f"block_ids {block_ids[already_free]} are already free") self._free_mask[block_ids] = True self._num_free += len(block_ids) diff --git a/flexkv/cache/redis_meta.py b/flexkv/cache/redis_meta.py index b2888bf838..1721f80f85 100644 --- a/flexkv/cache/redis_meta.py +++ b/flexkv/cache/redis_meta.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import 
Iterable, List, Tuple, Optional, Union, Dict from dataclasses import dataclass from enum import IntEnum @@ -136,8 +137,12 @@ def delete_blockmeta_batch(self, node_id: int, hashes: Iterable[int], batch_size class RedisNodeInfo: """Redis node information management class implemented in Python""" + + # Default TTL for node: key in seconds. Active nodes renew before expiry. + # If a process crashes (kill -9), the key auto-expires after this period. + DEFAULT_NODE_TTL_SECONDS: int = 30 - def __init__(self, host: str, port: int, local_ip: str, password: str = "") -> None: + def __init__(self, host: str, port: int, local_ip: str, password: str = "", node_ttl_seconds: int = 0) -> None: if _redis is None: raise ImportError("redis-py is required: pip install redis") self.host = host @@ -145,12 +150,17 @@ def __init__(self, host: str, port: int, local_ip: str, password: str = "") -> N self.local_ip = str(local_ip) self.password = str(password) self.uuid = str(uuid1()) + # Use provided TTL or fall back to default + self.node_ttl_seconds: int = node_ttl_seconds if node_ttl_seconds > 0 else self.DEFAULT_NODE_TTL_SECONDS + # Heartbeat interval – renew TTL at roughly 1/3 of the TTL period + self.heartbeat_interval_seconds: float = max(1.0, self.node_ttl_seconds / 3.0) self._node_id: Optional[int] = None self._running = False self._listener_thread: Optional[threading.Thread] = None + self._heartbeat_thread: Optional[threading.Thread] = None self.current_node_id_set: set = set() - self._client: Optional[_redis.Redis] = None - self._sub_client: Optional[_redis.Redis] = None + self._client: Optional["_redis.Redis"] = None + self._sub_client: Optional["_redis.Redis"] = None self._cleanup_done = False # register cleanup function on exit @@ -166,7 +176,7 @@ def __del__(self) -> None: # ignore exceptions in destructor, avoid affecting program exit pass - def _get_client(self) -> _redis.Redis: + def _get_client(self) -> "_redis.Redis": """Get Redis client with connection settings""" return 
_redis.Redis( host=self.host, @@ -178,7 +188,7 @@ def _get_client(self) -> _redis.Redis: ) def connect(self) -> bool: - """Connect to Redis and start listener thread""" + """Connect to Redis and start listener + heartbeat threads""" try: self._client = self._get_client() # Test connection @@ -192,17 +202,29 @@ def connect(self) -> bool: daemon=True ) self._listener_thread.start() + + # Start heartbeat thread for TTL renewal + self._heartbeat_thread = threading.Thread( + target=self._heartbeat_worker, + name="redis-node-heartbeat", + daemon=True + ) + self._heartbeat_thread.start() return True except Exception: return False def disconnect(self) -> None: - """Disconnect from Redis and stop listener thread""" + """Disconnect from Redis and stop listener + heartbeat threads""" self._running = False if self._listener_thread and self._listener_thread.is_alive(): self._listener_thread.join(timeout=2.0) self._listener_thread = None + + if self._heartbeat_thread and self._heartbeat_thread.is_alive(): + self._heartbeat_thread.join(timeout=2.0) + self._heartbeat_thread = None if self._client: self._client.close() @@ -240,11 +262,14 @@ def _cleanup(self) -> None: pass def register_node(self) -> Optional[int]: - """Register a new node and get node_id""" + """Register a new node and get node_id, with TTL for automatic expiry on crash""" if not self._client: return None try: + # Clean up stale nodes from the same IP before registering + self._cleanup_stale_nodes_by_ip() + # Atomically increment global:node_id to get new node_id node_id = self._client.incr("global:node_id") self._node_id = node_id @@ -257,8 +282,13 @@ def register_node(self) -> Optional[int]: "local_ip": self.local_ip, # Keep for backward compatibility "uuid": self.uuid, "status": "active", - "timestamp": str(int(time.time())) + "timestamp": str(int(time.time())), + "pp_rank": str(getattr(self, 'pp_rank', 0)), + "pp_size": str(getattr(self, 'pp_size', 1)), }) + + # Set TTL so the key auto-expires if the process 
crashes
+            self._client.expire(node_key, self.node_ttl_seconds)
 
             # Publish node update event
             self._client.publish("flexkv_node_id_updated", str(node_id))
@@ -268,17 +298,22 @@ def register_node(self) -> Optional[int]:
             return None
 
     def unregister_node(self) -> bool:
-        """Unregister current node"""
+        """Unregister current node and clean up associated meta/block data"""
         if not self._client or self._node_id is None:
             return False
 
         try:
+            node_id = self._node_id
+
             # Delete node:node_id key
-            node_key = f"node:{self._node_id}"
+            node_key = f"node:{node_id}"
             self._client.delete(node_key)
+
+            # Also clean up meta:<node_id> and block keys to prevent stale RDMA addresses
+            self._cleanup_node_data(node_id)
 
             # Publish node update event
-            self._client.publish("flexkv_node_id_updated", str(self._node_id))
+            self._client.publish("flexkv_node_id_updated", str(node_id))
             self._node_id = None
             return True
 
@@ -302,6 +337,48 @@ def is_node_active(self, node_id: int) -> bool:
         """Check if a node_id is active - lock-free RCU check"""
         return node_id in self.current_node_id_set
 
+    def _heartbeat_worker(self) -> None:
+        """Background thread that periodically renews the TTL of the node:<node_id> key.
+
+        This ensures that if the process is alive, the node key never expires.
+        If the process crashes (kill -9), the TTL will not be renewed and the
+        key will auto-expire after ``self.node_ttl_seconds``, allowing other nodes
+        to detect the crash and stop using stale meta/block data.
+        """
+        heartbeat_client: Optional["_redis.Redis"] = None
+        while self._running:
+            try:
+                if heartbeat_client is None:
+                    heartbeat_client = self._get_client()
+
+                if self._node_id is not None:
+                    node_key = f"node:{self._node_id}"
+                    # Renew the TTL. EXPIRE reports whether the key still exists; only then refresh the
+                    # timestamp, because HSET on a missing key would recreate it as a TTL-less ghost node.
+                    if heartbeat_client.expire(node_key, self.node_ttl_seconds):
+                        heartbeat_client.hset(node_key, "timestamp", str(int(time.time())))
+
+            except Exception:
+                # Connection lost, reset client so it reconnects next iteration
+                if heartbeat_client:
+                    try:
+                        heartbeat_client.close()
+                    except Exception:
+                        pass
+                    heartbeat_client = None
+
+            # Sleep in small increments so we can exit quickly when _running becomes False
+            for _ in range(int(self.heartbeat_interval_seconds * 10)):
+                if not self._running:
+                    break
+                time.sleep(0.1)
+
+        if heartbeat_client:
+            try:
+                heartbeat_client.close()
+            except Exception:
+                pass
+
     def _listener_worker(self) -> None:
         """Background thread that listens for node updates"""
         backoff = 0.5
@@ -343,6 +420,10 @@ def scan_active_nodes(self) -> None:
 
         This method can be called externally to manually refresh the active nodes list.
         It uses SCAN to avoid blocking Redis server.
+
+        Because node: keys now have a TTL (heartbeat), expired keys are
+        automatically removed by Redis. SCAN will only return keys that are
+        still alive, so stale/crashed nodes are naturally excluded.
"""
         if not self._client:
             return
@@ -366,6 +447,15 @@ def scan_active_nodes(self) -> None:
                 if cursor == 0:
                     break
 
+            # Detect nodes that disappeared (TTL expired or unregistered)
+            disappeared = self.current_node_id_set - new_active_nodes
+            if disappeared:
+                # Clean up meta and block data for disappeared nodes
+                for stale_nid in disappeared:
+                    if stale_nid == self._node_id:
+                        continue  # Don't clean up ourselves
+                    self._cleanup_node_data(stale_nid)
+
             # lock-free RCU switch: atomic assignment
             self.current_node_id_set = new_active_nodes
 
@@ -373,10 +463,97 @@ def scan_active_nodes(self) -> None:
             # If scan fails, continue with current active nodes
             pass
 
+    def _cleanup_stale_nodes_by_ip(self) -> None:
+        """Clean up stale node registrations from the same IP.
+
+        On startup, scan all node:* keys and remove those that have the same
+        local_ip but a different UUID (i.e. leftover from a previous crashed process).
+        """
+        if not self._client:
+            return
+
+        try:
+            cursor = 0
+            stale_node_ids = []
+
+            while True:
+                cursor, keys = self._client.scan(cursor=cursor, match="node:*", count=100)
+                for key in keys:
+                    if not key.startswith("node:"):
+                        continue
+                    try:
+                        nid = int(key[5:])
+                    except (ValueError, IndexError):
+                        continue
+
+                    data = self._client.hgetall(key)
+                    node_ip = data.get("ip", "") or data.get("local_ip", "")
+                    node_uuid = data.get("uuid", "")
+
+                    # Same IP, different UUID → treated as stale. NOTE(review): also matches a live sibling process on this host — confirm one node per IP
+                    if node_ip == self.local_ip and node_uuid != self.uuid:
+                        stale_node_ids.append(nid)
+
+                if cursor == 0:
+                    break
+
+            for stale_nid in stale_node_ids:
+                print(f"[RedisNodeInfo] Cleaning up stale node:{stale_nid} (same IP={self.local_ip}, different UUID)")
+                self._client.delete(f"node:{stale_nid}")
+                self._cleanup_node_data(stale_nid)
+
+            if stale_node_ids:
+                # Notify other nodes about the cleanup
+                self._client.publish("flexkv_node_id_updated", "cleanup")
+
+        except Exception:
+            pass
+
+    def _cleanup_node_data(self, node_id: int) -> None:
+        """Clean up meta:<node_id> and CPUB/SSDB/PCFSB block keys for a given node.
+
+        This is called when:
+        1. A node is unregistered (graceful shutdown)
+        2. A stale node is detected (TTL expired / startup cleanup)
+        """
+        if not self._client:
+            return
+
+        try:
+            # Delete meta:<node_id> plus meta:<node_id>:* (assumes PP keys are meta:<node_id>:pp* — confirm); ':' stops node 1 matching meta:10
+            cursor = 0
+            meta_keys = [f"meta:{node_id}"]
+            while True:
+                cursor, keys = self._client.scan(cursor=cursor, match=f"meta:{node_id}:*", count=100)
+                meta_keys.extend(keys)
+                if cursor == 0:
+                    break
+            deleted = self._client.delete(*meta_keys)
+            if deleted:
+                print(f"[RedisNodeInfo] Deleted {deleted} meta key(s) for node {node_id}")
+
+            # Delete CPUB:block:<node_id>:* / SSDB:block:<node_id>:* / PCFSB:block:<node_id>:* keys
+            for prefix in ("CPUB", "SSDB", "PCFSB"):
+                cursor = 0
+                block_keys = []
+                while True:
+                    cursor, keys = self._client.scan(cursor=cursor, match=f"{prefix}:block:{node_id}:*", count=500)
+                    block_keys.extend(keys)
+                    if cursor == 0:
+                        break
+                if block_keys:
+                    # Delete in batches to avoid blocking Redis
+                    batch_size = 500
+                    for i in range(0, len(block_keys), batch_size):
+                        self._client.delete(*block_keys[i:i + batch_size])
+                    print(f"[RedisNodeInfo] Deleted {len(block_keys)} {prefix}:block key(s) for node {node_id}")
+
+        except Exception as e:
+            print(f"[RedisNodeInfo] Warning: failed to clean up data for node {node_id}: {e}")
 
 
 class RedisMeta:
-    def __init__(self, host: str, port: int, password: Optional[str] = None, local_ip: str = "127.0.0.1", decode_responses: bool = True) -> None:
+    def __init__(self, host: str, port: int, password: Optional[str] = None, local_ip: str = "127.0.0.1", decode_responses: bool = True, node_ttl_seconds: int = 0) -> None:
         if _redis is None:  # pragma: no cover
             raise ImportError("redis-py is required: pip install redis")
         self.host = host
@@ -393,7 +570,7 @@ def __init__(self, host: str, port: int, password: Optional[str] = None, local_i
         self._init_error: Optional[Exception] = None
 
         # create RedisNodeInfo object
-        self.nodeinfo = RedisNodeInfo(host, port, local_ip, password or "")
+
self.nodeinfo = RedisNodeInfo(host, port, local_ip, password or "", node_ttl_seconds=node_ttl_seconds) # get UUID via nodeinfo self._uuid = self.nodeinfo.get_uuid() diff --git a/flexkv/cache/transfer_pattern.py b/flexkv/cache/transfer_pattern.py index fcf207408b..6290e3e69a 100644 --- a/flexkv/cache/transfer_pattern.py +++ b/flexkv/cache/transfer_pattern.py @@ -61,7 +61,8 @@ def convert_read_graph_to_layer_wise_graph( layer_id=i * layer_granularity, layer_granularity=layer_granularity, # Inherit these fields directly - dp_id=op.dp_id, + dp_rank=op.dp_rank, + pp_rank=op.pp_rank, ) new_graph.add_transfer_op(new_op) split_op_ids.append(new_op.op_id) diff --git a/flexkv/common/config.py b/flexkv/common/config.py index df792f5f8b..8b92e92c72 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -1,6 +1,7 @@ import os import json -from dataclasses import dataclass +import yaml +from dataclasses import dataclass, field, fields from enum import Enum from typing import Optional, List, Union, Dict, Any from argparse import Namespace @@ -11,6 +12,16 @@ from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType from flexkv.common.debug import flexkv_logger + +@dataclass +class IndexerCacheConfig: + """Indexer-specific cache configuration, embedded inside CacheConfig.""" + # Indexer head layout + head_size: int = 0 # qk_rope_head_dim for DSA/NSA models + num_kv_heads: int = 1 # typically 1 for MLA-style indexer + dtype: torch.dtype = torch.uint8 # indexer storage dtype (fp8 quantized) + + @dataclass class ModelConfig: num_layers: int = 1 @@ -19,14 +30,188 @@ class ModelConfig: use_mla: bool = False dtype: torch.dtype = torch.bfloat16 - # parallel configs + # ------------------------------------------------------------------ + # Parallel configs + # ------------------------------------------------------------------ tp_size: int = 1 + tp_rank: int = 0 + + pp_size: int = 1 + pp_rank: int = 0 + dp_size: int = 1 + dp_rank: int = 0 + + # pp_start_layer / 
pp_end_layer: [start, end) layer indices for this PP stage. + pp_start_layer: int = 0 + pp_end_layer: int = -1 # -1 → lazily resolved to num_layers + + # ------------------------------------------------------------------ + # Attention-level parallel configs + # ------------------------------------------------------------------ + # enable_dp_attention: whether DP-attention is enabled (sglang + # ``--enable-dp-attention`` or TRT-LLM ``enable_attention_dp``). + # When True, the physical TP group is split into + # attn_tp × attn_cp × attn_dp. + enable_dp_attention: bool = False + + # attn_cp_size / attn_cp_rank: context-parallel size/rank. + attn_cp_size: int = 1 + attn_cp_rank: int = 0 + + # ------------------------------------------------------------------ + # Topology configs + # ------------------------------------------------------------------ + # nnodes: number of physical machines spanned by one replica + # node_rank: index of this machine within ``nnodes`` + nnodes: int = 1 + node_rank: int = 0 + + # Multi-node bootstrap: master node's IP for TransferManager rendezvous. + # ``None`` falls back to ``FLEXKV_MASTER_HOST`` env var (default + # ``"localhost"``) inside ``resolve_master_host_and_ports``. Set this + # from the framework's own launch config (e.g. sglang's + # ``--dist-init-addr``) to avoid exposing an extra env knob. + master_host: Optional[str] = None + + # NSA context parallelism: when True, every rank in the CP group holds + # identical (full) KV cache. FlexKV treats CP ranks the same as TP ranks + # in MLA mode. + is_nsa_cp: bool = False + + # ------------------------------------------------------------------ + # Freeze mechanism: after post_init, ModelConfig must not be mutated + # ------------------------------------------------------------------ + _frozen: bool = field(default=False, init=False, repr=False) + + def freeze(self) -> None: + """Lock the config so that any subsequent __setattr__ raises an error. 
+ """ + object.__setattr__(self, '_frozen', True) + flexkv_logger.info( + f"[FlexKV] ModelConfig FROZEN — primitive vars are now immutable. " + f"Derived: attn_tp_size={self.attn_tp_size}, attn_tp_rank={self.attn_tp_rank}, " + f"tp_size_per_node={self.tp_size_per_node}, " + f"nnodes_per_pp_rank={self.nnodes_per_pp_rank}, " + f"nnodes_per_tp_group={self.nnodes_per_tp_group}, " + f"total_gpus={self.total_gpus}, " + f"gpus_per_node={self.gpus_per_node}, " + f"num_kv_heads_per_node={self.num_kv_heads_per_node}, " + f"tp_rank_per_node={self.tp_rank_per_node}, " + f"local_rank={self.local_rank}" + ) + + def __setattr__(self, name: str, value) -> None: + if name == '_frozen': + return object.__setattr__(self, name, value) + if getattr(self, '_frozen', False): + raise AttributeError( + f"ModelConfig is frozen — cannot set '{name}'. " + f"All primitive fields must be set during post_init_from_*(), " + f"after which freeze() is called. Derived fields (attn_tp_size, " + f"attn_tp_rank) are @property " + f"and cannot be set at all." 
+ ) + object.__setattr__(self, name, value) + + # ------------------------------------------------------------------ + # Derived topology properties + # ------------------------------------------------------------------ + @property + def total_gpus(self) -> int: + """Total GPUs across all nodes for one FlexKV instance.""" + return self.dp_size * self.tp_size * self.pp_size + + @property + def gpus_per_node(self) -> int: + """Total GPUs on this node (across all DP, PP stages and TP groups).""" + return self.total_gpus // self.nnodes + + @property + def nnodes_per_pp_rank(self) -> int: + """Number of nodes spanned by one PP stage.""" + return max(self.nnodes // self.pp_size, 1) + + @property + def nnodes_per_tp_group(self) -> int: + """Number of nodes spanned by one TP group.""" + return self.nnodes_per_pp_rank + + @property + def tp_size_per_node(self) -> int: + """Number of TP ranks on this node within one TP group.""" + return self.tp_size // self.nnodes_per_tp_group + + @property + def tp_rank_per_node(self) -> int: + """TP rank index within the local node (within one TP group).""" + return self.tp_rank % self.tp_size_per_node + + @property + def local_rank(self) -> int: + """Local GPU device index within the node (a.k.a. ``LOCAL_RANK`` in + PyTorch distributed / sglang / vllm). + + Matches the standard Megatron-style rank layout: + global_rank = dp_rank * pp_size * tp_size + pp_rank * tp_size + tp_rank + + When DP-attention is enabled, DP replicas share the same physical + GPUs, so the DP dimension is not reflected in the device index. + The formula then reduces to: + local_rank = pp_rank_per_node * tp_size_per_node + tp_rank_per_node + + When DP-attention is disabled, each DP replica has its own GPUs: + local_rank = (dp_rank_per_node * pp_size_per_node + pp_rank_per_node) + * tp_size_per_node + tp_rank_per_node + where the ``_per_node`` values are derived inline from the global ranks + and topology. 
+ """ + pp_size_per_node = max(self.pp_size // self.nnodes, 1) + pp_rank_per_node = self.pp_rank % pp_size_per_node + if self.enable_dp_attention: + return pp_rank_per_node * self.tp_size_per_node + self.tp_rank_per_node + dp_size_per_node = self.gpus_per_node // (pp_size_per_node * self.tp_size_per_node) + dp_rank_per_node = self.dp_rank % dp_size_per_node + return (dp_rank_per_node * pp_size_per_node + pp_rank_per_node) * self.tp_size_per_node + self.tp_rank_per_node + + @property + def attn_tp_size(self) -> int: + """Attention-level TP size derived from tp / attn_dp / attn_cp.""" + attn_dp = max(1, self.dp_size) if self.enable_dp_attention else 1 + cp = max(1, self.attn_cp_size) + return max(1, max(1, self.tp_size) // (attn_dp * cp)) + + @property + def attn_tp_rank(self) -> int: + """Attention-level TP rank derived from tp_rank / attn_tp_size.""" + return self.tp_rank % max(1, self.attn_tp_size) + + @property + def num_kv_heads_per_node(self) -> int: + """Number of KV heads visible to a single node.""" + if self.use_mla: + return self.num_kv_heads + return self.num_kv_heads * self.tp_size_per_node // max(1, self.attn_tp_size) + + @property + def kv_dim(self) -> int: + """KV dimension: 1 for MLA (no head split), 2 for standard (head split).""" + return 1 if self.use_mla else 2 + + @property + def num_layers_per_pp_stage(self) -> int: + """Number of layers managed by this PP stage.""" + end = self.pp_end_layer if self.pp_end_layer >= 0 else self.num_layers + return end - self.pp_start_layer @property def token_size_in_bytes(self) -> int: - kv_dim = 1 if self.use_mla else 2 - return self.num_layers * self.num_kv_heads * self.head_size * kv_dim * self.dtype.itemsize + return self.num_layers * self.num_kv_heads * self.head_size * self.kv_dim * self.dtype.itemsize + + @property + def token_size_in_bytes_per_pp_stage(self) -> int: + """Token size in bytes for one PP stage (used by data plane).""" + return self.num_layers_per_pp_stage * self.num_kv_heads * 
self.head_size * self.kv_dim * self.dtype.itemsize @dataclass class CacheConfig: @@ -44,7 +229,21 @@ class CacheConfig: distributed_node_id: int = -1 # only used when distributed cpu/ssd and only can be set when redis_meta_client initialized num_tmp_cpu_blocks: int = 500 # only used when distributed ssd p2p, it controls the number blocks of temp cpu buffer which used for copy data from ssd to cpu - + # When True, the main CPU KV cache is allocated from Linux HugePages via + # ``mmap(MAP_HUGETLB)`` instead of regular CPU memory. Requires pre-reserved + # huge pages on the host (see ``/proc/sys/vm/nr_hugepages``). Falls back + # silently if allocation fails. + use_hugepage_cpu_buffer: bool = False + # When True, the temporary SSD->CPU staging buffer (used by PEER2CPUTransferWorker + # under enable_p2p_ssd) is allocated from Linux HugePages via ``mmap(MAP_HUGETLB)`` + # instead of a pinned ``torch.empty``. Requires pre-reserved huge pages on the host + # (see ``/proc/sys/vm/nr_hugepages``). Falls back silently if allocation fails. + use_hugepage_tmp_buffer: bool = False + hugepage_size_bytes: int = 2 * 1024 * 1024 # 2 MiB by default; set to 1<<30 for 1GiB + + + # Indexer configuration + indexer: Optional[IndexerCacheConfig] = None # mempool capacity configs num_cpu_blocks: int = 1000000 @@ -72,6 +271,12 @@ class CacheConfig: redis_port: int = 6379 local_ip: str = "127.0.0.1" redis_password: Optional[str] = None + # TTL (seconds) for node: key in Redis. Active nodes renew via heartbeat. + # If a process crashes, the key auto-expires after this period. 
+ node_ttl_seconds: int = 30 + + # Mooncake transfer engine config path (serialized via pickle to survive spawn subprocesses) + mooncake_config_path: Optional[str] = None def __post_init__(self): self.enable_kv_sharing = self.enable_p2p_cpu or \ @@ -101,6 +306,8 @@ def __post_init__(self): remote_layout_type=KVCacheLayoutType(os.getenv('FLEXKV_REMOTE_LAYOUT', 'BLOCKFIRST').upper()), gds_layout_type=KVCacheLayoutType(os.getenv('FLEXKV_GDS_LAYOUT', 'BLOCKFIRST').upper()), + enable_layerwise_transfer=bool(int(os.getenv('FLEXKV_ENABLE_LAYERWISE_TRANSFER', 0))), + use_ce_transfer_h2d=bool(int(os.getenv('FLEXKV_USE_CE_TRANSFER_H2D', 0))), use_ce_transfer_d2h=bool(int(os.getenv('FLEXKV_USE_CE_TRANSFER_D2H', 0))), transfer_num_cta_h2d=int(os.getenv('FLEXKV_TRANSFER_NUM_CTA_H2D', 4)), @@ -126,7 +333,7 @@ def __post_init__(self): lt_pool_initial_capacity=int(os.getenv('FLEXKV_LT_POOL_INITIAL_CAPACITY', 10000000)), refresh_batch_size=int(os.getenv('FLEXKV_REFRESH_BATCH_SIZE', 256)), - rebuild_interval_ms=int(os.getenv('FLEXKV_REBUILD_INTERVAL_MS', 10000)), + rebuild_interval_ms=int(os.getenv('FLEXKV_REBUILD_INTERVAL_MS', 100)), idle_sleep_ms=int(os.getenv('FLEXKV_IDLE_SLEEP_MS', 10)), lease_ttl_ms=int(os.getenv('FLEXKV_LEASE_TTL_MS', 30000)), safety_ttl_ms=int(os.getenv('FLEXKV_SAFETY_TTL_MS', 100)), @@ -139,6 +346,9 @@ class UserConfig: ssd_cache_gb: int = 0 # 0 means disable ssd ssd_cache_dir: Union[str, List[str]] = "./ssd_cache" enable_gds: bool = False + use_hugepage_cpu_buffer: bool = False + use_hugepage_tmp_buffer: bool = False + hugepage_size_bytes: int = 2 * 1024 * 1024 enable_p2p_cpu: bool = False enable_p2p_ssd: bool = False enable_3rd_remote: bool = False @@ -151,6 +361,7 @@ class UserConfig: redis_port: Optional[int] = None local_ip: Optional[str] = None redis_password: Optional[str] = None + node_ttl_seconds: Optional[int] = None kv_cache_dtype: Optional[str] = None # Override kv_cache_dtype when TRT config uses "auto". 
Supported values: "fp8", "float8", "e4m3", "fp16", "float16", "bf16", "bfloat16", "fp32", "float32" def __post_init__(self): @@ -167,10 +378,6 @@ def parse_path_list(path_str: str) -> List[str]: return paths def load_user_config_from_file(config_file: str) -> UserConfig: - import json - import yaml - from dataclasses import fields - # read json config file or yaml config file if config_file.endswith('.json'): with open(config_file) as f: @@ -201,6 +408,9 @@ def load_user_config_from_env() -> UserConfig: ssd_cache_gb=int(os.getenv('FLEXKV_SSD_CACHE_GB', 0)), ssd_cache_dir=parse_path_list(os.getenv('FLEXKV_SSD_CACHE_DIR', "./flexkv_ssd")), enable_gds=bool(int(os.getenv('FLEXKV_ENABLE_GDS', 0))), + use_hugepage_cpu_buffer=bool(int(os.getenv('FLEXKV_USE_HUGEPAGE_CPU_BUFFER', 0))), + use_hugepage_tmp_buffer=bool(int(os.getenv('FLEXKV_USE_HUGEPAGE_TMP_BUFFER', 0))), + hugepage_size_bytes=int(os.getenv('FLEXKV_HUGEPAGE_SIZE_BYTES', 2 * 1024 * 1024)), kv_cache_dtype=os.getenv('FLEXKV_KV_CACHE_DTYPE', None), ) @@ -210,7 +420,7 @@ def convert_to_block_num(size_in_GB: float, block_size_in_bytes: int) -> int: def update_default_config_from_user_config(model_config: ModelConfig, cache_config: CacheConfig, user_config: UserConfig) -> None: - block_size_in_bytes = model_config.token_size_in_bytes * cache_config.tokens_per_block + block_size_in_bytes = model_config.token_size_in_bytes_per_pp_stage * cache_config.tokens_per_block assert user_config.cpu_cache_gb > 0 assert user_config.ssd_cache_gb >= 0 @@ -221,6 +431,9 @@ def update_default_config_from_user_config(model_config: ModelConfig, cache_config.ssd_cache_dir = user_config.ssd_cache_dir cache_config.enable_ssd = user_config.ssd_cache_gb > 0 cache_config.enable_gds = user_config.enable_gds + cache_config.use_hugepage_cpu_buffer = user_config.use_hugepage_cpu_buffer + cache_config.use_hugepage_tmp_buffer = user_config.use_hugepage_tmp_buffer + cache_config.hugepage_size_bytes = user_config.hugepage_size_bytes 
cache_config.enable_p2p_cpu = user_config.enable_p2p_cpu cache_config.enable_p2p_ssd = user_config.enable_p2p_ssd cache_config.enable_3rd_remote = user_config.enable_3rd_remote @@ -250,6 +463,8 @@ def update_default_config_from_user_config(model_config: ModelConfig, cache_config.local_ip = user_config.local_ip if user_config.redis_password is not None: cache_config.redis_password = user_config.redis_password + if user_config.node_ttl_seconds is not None: + cache_config.node_ttl_seconds = user_config.node_ttl_seconds global_config_attrs = set(vars(GLOBAL_CONFIG_FROM_ENV).keys()) for attr_name in dir(user_config): diff --git a/flexkv/common/memory_handle.py b/flexkv/common/memory_handle.py index 9838c1ba6c..e5fff581fc 100644 --- a/flexkv/common/memory_handle.py +++ b/flexkv/common/memory_handle.py @@ -170,10 +170,10 @@ def _init_from_ipc_handle( self.offset = offset - flexkv_logger.info( - f"TensorSharedHandle constructed from external IPC handle {self.ipc_handle.hex()} on device {self.device} \ - with shape {self.tensor_shape} and dtype {self.tensor_dtype}, ptr offset={offset}" - ) + # flexkv_logger.info( + # f"TensorSharedHandle constructed from external IPC handle {self.ipc_handle.hex()} on device {self.device} \ + # with shape {self.tensor_shape} and dtype {self.tensor_dtype}, ptr offset={offset}" + # ) @staticmethod def _ensure_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype: diff --git a/flexkv/common/storage.py b/flexkv/common/storage.py index ffd1d2cd46..bc89cb2aed 100644 --- a/flexkv/common/storage.py +++ b/flexkv/common/storage.py @@ -172,6 +172,7 @@ class StorageHandle: num_blocks_per_file: Optional[int] = None gpu_device_id: Optional[int] = None remote_config_custom: Optional[Dict[str, Any]] = None + worker_data: Optional[Any] = None def get_tensor_list(self) -> List[torch.Tensor]: assert isinstance(self.data, list) and \ @@ -195,6 +196,11 @@ def get_tensor(self) -> torch.Tensor: else: raise ValueError(f"Invalid handle type: {self.handle_type}, 
expected TENSOR") + def get_worker_tensor(self) -> Any: + if self.worker_data is not None: + return self.worker_data + return self.get_tensor() + def get_file_list(self) -> Union[List[str], Dict[int, List[str]]]: if self.handle_type == AccessHandleType.FILE: return self.data # type: ignore diff --git a/flexkv/common/tracer.py b/flexkv/common/tracer.py index cd4553526b..47b463a2a5 100644 --- a/flexkv/common/tracer.py +++ b/flexkv/common/tracer.py @@ -195,7 +195,8 @@ def trace_request(self, slot_mapping: Union[torch.Tensor, np.ndarray], token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, layer_granularity: int = -1, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, **kwargs): """Record a request operation""" if not self.enabled: @@ -211,7 +212,8 @@ def trace_request(self, "slot_mapping": self._convert_tensor_to_list(slot_mapping), "token_mask": self._convert_tensor_to_list(token_mask) if token_mask is not None else None, "layer_granularity": layer_granularity, - "dp_id": dp_id, + "dp_rank": dp_rank, + "pp_rank": pp_rank, "token_ids_shape": list(token_ids.shape), "slot_mapping_shape": list(slot_mapping.shape), "token_mask_shape": list(token_mask.shape) if token_mask is not None else None, @@ -329,6 +331,7 @@ def flush(self): def __del__(self): """Ensure all records are flushed when tracer is destroyed""" - from contextlib import suppress - with suppress(Exception): + try: self.flush() + except Exception: + pass diff --git a/flexkv/common/transfer.py b/flexkv/common/transfer.py index 83f3c38518..cca69bd756 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -5,11 +5,27 @@ import numpy as np +from flexkv.common.debug import flexkv_logger + + +@dataclass(frozen=True) +class WorkerKey: + """Immutable, hashable key that uniquely identifies a worker by (dp_rank, pp_rank). + + Used as dict keys in TransferEngine's worker maps and TransferManager's + GPU grouping instead of raw ``Tuple[int, int]`` to avoid ambiguity. 
+ """ + dp_rank: int + pp_rank: int + @dataclass(frozen=True) class CompletedOp: graph_id: int op_id: int + transfer_type: Optional[str] = None + num_blocks: int = 0 + num_bytes: int = 0 def is_graph_completed(self) -> bool: return self.op_id == -1 @@ -53,6 +69,7 @@ class TransferType(Enum): # so that the op 3 will not be executed actually, but can indicate the completion of # a group of transfer ops VIRTUAL = "Virtual" + LAYERWISE = "LAYERWISE" # class DistType(Enum): # DISTH = "DISTH" @@ -85,7 +102,8 @@ class TransferOp: # this will keep the full info successors: Set[int] = field(default_factory=set) status: TransferOpStatus = TransferOpStatus.PENDING - dp_id: int = 0 + dp_rank: int = 0 + pp_rank: int = 0 # used for get block ids inner worker process src_slot_id: int = -1 dst_slot_id: int = -1 @@ -93,6 +111,10 @@ class TransferOp: remote_node_ids: Optional[np.ndarray] = None # used for distributed cpu and ssd src_block_node_ids: Optional[np.ndarray] = None + # pending_count tracks how many workers (main KV + indexer) have not yet completed this op. + # Initialized to 1; incremented before submitting to indexer worker. + # _scheduler_loop decrements it on each worker completion; finalization happens only when it reaches 0. 
+ pending_count: int = 1 def __post_init__(self) -> None: if self.transfer_type != TransferType.VIRTUAL and \ @@ -107,6 +129,68 @@ def __post_init__(self) -> None: assert self.dst_block_ids.dtype == np.int64 self.valid_block_num = self.src_block_ids.size +@dataclass +class LayerwiseTransferOp(TransferOp): + + src_block_ids_h2d: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) + dst_block_ids_h2d: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) + src_block_ids_disk2h: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) + dst_block_ids_disk2h: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) + counter_id: int = 0 # Counter set index for triple buffering eventfd notification + # Indexer block_ids for fused indexer transfer (1:1 with main KV block_ids) + indexer_src_block_ids: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) + indexer_dst_block_ids: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) + + def __init__(self, + graph_id: int, + src_block_ids_h2d: np.ndarray, + dst_block_ids_h2d: np.ndarray, + src_block_ids_disk2h: np.ndarray, + dst_block_ids_disk2h: np.ndarray, + layer_id: int = 0, + layer_granularity: int = 1, + dp_rank: int = 0, + pp_rank: int = 0, + counter_id: int = 0, + indexer_src_block_ids: Optional[np.ndarray] = None, + indexer_dst_block_ids: Optional[np.ndarray] = None) -> None: + self.src_block_ids_h2d = src_block_ids_h2d + self.dst_block_ids_h2d = dst_block_ids_h2d + self.src_block_ids_disk2h = src_block_ids_disk2h + self.dst_block_ids_disk2h = dst_block_ids_disk2h + self.counter_id = counter_id + self.indexer_src_block_ids = indexer_src_block_ids if indexer_src_block_ids is not None \ + else np.array([], dtype=np.int64) + self.indexer_dst_block_ids = indexer_dst_block_ids if indexer_dst_block_ids is not None \ + else np.array([], dtype=np.int64) + + super().__init__( + graph_id=graph_id, + 
transfer_type=TransferType.LAYERWISE, + src_block_ids=np.array([], dtype=np.int64), + dst_block_ids=np.array([], dtype=np.int64), + layer_id=layer_id, + layer_granularity=layer_granularity, + dp_rank=dp_rank, + pp_rank=pp_rank, + ) + + def __post_init__(self) -> None: + super().__post_init__() + + if self.layer_granularity == -1: + flexkv_logger.warning("layer_granularity is not set, using default value 1") + self.layer_granularity = 1 + assert self.src_block_ids_h2d.size == self.dst_block_ids_h2d.size + assert self.src_block_ids_disk2h.size == self.dst_block_ids_disk2h.size + assert self.indexer_src_block_ids.size == self.indexer_dst_block_ids.size + + assert self.src_block_ids_h2d.dtype == np.int64 + assert self.dst_block_ids_h2d.dtype == np.int64 + assert self.src_block_ids_disk2h.dtype == np.int64 + assert self.dst_block_ids_disk2h.dtype == np.int64 + assert self.indexer_src_block_ids.dtype == np.int64 + assert self.indexer_dst_block_ids.dtype == np.int64 class TransferOpGraph: _next_graph_id = 0 @@ -219,13 +303,30 @@ def set_gpu_blocks(self, gpu_blocks: np.ndarray) -> None: assert op.src_block_ids.size == op.dst_block_ids.size, \ f"src_block_ids.size={op.src_block_ids.size}, dst_block_ids.size={op.dst_block_ids.size}" + def clear_gpu_blocks(self) -> None: + """Clear GPU block_ids from the graph. + """ + for op_id in self._gpu_transfer_op_id: + op = self._op_map[op_id] + # Replace with empty arrays; set_gpu_blocks() will fill them later + if op.src_block_ids.size > 0: + op.src_block_ids = np.array([], dtype=op.src_block_ids.dtype) + if op.dst_block_ids.size > 0: + op.dst_block_ids = np.array([], dtype=op.dst_block_ids.dtype) + @property def num_ops(self) -> int: return len(self._op_map) - def bind_to_dp_group(self, dp_id: int) -> None: + def bind_to_worker(self, dp_rank: int, pp_rank: int) -> None: + """Bind all ops in this graph to the specified DP group and PP stage. 
+ + Both fields are always set together because they jointly determine + which worker (GPU) handles the transfer. + """ for op in self._op_map.values(): - op.dp_id = dp_id + op.dp_rank = dp_rank + op.pp_rank = pp_rank def visualize(self) -> str: """ @@ -281,7 +382,7 @@ def format_blocks(block_ids, max_show=4): dst_str = format_blocks(op.dst_block_ids) lines.append(f"║ ├─ src_blocks: {src_str}".ljust(71) + "║") lines.append(f"║ ├─ dst_blocks: {dst_str}".ljust(71) + "║") - lines.append(f"║ └─ layer_id={op.layer_id}, dp_id={op.dp_id}".ljust(71) + "║") + lines.append(f"║ └─ layer_id={op.layer_id}, dp_rank={op.dp_rank}, pp_rank={op.pp_rank}".ljust(71) + "║") else: lines.append("║ └─ (VIRTUAL - no blocks)".ljust(71) + "║") @@ -325,9 +426,9 @@ def _merge_ops(ops: List[TransferOp], transfer_type: TransferType, dst_block_ids=dst_blocks, layer_id=ops[0].layer_id, layer_granularity=ops[0].layer_granularity, - dp_id=ops[0].dp_id, + dp_rank=ops[0].dp_rank, + pp_rank=ops[0].pp_rank, ) - graph.add_transfer_op(merged_op) if callbacks: if len(callbacks) == 1: op_callback_dict[merged_op.op_id] = callbacks[0] @@ -336,8 +437,12 @@ def _merge_ops(ops: List[TransferOp], transfer_type: TransferType, return merged_op -def merge_to_batch_graph(batch_id: int, transfer_graphs: List[TransferOpGraph], task_end_op_ids: List[int], - op_callback_dict: Dict[int, Callable]) -> Tuple[TransferOpGraph, int, Dict[int, Callable]]: +def merge_to_batch_graph(batch_id: int, + transfer_graphs: List[TransferOpGraph], + task_end_op_ids: List[int], + op_callback_dict: Dict[int, Callable], + layerwise_transfer: bool = False, + counter_id: int = 0) -> Tuple[TransferOpGraph, int, Dict[int, Callable]]: """ Merge multiple TransferOpGraphs into a single batch graph. 
@@ -351,6 +456,7 @@ def merge_to_batch_graph(batch_id: int, transfer_graphs: List[TransferOpGraph], transfer_graphs: List of graphs to merge task_end_op_ids: List of end op IDs for each task (one per graph) op_callback_dict: Dict mapping old op_id -> callback + layerwise_transfer: Whether to merge the graphs into a layerwise transfer op Returns: (merged_graph, batch_end_op_id, new_op_callback_dict) @@ -392,28 +498,63 @@ def merge_to_batch_graph(batch_id: int, transfer_graphs: List[TransferOpGraph], merged_graph, callbacks_by_type[TransferType.DISK2H], new_op_callback_dict) merged_h2d_op = _merge_ops(ops_by_type[TransferType.H2D], TransferType.H2D, merged_graph, callbacks_by_type[TransferType.H2D], new_op_callback_dict) - if merged_disk2h_op is not None and merged_h2d_op is not None: - merged_graph.add_dependency(merged_h2d_op.op_id, merged_disk2h_op.op_id) - - # PUT path: D2H -> H2DISK - merged_d2h_op = _merge_ops(ops_by_type[TransferType.D2H], TransferType.D2H, - merged_graph, callbacks_by_type[TransferType.D2H], new_op_callback_dict) - merged_h2disk_op = _merge_ops(ops_by_type[TransferType.H2DISK], TransferType.H2DISK, - merged_graph, callbacks_by_type[TransferType.H2DISK], new_op_callback_dict) - if merged_d2h_op is not None and merged_h2disk_op is not None: - merged_graph.add_dependency(merged_h2disk_op.op_id, merged_d2h_op.op_id) - - # batch_end_op_id: GET: H2D > DISK2H; PUT: H2DISK > D2H - if merged_h2d_op is not None: - batch_end_op_id = merged_h2d_op.op_id - elif merged_disk2h_op is not None: - batch_end_op_id = merged_disk2h_op.op_id - elif merged_h2disk_op is not None: - batch_end_op_id = merged_h2disk_op.op_id - elif merged_d2h_op is not None: - batch_end_op_id = merged_d2h_op.op_id - else: + + if layerwise_transfer: + if merged_h2d_op is not None: + layerwise_transfer_op = LayerwiseTransferOp( + graph_id=merged_graph.graph_id, + src_block_ids_h2d=merged_h2d_op.src_block_ids, + dst_block_ids_h2d=merged_h2d_op.dst_block_ids, + 
src_block_ids_disk2h=merged_disk2h_op.src_block_ids \ + if merged_disk2h_op is not None \ + else np.array([], dtype=np.int64), + dst_block_ids_disk2h=merged_disk2h_op.dst_block_ids \ + if merged_disk2h_op is not None \ + else np.array([], dtype=np.int64), + layer_id=0, + layer_granularity=1, + dp_rank=ops_by_type[TransferType.H2D][0].dp_rank, + pp_rank=ops_by_type[TransferType.H2D][0].pp_rank, + counter_id=counter_id, + # Indexer maps 1:1 with main KV blocks, use same block_ids + # CPU side (src) and GPU side (dst) for H2D direction + indexer_src_block_ids=merged_h2d_op.src_block_ids.copy(), + indexer_dst_block_ids=merged_h2d_op.dst_block_ids.copy(), + ) + merged_graph.add_transfer_op(layerwise_transfer_op) batch_end_op_id = -1 + new_op_callback_dict.clear() + else: + if merged_disk2h_op is not None: + merged_graph.add_transfer_op(merged_disk2h_op) + if merged_h2d_op is not None: + merged_graph.add_transfer_op(merged_h2d_op) + if merged_disk2h_op is not None and merged_h2d_op is not None: + merged_graph.add_dependency(merged_h2d_op.op_id, merged_disk2h_op.op_id) + + # PUT path: D2H -> H2DISK + merged_d2h_op = _merge_ops(ops_by_type[TransferType.D2H], TransferType.D2H, + merged_graph, callbacks_by_type[TransferType.D2H], new_op_callback_dict) + merged_h2disk_op = _merge_ops(ops_by_type[TransferType.H2DISK], TransferType.H2DISK, + merged_graph, callbacks_by_type[TransferType.H2DISK], new_op_callback_dict) + if merged_d2h_op is not None: + merged_graph.add_transfer_op(merged_d2h_op) + if merged_h2disk_op is not None: + merged_graph.add_transfer_op(merged_h2disk_op) + if merged_d2h_op is not None and merged_h2disk_op is not None: + merged_graph.add_dependency(merged_h2disk_op.op_id, merged_d2h_op.op_id) + + # batch_end_op_id: GET: H2D > DISK2H; PUT: H2DISK > D2H + if merged_h2d_op is not None: + batch_end_op_id = merged_h2d_op.op_id + elif merged_disk2h_op is not None: + batch_end_op_id = merged_disk2h_op.op_id + elif merged_h2disk_op is not None: + batch_end_op_id = 
merged_h2disk_op.op_id + elif merged_d2h_op is not None: + batch_end_op_id = merged_d2h_op.op_id + else: + batch_end_op_id = -1 return merged_graph, batch_end_op_id, new_op_callback_dict diff --git a/flexkv/common/type.py b/flexkv/common/type.py index 8b893f2eb5..25f17d2d93 100644 --- a/flexkv/common/type.py +++ b/flexkv/common/type.py @@ -11,9 +11,13 @@ class MatchResultAccel: last_node: Optional['CRadixNode'] = None last_node_matched_length: int = 0 physical_blocks: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) + # Single node_id for all matched blocks (single-node matching constraint). + # -1 means no remote match. Preferred over the deprecated per-block arrays. + matched_node_id: Optional[int] = None + # deprecated: kept for backward compat; prefer matched_node_id block_node_ids: Optional[np.ndarray] = None matched_pos: Optional[str] = None - matched_node_ids: Optional[np.ndarray] = None #TODO id or ids? should we allow one req match results on multiple nodes? + matched_node_ids: Optional[np.ndarray] = None # deprecated: prefer matched_node_id insert_to_local_cpu_index: bool = True def __post_init__(self) -> None: diff --git a/flexkv/integration/config.py b/flexkv/integration/config.py index 82bf15634c..875c7d9a4d 100644 --- a/flexkv/integration/config.py +++ b/flexkv/integration/config.py @@ -3,7 +3,7 @@ import os import torch import tempfile -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from dataclasses import dataclass, field from flexkv.common.debug import flexkv_logger @@ -16,6 +16,28 @@ logger = flexkv_logger + +def _parse_dtype_str(dtype_str: str) -> torch.dtype: + """Convert a dtype string (e.g. 'fp8', 'bfloat16', 'fp8_e4m3') to torch.dtype. + + Shared by sglang / vllm / TRT-LLM integration adapters so that dtype + parsing logic is defined in exactly one place. 
+ """ + dtype_map = { + "float16": torch.float16, + "float32": torch.float32, + "bfloat16": torch.bfloat16, + "fp16": torch.float16, + "fp32": torch.float32, + "bf16": torch.bfloat16, + "fp8": torch.float8_e4m3fn, + "float8": torch.float8_e4m3fn, + "e4m3": torch.float8_e4m3fn, + "fp8_e4m3": torch.float8_e4m3fn, + } + return dtype_map.get(dtype_str.lower(), torch.bfloat16) + + @dataclass class FlexKVConfig: enable_flexkv: bool = True @@ -40,6 +62,42 @@ def __post_init__(self): if self.gpu_register_port == "": self.gpu_register_port = self.server_recv_port + "_gpu_register" + def _detect_indexer_config_from_hf(self, hf_config, source: str = "") -> None: + if hf_config is None: + return + + try: + qk_rope_head_dim = getattr(hf_config, 'qk_rope_head_dim', None) + if qk_rope_head_dim is None or qk_rope_head_dim <= 0: + return + + index_head_dim = getattr(hf_config, 'index_head_dim', None) + if index_head_dim is not None and index_head_dim > 0: + quant_block_size = 128 + head_size = self.cache_config.tokens_per_block * ( + index_head_dim + index_head_dim // quant_block_size * 4 + ) + else: + head_size = qk_rope_head_dim + + # tokens_per_block is already set to sglang page_size before this + # call, so each FlexKV block = 1 sglang page. The indexer maps + # 1:1 with blocks — no extra page_size grouping is needed. For + # NSA/DSA models, head_size stores the packed per-page buffer width + # so the CPU layout matches the GPU indexer tensor shape. 
+ self.cache_config.indexer = IndexerCacheConfig( + head_size=head_size, + num_kv_heads=1, + dtype=torch.uint8, + ) + source_label = f" ({source})" if source else "" + logger.info( + f"Detected sparse attention indexer config{source_label}: " + f"head_size={head_size}, dtype=uint8, " + f"tokens_per_block={self.cache_config.tokens_per_block}") + except Exception as e: + logger.debug(f"Could not detect indexer config ({source}): {e}") + @classmethod def from_env(cls) -> 'FlexKVConfig': enable_flexkv = bool(int(os.getenv('ENABLE_FLEXKV', 1))) @@ -68,6 +126,19 @@ def post_init_from_vllm_config( self.model_config.use_mla = vllm_config.model_config.is_deepseek_mla self.model_config.tp_size = vllm_config.parallel_config.tensor_parallel_size self.model_config.dp_size = vllm_config.parallel_config.data_parallel_size + self.model_config.pp_size = vllm_config.parallel_config.pipeline_parallel_size + self.model_config.pp_rank = getattr(vllm_config.parallel_config, 'pipeline_parallel_rank', 0) + + if self.model_config.pp_size > 1: + from vllm.distributed.utils import get_pp_indices as vllm_get_pp_indices + start_layer, end_layer = vllm_get_pp_indices( + self.model_config.num_layers, self.model_config.pp_rank, self.model_config.pp_size + ) + self.model_config.pp_start_layer = start_layer + self.model_config.pp_end_layer = end_layer + else: + self.model_config.pp_start_layer = 0 + self.model_config.pp_end_layer = self.model_config.num_layers if self.model_config.use_mla: self.model_config.num_kv_heads = 1 else: @@ -76,48 +147,201 @@ def post_init_from_vllm_config( self.server_recv_port = GLOBAL_CONFIG_FROM_ENV.server_recv_port self.gpu_register_port = self.server_recv_port + "_gpu_register" + hf_config = getattr(vllm_config.model_config, 'hf_config', None) + self._detect_indexer_config_from_hf(hf_config, source="vllm") + + logger.info( + f"[FlexKV vllm] Primitive vars set: tp_size={self.model_config.tp_size}, " + f"dp_size={self.model_config.dp_size}, 
dp_rank={self.model_config.dp_rank}, " + f"pp_size={self.model_config.pp_size}, pp_rank={self.model_config.pp_rank}, " + f"enable_dp_attention={self.model_config.enable_dp_attention}, " + f"attn_cp_size={self.model_config.attn_cp_size}, " + f"attn_cp_rank={self.model_config.attn_cp_rank}" + ) + logger.info( + f"[FlexKV vllm] Derived vars: attn_tp_size={self.model_config.attn_tp_size}, " + f"attn_tp_rank={self.model_config.attn_tp_rank}, " + f"local_rank={self.model_config.local_rank}" + ) + + # Freeze model_config — no further mutations allowed + self.model_config.freeze() def post_init_from_sglang_config( self, sglang_config, - tp_size: int, - page_size: int, + server_args, + page_size: int = 64, + tp_rank: Optional[int] = 0, + pp_rank: Optional[int] = 0, + dp_rank: Optional[int] = 0, + attn_cp_rank: Optional[int] = 0, ): """ Initialize FlexKVConfig fields from sglang config. Args: sglang_config: sglang.srt.configs.model_config.ModelConfig-like object - tp_size: tensor parallel size used by sglang + server_args: sglang ServerArgs — source of tp_size, dp_size, + nnodes, node_rank, enable_dp_attention, attn_cp_size, + is_nsa_cp, kv_cache_dtype, dist_init_addr page_size: KV block size (tokens per block) used by sglang + tp_rank: physical tensor parallel rank (runtime, from process group) + pp_rank: pipeline parallel rank (runtime, from process group) + dp_rank: data parallel rank (runtime, from process group) + attn_cp_rank: attention-level context parallel rank (runtime) """ - # cache config - self.cache_config.tokens_per_block = int(page_size) + # Extract parallelism params from server_args + tp_size = server_args.tp_size + pp_size = server_args.pp_size + dp_size = server_args.dp_size + nnodes = server_args.nnodes + node_rank = server_args.node_rank + enable_dp_attention = server_args.enable_dp_attention + attn_cp_size = server_args.attn_cp_size + is_nsa_cp = getattr(server_args, 'enable_nsa_prefill_context_parallel', False) + kv_cache_dtype = getattr(server_args, 
'kv_cache_dtype', None) + + # cache config: use page_size as tokens_per_block so that FlexKV's + # CPU radix tree manages blocks at page granularity, ensuring that + # hash generation, matching, insertion and eviction are all page-aligned. + self.cache_config.tokens_per_block = page_size self.model_config.num_layers = int(getattr(sglang_config, "num_hidden_layers", 0)) - if hasattr(sglang_config, "get_num_kv_heads"): - try: - self.model_config.num_kv_heads = int(sglang_config.get_num_kv_heads(tp_size)) - except Exception: + from sglang.srt.configs.model_config import AttentionArch + use_mla = getattr(sglang_config, "attention_arch", None) == AttentionArch.MLA + + if use_mla: + kv_lora_rank = int(getattr(sglang_config, "kv_lora_rank", 0)) + qk_rope_head_dim = int(getattr(sglang_config, "qk_rope_head_dim", 0)) + mla_head_size = kv_lora_rank + qk_rope_head_dim + self.model_config.num_kv_heads = 1 + self.model_config.head_size = int(mla_head_size) + else: + if hasattr(sglang_config, "get_total_num_kv_heads"): + try: + self.model_config.num_kv_heads = int(sglang_config.get_total_num_kv_heads()) + except Exception: + self.model_config.num_kv_heads = int(getattr(sglang_config, "num_key_value_heads", 0)) + elif hasattr(sglang_config, "get_num_kv_heads"): + try: + per_rank = int(sglang_config.get_num_kv_heads(tp_size)) + self.model_config.num_kv_heads = per_rank * tp_size + except Exception: + self.model_config.num_kv_heads = int(getattr(sglang_config, "num_key_value_heads", 0)) + else: self.model_config.num_kv_heads = int(getattr(sglang_config, "num_key_value_heads", 0)) + self.model_config.head_size = int(getattr(sglang_config, "head_dim", 0)) + + # Determine KV cache dtype: prioritize user_config.kv_cache_dtype (from + # flexkv_config.yaml or FLEXKV_KV_CACHE_DTYPE env var), then fall back + # to the sglang model dtype. sglang's ModelConfig.dtype is the *model + # weight* dtype (e.g. bfloat16), which may differ from the KV cache + # dtype (e.g. 
fp8_e4m3 when --kv-cache-dtype fp8_e4m3 is used). + user_dtype_str = self.user_config.kv_cache_dtype + if user_dtype_str is not None: + self.model_config.dtype = _parse_dtype_str(user_dtype_str) + logger.info( + f"[FlexKV] Using kv_cache_dtype from user_config: " + f"'{user_dtype_str}' -> {self.model_config.dtype}" + ) + elif kv_cache_dtype is not None and kv_cache_dtype != "auto": + # Use the kv_cache_dtype from sglang server_args (e.g. "fp8_e4m3") + self.model_config.dtype = _parse_dtype_str(kv_cache_dtype) + logger.info( + f"[FlexKV] Using kv_cache_dtype from sglang server_args: " + f"'{kv_cache_dtype}' -> {self.model_config.dtype}" + ) else: - self.model_config.num_kv_heads = int(getattr(sglang_config, "num_key_value_heads", 0)) - self.model_config.head_size = int(getattr(sglang_config, "head_dim", 0)) + self.model_config.dtype = getattr(sglang_config, "dtype", torch.bfloat16) + logger.warning( + f"[FlexKV] No kv_cache_dtype in user_config or server_args, falling back to sglang " + f"model dtype: {self.model_config.dtype}. If your KV cache uses a " + f"different dtype (e.g. fp8), add 'kv_cache_dtype: fp8' to your " + f"flexkv_config.yaml or set FLEXKV_KV_CACHE_DTYPE=fp8 environment variable." 
+ ) - self.model_config.dtype = getattr(sglang_config, "dtype", torch.bfloat16) + if use_mla and getattr(sglang_config, "index_head_dim", None) is not None: + kv_lora_rank = int(getattr(sglang_config, "kv_lora_rank", 0)) + qk_rope_head_dim = int(getattr(sglang_config, "qk_rope_head_dim", 0)) + if self.model_config.dtype == torch.float8_e4m3fn: + assert kv_lora_rank % 128 == 0, ( + f"kv_lora_rank {kv_lora_rank} must be multiple of 128 " + "for NSA FP8 KV cache layout" + ) + self.model_config.head_size = int( + kv_lora_rank + + kv_lora_rank // 128 * 4 + + qk_rope_head_dim * torch.bfloat16.itemsize + ) - attn_arch = getattr(sglang_config, "attention_arch", None) - use_mla = False - if hasattr(attn_arch, "name"): - use_mla = (attn_arch.name.upper() == "MLA") - elif isinstance(attn_arch, str): - use_mla = (attn_arch.upper() == "MLA") self.model_config.use_mla = use_mla self.model_config.tp_size = int(tp_size) - self.model_config.dp_size = int(getattr(sglang_config, "dp_size", 1)) + self.model_config.tp_rank = int(tp_rank) + self.model_config.dp_size = int(dp_size if dp_size is not None else 1) + self.model_config.dp_rank = int(dp_rank if dp_rank is not None else 0) + self.model_config.pp_size = int(pp_size) + self.model_config.pp_rank = int(pp_rank) + + if pp_size > 1: + from sglang.srt.distributed.utils import get_pp_indices as sglang_get_pp_indices + start_layer, end_layer = sglang_get_pp_indices( + self.model_config.num_layers, self.model_config.pp_rank, self.model_config.pp_size + ) + self.model_config.pp_start_layer = start_layer + self.model_config.pp_end_layer = end_layer + else: + self.model_config.pp_start_layer = 0 + self.model_config.pp_end_layer = self.model_config.num_layers + self.model_config.enable_dp_attention = bool(enable_dp_attention) + self.model_config.attn_cp_size = int(attn_cp_size) + self.model_config.attn_cp_rank = int(attn_cp_rank) + self.model_config.is_nsa_cp = is_nsa_cp + self.model_config.nnodes = max(1, int(nnodes)) + 
self.model_config.node_rank = int(node_rank) + # Multi-node bootstrap: master host (derived from sglang --dist-init-addr). + # ``None`` here falls back to FLEXKV_MASTER_HOST env var downstream. + _dist_init_addr = getattr(server_args, 'dist_init_addr', None) + master_host = _dist_init_addr.split(":")[0] if _dist_init_addr and int(nnodes) > 1 else None + self.model_config.master_host = master_host update_default_config_from_user_config(self.model_config, self.cache_config, self.user_config) + hf_config = getattr(sglang_config, 'hf_config', None) + self._detect_indexer_config_from_hf(hf_config, source="sglang") + + if self.cache_config.indexer is not None: + logger.info( + f"[FlexKV] Complete indexer config (sglang): " + f"head_size={self.cache_config.indexer.head_size}, " + f"dtype={self.cache_config.indexer.dtype}, " + f"num_layers={self.model_config.num_layers}, " + f"tokens_per_block={self.cache_config.tokens_per_block}" + ) + + # Log primitive and derived variables for verification + logger.info( + f"[FlexKV sglang] Primitive vars set: tp_size={self.model_config.tp_size}, " + f"tp_rank={self.model_config.tp_rank}, dp_size={self.model_config.dp_size}, " + f"dp_rank={self.model_config.dp_rank}, pp_size={self.model_config.pp_size}, " + f"pp_rank={self.model_config.pp_rank}, " + f"enable_dp_attention={self.model_config.enable_dp_attention}, " + f"attn_cp_size={self.model_config.attn_cp_size}, " + f"attn_cp_rank={self.model_config.attn_cp_rank}, " + f"is_nsa_cp={self.model_config.is_nsa_cp}, " + f"nnodes={self.model_config.nnodes}, node_rank={self.model_config.node_rank}" + ) + logger.info( + f"[FlexKV sglang] Derived vars: attn_tp_size={self.model_config.attn_tp_size}, " + f"attn_tp_rank={self.model_config.attn_tp_rank}, " + f"tp_rank_per_node={self.model_config.tp_rank_per_node}, " + f"tp_size_per_node={self.model_config.tp_size_per_node}, " + f"local_rank={self.model_config.local_rank}" + ) + + # Freeze model_config — no further mutations allowed + 
self.model_config.freeze() + def post_init_from_trt_config( self, config, @@ -126,22 +350,7 @@ def post_init_from_trt_config( # Convert dtype string to torch.dtype dtype_str = config.pytorch_backend_config.kv_cache_dtype flexkv_logger.info(f"[FlexKVConfig] dtype_str from TRT config: {dtype_str}") - - # Helper function to convert dtype string to torch.dtype - def _parse_dtype_str(dtype_str: str) -> torch.dtype: - dtype_map = { - "float16": torch.float16, - "float32": torch.float32, - "bfloat16": torch.bfloat16, - "fp16": torch.float16, - "fp32": torch.float32, - "bf16": torch.bfloat16, - "fp8": torch.float8_e4m3fn, - "float8": torch.float8_e4m3fn, - "e4m3": torch.float8_e4m3fn, - } - return dtype_map.get(dtype_str.lower(), torch.bfloat16) - + if dtype_str == "auto": # When dtype_str is "auto", try to get kv_cache_dtype from user_config first # This allows users to specify kv_cache_dtype in flexkv_config.json or via environment variable @@ -164,7 +373,7 @@ def _parse_dtype_str(dtype_str: str) -> torch.dtype: self.model_config.dtype = _parse_dtype_str(dtype_str) else: self.model_config.dtype = dtype_str - + # Set model config (parallel configs part) if config.mapping.enable_attention_dp: self.model_config.tp_size = 1 @@ -172,19 +381,29 @@ def _parse_dtype_str(dtype_str: str) -> torch.dtype: else: self.model_config.tp_size = config.mapping.tp_size self.model_config.dp_size = 1 - + self.model_config.pp_size = getattr(config.mapping, 'pp_size', 1) + self.model_config.pp_rank = getattr(config.mapping, 'pp_rank', 0) + + if self.model_config.pp_size > 1: + layers_range = config.mapping.pp_layers(self.model_config.num_layers) + self.model_config.pp_start_layer = layers_range[0] + self.model_config.pp_end_layer = layers_range[-1] + 1 + else: + self.model_config.pp_start_layer = 0 + self.model_config.pp_end_layer = self.model_config.num_layers + # self.model_config (model configs part) try: model_path = getattr(config, 'hf_model_dir', None) from transformers import AutoConfig 
as HFAutoConfig hf_config = HFAutoConfig.from_pretrained( - str(model_path), + str(model_path), trust_remote_code=True ) self.model_config.num_layers = hf_config.num_hidden_layers - self.model_config.use_mla = (hasattr(hf_config, 'kv_lora_rank') and + self.model_config.use_mla = (hasattr(hf_config, 'kv_lora_rank') and hf_config.kv_lora_rank is not None and - hasattr(hf_config, 'qk_rope_head_dim') and + hasattr(hf_config, 'qk_rope_head_dim') and hf_config.qk_rope_head_dim is not None) if self.model_config.use_mla: self.model_config.head_size = hf_config.kv_lora_rank + hf_config.qk_rope_head_dim @@ -197,8 +416,29 @@ def _parse_dtype_str(dtype_str: str) -> torch.dtype: else: self.model_config.head_size = hf_config.hidden_size // hf_config.num_attention_heads self.model_config.num_kv_heads = hf_config.num_attention_heads - + + self._detect_indexer_config_from_hf(hf_config, source="TRT-LLM") except Exception as e: flexkv_logger.error(f"Failed to load config from {model_path}: {e}") # Update cache config with user config after model config is initialized update_default_config_from_user_config(self.model_config, self.cache_config, self.user_config) + + # Log primitive and derived variables for verification + flexkv_logger.info( + f"[FlexKV TRT-LLM] Primitive vars set: tp_size={self.model_config.tp_size}, " + f"tp_rank={self.model_config.tp_rank}, dp_size={self.model_config.dp_size}, " + f"dp_rank={self.model_config.dp_rank}, pp_size={self.model_config.pp_size}, " + f"pp_rank={self.model_config.pp_rank}, " + f"enable_dp_attention={self.model_config.enable_dp_attention}, " + f"attn_cp_size={self.model_config.attn_cp_size}, " + f"attn_cp_rank={self.model_config.attn_cp_rank}, " + f"nnodes={self.model_config.nnodes}, node_rank={self.model_config.node_rank}" + ) + flexkv_logger.info( + f"[FlexKV TRT-LLM] Derived vars: attn_tp_size={self.model_config.attn_tp_size}, " + f"attn_tp_rank={self.model_config.attn_tp_rank}, " + f"local_rank={self.model_config.local_rank}" + ) + + # 
Freeze model_config — no further mutations allowed + self.model_config.freeze() diff --git a/flexkv/integration/tensorrt_llm/trtllm_adapter.py b/flexkv/integration/tensorrt_llm/trtllm_adapter.py index 58853da657..db925fe6b7 100644 --- a/flexkv/integration/tensorrt_llm/trtllm_adapter.py +++ b/flexkv/integration/tensorrt_llm/trtllm_adapter.py @@ -45,7 +45,7 @@ def __init__(self, config: ExecutorConfig): self.flexkv_manager = KVManager(model_config=self.model_config, cache_config=self.cache_config, server_recv_port=flexkv_config.server_recv_port, - dp_client_id=self.dp_rank) + dp_client_id=self.model_config.dp_rank) self.flexkv_manager.start() # self.dp_client = KVDPClient(self.server_recv_port, self.model_config) @@ -209,6 +209,8 @@ def _get_match( task_id, matched_mask = self.flexkv_manager.get_match( token_ids=np_token_ids, token_mask=np_token_mask, + dp_rank=self.flexkv_config.model_config.dp_rank, + pp_rank=self.flexkv_config.model_config.pp_rank, namespace=namespace, ) num_new_matched_tokens = matched_mask.sum().item() @@ -366,6 +368,8 @@ def _put_match( task_id, unmatched_mask = self.flexkv_manager.put_match( token_ids=np_token_ids, + dp_rank=self.flexkv_config.model_config.dp_rank, + pp_rank=self.flexkv_config.model_config.pp_rank, namespace=namespace, ) @@ -538,26 +542,30 @@ def __init__(self, config: ExecutorConfig): self.remote_process = TransferManagerOnRemote.create_process() flexkv_logger.info(f"TransferManagerOnRemote process created, PID: {self.remote_process.pid}") - flexkv_logger.info(f"Start init FlexKVWorkerConnector to {flexkv_config.gpu_register_port}, dp_client_id: {dp_client_id}") - self.tp_client = KVTPClient(flexkv_config.gpu_register_port, dp_client_id, current_device_id) + flexkv_logger.info(f"Start init FlexKVWorkerConnector to {flexkv_config.gpu_register_port}, dp_rank: {dp_rank}") + self.tp_client = KVTPClient(flexkv_config.gpu_register_port, dp_rank=dp_rank, pp_rank=flexkv_config.model_config.pp_rank, device_id=current_device_id) 
flexkv_logger.info("Finish init FlexKVWorkerConnector") def _need_to_create_remote_process(self) -> bool: """Check if need to create TransferManagerOnRemote process. Returns True when all of the following conditions are met: - - Multi-node TP is detected (tp_size > gpus_per_node) + - Multi-node TP is detected (nnodes_per_tp_group > 1) - Current node is not master node (node_rank > 0) - - Current worker is worker0 in TP group (tp_rank == 0) + - Current worker is worker0 in the local TP group (tp_rank_per_node == 0) Returns: bool: True if need to create TransferManagerOnRemote process, False otherwise. """ try: is_master_node = self.node_rank == 0 - is_first_worker = self.tp_rank % 8 == 0 - is_multinode_tp = self.flexkv_config.model_config.tp_size > torch.cuda.device_count() - flexkv_logger.info(f"{is_master_node=}, {is_first_worker=}, {is_multinode_tp=}") + is_first_worker = self.flexkv_config.model_config.tp_rank_per_node == 0 + is_multinode_tp = self.flexkv_config.model_config.nnodes_per_tp_group > 1 + flexkv_logger.info( + f"{is_master_node=}, {is_first_worker=}, {is_multinode_tp=}, " + f"nnodes_per_tp_group={self.flexkv_config.model_config.nnodes_per_tp_group}, " + f"tp_rank_per_node={self.flexkv_config.model_config.tp_rank_per_node}" + ) return is_multinode_tp and not is_master_node and is_first_worker except Exception as e: diff --git a/flexkv/integration/vllm/vllm_v1_adapter.py b/flexkv/integration/vllm/vllm_v1_adapter.py index b88e17b55b..74cebf5f78 100644 --- a/flexkv/integration/vllm/vllm_v1_adapter.py +++ b/flexkv/integration/vllm/vllm_v1_adapter.py @@ -336,6 +336,8 @@ def _get_match( task_id, matched_mask = self.flexkv_manager.get_match( token_ids=np_token_ids, token_mask=np_token_mask, + dp_rank=self.flexkv_config.model_config.dp_rank, + pp_rank=self.flexkv_config.model_config.pp_rank, namespace=namespace, ) num_new_matched_tokens = matched_mask.sum().item() @@ -484,6 +486,8 @@ def _put_match( namespace = self._extract_namespace(request) task_id, 
unmatched_mask = self.flexkv_manager.put_match( token_ids=np_token_ids, + dp_rank=self.flexkv_config.model_config.dp_rank, + pp_rank=self.flexkv_config.model_config.pp_rank, namespace=namespace, ) @@ -704,13 +708,27 @@ def __init__( logger.info(f"Start init FlexKVWorkerConnector to {flexkv_config.gpu_register_port}, " f"server_client_mode={server_client_mode}, dp_client_id={dp_client_id}, " f"client_id={client_id}, device_id={device_id}") - self.tp_client = KVTPClient(flexkv_config.gpu_register_port, client_id, device_id) + self.tp_client = KVTPClient(flexkv_config.gpu_register_port, + dp_rank=client_id, + pp_rank=self.flexkv_config.model_config.pp_rank, + device_id=device_id) logger.info("Finish init FlexKVWorkerConnector") def register_to_server(self, kv_caches: dict[str, torch.Tensor]): logger.info("Start register kv_caches") - gpu_blocks = list(kv_caches.values()) - num_layer = len(kv_caches) + + # Separate main KV caches from indexer caches by layer name. + main_kv_caches: dict[str, torch.Tensor] = {} + indexer_kv_caches: dict[str, torch.Tensor] = {} + for layer_name, tensor in kv_caches.items(): + if ".k_cache" in layer_name: + indexer_kv_caches[layer_name] = tensor + else: + main_kv_caches[layer_name] = tensor + + # Build main KV cache layout + gpu_blocks = list(main_kv_caches.values()) + num_layer = len(main_kv_caches) if self.flexkv_config.model_config.use_mla: assert gpu_blocks[0].ndim == 3, ( f"expect kv cached tensor has 3 dim but get shape={gpu_blocks[0].shape}.") @@ -734,7 +752,32 @@ def register_to_server(self, kv_caches: dict[str, torch.Tensor]): head_size=head_size, is_mla=self.flexkv_config.model_config.use_mla, ) - self.tp_client.register_to_server(gpu_blocks, gpu_layout) + + # Build indexer layout if indexer caches are present + indexer_buffers = None + indexer_layout = None + if indexer_kv_caches: + indexer_buffers = list(indexer_kv_caches.values()) + first_indexer_buffer = indexer_buffers[0] + assert first_indexer_buffer.ndim == 3, ( + 
f"expect indexer cache tensor has 3 dim but get shape={first_indexer_buffer.shape}.") + indexer_layout = KVCacheLayout( + type=KVCacheLayoutType.LAYERFIRST, + num_layer=len(indexer_buffers), + num_block=first_indexer_buffer.shape[0], + tokens_per_block=first_indexer_buffer.shape[1], + num_head=1, + head_size=first_indexer_buffer.shape[2], + is_mla=True, + ) + + self.tp_client.register_to_server( + kv_caches=gpu_blocks, + kv_layout=gpu_layout, + indexer_buffers=indexer_buffers, + indexer_layout=indexer_layout, + ) + logger.info("Finish register kv_caches") def __del__(self): diff --git a/flexkv/kvmanager.py b/flexkv/kvmanager.py index 8e61a15136..20dce81e01 100644 --- a/flexkv/kvmanager.py +++ b/flexkv/kvmanager.py @@ -55,6 +55,11 @@ def __init__(self, else: self.gpu_register_port = self.server_recv_port + "_gpu_register" + flexkv_logger.info( + f"[KVManager] IPC ports: server_recv_port={self.server_recv_port}, " + f"gpu_register_port={self.gpu_register_port}" + ) + # Multi-instance mode also requires server_client_mode self.server_client_mode = (model_config.dp_size > 1 or self.instance_num > 1 or @@ -67,24 +72,10 @@ def __init__(self, flexkv_logger.info(f"server_client_mode: {self.server_client_mode}") self.redis_meta_client = None - if self.cache_config.enable_kv_sharing: - flexkv_logger.info(f"[kv manager] initializing RedisMeta and connection to \ - {self.cache_config.redis_host}:{self.cache_config.redis_port}") - # initialize redis Meta obj - self.redis_meta_client = RedisMeta( - self.cache_config.redis_host, - self.cache_config.redis_port, - self.cache_config.redis_password, - self.cache_config.local_ip, - ) - self.redis_meta_client.init_meta() - # update distributed_node_id - self.cache_config.distributed_node_id = self.redis_meta_client.get_node_id() # update distributed_node_id of current node - - self.enable_mps = GLOBAL_CONFIG_FROM_ENV.enable_mps if self.server_client_mode: + # In server_client_mode, RedisMeta is created and initialized inside KVServer # 
Server should only be created once across all instances and dp ranks if self.instance_id == 0 and dp_client_id == 0: total_clients = self.instance_num * model_config.dp_size @@ -99,6 +90,21 @@ def __init__(self, self.server_handle = None self.dp_client = KVDPClient(self.server_recv_port, self.model_config, self.global_client_id) else: + # In non-server_client_mode, create RedisMeta here and pass to KVTaskEngine + if self.cache_config.enable_kv_sharing: + flexkv_logger.info(f"[kv manager] initializing RedisMeta and connection to " + f"{self.cache_config.redis_host}:{self.cache_config.redis_port}") + self.redis_meta_client = RedisMeta( + self.cache_config.redis_host, + self.cache_config.redis_port, + self.cache_config.redis_password, + self.cache_config.local_ip, + node_ttl_seconds=self.cache_config.node_ttl_seconds, + ) + self.redis_meta_client.init_meta() + # update distributed_node_id + self.cache_config.distributed_node_id = self.redis_meta_client.get_node_id() + self.server_handle = None self.kv_task_engine = KVTaskEngine(self.model_config, self.cache_config, self.gpu_register_port, redis_meta=self.redis_meta_client, event_collector=event_collector) @@ -127,6 +133,10 @@ def is_ready(self) -> bool: def shutdown(self) -> None: if self.server_client_mode: self.dp_client.shutdown() + # Wait for the server process to exit after sending shutdown request + if self.server_handle is not None: + self.server_handle.shutdown() + self.server_handle = None else: self.kv_task_engine.shutdown() @@ -141,7 +151,8 @@ def get_async(self, slot_mapping: Union[torch.Tensor, np.ndarray], token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, layer_granularity: int = -1, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, namespace: Optional[List[str]] = None, ) -> int: if isinstance(token_ids, torch.Tensor): @@ -155,13 +166,15 @@ def get_async(self, slot_mapping, token_mask, layer_granularity, + pp_rank=pp_rank, namespace=namespace) else: task_id, _ = 
self.kv_task_engine.get_async(token_ids, slot_mapping, token_mask, layer_granularity, - dp_id, + dp_rank, + pp_rank=pp_rank, namespace=namespace) return task_id @@ -169,7 +182,9 @@ def get_match(self, token_ids: Union[torch.Tensor, np.ndarray], token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, layer_granularity: int = -1, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, + cpu_only: bool = False, namespace: Optional[List[str]] = None, ) -> Tuple[int, np.ndarray]: if isinstance(token_ids, torch.Tensor): @@ -180,12 +195,16 @@ def get_match(self, task_id, mask = self.dp_client.get_match(token_ids, token_mask, layer_granularity, + pp_rank=pp_rank, + cpu_only=cpu_only, namespace=namespace) else: task_id, mask = self.kv_task_engine.get_match(token_ids, token_mask, layer_granularity, - dp_id, + dp_rank, + pp_rank=pp_rank, + cpu_only=cpu_only, namespace=namespace) return task_id, mask @@ -193,7 +212,8 @@ def put_async(self, token_ids: Union[torch.Tensor, np.ndarray], slot_mapping: Union[torch.Tensor, np.ndarray], token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, namespace: Optional[List[str]] = None, ) -> int: if isinstance(token_ids, torch.Tensor): @@ -203,15 +223,16 @@ def put_async(self, if isinstance(token_mask, torch.Tensor): token_mask = token_mask.numpy() if self.server_client_mode: - task_id = self.dp_client.put_async(token_ids, slot_mapping, token_mask, namespace=namespace) + task_id = self.dp_client.put_async(token_ids, slot_mapping, token_mask, pp_rank=pp_rank, namespace=namespace) else: - task_id, _ = self.kv_task_engine.put_async(token_ids, slot_mapping, token_mask, dp_id, namespace=namespace) + task_id, _ = self.kv_task_engine.put_async(token_ids, slot_mapping, token_mask, dp_rank, pp_rank=pp_rank, namespace=namespace) return task_id def put_match(self, token_ids: Union[torch.Tensor, np.ndarray], token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, - dp_id: int 
= 0, + dp_rank: int = 0, + pp_rank: int = 0, namespace: Optional[List[str]] = None, ) -> Tuple[int, np.ndarray]: if isinstance(token_ids, torch.Tensor): @@ -219,27 +240,30 @@ def put_match(self, if isinstance(token_mask, torch.Tensor): token_mask = token_mask.numpy() if self.server_client_mode: - task_id, mask = self.dp_client.put_match(token_ids, token_mask, namespace=namespace) + task_id, mask = self.dp_client.put_match(token_ids, token_mask, pp_rank=pp_rank, namespace=namespace) else: - task_id, mask = self.kv_task_engine.put_match(token_ids, token_mask, dp_id, namespace=namespace) + task_id, mask = self.kv_task_engine.put_match(token_ids, token_mask, dp_rank, pp_rank=pp_rank, namespace=namespace) return task_id, mask def prefetch_async(self, token_ids: np.ndarray, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, namespace: Optional[List[str]] = None) -> int: if isinstance(token_ids, torch.Tensor): token_ids = token_ids.numpy() if self.server_client_mode: - task_id = self.dp_client.prefetch_async(token_ids, namespace=namespace) + task_id = self.dp_client.prefetch_async(token_ids, pp_rank=pp_rank, namespace=namespace) else: - task_id = self.kv_task_engine.prefetch_async(token_ids, dp_id=dp_id, namespace=namespace) + task_id = self.kv_task_engine.prefetch_async(token_ids, dp_rank=dp_rank, pp_rank=pp_rank, namespace=namespace) return task_id def launch(self, task_ids: Union[int, List[int]], slot_mappings: Union[np.ndarray, List[np.ndarray], torch.Tensor, List[torch.Tensor]], - as_batch: bool = False) -> List[int]: + as_batch: bool = False, + layerwise_transfer: bool = False, + counter_id: int = 0) -> List[int]: if isinstance(task_ids, int): task_ids = [task_ids] if not isinstance(slot_mappings, List): @@ -247,9 +271,15 @@ def launch(self, if isinstance(slot_mappings[0], torch.Tensor): slot_mappings = [slot_mapping.numpy() for slot_mapping in slot_mappings] if self.server_client_mode: - return self.dp_client.launch_tasks(task_ids, slot_mappings, as_batch) + 
return self.dp_client.launch_tasks(task_ids, slot_mappings, as_batch, layerwise_transfer, counter_id) else: - return self.kv_task_engine.launch_tasks(task_ids, slot_mappings, as_batch) + return self.kv_task_engine.launch_tasks( + task_ids, + slot_mappings, + as_batch=as_batch, + layerwise_transfer=layerwise_transfer, + counter_id=counter_id + ) def cancel(self, task_ids: Union[int, List[int]]) -> None: if isinstance(task_ids, int): diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index fa77bd0c1f..99e85b5b3c 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -9,23 +9,23 @@ import os from expiring_dict import ExpiringDict import nvtx -import torch import numpy as np -from flexkv.common.config import CacheConfig, ModelConfig +from flexkv.common.config import CacheConfig, ModelConfig, GLOBAL_CONFIG_FROM_ENV from flexkv.common.debug import flexkv_logger from flexkv.common.block import hash_token from flexkv.common.transfer import TransferOpGraph, merge_to_batch_graph, get_nvtx_default_color, CompletedOp from flexkv.common.tracer import FlexKVTracer -from flexkv.cache.cache_engine import GlobalCacheEngine, DEFAULT_CACHE_STRATEGY +from flexkv.cache.cache_engine import GlobalCacheEngine, DEFAULT_CACHE_STRATEGY, CPUONLY_CACHE_STRATEGY from flexkv.transfer_manager import TransferManagerHandle, TransferManagerOnRemote from flexkv.common.request import KVResponseStatus, KVResponse from flexkv.transfer_manager import ( - get_master_host_and_ports_from_env, + resolve_master_host_and_ports, get_trtllm_subprocess_host_and_ports_from_env ) from flexkv.cache.redis_meta import RedisMeta from flexkv.integration.dynamo.collector import KVEventCollector +from flexkv.transfer_manager import TransferManagerMultiNodeHandle class TaskStatus(Enum): # slot mapping is not ready @@ -61,7 +61,6 @@ class KVTask: token_ids: np.ndarray slot_mapping: np.ndarray token_mask: Optional[np.ndarray] - dp_id: int # cache engine return graph: TransferOpGraph @@ -69,6 +68,9 @@ class KVTask: callback: 
Optional[Union[Callable, List[Callable]]] op_callback_dict: Dict[int, Callable] + dp_rank: int = 0 + pp_rank: int = 0 + # batch: points to the batch task id if this task was merged into a batch batch_task_id: Optional[int] = None @@ -107,27 +109,55 @@ def __init__(self, self.model_config = model_config self._check_config(model_config, cache_config) - self.is_multinode_tp = False - self.tp_node_count = 1 - if self.model_config.tp_size > torch.cuda.device_count(): - if self.model_config.tp_size != torch.cuda.device_count() * 2: - raise ValueError("Only support 2 nodes TP for now") - assert self.model_config.dp_size == 1 - self.tp_node_count = self.model_config.tp_size // torch.cuda.device_count() - self.is_multinode_tp = True + # ---- Multi-node topology ---- + nnodes = self.model_config.nnodes + pp_size = self.model_config.pp_size + tp_size = self.model_config.tp_size + + total_gpus = tp_size * pp_size + if total_gpus % nnodes != 0: + raise ValueError( + f"[KVTaskEngine] cannot derive gpus_per_node: " + f"tp*pp={total_gpus} not divisible by nnodes={nnodes}" + ) + gpus_per_node = total_gpus // nnodes + + self.nnodes_per_tp_group = max( + (tp_size + gpus_per_node - 1) // gpus_per_node, 1 + ) + if self.nnodes_per_tp_group > 2: + raise ValueError( + f"Only support 2-nodes TP for now, but got " + f"nnodes_per_tp_group={self.nnodes_per_tp_group} " + f"(tp_size={tp_size}, gpus_per_node={gpus_per_node})" + ) + + if tp_size % self.nnodes_per_tp_group != 0: + raise ValueError( + f"[KVTaskEngine] tp_size={tp_size} not divisible by " + f"nnodes_per_tp_group={self.nnodes_per_tp_group}" + ) + tp_size_per_node = tp_size // self.nnodes_per_tp_group + + flexkv_logger.info( + f"[KVTaskEngine] topology: " + f"nnodes={nnodes}, " + f"node_rank={self.model_config.node_rank}, " + f"gpus_per_node={gpus_per_node}, " + f"tp_size={tp_size}, " + f"pp_size={pp_size}, " + f"dp_size={self.model_config.dp_size}, " + f"nnodes_per_tp_group={self.nnodes_per_tp_group}, " + 
f"tp_size_per_node={tp_size_per_node}, " + f"master_host={self.model_config.master_host!r}" + ) self.cache_engine = GlobalCacheEngine(cache_config, model_config, redis_meta, event_collector) - model_config_for_transfer = copy.deepcopy(self.model_config) - if self.is_multinode_tp: - model_config_for_transfer.tp_size //= self.tp_node_count - if not self.model_config.use_mla: - model_config_for_transfer.num_kv_heads //= self.tp_node_count - combine_with_trtllm = os.getenv("FLEXKV_WITH_TRTLLM", "0") == "1" if not combine_with_trtllm: self.transfer_handles = [TransferManagerHandle( - model_config_for_transfer, + self.model_config, self.cache_config, mode="process", gpu_register_port=gpu_register_port @@ -140,7 +170,7 @@ def __init__(self, self.remote_process = TransferManagerOnRemote.create_process(mode="TrtllmSubprocess") self.transfer_handles = [ TransferManagerHandle( - model_config_for_transfer, + self.model_config, self.cache_config, mode="remote", gpu_register_port=gpu_register_port, @@ -150,10 +180,12 @@ def __init__(self, ] self.transfer_handles[0]._handle.send_config_to_remotes() - if self.is_multinode_tp: - master_host, master_ports = get_master_host_and_ports_from_env() + if self.model_config.nnodes > 1: + master_host, master_ports = resolve_master_host_and_ports( + master_host=self.model_config.master_host + ) self.transfer_handles.append(TransferManagerHandle( - model_config_for_transfer, + self.model_config, self.cache_config, mode="remote", gpu_register_port=gpu_register_port, @@ -206,8 +238,10 @@ def create_get_task(self, slot_mapping: np.ndarray, token_mask: Optional[np.ndarray] = None, layer_granularity: int = -1, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, is_fake_slot_mapping: bool = False, + temp_cache_strategy=DEFAULT_CACHE_STRATEGY, namespace: Optional[List[str]] = None, ) -> None: if task_id in self.tasks: @@ -217,9 +251,11 @@ def create_get_task(self, token_ids=token_ids, token_mask=token_mask, slot_mapping=slot_mapping, - 
layer_num=self.model_config.num_layers, + layer_num=self.model_config.num_layers_per_pp_stage, layer_granularity=layer_granularity, - dp_id=dp_id, + dp_rank=dp_rank, + pp_rank=pp_rank, + temp_cache_strategy=temp_cache_strategy, namespace=namespace) self.tasks[task_id] = KVTask( task_id=task_id, @@ -230,7 +266,8 @@ def create_get_task(self, token_ids=token_ids, slot_mapping=slot_mapping, token_mask=token_mask, - dp_id=dp_id, + dp_rank=dp_rank, + pp_rank=pp_rank, graph=graph, return_mask=return_mask, callback=callback, @@ -243,7 +280,8 @@ def create_put_task(self, token_ids: np.ndarray, slot_mapping: np.ndarray, token_mask: Optional[np.ndarray] = None, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, is_fake_slot_mapping: bool = False, namespace: Optional[List[str]] = None, ) -> None: @@ -254,8 +292,9 @@ def create_put_task(self, token_ids=token_ids, token_mask=token_mask, slot_mapping=slot_mapping, - layer_num=self.model_config.num_layers, - dp_id=dp_id, + layer_num=self.model_config.num_layers_per_pp_stage, + dp_rank=dp_rank, + pp_rank=pp_rank, namespace=namespace) self.tasks[task_id] = KVTask( task_id=task_id, @@ -266,7 +305,8 @@ def create_put_task(self, token_ids=token_ids, slot_mapping=slot_mapping, token_mask=token_mask, - dp_id=dp_id, + dp_rank=dp_rank, + pp_rank=pp_rank, graph=graph, return_mask=return_mask, callback=callback, @@ -276,6 +316,7 @@ def create_put_task(self, def create_prefetch_task(self, task_id: int, token_ids: np.ndarray, + pp_rank: int = 0, namespace: Optional[List[str]] = None, ) -> None: if task_id in self.tasks: @@ -290,7 +331,9 @@ def create_prefetch_task(self, token_ids=token_ids, token_mask=fake_token_mask, slot_mapping=fake_slot_mapping, - layer_num=self.model_config.num_layers, + layer_num=self.model_config.num_layers_per_pp_stage, + dp_rank=0, # dp_rank irrelevant: prefetch only uploads to CPU (ignore_gpu=True) + pp_rank=pp_rank, temp_cache_strategy=temp_cache_strategy, namespace=namespace) self.tasks[task_id] = KVTask( @@ 
-302,7 +345,8 @@ def create_prefetch_task(self, token_ids=token_ids, slot_mapping=fake_slot_mapping, # ignore slot_mapping for prefetch token_mask=fake_token_mask, # ignore token_mask for prefetch - dp_id=0, # ignore dp_id for prefetch + dp_rank=0, # ignore dp_rank for prefetch + pp_rank=pp_rank, graph=graph, return_mask=return_mask, callback=callback, @@ -319,7 +363,21 @@ def _launch_task(self, task_id: int) -> None: nvtx.mark(f"launch task: task_id={task_id}, graph_id={transfer_graph.graph_id}") if transfer_graph.num_ops > 0: for transfer_handle in self.transfer_handles: - transfer_handle.submit(transfer_graph) + # For remote handles: deepcopy graph and clear GPU blocks when + # it's a cross-machine PP handle (different PP stages have + # different GPU block_ids). Cross-machine TP handles share + # the same slot_mapping, so no clear is needed. + if isinstance(transfer_handle._handle, TransferManagerMultiNodeHandle): + if self.model_config.nnodes > 1 and self.model_config.pp_size > 1: + # Cross-machine PP: each PP rank has different GPU blocks + graph_copy = copy.deepcopy(transfer_graph) + graph_copy.clear_gpu_blocks() + transfer_handle.submit(graph_copy, task_end_op_id=self.tasks[task_id].task_end_op_id) + else: + # Cross-machine TP: same slot_mapping across TP ranks + transfer_handle.submit(transfer_graph, task_end_op_id=self.tasks[task_id].task_end_op_id) + else: + transfer_handle.submit(transfer_graph, task_end_op_id=self.tasks[task_id].task_end_op_id) def _update_tasks(self, timeout: float = 0.001) -> None: completed_ops = self._get_completed_ops(timeout) @@ -456,7 +514,7 @@ def _check_config(self, model_config: ModelConfig, cache_config: CacheConfig) -> raise ValueError("remote_file_size must not None if use file_size model") if model_config.use_mla: kv_size = ( - model_config.num_layers + model_config.num_layers_per_pp_stage * cache_config.tokens_per_block * model_config.num_kv_heads * model_config.head_size @@ -464,7 +522,7 @@ def _check_config(self, 
model_config: ModelConfig, cache_config: CacheConfig) -> ) else: kv_size = ( - model_config.num_layers + model_config.num_layers_per_pp_stage * 2 * cache_config.tokens_per_block * model_config.num_kv_heads @@ -493,16 +551,18 @@ def get_async(self, slot_mapping: np.ndarray, token_mask: Optional[np.ndarray] = None, layer_granularity: int = -1, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, task_id: int = -1, namespace: Optional[List[str]] = None) -> Tuple[int, np.ndarray]: - self._sync_prefetch(token_ids, namespace) + # self._sync_prefetch(token_ids, namespace) task_id, return_mask = self._get_match_impl(token_ids, slot_mapping, is_fake_slot_mapping=False, token_mask=token_mask, layer_granularity=layer_granularity, - dp_id=dp_id, + dp_rank=dp_rank, + pp_rank=pp_rank, task_id=task_id, namespace=namespace) # trace get request @@ -513,7 +573,8 @@ def get_async(self, slot_mapping=slot_mapping, token_mask=token_mask, layer_granularity=layer_granularity, - dp_id=dp_id + dp_rank=dp_rank, + pp_rank=pp_rank ) self._launch_task(task_id) return task_id, return_mask @@ -522,14 +583,16 @@ def put_async(self, token_ids: np.ndarray, slot_mapping: np.ndarray, token_mask: Optional[np.ndarray] = None, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, task_id: int = -1, namespace: Optional[List[str]] = None) -> Tuple[int, np.ndarray]: task_id, return_mask = self._put_match_impl(token_ids, slot_mapping, is_fake_slot_mapping=False, token_mask=token_mask, - dp_id=dp_id, + dp_rank=dp_rank, + pp_rank=pp_rank, task_id=task_id, namespace=namespace) # trace put request @@ -540,7 +603,8 @@ def put_async(self, slot_mapping=slot_mapping, token_mask=token_mask, layer_granularity=-1, # put has no layer_granularity parameter - dp_id=dp_id + dp_rank=dp_rank, + pp_rank=pp_rank ) self._launch_task(task_id) return task_id, return_mask @@ -645,11 +709,13 @@ def get_match(self, token_ids: np.ndarray, token_mask: Optional[np.ndarray] = None, layer_granularity: int = -1, - dp_id: int = 0, 
+ dp_rank: int = 0, + pp_rank: int = 0, + cpu_only: bool = False, task_id: int = -1, namespace: Optional[List[str]] = None) -> Tuple[int, np.ndarray]: nvtx.push_range(f"get match: task_id={task_id}", color=get_nvtx_default_color()) - self._sync_prefetch(token_ids, namespace) + # self._sync_prefetch(token_ids, namespace) if token_mask is None: token_mask = np.ones_like(token_ids, dtype=bool) fake_slot_mapping = np.zeros_like(token_ids[token_mask]) @@ -658,7 +724,9 @@ def get_match(self, is_fake_slot_mapping=True, token_mask=token_mask, layer_granularity=layer_granularity, - dp_id=dp_id, + dp_rank=dp_rank, + pp_rank=pp_rank, + cpu_only=cpu_only, task_id=task_id, namespace=namespace) # trace get match request @@ -669,7 +737,8 @@ def get_match(self, slot_mapping=fake_slot_mapping, token_mask=token_mask, layer_granularity=layer_granularity, - dp_id=dp_id + dp_rank=dp_rank, + pp_rank=pp_rank ) nvtx.pop_range() return result_task_id, return_mask @@ -680,23 +749,30 @@ def _get_match_impl(self, is_fake_slot_mapping: bool = False, token_mask: Optional[np.ndarray] = None, layer_granularity: int = -1, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, + cpu_only: bool = False, task_id: int = -1, namespace: Optional[List[str]] = None) -> Tuple[int, np.ndarray]: if token_mask is None: token_mask = np.ones_like(token_ids) if layer_granularity == -1: - layer_granularity = self.model_config.num_layers + layer_granularity = self.model_config.num_layers_per_pp_stage if task_id == -1: task_id = self._gen_task_id() + temp_cache_strategy = DEFAULT_CACHE_STRATEGY + if cpu_only: + temp_cache_strategy = CPUONLY_CACHE_STRATEGY nvtx.push_range(f"get match: task_id={task_id}", color=get_nvtx_default_color()) self.create_get_task(task_id, token_ids, slot_mapping, token_mask, layer_granularity, - dp_id, + dp_rank, + pp_rank=pp_rank, is_fake_slot_mapping=is_fake_slot_mapping, + temp_cache_strategy=temp_cache_strategy, namespace=namespace) self._process_empty_graph(task_id) 
nvtx.pop_range() @@ -705,7 +781,8 @@ def _get_match_impl(self, def put_match(self, token_ids: np.ndarray, token_mask: Optional[np.ndarray] = None, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, task_id: int = -1, namespace: Optional[List[str]] = None) -> Tuple[int, np.ndarray]: fake_slot_mapping = np.zeros_like(token_ids) @@ -713,7 +790,8 @@ def put_match(self, fake_slot_mapping, is_fake_slot_mapping=True, token_mask=token_mask, - dp_id=dp_id, + dp_rank=dp_rank, + pp_rank=pp_rank, task_id=task_id, namespace=namespace) # trace put match request @@ -724,7 +802,8 @@ def put_match(self, slot_mapping=fake_slot_mapping, token_mask=token_mask, layer_granularity=-1, # put has no layer_granularity parameter - dp_id=dp_id + dp_rank=dp_rank, + pp_rank=pp_rank ) return result_task_id, return_mask @@ -733,7 +812,8 @@ def _put_match_impl(self, slot_mapping: np.ndarray, is_fake_slot_mapping: bool = False, token_mask: Optional[np.ndarray] = None, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, task_id: int = -1, namespace: Optional[List[str]] = None) -> Tuple[int, np.ndarray]: if token_mask is None: @@ -745,7 +825,8 @@ def _put_match_impl(self, token_ids, slot_mapping, token_mask, - dp_id, + dp_rank, + pp_rank=pp_rank, is_fake_slot_mapping=is_fake_slot_mapping, namespace=namespace) self._process_empty_graph(task_id) @@ -754,13 +835,14 @@ def _put_match_impl(self, def prefetch_async(self, token_ids: np.ndarray, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, task_id: int = -1, namespace: Optional[List[str]] = None) -> int: if task_id == -1: task_id = self._gen_task_id() nvtx.push_range(f"prefetch match: task_id={task_id}", color=get_nvtx_default_color()) - self.create_prefetch_task(task_id, token_ids, namespace=namespace) + self.create_prefetch_task(task_id, token_ids, pp_rank=pp_rank, namespace=namespace) self._process_empty_graph(task_id) nvtx.pop_range() # trace prefetch async request @@ -771,15 +853,20 @@ def prefetch_async(self, 
slot_mapping=np.zeros_like(token_ids), token_mask=np.ones_like(token_ids), layer_granularity=-1, - dp_id=dp_id + dp_rank=dp_rank, + pp_rank=pp_rank ) self._launch_task(task_id) return task_id def merge_to_batch_kvtask(self, + batch_id: int, + task_ids: List[int], - batch_task_type: TaskType) -> TransferOpGraph: + batch_task_type: TaskType, + layerwise_transfer: bool = False, + counter_id: int = 0) -> TransferOpGraph: op_callback_dict = {} task_end_op_ids = [] callbacks = [] @@ -796,11 +883,12 @@ def merge_to_batch_kvtask(self, task_end_op_ids.append(self.tasks[task_id].task_end_op_id) callbacks.append(self.tasks[task_id].callback) return_masks.append(self.tasks[task_id].return_mask) - batch_task_graph, task_end_op_id, op_callback_dict = merge_to_batch_graph(batch_id, transfer_graphs, task_end_op_ids, - op_callback_dict) + op_callback_dict, + layerwise_transfer, + counter_id) self.tasks[batch_id] = KVTask( task_id=batch_id, token_ids=np.concatenate([self.tasks[task_id].token_ids for task_id in task_ids]), @@ -810,7 +898,8 @@ def merge_to_batch_kvtask(self, task_end_op_id=task_end_op_id, task_end_op_finished=False, status=TaskStatus.READY, - dp_id=self.tasks[task_ids[0]].dp_id, + dp_rank=self.tasks[task_ids[0]].dp_rank, + pp_rank=self.tasks[task_ids[0]].pp_rank, graph=batch_task_graph, return_mask=return_masks, callback=callbacks, @@ -826,7 +915,9 @@ def launch_tasks(self, task_ids: List[int], slot_mappings: List[np.ndarray], as_batch: bool = False, - batch_id: int = -1) -> List[int]: + batch_id: int = -1, + layerwise_transfer: bool = False, + counter_id: int = 0) -> List[int]: assert isinstance(slot_mappings[0], np.ndarray) # trace launch tasks self.tracer.trace_launch_tasks(task_ids, slot_mappings, as_batch) @@ -837,11 +928,21 @@ def launch_tasks(self, all_get = all(self.tasks[tid].task_type == TaskType.GET for tid in task_ids) all_put = all(self.tasks[tid].task_type == TaskType.PUT for tid in task_ids) - if len(task_ids) > 1 and as_batch and (all_get or all_put): 
+ if (len(task_ids) > 1 or layerwise_transfer) and as_batch and (all_get or all_put): if batch_id == -1: batch_id = self._gen_task_id() + if layerwise_transfer: + if not GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer: + flexkv_logger.warning("layerwise transfer is not enabled") + layerwise_transfer = False + elif not all_get: + flexkv_logger.warning("only support layerwise get") + layerwise_transfer = False batch_task_type = TaskType.BATCH_GET if all_get else TaskType.BATCH_PUT - transfer_graphs = [self.merge_to_batch_kvtask(batch_id, task_ids, batch_task_type)] + batch_task_graph = self.merge_to_batch_kvtask( + batch_id, task_ids, batch_task_type, layerwise_transfer, counter_id + ) + transfer_graphs = [batch_task_graph] self.tasks[batch_id].status = TaskStatus.RUNNING task_ids = [batch_id] else: diff --git a/flexkv/mooncakeEngineWrapper.py b/flexkv/mooncakeEngineWrapper.py index bcb080f44b..b0f24f3eb1 100644 --- a/flexkv/mooncakeEngineWrapper.py +++ b/flexkv/mooncakeEngineWrapper.py @@ -31,13 +31,18 @@ def __init__( ) if config is None: - mooncake_config_path = os.environ["MOONCAKE_CONFIG_PATH"] + mooncake_config_path = os.environ.get("MOONCAKE_CONFIG_PATH") + if mooncake_config_path is None: + raise RuntimeError( + "MOONCAKE_CONFIG_PATH is not set. Please set the MOONCAKE_CONFIG_PATH " + "environment variable or pass a MooncakeTransferEngineConfig object." 
+ ) self.config = MooncakeTransferEngineConfig.from_file(mooncake_config_path) else: self.config = config - self.engine_ip = config.engine_ip - self.engien_port = config.engine_port - self.mooncake_addr = f"{self.engine_ip}:{self.engien_port}" + self.engine_ip = self.config.engine_ip + self.engine_port = self.config.engine_port + self.mooncake_addr = f"{self.engine_ip}:{self.engine_port}" flexkv_logger.info(f"Mooncake listen on: {self.mooncake_addr}") supported_backend = ["redis"] @@ -51,13 +56,18 @@ def __init__( # transfer engine initialize self.engine = TransferEngine() + # Set Redis auth env vars for mooncake engine (it reads MC_REDIS_PASSWORD internally) + if self.config.metadata_server_auth: + os.environ["MC_REDIS_PASSWORD"] = self.config.metadata_server_auth + flexkv_logger.info("Set MC_REDIS_PASSWORD environment variable for mooncake Redis authentication") + self.engine.initialize_ext( self.mooncake_addr, self.config.metadata_server, self.config.protocol, self.config.device_name, self.metadata_backend, - ) + ) # mooncake operations def regist_buffer(self, buffer_ptr: int, buffer_size: int) -> int: @@ -65,7 +75,7 @@ def regist_buffer(self, buffer_ptr: int, buffer_size: int) -> int: ret = self.engine.register_memory(buffer_ptr, buffer_size) return ret if ret == 0 else -1 - def unregist_buffer(self, buffer_ptr: int) -> None: + def unregist_buffer(self, buffer_ptr: int) -> int: """Unregister the buffer to the mooncake engine.""" ret = self.engine.unregister_memory(buffer_ptr) return ret if ret == 0 else -1 @@ -92,7 +102,7 @@ def batch_transfer_sync_write(self, peer_engine_addr: str, src_ptr_list: List[in ret = self.engine.batch_transfer_sync_write(peer_engine_addr, src_ptr_list, dst_ptr_list, data_size_list) return ret if ret == 0 else -1 - def transfer_sync_write_with_notify(self, peer_engine_addr: str, src_ptr: int, dst_ptr: int, data_size: int, notify_name: str, msg : NotifyMsg): + def transfer_sync_write_with_notify(self, peer_engine_addr: str, src_ptr: 
int, dst_ptr: int, data_size: int, notify_name: str, msg : NotifyMsg) -> int: if not MOONCAKE_AVAILABLE: raise RuntimeError("Mooncake engine is not available") notify = engine.TransferNotify(notify_name, msg.to_string()) diff --git a/flexkv/server/client.py b/flexkv/server/client.py index 9947fa2cf7..cc0ac1d267 100644 --- a/flexkv/server/client.py +++ b/flexkv/server/client.py @@ -94,6 +94,7 @@ def put_async( token_ids: np.ndarray, slot_mapping: np.ndarray, token_mask: Optional[np.ndarray], + pp_rank: int = 0, namespace: Optional[List[str]] = None, ) -> int: req = PutRequest(self.dp_client_id, @@ -101,6 +102,7 @@ def put_async( slot_mapping, token_mask if token_mask is not None else None, self._get_task_id(), + pp_rank, namespace) self.send_to_server.send_pyobj(req) return req.task_id @@ -109,12 +111,14 @@ def put_match( self, token_ids: np.ndarray, token_mask: Optional[np.ndarray], + pp_rank: int = 0, namespace: Optional[List[str]] = None, ) -> Optional[Tuple[int, np.ndarray]]: req = PutMatchRequest(self.dp_client_id, token_ids, token_mask if token_mask is not None else None, self._get_task_id(), + pp_rank, namespace) self.send_to_server.send_pyobj(req) response: Response = self.recv_from_server.recv_pyobj() @@ -127,9 +131,10 @@ def put_match( def prefetch_async( self, token_ids: np.ndarray, + pp_rank: int = 0, namespace: Optional[List[str]] = None, ) -> int: - req = PrefetchRequest(self.dp_client_id, token_ids, self._get_task_id(), namespace) + req = PrefetchRequest(self.dp_client_id, token_ids, self._get_task_id(), pp_rank, namespace) self.send_to_server.send_pyobj(req) return req.task_id @@ -139,6 +144,7 @@ def get_async( slot_mapping: np.ndarray, token_mask: Optional[np.ndarray], layer_granularity: int, + pp_rank: int = 0, namespace: Optional[List[str]] = None, ) -> int: req = GetRequest(self.dp_client_id, @@ -147,6 +153,7 @@ def get_async( token_mask if token_mask is not None else None, self._get_task_id(), layer_granularity, + pp_rank, namespace) 
self.send_to_server.send_pyobj(req) return req.task_id @@ -156,13 +163,17 @@ def get_match( token_ids: np.ndarray, token_mask: Optional[np.ndarray], layer_granularity: int, + pp_rank: int = 0, + cpu_only: bool = False, namespace: Optional[List[str]] = None, ) -> Optional[Tuple[int, np.ndarray]]: req = GetMatchRequest(self.dp_client_id, token_ids, token_mask if token_mask is not None else None, layer_granularity, + cpu_only, self._get_task_id(), + pp_rank, namespace) self.send_to_server.send_pyobj(req) response: Response = self.recv_from_server.recv_pyobj() @@ -177,11 +188,13 @@ def launch_tasks( task_ids: List[int], slot_mappings: List[np.ndarray], as_batch: bool = False, + layerwise_transfer: bool = False, + counter_id: int = 0, ) -> List[int]: batch_id = -1 if as_batch: batch_id = self._get_task_id() - req = LaunchTaskRequest(self.dp_client_id, task_ids, slot_mappings, as_batch, batch_id) + req = LaunchTaskRequest(self.dp_client_id, task_ids, slot_mappings, as_batch, batch_id, layerwise_transfer, counter_id) self.send_to_server.send_pyobj(req) return [batch_id] if as_batch else task_ids @@ -235,7 +248,8 @@ class KVTPClient: def __init__( self, gpu_register_port: str, - dp_client_id: int, + dp_rank: int, + pp_rank: int, device_id: int, ): # Init inter-process communication @@ -244,16 +258,47 @@ def __init__( context, zmq.SocketType.PUSH, gpu_register_port, False ) - self.dp_client_id = dp_client_id + self.dp_rank = dp_rank + self.pp_rank = pp_rank self.device_id = device_id - flexkv_logger.info(f"KVTPClient {device_id} of KVDPClient {self.dp_client_id} Initialized!") + flexkv_logger.info(f"KVTPClient {device_id} of DP {self.dp_rank} Initialized! " + f"(gpu_register_port={gpu_register_port}, dp_rank={dp_rank}, pp_rank={pp_rank})") + + def set_slot_mapping(self, task_id: int, slot_mapping: np.ndarray) -> None: + """Send set_slot_mapping message to TransferManagerOnRemote via existing ZMQ channel. 
+ + Reuses the same PUSH socket (send_to_server) that connects to + TransferManagerOnRemote's command_socket — no separate IPC socket needed. + """ + message = { + 'type': 'set_slot_mapping', + 'task_id': task_id, + 'slot_mapping': slot_mapping, + } + try: + self.send_to_server.send_pyobj(message, flags=zmq.NOBLOCK) + flexkv_logger.debug( + f"KVTPClient {self.device_id}: set_slot_mapping sent for task_id={task_id}" + ) + except zmq.Again: + flexkv_logger.warning( + f"KVTPClient {self.device_id}: zmq.Again when sending set_slot_mapping, " + f"retrying with blocking send..." + ) + self.send_to_server.send_pyobj(message) + flexkv_logger.info( + f"KVTPClient {self.device_id}: set_slot_mapping sent (blocking retry) " + f"for task_id={task_id}" + ) def register_to_server( self, kv_caches: List[torch.Tensor], kv_layout: KVCacheLayout, override_device_id: Optional[int] = None, + indexer_buffers: Optional[List[torch.Tensor]] = None, + indexer_layout: Optional[KVCacheLayout] = None, ) -> None: if not kv_caches or not kv_caches[0].is_cuda: raise ValueError("GPU blocks must be CUDA tensors") @@ -266,14 +311,34 @@ def register_to_server( handle = TensorSharedHandle(tensor, device_id) handles.append(handle) + # Build optional indexer handles + indexer_handles = None + if indexer_buffers is not None and len(indexer_buffers) > 0: + indexer_handles = [] + for tensor in indexer_buffers: + indexer_handles.append(TensorSharedHandle(tensor, device_id)) + register_req = RegisterTPClientRequest( - self.dp_client_id, + self.dp_rank, + self.pp_rank, device_id, handles, - kv_layout + kv_layout, + indexer_handles=indexer_handles, + indexer_gpu_layout=indexer_layout, ) - self.send_to_server.send_pyobj(register_req, flags=zmq.NOBLOCK) + try: + self.send_to_server.send_pyobj(register_req, flags=zmq.NOBLOCK) + flexkv_logger.info( + f"KVTPClient {device_id}: registration message sent " + f"(dp_rank={self.dp_rank}, num_kv_caches={len(kv_caches)})") + except zmq.Again: + flexkv_logger.error( + 
f"KVTPClient {device_id}: zmq.Again when sending registration " + f"(send buffer full or no connection). Retrying with blocking send...") + self.send_to_server.send_pyobj(register_req) + flexkv_logger.info(f"KVTPClient {device_id}: registration message sent (blocking retry)") if __name__ == "__main__": diff --git a/flexkv/server/request.py b/flexkv/server/request.py index e540f495cb..f1e22fdf62 100644 --- a/flexkv/server/request.py +++ b/flexkv/server/request.py @@ -18,10 +18,14 @@ class RegisterDPClientRequest: @dataclass class RegisterTPClientRequest: - dp_client_id: int + dp_rank: int + pp_rank: int device_id: int handles: List[TensorSharedHandle] gpu_layout: KVCacheLayout + # --- Indexer shadow transfer fields --- + indexer_handles: Optional[List[TensorSharedHandle]] = None + indexer_gpu_layout: Optional[KVCacheLayout] = None @dataclass class IsReadyRequest: @@ -34,6 +38,7 @@ class PutRequest: slot_mapping: np.ndarray token_mask: Optional[np.ndarray] task_id: int = -1 + pp_rank: int = 0 namespace: Optional[List[str]] = None @@ -45,6 +50,7 @@ class GetRequest: token_mask: Optional[np.ndarray] task_id: int = -1 layer_granularity: int = -1 + pp_rank: int = 0 namespace: Optional[List[str]] = None @dataclass @@ -52,6 +58,7 @@ class PrefetchRequest: dp_client_id: int token_ids: np.ndarray task_id: int = -1 + pp_rank: int = 0 namespace: Optional[List[str]] = None @dataclass @@ -60,6 +67,7 @@ class PutMatchRequest: token_ids: np.ndarray token_mask: Optional[np.ndarray] task_id: int = -1 + pp_rank: int = 0 namespace: Optional[List[str]] = None @dataclass @@ -68,7 +76,9 @@ class GetMatchRequest: token_ids: np.ndarray token_mask: Optional[np.ndarray] layer_granularity: int + cpu_only: bool = False task_id: int = -1 + pp_rank: int = 0 namespace: Optional[List[str]] = None @dataclass @@ -78,6 +88,8 @@ class LaunchTaskRequest: slot_mappings: List[np.ndarray] as_batch: bool = False batch_id: int = -1 + layerwise_transfer: bool = False + counter_id: int = 0 # Counter set index 
for triple buffering eventfd notification @dataclass class CancelTaskRequest: @@ -125,3 +137,4 @@ class ShutdownRequest: @dataclass class CheckRunningRequest: dp_client_id: int + diff --git a/flexkv/server/server.py b/flexkv/server/server.py index dbefb260e2..208259c40e 100644 --- a/flexkv/server/server.py +++ b/flexkv/server/server.py @@ -16,6 +16,7 @@ from flexkv.common.config import CacheConfig, ModelConfig from flexkv.common.debug import flexkv_logger +from flexkv.cache.redis_meta import RedisMeta from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType from flexkv.kvtask import KVTaskEngine @@ -111,15 +112,31 @@ class KVServerHandle: def __init__(self, process: Union[mp.Process, 'subprocess.Popen']): self.process = process + def _is_alive(self) -> bool: + """Check if the process is still running (compatible with both Process and Popen).""" + if isinstance(self.process, subprocess.Popen): + return self.process.poll() is None + return self.process.is_alive() + + def _join(self, timeout: float = None) -> None: + """Wait for the process to finish (compatible with both Process and Popen).""" + if isinstance(self.process, subprocess.Popen): + try: + self.process.wait(timeout=timeout) + except subprocess.TimeoutExpired: + pass + else: + self.process.join(timeout=timeout) + def shutdown(self) -> None: - self.process.join(timeout=5) - if self.process.is_alive(): + self._join(timeout=5) + if self._is_alive(): flexkv_logger.info("force terminate the server process") self.process.terminate() - self.process.join() + self._join() def __del__(self) -> None: - if self.process.is_alive(): + if self._is_alive(): self.shutdown() class KVServer: @@ -137,11 +154,32 @@ def __init__( self.context = zmq.Context(2) self.recv_from_client = get_zmq_socket( self.context, zmq.SocketType.PULL, server_recv_port, True) + flexkv_logger.info( + f"[KVServer] IPC ports bound: server_recv_port={server_recv_port}, " + 
f"gpu_register_port={gpu_register_port}" + ) # Use total_clients if provided (multi-instance mode), otherwise use dp_size max_clients = total_clients if total_clients > 0 else model_config.dp_size self.client_manager = ClientManager(max_num_dp_client=max_clients) - self.kv_task_engine = KVTaskEngine(model_config, cache_config, gpu_register_port) + + # Initialize RedisMeta in KVServer for server_client_mode + self.redis_meta_client = None + if cache_config.enable_kv_sharing: + flexkv_logger.info(f"[kv server] initializing RedisMeta and connection to " + f"{cache_config.redis_host}:{cache_config.redis_port}") + self.redis_meta_client = RedisMeta( + cache_config.redis_host, + cache_config.redis_port, + cache_config.redis_password, + cache_config.local_ip, + node_ttl_seconds=cache_config.node_ttl_seconds, + ) + self.redis_meta_client.init_meta() + # update distributed_node_id + cache_config.distributed_node_id = self.redis_meta_client.get_node_id() + + self.kv_task_engine = KVTaskEngine(model_config, cache_config, gpu_register_port, redis_meta=self.redis_meta_client) self.req_counter = 0 self._is_ready = False @@ -209,6 +247,12 @@ def create_server(cls, env.update(child_env) else: env = child_env or {} + # Always propagate FLEXKV_* env vars to child process so that + # runtime config overrides (e.g. FLEXKV_REBUILD_INTERVAL_MS) + # are visible when config.py is re-imported in the subprocess. 
+ for key, val in os.environ.items(): + if key.startswith("FLEXKV_") and key not in env: + env[key] = val # Remove CUDA_VISIBLE_DEVICES so server can see all GPUs env.pop('CUDA_VISIBLE_DEVICES', None) @@ -285,17 +329,19 @@ def run(self) -> None: # Cleanup after shutdown flexkv_logger.info("Server shutting down, cleaning up...") - if hasattr(self, 'kvmanager'): - self.kvmanager.shutdown() + if hasattr(self, 'kv_task_engine'): + self.kv_task_engine.shutdown() flexkv_logger.info("Server shutdown complete") def _verify_model_config(self, model_config: ModelConfig) -> None: """Verify that client's model config matches server's config.""" + skip_fields = {"dp_rank"} for field in fields(ModelConfig): + if field.name in skip_fields: + continue client_val = getattr(model_config, field.name) server_val = getattr(self.model_config, field.name) - print(f"ModelConfig.{field.name} mismatch: client={client_val}, server={server_val}") assert client_val == server_val, \ f"ModelConfig.{field.name} mismatch: client={client_val}, server={server_val}" @@ -330,7 +376,8 @@ def _handle_get_request(self, req: GetRequest) -> None: slot_mapping=req.slot_mapping, token_mask=req.token_mask, layer_granularity=req.layer_granularity, - dp_id=req.dp_client_id, + dp_rank=req.dp_client_id, + pp_rank=req.pp_rank, namespace=req.namespace, ) @@ -340,7 +387,8 @@ def _handle_put_request(self, req: PutRequest) -> None: token_ids=req.token_ids, slot_mapping=req.slot_mapping, token_mask=req.token_mask, - dp_id=req.dp_client_id, + dp_rank=req.dp_client_id, + pp_rank=req.pp_rank, task_id=req.task_id, namespace=req.namespace, ) @@ -351,7 +399,9 @@ def _handle_get_match_request(self, req: GetMatchRequest) -> None: token_ids=req.token_ids, token_mask=req.token_mask, layer_granularity=req.layer_granularity, - dp_id=req.dp_client_id, + dp_rank=req.dp_client_id, + pp_rank=req.pp_rank, + cpu_only=req.cpu_only, task_id=req.task_id, namespace=req.namespace, ) @@ -364,7 +414,8 @@ def _handle_put_match_request(self, 
req: PutMatchRequest) -> None: req_id, mask = self.kv_task_engine.put_match( token_ids=req.token_ids, token_mask=req.token_mask, - dp_id=req.dp_client_id, + dp_rank=req.dp_client_id, + pp_rank=req.pp_rank, task_id=req.task_id, namespace=req.namespace, ) @@ -376,14 +427,20 @@ def _handle_prefetch_request(self, req: PrefetchRequest) -> None: """Handle Prefetch request""" task_id = self.kv_task_engine.prefetch_async( token_ids=req.token_ids, - dp_id=req.dp_client_id, + dp_rank=req.dp_client_id, + pp_rank=req.pp_rank, task_id=req.task_id, namespace=req.namespace, ) def _handle_launch_task_request(self, req: LaunchTaskRequest) -> None: """Handle LaunchTask request""" - self.kv_task_engine.launch_tasks(req.task_ids, req.slot_mappings, req.as_batch, req.batch_id) + self.kv_task_engine.launch_tasks(req.task_ids, + req.slot_mappings, + req.as_batch, + req.batch_id, + req.layerwise_transfer, + req.counter_id) def _handle_cancel_task_request(self, req: CancelTaskRequest) -> None: """Handle CancelTask request""" diff --git a/flexkv/storage/allocator.py b/flexkv/storage/allocator.py index ccb99e5b32..8e60841541 100644 --- a/flexkv/storage/allocator.py +++ b/flexkv/storage/allocator.py @@ -1,6 +1,10 @@ +import ctypes +import mmap import os +import weakref +from dataclasses import dataclass from abc import ABC, abstractmethod -from typing import Tuple, Optional, List, Union, Dict, Any, BinaryIO +from typing import Tuple, List, Union, Dict, Any, BinaryIO try: from flexkv.c_ext import Pcfs except ImportError: @@ -128,6 +132,463 @@ def from_raw_data(cls, dtype=dtype, ) +# --------------------------------------------------------------------------- +# HugePage helpers (standalone, reusable outside of BaseStorageAllocator) +# --------------------------------------------------------------------------- +DEFAULT_HUGE_PAGE_SIZE = 2 * 1024 * 1024 # 2 MiB +DEFAULT_HUGETLBFS_DIR = "/mnt/hugepages" + +_MAP_SHARED = 0x01 +_MAP_PRIVATE = 0x02 +_MAP_ANONYMOUS = 0x20 +_MAP_HUGETLB = 0x40000 
+_MAP_HUGE_SHIFT = 26 +_PROT_READ = 0x1 +_PROT_WRITE = 0x2 +_MAP_FAILED = ctypes.c_void_p(-1).value # (void*)-1 +_HUGETLBFS_MAGIC = 0x958458F6 + +_libc = ctypes.CDLL("libc.so.6", use_errno=True) +_libc.mmap.restype = ctypes.c_void_p +_libc.mmap.argtypes = [ + ctypes.c_void_p, + ctypes.c_size_t, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_long, +] +_libc.munmap.restype = ctypes.c_int +_libc.munmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t] +_libc.ftruncate.restype = ctypes.c_int +_libc.ftruncate.argtypes = [ctypes.c_int, ctypes.c_long] +_libc.close.restype = ctypes.c_int +_libc.close.argtypes = [ctypes.c_int] + + +class _StatFS(ctypes.Structure): + _fields_ = [ + ("f_type", ctypes.c_long), + ("f_bsize", ctypes.c_long), + ("f_blocks", ctypes.c_ulong), + ("f_bfree", ctypes.c_ulong), + ("f_bavail", ctypes.c_ulong), + ("f_files", ctypes.c_ulong), + ("f_ffree", ctypes.c_ulong), + ("f_fsid", ctypes.c_int * 2), + ("f_namelen", ctypes.c_long), + ("f_frsize", ctypes.c_long), + ("f_flags", ctypes.c_long), + ("f_spare", ctypes.c_long * 4), + ] + + +_libc.statfs.restype = ctypes.c_int +_libc.statfs.argtypes = [ctypes.c_char_p, ctypes.POINTER(_StatFS)] + +@dataclass +class _HugePageMapping: + finalizer: Any + aligned: int + path: str | None = None + + +_live_hugepage_mappings: "Dict[int, _HugePageMapping]" = {} + + +@dataclass(frozen=True) +class HugePageTensorHandle: + path: str + num_elements: int + dtype: torch.dtype + aligned: int + + def get_tensor(self) -> torch.Tensor: + return _materialize_shareable_hugepage_tensor( + path=self.path, + num_elements=self.num_elements, + dtype=self.dtype, + aligned=self.aligned, + ) + + +def _cleanup_hugepage_mapping(addr: int, aligned: int, fd: int, + path: str | None, data_ptr: int) -> None: + _munmap_huge(addr, aligned) + if fd >= 0: + _libc.close(fd) + if path is not None: + try: + os.unlink(path) + except FileNotFoundError: + pass + _live_hugepage_mappings.pop(data_ptr, None) + + +def _cleanup_hugepage_mmap(mm: 
mmap.mmap, path: str | None, data_ptr: int) -> None: + try: + mm.close() + finally: + if path is not None: + try: + os.unlink(path) + except FileNotFoundError: + pass + _live_hugepage_mappings.pop(data_ptr, None) + + +def _statfs_type(path: str) -> int: + statfs_buf = _StatFS() + if _libc.statfs(path.encode(), ctypes.byref(statfs_buf)) != 0: + err = ctypes.get_errno() + raise RuntimeError( + f"HugePage: statfs({path}) failed: {os.strerror(err)} (errno={err})" + ) + return int(statfs_buf.f_type) + + +def _ensure_hugetlbfs_mount(mnt_dir: str) -> None: + if not os.path.isdir(mnt_dir): + raise RuntimeError(f"HugePage: hugetlbfs directory does not exist: {mnt_dir}") + fs_type = _statfs_type(mnt_dir) + if fs_type != _HUGETLBFS_MAGIC: + raise RuntimeError( + f"HugePage: {mnt_dir} is not a hugetlbfs mount " + f"(f_type=0x{fs_type:x}, expected=0x{_HUGETLBFS_MAGIC:x})" + ) + + +def _create_hugetlbfs_file(aligned: int) -> tuple[str, int]: + mnt_dir = os.environ.get("FLEXKV_HUGETLBFS_DIR", DEFAULT_HUGETLBFS_DIR) + _ensure_hugetlbfs_mount(mnt_dir) + path = os.path.join(mnt_dir, f"flexkv_hugepage_{os.getpid()}_{id(object()):x}") + fd = os.open(path, os.O_CREAT | os.O_RDWR | os.O_EXCL, 0o600) + try: + ctypes.set_errno(0) + if _libc.ftruncate(fd, aligned) != 0: + err = ctypes.get_errno() + raise RuntimeError( + f"HugePage: ftruncate({path}, {aligned}) failed: " + f"{os.strerror(err)} (errno={err})" + ) + except Exception: + os.close(fd) + try: + os.unlink(path) + except FileNotFoundError: + pass + raise + return path, fd + + +def _wrap_mmap_tensor(mm: mmap.mmap, + aligned: int, + num_elements: int, + dtype: torch.dtype, + cleanup_path: str | None) -> torch.Tensor: + num_bytes = num_elements * dtype.itemsize + tensor = torch.frombuffer(mm, dtype=torch.uint8, count=num_bytes).view(dtype)[:num_elements] + ptr = tensor.data_ptr() + finalizer = weakref.finalize(tensor, _cleanup_hugepage_mmap, mm, cleanup_path, ptr) + _live_hugepage_mappings[ptr] = _HugePageMapping( + 
finalizer=finalizer, + aligned=aligned, + path=cleanup_path, + ) + return tensor + + +def _materialize_shareable_hugepage_tensor(path: str, + num_elements: int, + dtype: torch.dtype, + aligned: int) -> torch.Tensor: + fd = os.open(path, os.O_RDWR) + try: + mm = mmap.mmap( + fd, + aligned, + flags=mmap.MAP_SHARED, + prot=mmap.PROT_READ | mmap.PROT_WRITE, + ) + finally: + os.close(fd) + return _wrap_mmap_tensor(mm, aligned, num_elements, dtype, cleanup_path=None) + + +def materialize_worker_tensor(data: Union[torch.Tensor, HugePageTensorHandle]) -> torch.Tensor: + if isinstance(data, torch.Tensor): + return data + if isinstance(data, HugePageTensorHandle): + return data.get_tensor() + raise TypeError(f"Unsupported worker tensor type: {type(data)}") + + +def get_worker_hugepage_handle(tensor: torch.Tensor, + num_elements: int, + dtype: torch.dtype) -> HugePageTensorHandle | None: + mapping = _live_hugepage_mappings.get(tensor.data_ptr()) + if mapping is None or mapping.path is None: + return None + return HugePageTensorHandle( + path=mapping.path, + num_elements=num_elements, + dtype=dtype, + aligned=mapping.aligned, + ) + + +def _align_to_page(num_bytes: int, page_size_bytes: int) -> int: + """Round *num_bytes* up to the next multiple of *page_size_bytes*.""" + return (num_bytes + page_size_bytes - 1) & ~(page_size_bytes - 1) + + +def _read_hugepages_free(page_size_bytes: int) -> int: + """Return the number of free huge pages for *page_size_bytes*.""" + try: + size_kb = page_size_bytes // 1024 + cur_kb = 0 + free = 0 + with open("/proc/meminfo") as f: + for line in f: + if line.startswith("Hugepagesize:"): + cur_kb = int(line.split()[1]) + elif line.startswith("HugePages_Free:"): + free = int(line.split()[1]) + if cur_kb != size_kb: + return 0 # wrong page size pool + return free + except Exception: + return 0 + + +def _mmap_huge(num_bytes: int, page_size_bytes: int) -> Tuple[int, int, int]: + if num_bytes <= 0: + raise ValueError(f"HugePage: num_bytes must be > 0, 
got {num_bytes}") + if page_size_bytes <= 0 or (page_size_bytes & (page_size_bytes - 1)) != 0: + raise ValueError( + f"HugePage: page_size_bytes must be a power of two, got {page_size_bytes}" + ) + + aligned = _align_to_page(num_bytes, page_size_bytes) + page_shift = page_size_bytes.bit_length() - 1 + + # 1) Anonymous MAP_HUGETLB — no hugetlbfs mount needed. + free_pages = _read_hugepages_free(page_size_bytes) + if free_pages and aligned > free_pages * page_size_bytes: + flexkv_logger.warning( + f"HugePage: requested {aligned // page_size_bytes} pages " + f"({aligned / (1024**3):.3f} GiB) but only {free_pages} free " + f"(page_size={page_size_bytes // (1024*1024)} MiB). " + f"The kernel may fall back to regular pages or overcommit." + ) + + ctypes.set_errno(0) + huge_flags = _MAP_PRIVATE | _MAP_ANONYMOUS | _MAP_HUGETLB | (page_shift << _MAP_HUGE_SHIFT) + ret = _libc.mmap(None, aligned, _PROT_READ | _PROT_WRITE, huge_flags, -1, 0) + if ret is not None and ret != _MAP_FAILED: + return int(ret), aligned, -1 + + # 2) Fallback: file-backed hugetlbfs. Reject non-hugetlbfs mounts so we + # never silently succeed on regular 4 KiB pages. 
+ fd = -1 + try: + path, fd = _create_hugetlbfs_file(aligned) + try: + os.unlink(path) + except OSError: + pass + + ctypes.set_errno(0) + ret = _libc.mmap( + None, + aligned, + _PROT_READ | _PROT_WRITE, + _MAP_SHARED, + fd, + 0, + ) + if ret is None or ret == _MAP_FAILED: + err = ctypes.get_errno() + raise RuntimeError( + f"HugePage: mmap({path}, {aligned}) failed: " + f"{os.strerror(err)} (errno={err})" + ) + return int(ret), aligned, fd + except Exception: # noqa: BLE001 + if fd >= 0: + os.close(fd) + raise + + +def _munmap_huge(addr: int, length: int) -> None: + if _libc.munmap(ctypes.c_void_p(addr), length) != 0: + err = ctypes.get_errno() + flexkv_logger.warning( + f"HugePage: munmap({hex(addr)}, {length}) failed: " + f"{os.strerror(err)} (errno={err})" + ) + + +def alloc_hugepage_tensor(num_elements: int, + dtype: torch.dtype, + page_size_bytes: int = DEFAULT_HUGE_PAGE_SIZE, + shareable: bool = False) -> torch.Tensor: + """Allocate ``num_elements`` values of ``dtype`` on HugePage-backed memory. + + Returns a 1-D ``torch.Tensor`` that zero-copy wraps the mmap'd region. + The tensor's ``data_ptr()`` can be passed to ``cudaHostRegister`` or to + other RDMA-style registration APIs. + + Use ``free_hugepage_tensor(tensor)`` to explicitly release the mapping; + otherwise it will be released when the tensor (and all references to it) + are garbage-collected. + + Raises: + RuntimeError: if the mmap fails or if the resulting VMA is not backed + by huge pages of the requested size (i.e. no silent fallback). 
+ """ + num_bytes = num_elements * dtype.itemsize + + if shareable: + aligned = _align_to_page(num_bytes, page_size_bytes) + path, fd = _create_hugetlbfs_file(aligned) + try: + mm = mmap.mmap( + fd, + aligned, + flags=mmap.MAP_SHARED, + prot=mmap.PROT_READ | mmap.PROT_WRITE, + ) + finally: + os.close(fd) + return _wrap_mmap_tensor(mm, aligned, num_elements, dtype, cleanup_path=path) + + addr, aligned, fd = _mmap_huge(num_bytes, page_size_bytes) + + # Zero-copy wrap: build a numpy uint8 array pointing at the raw memory, + # then view it as the requested dtype via ``torch.frombuffer``. The numpy + # array keeps a reference (``_base_keepalive``) so Python's GC cannot free + # the underlying bytes while the tensor is still live. + buf_type = (ctypes.c_uint8 * aligned) + raw = buf_type.from_address(addr) + np_arr = np.frombuffer(raw, dtype=np.uint8, count=num_bytes) + tensor = torch.frombuffer(np_arr, dtype=torch.uint8, count=num_bytes) \ + .view(dtype)[:num_elements] + + ptr = tensor.data_ptr() + finalizer = weakref.finalize(tensor, _cleanup_hugepage_mapping, + addr, aligned, fd, None, ptr) + _live_hugepage_mappings[ptr] = _HugePageMapping( + finalizer=finalizer, + aligned=aligned, + path=None, + ) + return tensor + + +def free_hugepage_tensor(tensor: torch.Tensor) -> None: + """Release the HugePage mapping previously created by :func:`alloc_hugepage_tensor`. + + No-op if ``tensor`` is not known to be HugePage-backed. + The caller must ensure no other references to the tensor's memory remain + in active use (e.g. ``cudaHostUnregister`` should be called first, and + any Python reference to ``tensor`` should be dropped after this call). + """ + if not isinstance(tensor, torch.Tensor): + return + ptr = tensor.data_ptr() + mapping = _live_hugepage_mappings.pop(ptr, None) + if mapping is None: + return + mapping.finalizer() + + +class HugePageAllocator(BaseStorageAllocator): + """CPU KV-cache allocator backed by hugetlbfs HugePages. 
+ + Unlike :class:`CPUAllocator` (which relies on ``torch.empty`` on top of 4KiB + pages), this allocator maps a hugetlbfs file and wraps the resulting buffer + into a 1-D ``torch.Tensor`` (zero-copy). + + Benefits: + * Reduced TLB pressure for large KV caches (2MiB / 1GiB pages). + * The returned tensor's ``data_ptr()`` can still be passed to + ``cudaHostRegister`` for pinned H2D/D2H transfers. + + Prerequisites: + * The kernel must have huge pages reserved, e.g. for 2MiB pages:: + + echo N > /proc/sys/vm/nr_hugepages + # or, per-size on recent kernels: + echo N > /sys/kernel/mm/hugepages/hugepages-2048kB/nr_hugepages + + For 1GiB pages the kernel usually needs ``hugepagesz=1G`` at boot + and a corresponding ``hugepages=N`` reservation. + + kwargs: + page_size_bytes (int): Huge page size in bytes. Supported values: + ``2 * 1024 * 1024`` (default) or ``1024 * 1024 * 1024``. + """ + + @classmethod + def allocate(cls, + layout: KVCacheLayout, + dtype: torch.dtype, + **kwargs: Any) -> StorageHandle: + page_size_bytes = int(kwargs.get("page_size_bytes", DEFAULT_HUGE_PAGE_SIZE)) + total_elements = layout.get_total_elements() + element_size = dtype.itemsize + + flexkv_logger.info( + f"HugePage allocate total_size: " + f"{total_elements * element_size / 1024 / 1024 / 1024:.4f} GB " + f"(page_size={page_size_bytes // (1024 * 1024)}MiB)" + ) + try: + physical_tensor = alloc_hugepage_tensor( + total_elements, + dtype, + page_size_bytes, + shareable=True, + ) + except Exception as e: # noqa: BLE001 + flexkv_logger.warning( + f"HugePage allocation failed ({e}); falling back to regular CPU memory." 
+ ) + return CPUAllocator.allocate(layout, dtype, **kwargs) + worker_data = get_worker_hugepage_handle(physical_tensor, total_elements, dtype) + return StorageHandle( + handle_type=AccessHandleType.TENSOR, + data=physical_tensor, + kv_layout=layout, + dtype=dtype, + worker_data=worker_data, + ) + + @classmethod + def free(cls, accessible_handle: StorageHandle) -> None: + if accessible_handle.handle_type != AccessHandleType.TENSOR: + return + tensor = accessible_handle.data + if isinstance(tensor, torch.Tensor): + free_hugepage_tensor(tensor) + + @classmethod + def from_raw_data(cls, + data: torch.Tensor, # type: ignore + layout: KVCacheLayout, + dtype: torch.dtype, + **kwargs: Any) -> StorageHandle: + # We assume the caller already backs ``data`` with huge pages (or does + # not care). We do not take ownership of any mmap here. + return StorageHandle( + handle_type=AccessHandleType.TENSOR, + data=data, + kv_layout=layout, + dtype=dtype, + ) + + class SSDAllocator(BaseStorageAllocator): @classmethod def allocate(cls, @@ -138,7 +599,7 @@ def allocate(cls, file_prefix = kwargs.get("file_prefix", "flexkv_ssd_cache") cfg_max_file_size_gb = kwargs.get("max_file_size_gb", -1) cfg_max_blocks_per_file = int(1e9) - + if cache_dir is None: raise ValueError("cache_dir is required for SSD allocator") if isinstance(cache_dir, str): @@ -172,6 +633,12 @@ def allocate(cls, real_file_size = num_blocks_per_file * block_size ssd_files: Dict[int, List[str]] = {} + total_num_files = num_files_per_device * num_ssd_devices + real_total_size = total_num_files * real_file_size + flexkv_logger.info(f"SSD allocator creating {total_num_files} files in {cache_dir}, " + f"each file {real_file_size/1024/1024/1024:.2f} GB, " + f"total {real_total_size/1024/1024/1024:.2f} GB") + file_count = 0 for i in range(num_ssd_devices): ssd_files[i] = [] for j in range(num_files_per_device): @@ -179,9 +646,13 @@ def allocate(cls, with open(file_path, "wb+", buffering=0) as file: cls._create_file(file, 
real_file_size) ssd_files[i].append(file_path) - total_num_files = num_files_per_device * num_ssd_devices - real_total_size = total_num_files * real_file_size - flexkv_logger.info(f"SSD allocator create total {total_num_files} files in {cache_dir}, " + file_count += 1 + if file_count % max(1, total_num_files // 10) == 0 or file_count == total_num_files: + flexkv_logger.info( + f"SSD allocator progress: {file_count}/{total_num_files} files created " + f"({file_count * 100 // total_num_files}%)" + ) + flexkv_logger.info(f"SSD allocator done: {total_num_files} files in {cache_dir}, " f"each file has {real_file_size/1024/1024/1024:.2f} GB, total size {real_total_size/1024/1024/1024:.2f} GB") return StorageHandle( handle_type=AccessHandleType.FILE, diff --git a/flexkv/storage/storage_engine.py b/flexkv/storage/storage_engine.py index 50fe069cc5..46eb8824ec 100644 --- a/flexkv/storage/storage_engine.py +++ b/flexkv/storage/storage_engine.py @@ -6,27 +6,42 @@ import hashlib from flexkv.common.config import ModelConfig, CacheConfig, GLOBAL_CONFIG_FROM_ENV +from flexkv.common.debug import flexkv_logger from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import StorageHandle, KVCacheLayout, KVCacheLayoutType from flexkv.common.transfer import DeviceType -from flexkv.storage.allocator import CPUAllocator, GPUAllocator, SSDAllocator, RemoteAllocator +from flexkv.storage.allocator import ( + CPUAllocator, + GPUAllocator, + HugePageAllocator, + RemoteAllocator, + SSDAllocator, +) class StorageEngine: + def _cpu_allocator(self) -> type[CPUAllocator] | type[HugePageAllocator]: + if self._cache_config.use_hugepage_cpu_buffer: + return HugePageAllocator + return CPUAllocator + def __init__(self, model_config: ModelConfig, cache_config: CacheConfig): """Initialize storage engine""" self._storage_handles: Dict[Tuple[DeviceType, int], StorageHandle] = {} + self._indexer_storage_handles: Dict[Tuple[DeviceType, int], StorageHandle] = {} 
self._model_config = model_config self._cache_config = cache_config + self._indexer_config = cache_config.indexer + if self._cache_config.enable_cpu: self._cpu_layout: Optional[KVCacheLayout] = KVCacheLayout( type=GLOBAL_CONFIG_FROM_ENV.cpu_layout_type, - num_layer=self._model_config.num_layers, + num_layer=self._model_config.num_layers_per_pp_stage, num_block=self._cache_config.num_cpu_blocks, tokens_per_block=self._cache_config.tokens_per_block, - num_head=self._model_config.num_kv_heads, + num_head=self._model_config.num_kv_heads_per_node, head_size=self._model_config.head_size, is_mla=self._model_config.use_mla ) @@ -35,15 +50,35 @@ def __init__(self, layout=self._cpu_layout, dtype=self._model_config.dtype, ) + if self._indexer_config is not None: + # Indexer maps 1:1 with main KV blocks (each block = 1 page), + # so indexer num_blocks equals main KV num_blocks and + # tokens_per_block is 1 (one indexer entry per page). + indexer_cpu_layout = KVCacheLayout( + type=GLOBAL_CONFIG_FROM_ENV.cpu_layout_type, + num_layer=self._model_config.num_layers_per_pp_stage, + num_block=self._cache_config.num_cpu_blocks, + tokens_per_block=1, + num_head=self._indexer_config.num_kv_heads, + head_size=self._indexer_config.head_size, + is_mla=True + ) + self.allocate( + device_type=DeviceType.CPU, + layout=indexer_cpu_layout, + dtype=self._indexer_config.dtype, + is_indexer=True, + ) + if self._cache_config.enable_ssd: if not GLOBAL_CONFIG_FROM_ENV.ssd_layout_type == self._cpu_layout.type: raise ValueError(f"SSD layout type must be the same as CPU layout type: {self._cpu_layout.type}") self._ssd_layout: Optional[KVCacheLayout] = KVCacheLayout( type=GLOBAL_CONFIG_FROM_ENV.ssd_layout_type, - num_layer=self._model_config.num_layers, + num_layer=self._model_config.num_layers_per_pp_stage, num_block=self._cache_config.num_ssd_blocks, tokens_per_block=self._cache_config.tokens_per_block, - num_head=self._model_config.num_kv_heads, + num_head=self._model_config.num_kv_heads_per_node, 
head_size=self._model_config.head_size, is_mla=self._model_config.use_mla ) @@ -54,15 +89,34 @@ def __init__(self, cache_dir=self._cache_config.ssd_cache_dir, max_file_size_gb=GLOBAL_CONFIG_FROM_ENV.max_file_size_gb ) + if self._indexer_config is not None: + indexer_ssd_layout = KVCacheLayout( + type=GLOBAL_CONFIG_FROM_ENV.ssd_layout_type, + num_layer=self._model_config.num_layers_per_pp_stage, + num_block=self._cache_config.num_ssd_blocks, + tokens_per_block=1, + num_head=self._indexer_config.num_kv_heads, + head_size=self._indexer_config.head_size, + is_mla=True + ) + self.allocate( + device_type=DeviceType.SSD, + layout=indexer_ssd_layout, + dtype=self._indexer_config.dtype, + cache_dir=self._cache_config.ssd_cache_dir, + max_file_size_gb=GLOBAL_CONFIG_FROM_ENV.max_file_size_gb, + is_indexer=True, + ) + if self._cache_config.enable_remote: if not GLOBAL_CONFIG_FROM_ENV.remote_layout_type == self._cpu_layout.type: raise ValueError(f"Remote layout type must be the same as CPU layout type: {self._cpu_layout.type}") self._remote_layout: Optional[KVCacheLayout] = KVCacheLayout( type=GLOBAL_CONFIG_FROM_ENV.remote_layout_type, - num_layer=self._model_config.num_layers, + num_layer=self._model_config.num_layers_per_pp_stage, num_block=self._cache_config.num_remote_blocks, tokens_per_block=self._cache_config.tokens_per_block, - num_head=self._model_config.num_kv_heads, + num_head=self._model_config.num_kv_heads_per_node, head_size=self._model_config.head_size, is_mla=self._model_config.use_mla ) @@ -73,12 +127,43 @@ def __init__(self, file_path=self._cache_config.remote_cache_path, remote_config_custom = self._cache_config.remote_config_custom ) + if self._indexer_config is not None: + indexer_remote_layout = KVCacheLayout( + type=GLOBAL_CONFIG_FROM_ENV.remote_layout_type, + num_layer=self._model_config.num_layers_per_pp_stage, + num_block=self._cache_config.num_remote_blocks, + tokens_per_block=1, + num_head=self._indexer_config.num_kv_heads, + 
head_size=self._indexer_config.head_size, + is_mla=True + ) + indexer_remote_path = self._cache_config.remote_cache_path + if isinstance(indexer_remote_path, str): + indexer_remote_path = indexer_remote_path + "_indexer" + elif isinstance(indexer_remote_path, list): + indexer_remote_path = [p + "_indexer" for p in indexer_remote_path] + self.allocate( + device_type=DeviceType.REMOTE, + layout=indexer_remote_layout, + dtype=self._indexer_config.dtype, + file_path=indexer_remote_path, + remote_config_custom=self._cache_config.remote_config_custom, + is_indexer=True, + ) + + @property + def _has_indexer(self) -> bool: + """True when indexer is configured and CPU buffer is allocated.""" + return (DeviceType.CPU, 0) in self._indexer_storage_handles def register_gpu_blocks(self, gpu_blocks: List[TensorSharedHandle], gpu_layout: KVCacheLayout, device_id: int = 0, - dtype: torch.dtype = torch.float16) -> None: + dtype: torch.dtype = torch.float16, + indexer_gpu_blocks: Optional[List[TensorSharedHandle]] = None, + indexer_gpu_layout: Optional[KVCacheLayout] = None, + indexer_dtype: Optional[torch.dtype] = None) -> None: self.allocate( device_type=DeviceType.GPU, layout=gpu_layout, @@ -86,6 +171,29 @@ def register_gpu_blocks(self, device_id=device_id, raw_data=gpu_blocks ) + if indexer_gpu_blocks is not None: + # Indexer maps 1:1 with main KV blocks; validate consistency. 
+ flexkv_logger.info( + f"[StorageEngine] Registering indexer GPU buffer: " + f"num_block={indexer_gpu_layout.num_block}, " + f"head_size={indexer_gpu_layout.head_size}, " + f"num_head={indexer_gpu_layout.num_head}, " + f"dtype={indexer_dtype}" + ) + if indexer_gpu_layout.num_block != gpu_layout.num_block: + flexkv_logger.warning( + f"[StorageEngine] Indexer GPU num_block mismatch: " + f"indexer_num_block={indexer_gpu_layout.num_block}, " + f"expected={gpu_layout.num_block} (1:1 with main KV blocks)" + ) + self.allocate( + device_type=DeviceType.GPU, + layout=indexer_gpu_layout, + dtype=indexer_dtype if indexer_dtype is not None else dtype, + device_id=device_id, + raw_data=indexer_gpu_blocks, + is_indexer=True, + ) def allocate(self, device_type: DeviceType, @@ -93,43 +201,64 @@ def allocate(self, dtype: torch.dtype, device_id: int = 0, raw_data: Optional[Union[List[TensorSharedHandle], List[str], str]] = None, + is_indexer: bool = False, **kwargs: Any) -> bool: """ - Create and add an allocator for specified device + Create and add an allocator for specified device. Args: - device_type: Type of the device (CPU, GPU, etc.) - layout: Layout of kv cache - dtype: Data type of tensors - device_id: Device ID (default 0) - raw_data: Optional raw data to be used for initialization + device_type: Type of the device (CPU, GPU, SSD, REMOTE). + layout: Layout of kv cache. + dtype: Data type of tensors. + device_id: Device ID (default 0). + raw_data: Optional raw data to be used for initialization. + The expected type depends on ``device_type``: + + * ``DeviceType.CPU`` – ``torch.Tensor`` + * ``DeviceType.GPU`` – ``List[TensorSharedHandle]`` or + ``List[torch.Tensor]`` + * ``DeviceType.SSD`` – ``str`` or ``List[str]`` + (file path(s) to existing SSD cache files) + * ``DeviceType.REMOTE`` – ``str`` or ``List[str]`` + (remote file path(s)) + is_indexer: Whether this allocation is for indexer storage. + When True, SSD file_prefix uses 'indexer_' tag + (e.g. 
``flexkv_indexer_ssdcache_``). **kwargs: Additional arguments for specific allocator types - (e.g., pin_memory for CPU, file_path for Disk) + (e.g., pin_memory for CPU, file_path for Disk). Returns: - bool: True if allocator created successfully, False otherwise + bool: True if allocator created successfully, False if already exists. """ + storage_handles = self._indexer_storage_handles if is_indexer else self._storage_handles key = (device_type, device_id) - if key in self._storage_handles: + if key in storage_handles: return False storage_handle: StorageHandle if device_type == DeviceType.CPU: + cpu_allocator = self._cpu_allocator() pin_memory = kwargs.get('pin_memory', False) + page_size_bytes = kwargs.get( + 'page_size_bytes', + self._cache_config.hugepage_size_bytes, + ) if raw_data is not None: assert isinstance(raw_data, torch.Tensor), \ "raw_data for CPUAllocator must be Tensor" - storage_handle = CPUAllocator.from_raw_data( + storage_handle = cpu_allocator.from_raw_data( data=raw_data, # type: ignore layout=layout, dtype=dtype, - pin_memory=pin_memory + pin_memory=pin_memory, + page_size_bytes=page_size_bytes, ) else: - storage_handle = CPUAllocator.allocate( + storage_handle = cpu_allocator.allocate( layout=layout, dtype=dtype, - pin_memory=pin_memory + pin_memory=pin_memory, + page_size_bytes=page_size_bytes, ) elif device_type == DeviceType.GPU: num_chunks = kwargs.get('num_chunks', 1) @@ -137,7 +266,7 @@ def allocate(self, assert isinstance(raw_data, list) and \ (all(isinstance(x, TensorSharedHandle) for x in raw_data) or \ all(isinstance(x, torch.Tensor) for x in raw_data)), \ - "raw_data for GPUAllocator must be List[TensorWrapper] or List[Tensor]" + "raw_data for GPUAllocator must be List[TensorSharedHandle] or List[Tensor]" storage_handle = GPUAllocator.from_raw_data( data=raw_data, # type: ignore layout=layout, @@ -169,7 +298,8 @@ def allocate(self, server_recv_port = GLOBAL_CONFIG_FROM_ENV.server_recv_port hash_value = 
hashlib.md5(server_recv_port.encode()).hexdigest() rand_suffix = f"{hash_value[:6]}" - file_prefix = f"flexkv_ssdcache_{rand_suffix}" + ssd_prefix_tag = "indexer_" if is_indexer else "" + file_prefix = f"flexkv_{ssd_prefix_tag}ssdcache_{rand_suffix}" storage_handle = SSDAllocator.allocate( layout=layout, dtype=dtype, @@ -206,28 +336,41 @@ def allocate(self, ) else: raise ValueError(f"Unsupported device type: {device_type}") - self._storage_handles[key] = storage_handle + storage_handles[key] = storage_handle return True def get_storage_handle(self, device_type: DeviceType, - device_id: int = 0) -> StorageHandle: + device_id: int = 0, + is_indexer: bool = False) -> StorageHandle: """ - Get accessible handle for specified blocks + Get accessible handle for specified blocks. Args: - device_type: Type of the device to get handle from - device_id: Device ID + device_type: Type of the device to get handle from. + device_id: Device ID. + is_indexer: Whether to get indexer storage handle. """ + storage_handles = self._indexer_storage_handles if is_indexer else self._storage_handles key = (device_type, device_id) - if key not in self._storage_handles: - raise ValueError(f"Storage handle not found for device type: {device_type}, device id: {device_id}") - - storage_handle = self._storage_handles[key] - return storage_handle + if key not in storage_handles: + raise ValueError( + f"Storage handle not found for device type: {device_type}, " + f"device id: {device_id}, is_indexer: {is_indexer}" + ) + return storage_handles[key] def has_storage_handle(self, device_type: DeviceType, - device_id: int = 0) -> bool: - """Check if storage handle exists for given device type and id""" - return (device_type, device_id) in self._storage_handles + device_id: int = 0, + is_indexer: bool = False) -> bool: + """ + Check if storage handle exists for given device type and id. + + Args: + device_type: Type of the device. + device_id: Device ID. 
+            is_indexer: Whether to check indexer storage handle.
+        """
+        storage_handles = self._indexer_storage_handles if is_indexer else self._storage_handles
+        return (device_type, device_id) in storage_handles
diff --git a/flexkv/transfer/host_buffer.py b/flexkv/transfer/host_buffer.py
new file mode 100644
index 0000000000..bc2b9ea44d
--- /dev/null
+++ b/flexkv/transfer/host_buffer.py
@@ -0,0 +1,206 @@
+from __future__ import annotations
+
+import ctypes
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+from flexkv.common.debug import flexkv_logger
+from flexkv.storage.allocator import alloc_hugepage_tensor, free_hugepage_tensor
+
+# Lazily loaded CUDA runtime handle. The load error is cached as well, so a
+# failed load is reported on every call without retrying the load.
+_cudart = None
+_cudart_load_error: Optional[OSError] = None
+
+# Flags argument for cudaHostRegister; 1 is cudaHostRegisterPortable in the
+# CUDA runtime API headers.
+_CUDA_HOST_REGISTER_FLAGS = 1
+
+
+def _get_cudart():
+    """Return the cached ``libcudart.so`` handle, loading it on first use.
+
+    Raises:
+        RuntimeError: If the library cannot be loaded (now or on an earlier call).
+    """
+    global _cudart
+    global _cudart_load_error
+
+    if _cudart is None and _cudart_load_error is None:
+        try:
+            _cudart = ctypes.CDLL("libcudart.so")
+        except OSError as e:
+            _cudart_load_error = e
+
+    if _cudart is None:
+        raise RuntimeError(f"libcudart.so is unavailable: {_cudart_load_error}")
+    return _cudart
+
+
+def cuda_host_registration_available() -> bool:
+    """Return True when ``libcudart.so`` can be loaded, False otherwise."""
+    try:
+        _get_cudart()
+    except RuntimeError:
+        return False
+    return True
+
+
+def cudaHostRegister(tensor: torch.Tensor) -> None:
+    """Register the tensor's memory with CUDA via ``cudaHostRegister``.
+
+    The registered range starts at ``tensor.data_ptr()`` and spans
+    ``numel() * element_size()`` bytes.
+
+    Raises:
+        RuntimeError: If libcudart is unavailable or the call returns non-zero.
+    """
+    cudart = _get_cudart()
+    ptr = tensor.data_ptr()
+    size = tensor.numel() * tensor.element_size()
+    ret = cudart.cudaHostRegister(
+        ctypes.c_void_p(ptr),
+        ctypes.c_size_t(size),
+        _CUDA_HOST_REGISTER_FLAGS,
+    )
+    if ret != 0:
+        raise RuntimeError(f"cudaHostRegister failed with error code {ret}")
+
+
+def cudaHostUnregister(tensor: torch.Tensor) -> None:
+    """Undo ``cudaHostRegister`` for the memory starting at ``tensor.data_ptr()``.
+
+    Raises:
+        RuntimeError: If libcudart is unavailable or the call returns non-zero.
+    """
+    cudart = _get_cudart()
+    ptr = tensor.data_ptr()
+    ret = cudart.cudaHostUnregister(ctypes.c_void_p(ptr))
+    if ret != 0:
+        raise RuntimeError(f"cudaHostUnregister failed with error code {ret}")
+
+
+@dataclass
+class HostBufferHandle:
+    """A host tensor plus the flags needed to release it correctly.
+
+    ``is_cuda_registered`` is only valid together with ``is_hugepage``;
+    ``__post_init__`` rejects any other combination.
+    """
+
+    tensor: torch.Tensor
+    is_hugepage: bool = False
+    is_cuda_registered: bool = False
+
+    def __post_init__(self) -> None:
+        if self.is_cuda_registered and not self.is_hugepage:
+            raise ValueError("CUDA-registered host buffer must be HugePage-backed")
+
+    @classmethod
+    def pinned(cls, tensor: torch.Tensor) -> HostBufferHandle:
+        """Wrap a tensor that needs no explicit release (both flags False)."""
+        return cls(tensor=tensor)
+
+    @classmethod
+    def hugepage(cls, tensor: torch.Tensor) -> HostBufferHandle:
+        """Wrap a hugepage tensor that has already been CUDA-registered."""
+        return cls(tensor=tensor, is_hugepage=True, is_cuda_registered=True)
+
+    def release(self) -> None:
+        """Unregister and free a hugepage buffer; a no-op for other handles.
+
+        A failing CUDA unregister is logged and does not stop the free.
+        Calling ``release`` again after a successful release does nothing.
+        """
+        if not self.is_hugepage:
+            return
+
+        if self.is_cuda_registered:
+            try:
+                cudaHostUnregister(self.tensor)
+            except Exception as e:
+                flexkv_logger.warning(
+                    f"[host_buffer] release hugepage host buffer: cuda unregister failed ({e})"
+                )
+            self.is_cuda_registered = False
+
+        free_hugepage_tensor(self.tensor)
+        flexkv_logger.info("[host_buffer] release hugepage host buffer")
+        self.is_hugepage = False
+
+
+def _allocate_pinned_cpu_tensor(num_elements: int, dtype: torch.dtype) -> HostBufferHandle:
+    """Allocate an uninitialized pinned CPU tensor and wrap it in a handle."""
+    return HostBufferHandle.pinned(
+        torch.empty(
+            num_elements,
+            dtype=dtype,
+            device="cpu",
+            pin_memory=True,
+        )
+    )
+
+
+def _fallback_to_pinned(
+    num_elements: int,
+    dtype: torch.dtype,
+    reason: Exception,
+) -> HostBufferHandle:
+    """Log why the hugepage path failed and return a pinned buffer instead."""
+    flexkv_logger.warning(
+        f"[host_buffer] fallback to pinned host buffer ({reason})"
+    )
+    return _allocate_pinned_cpu_tensor(num_elements, dtype)
+
+
+def allocate_host_buffer(
+    num_elements: int,
+    dtype: torch.dtype,
+    use_hugepage: bool,
+    hugepage_size_bytes: int,
+) -> HostBufferHandle:
+    """Allocate a host buffer, preferring hugepages when requested.
+
+    Args:
+        num_elements: Number of elements in the 1-D buffer.
+        dtype: Element dtype.
+        use_hugepage: When False, a pinned tensor is returned directly.
+        hugepage_size_bytes: Passed to ``alloc_hugepage_tensor`` as
+            ``page_size_bytes``.
+
+    Returns:
+        A hugepage-backed, CUDA-registered handle, or a pinned handle when
+        hugepages are disabled or the hugepage path fails.
+    """
+    if not use_hugepage:
+        return _allocate_pinned_cpu_tensor(num_elements, dtype)
+
+    flexkv_logger.info("[host_buffer] attempt hugepage host buffer")
+
+    hugepage_buf = None
+    try:
+        hugepage_buf = alloc_hugepage_tensor(
+            num_elements=num_elements,
+            dtype=dtype,
+            page_size_bytes=hugepage_size_bytes,
+        )
+        cudaHostRegister(hugepage_buf)
+    except Exception as e:
+        if hugepage_buf is not None:
+            # Cleanup is best-effort: an error while freeing must not replace
+            # the original failure or prevent the pinned fallback.
+            try:
+                free_hugepage_tensor(hugepage_buf)
+            except Exception as free_err:
+                flexkv_logger.warning(
+                    f"[host_buffer] failed to free hugepage buffer ({free_err})"
+                )
+        return _fallback_to_pinned(num_elements, dtype, e)
+
+    flexkv_logger.info(
+        f"[host_buffer] hugepage host buffer ready: "
+        f"{hugepage_buf.numel() * hugepage_buf.element_size() / (1024 ** 3):.3f} GB"
+    )
+    return HostBufferHandle.hugepage(hugepage_buf)
diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py
new file mode 100644
index 0000000000..e3161d9cf1
--- /dev/null
+++ b/flexkv/transfer/layerwise.py
@@ -0,0 +1,589 @@
+import time
+import os
+import socket
+import struct
+from torch.multiprocessing import Queue as MPQueue
+from multiprocessing.connection import Connection
+from typing import List, Any, Dict, Union, Optional, Tuple
+
+import torch
+
+from flexkv.c_ext import LayerwiseTransferGroup
+from flexkv.common.debug import flexkv_logger
+from flexkv.common.memory_handle import TensorSharedHandle
+from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType
+from flexkv.common.config import ModelConfig, GLOBAL_CONFIG_FROM_ENV
+from flexkv.storage.allocator import HugePageTensorHandle, materialize_worker_tensor
+
+from flexkv.transfer.worker_op import WorkerLayerwiseTransferOp
+from flexkv.transfer.worker import TransferWorkerBase, cudaHostRegister
+
+
+def build_layerwise_eventfd_socket_path(
+    pp_rank: int,
+    dp_rank: int,
+    pp_size: int = 1,
+    dp_size: int = 1,
+) -> str:
+    """Construct the LayerwiseWorker's UDS socket path.
+
+    Disambiguated by ``(pp_rank, dp_rank)`` so multiple PP stages and DP
+    replicas on the same host each get their own endpoint.
+ """ + base = os.environ.get( + 'FLEXKV_LAYERWISE_EVENTFD_SOCKET', + '/tmp/flexkv_layerwise_eventfd.sock', + ) + suffix = "" + if pp_size > 1: + suffix += f"_pp{pp_rank}" + if dp_size > 1: + suffix += f"_dp{dp_rank}" + if not suffix: + return base + root, ext = os.path.splitext(base) + return f"{root}{suffix}{ext}" + + +def _recv_fds(sock: socket.socket, num_fds: int) -> Tuple[List[int], bytes]: + """Receive multiple fds + extra_data via Unix domain socket (SCM_RIGHTS).""" + data_buf = bytearray(256) + anc_buf_size = socket.CMSG_SPACE(num_fds * struct.calcsize("i")) + + nbytes, ancdata, flags, addr = sock.recvmsg_into([data_buf], anc_buf_size, 0) + data = bytes(data_buf[:nbytes]) + + fds = [] + for level, ctype, cdata in ancdata: + if level == socket.SOL_SOCKET and ctype == socket.SCM_RIGHTS: + num_received = len(cdata) // struct.calcsize("i") + fds = list(struct.unpack(f"{num_received}i", cdata[:num_received * struct.calcsize("i")])) + break + if not fds: + raise RuntimeError("did not receive fds via SCM_RIGHTS") + return fds, data + +class LayerwiseTransferWorker(TransferWorkerBase): + def __init__(self, + worker_id: int, + transfer_conn: Connection, + finished_ops_queue: MPQueue, + op_buffer_tensor: torch.Tensor, + gpu_blocks: List[List[TensorSharedHandle]], + cpu_blocks: Union[torch.Tensor, HugePageTensorHandle], + ssd_files: Dict[int, List[str]], + gpu_kv_layouts: List[KVCacheLayout], + cpu_kv_layout: KVCacheLayout, + ssd_kv_layout: KVCacheLayout, + dtype: torch.dtype, + tp_group_size: int, + layerwise_eventfd_socket: str, + num_blocks_per_file: int, + use_ce_transfer_h2d: bool = False, + use_ce_transfer_d2h: bool = False, + h2d_cta_num: int = 4, + d2h_cta_num: int = 4, + enable_eventfd: bool = True, + indexer_gpu_blocks: Optional[List[List[TensorSharedHandle]]] = None, + indexer_cpu_blocks: Optional[Union[torch.Tensor, HugePageTensorHandle]] = None, + indexer_gpu_kv_layouts: Optional[List[KVCacheLayout]] = None, + indexer_cpu_kv_layout: 
Optional[KVCacheLayout] = None, + indexer_dtype: Optional[torch.dtype] = None, + indexer_ssd_files: Optional[Dict[int, List[str]]] = None, + indexer_ssd_kv_layout: Optional[KVCacheLayout] = None, + indexer_num_blocks_per_file: int = 0) -> None: + flexkv_logger.debug( + f"[LayerwiseWorker] __init__ started: worker_id={worker_id}, " + f"tp_group_size={tp_group_size}, " + f"enable_eventfd={enable_eventfd}, " + f"num_gpu_blocks={[len(b) for b in gpu_blocks]}") + super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) + assert len(gpu_blocks) == tp_group_size, f"len(gpu_blocks) = {len(gpu_blocks)}, tp_group_size = {tp_group_size}" + cpu_blocks = materialize_worker_tensor(cpu_blocks) + imported_gpu_blocks = [] + for handles_in_one_gpu in gpu_blocks: + blocks_in_one_gpu = [] + for handle in handles_in_one_gpu: + blocks_in_one_gpu.append(handle.get_tensor()) + imported_gpu_blocks.append(blocks_in_one_gpu) + self.gpu_blocks = imported_gpu_blocks + self.dtype = dtype # note this should be quantized data type + self.is_mla = gpu_kv_layouts[0].is_mla + + self.num_gpus = len(self.gpu_blocks) + self.tp_group_size = tp_group_size + # Pre-computed UDS socket path. Both ends (this worker and the + # sglang connector) derive the path from the same ModelConfig + # fields (pp_rank / dp_rank / node_rank / is_multinode_tp), so no + # env-var plumbing between processes is required. 
+ self.layerwise_eventfd_socket = layerwise_eventfd_socket + + # initialize GPU storage + self.num_layers = gpu_kv_layouts[0].num_layer + # here the chunk size doesn't include the layer info + self.gpu_chunk_sizes_in_bytes = [gpu_kv_layout.get_chunk_size() * self.dtype.itemsize \ + for gpu_kv_layout in gpu_kv_layouts] + self.gpu_kv_strides_in_bytes = [gpu_kv_layout.get_kv_stride() * self.dtype.itemsize \ + for gpu_kv_layout in gpu_kv_layouts] + self.gpu_block_strides_in_bytes = [gpu_kv_layout.get_block_stride() * self.dtype.itemsize \ + for gpu_kv_layout in gpu_kv_layouts] + self.gpu_layer_strides_in_bytes = [gpu_kv_layout.get_layer_stride() * self.dtype.itemsize \ + for gpu_kv_layout in gpu_kv_layouts] + + num_blocks_first_gpu = len(imported_gpu_blocks[0]) if imported_gpu_blocks else 0 + if num_blocks_first_gpu == 1: + self.gpu_block_type_ = 1 # TRTLLM + elif num_blocks_first_gpu == self.num_layers: + self.gpu_block_type_ = 0 # VLLM + elif num_blocks_first_gpu == self.num_layers * 2: + self.gpu_block_type_ = 2 # SGLANG + else: + raise ValueError(f"Invalid GPU block type: {num_blocks_first_gpu}") + + flexkv_logger.debug(f"[LayerwiseWorker] About to receive eventfds, enable_eventfd={enable_eventfd}") + if enable_eventfd: + layer_eventfds_tensor = self._receive_eventfds_from_sglang(tp_group_size) + else: + layer_eventfds_tensor = torch.empty(0, dtype=torch.int32) + flexkv_logger.debug(f"[LayerwiseWorker] Eventfds received, tensor shape={layer_eventfds_tensor.shape}") + + # initialize CPU storage + flexkv_logger.info(f"[LayerwiseWorker] Pinning CPU Memory: " + f"{cpu_blocks.numel() * cpu_blocks.element_size() / (1024 ** 3):.2f} GB") + cudaHostRegister(cpu_blocks) + flexkv_logger.debug("[LayerwiseWorker] CPU memory pinned successfully") + self.cpu_blocks = cpu_blocks + + self.cpu_chunk_size_in_bytes = cpu_kv_layout.get_chunk_size() * self.dtype.itemsize + self.cpu_block_stride_in_bytes = cpu_kv_layout.get_block_stride() * self.dtype.itemsize + # Full CPU strides (for 
SSD->CPU, which transfers all TP ranks' data) + self.cpu_kv_stride_in_bytes = cpu_kv_layout.get_kv_stride() * self.dtype.itemsize + self.cpu_layer_stride_in_bytes = cpu_kv_layout.get_layer_stride() * self.dtype.itemsize + # TP-divided CPU strides (for CPU->GPU, each rank reads its own portion) + if cpu_kv_layout.type == KVCacheLayoutType.BLOCKFIRST and not self.is_mla: + cpu_kv_layout_tp = cpu_kv_layout.div_head(self.tp_group_size) + else: + cpu_kv_layout_tp = cpu_kv_layout + self.cpu_tp_stride_in_bytes = self.cpu_block_stride_in_bytes // self.tp_group_size + self.h2d_cpu_kv_stride_in_bytes = cpu_kv_layout_tp.get_kv_stride() * self.dtype.itemsize + self.h2d_cpu_layer_stride_in_bytes = cpu_kv_layout_tp.get_layer_stride() * self.dtype.itemsize + + self.use_ce_transfer_h2d = use_ce_transfer_h2d + self.use_ce_transfer_d2h = use_ce_transfer_d2h + self.h2d_cta_num = h2d_cta_num + self.d2h_cta_num = d2h_cta_num + + # initialize SSD storage + self.enable_ssd = len(ssd_files) > 0 + self.ssd_files = ssd_files + if self.enable_ssd: + self.num_blocks_per_file = num_blocks_per_file + self.num_files = sum(len(file_list) for file_list in ssd_files.values()) + self.round_robin = 1 + + ssd_kv_layout_per_file = ssd_kv_layout.div_block(self.num_files, padding=True) + self.ssd_kv_stride_in_bytes = ssd_kv_layout_per_file.get_kv_stride() * self.dtype.itemsize + self.ssd_layer_stride_in_bytes = ssd_kv_layout_per_file.get_layer_stride() * self.dtype.itemsize + self.ssd_block_stride_in_bytes = ssd_kv_layout_per_file.get_block_stride() * self.dtype.itemsize + else: + self.num_blocks_per_file = 0 + self.round_robin = 1 + self.ssd_kv_stride_in_bytes = 0 + self.ssd_layer_stride_in_bytes = 0 + self.ssd_block_stride_in_bytes = 0 + + gpu_kv_strides_tensor = torch.tensor(self.gpu_kv_strides_in_bytes, dtype=torch.int64) + gpu_block_strides_tensor = torch.tensor(self.gpu_block_strides_in_bytes, dtype=torch.int64) + gpu_chunk_sizes_tensor = torch.tensor(self.gpu_chunk_sizes_in_bytes, 
dtype=torch.int64) + gpu_layer_strides_tensor = torch.tensor(self.gpu_layer_strides_in_bytes, dtype=torch.int64) + + # Create LayerwiseTransferGroup which handles both SSD->CPU and CPU->GPU transfers + flexkv_logger.debug("[LayerwiseWorker] Creating LayerwiseTransferGroup...") + + # Initialize indexer fuse support + self.enable_indexer = (indexer_gpu_blocks is not None and indexer_cpu_blocks is not None) + indexer_constructor_kwargs = {} + if self.enable_indexer: + assert indexer_gpu_kv_layouts is not None + assert indexer_cpu_kv_layout is not None + assert indexer_dtype is not None + indexer_cpu_blocks = materialize_worker_tensor(indexer_cpu_blocks) + + # Import indexer GPU tensor handles + imported_indexer_gpu_blocks = [] + for handles_in_one_gpu in indexer_gpu_blocks: + blocks_in_one_gpu = [] + for handle in handles_in_one_gpu: + blocks_in_one_gpu.append(handle.get_tensor()) + imported_indexer_gpu_blocks.append(blocks_in_one_gpu) + + # Pin indexer CPU memory + flexkv_logger.info( + f"[LayerwiseWorker] Pinning indexer CPU Memory: " + f"{indexer_cpu_blocks.numel() * indexer_cpu_blocks.element_size() / (1024 ** 3):.4f} GB") + cudaHostRegister(indexer_cpu_blocks) + + # Compute indexer GPU stride tensors + indexer_gpu_kv_strides = [layout.get_kv_stride() * indexer_dtype.itemsize + for layout in indexer_gpu_kv_layouts] + indexer_gpu_block_strides = [layout.get_block_stride() * indexer_dtype.itemsize + for layout in indexer_gpu_kv_layouts] + indexer_gpu_layer_strides = [layout.get_layer_stride() * indexer_dtype.itemsize + for layout in indexer_gpu_kv_layouts] + indexer_gpu_chunk_sizes = [layout.get_chunk_size() * indexer_dtype.itemsize + for layout in indexer_gpu_kv_layouts] + + # Compute indexer CPU strides. + # Indexer is always is_mla=True (1 head, ReplicatedLinear weights), + # so all TP ranks hold identical data and no head-partitioning is needed. + # Therefore indexer has no tp_stride — cpu_startoff is always 0. 
+ self.indexer_cpu_block_stride_in_bytes = indexer_cpu_kv_layout.get_block_stride() * indexer_dtype.itemsize + self.indexer_cpu_layer_stride_in_bytes = indexer_cpu_kv_layout.get_layer_stride() * indexer_dtype.itemsize + self.indexer_h2d_cpu_kv_stride_in_bytes = indexer_cpu_kv_layout.get_kv_stride() * indexer_dtype.itemsize + self.indexer_h2d_cpu_layer_stride_in_bytes = indexer_cpu_kv_layout.get_layer_stride() * indexer_dtype.itemsize + + self.indexer_gpu_blocks = imported_indexer_gpu_blocks + self.indexer_cpu_blocks = indexer_cpu_blocks + self.indexer_gpu_kv_strides_tensor = torch.tensor(indexer_gpu_kv_strides, dtype=torch.int64) + self.indexer_gpu_block_strides_tensor = torch.tensor(indexer_gpu_block_strides, dtype=torch.int64) + self.indexer_gpu_layer_strides_tensor = torch.tensor(indexer_gpu_layer_strides, dtype=torch.int64) + self.indexer_gpu_chunk_sizes_tensor = torch.tensor(indexer_gpu_chunk_sizes, dtype=torch.int64) + + flexkv_logger.info( + f"[LayerwiseWorker] Indexer fuse enabled: " + f"gpu_blocks={len(imported_indexer_gpu_blocks)}, " + f"cpu_size={indexer_cpu_blocks.numel() * indexer_cpu_blocks.element_size() / (1024 ** 2):.2f} MB, " + f"chunk_size={indexer_gpu_chunk_sizes[0]} bytes, " + f"cpu_block_stride={self.indexer_cpu_block_stride_in_bytes} bytes, " + f"cpu_layer_stride={self.indexer_cpu_layer_stride_in_bytes} bytes") + else: + self.indexer_cpu_block_stride_in_bytes = 0 + self.indexer_cpu_layer_stride_in_bytes = 0 + self.indexer_h2d_cpu_kv_stride_in_bytes = 0 + self.indexer_h2d_cpu_layer_stride_in_bytes = 0 + self.indexer_gpu_blocks = [] + self.indexer_cpu_blocks = torch.Tensor() + self.indexer_gpu_kv_strides_tensor = torch.empty(0, dtype=torch.int64) + self.indexer_gpu_block_strides_tensor = torch.empty(0, dtype=torch.int64) + self.indexer_gpu_layer_strides_tensor = torch.empty(0, dtype=torch.int64) + self.indexer_gpu_chunk_sizes_tensor = torch.empty(0, dtype=torch.int64) + + # Initialize indexer SSD support + self.enable_indexer_ssd = ( + 
self.enable_indexer and + indexer_ssd_files is not None and len(indexer_ssd_files) > 0 and + indexer_ssd_kv_layout is not None + ) + if self.enable_indexer_ssd: + assert indexer_dtype is not None + self.indexer_ssd_files = indexer_ssd_files + self.indexer_num_blocks_per_file = indexer_num_blocks_per_file + + indexer_ssd_kv_layout_per_file = indexer_ssd_kv_layout.div_block( + sum(len(fl) for fl in indexer_ssd_files.values()), padding=True) + self.indexer_ssd_kv_stride_in_bytes = indexer_ssd_kv_layout_per_file.get_kv_stride() * indexer_dtype.itemsize + self.indexer_ssd_layer_stride_in_bytes = indexer_ssd_kv_layout_per_file.get_layer_stride() * indexer_dtype.itemsize + self.indexer_cpu_chunk_size_in_bytes = indexer_cpu_kv_layout.get_chunk_size() * indexer_dtype.itemsize + + flexkv_logger.info( + f"[LayerwiseWorker] Indexer SSD fuse enabled: " + f"num_files={sum(len(fl) for fl in indexer_ssd_files.values())}, " + f"num_blocks_per_file={indexer_num_blocks_per_file}, " + f"ssd_kv_stride={self.indexer_ssd_kv_stride_in_bytes}, " + f"ssd_layer_stride={self.indexer_ssd_layer_stride_in_bytes}, " + f"cpu_chunk_size={self.indexer_cpu_chunk_size_in_bytes}") + else: + self.indexer_ssd_files = {} + self.indexer_num_blocks_per_file = 0 + self.indexer_ssd_kv_stride_in_bytes = 0 + self.indexer_ssd_layer_stride_in_bytes = 0 + self.indexer_cpu_chunk_size_in_bytes = 0 + + self.layerwise_transfer_group = LayerwiseTransferGroup( + self.num_gpus, self.gpu_blocks, cpu_blocks, ssd_files, + self.num_layers, + gpu_kv_strides_tensor, gpu_block_strides_tensor, + gpu_layer_strides_tensor, gpu_chunk_sizes_tensor, + GLOBAL_CONFIG_FROM_ENV.iouring_entries, + GLOBAL_CONFIG_FROM_ENV.iouring_flags, + layer_eventfds_tensor, tp_group_size, + self.indexer_gpu_blocks, self.indexer_cpu_blocks, + self.indexer_gpu_kv_strides_tensor, self.indexer_gpu_block_strides_tensor, + self.indexer_gpu_layer_strides_tensor, self.indexer_gpu_chunk_sizes_tensor, + self.indexer_ssd_files) + 
flexkv_logger.info(f"[LayerwiseWorker] __init__ completed successfully, worker_id={worker_id}") + + def _receive_eventfds_from_sglang(self, tp_group_size: int, + max_retries: int = 180, + retry_interval: float = 1.0) -> torch.Tensor: + """Receive eventfds from SGLang via Unix socket (FlexKV as server).""" + socket_path = self.layerwise_eventfd_socket + + def cleanup_socket(): + try: + if os.path.exists(socket_path): + os.unlink(socket_path) + except OSError: + pass + + cleanup_socket() + server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + + try: + server_sock.bind(socket_path) + # Use a larger backlog to accommodate client retries on failed connections + server_sock.listen(tp_group_size * 3) + os.chmod(socket_path, 0o777) + flexkv_logger.info( + f"[LayerwiseWorker] Eventfd server created: " + f"socket={socket_path}, waiting for {tp_group_size} connection(s)") + except Exception as e: + flexkv_logger.error( + f"[LayerwiseWorker] Failed to bind/listen on {socket_path}: {e}") + server_sock.close() + return torch.empty(0, dtype=torch.int32) + + # Use a per-connection timeout instead of a global one so that + # failed connections can be retried by the client without the server + # giving up too early. The total deadline is still bounded. + per_conn_timeout = 30 # seconds per accept() call + total_deadline = time.time() + max_retries * retry_interval + server_sock.settimeout(per_conn_timeout) + all_rank_eventfds: Dict[int, Dict[int, List[int]]] = {} + num_layers, num_counters = self.num_layers, 3 + conn_idx = 0 + + try: + # Keep accepting until we have eventfds from all ranks or deadline. 
+            while len(all_rank_eventfds) < tp_group_size:
+                if time.time() > total_deadline:
+                    flexkv_logger.error(
+                        f"[LayerwiseWorker] Deadline exceeded on {socket_path}, "
+                        f"received {len(all_rank_eventfds)}/{tp_group_size} ranks")
+                    break
+
+                remaining = total_deadline - time.time()
+                server_sock.settimeout(min(per_conn_timeout, max(remaining, 1)))
+
+                try:
+                    conn, _ = server_sock.accept()
+                    conn_idx += 1
+                    flexkv_logger.info(
+                        f"[LayerwiseWorker] Accepted connection "
+                        f"{conn_idx} (registered {len(all_rank_eventfds)}/{tp_group_size}) "
+                        f"on {socket_path}")
+                except socket.timeout:
+                    flexkv_logger.warning(
+                        f"[LayerwiseWorker] Timeout waiting for connection on {socket_path}, "
+                        f"registered {len(all_rank_eventfds)}/{tp_group_size}, retrying...")
+                    continue
+
+                with conn:
+                    try:
+                        # Receive 16-byte metadata: tp_rank_per_node, tp_size_per_node,
+                        # num_layers, num_counters
+                        metadata = conn.recv(16)
+                        if len(metadata) < 16:
+                            flexkv_logger.error(
+                                f"[LayerwiseWorker] Incomplete metadata on {socket_path}: "
+                                f"expected 16 bytes, got {len(metadata)}")
+                            continue
+
+                        rank_key, tp_size_per_node_recv, recv_num_layers, recv_num_counters = \
+                            struct.unpack("iiii", metadata[:16])
+
+                        if not all_rank_eventfds:
+                            num_layers, num_counters = recv_num_layers, recv_num_counters
+
+                        flexkv_logger.debug(
+                            f"[LayerwiseWorker] Connection {conn_idx}: "
+                            f"tp_rank_per_node={rank_key}, "
+                            f"tp_size_per_node={tp_size_per_node_recv}, "
+                            f"num_layers={recv_num_layers}, "
+                            f"num_counters={recv_num_counters}")
+
+                        rank_eventfds = {}
+                        for _ in range(recv_num_counters):
+                            fds, extra_data = _recv_fds(conn, recv_num_layers)
+                            counter_id = struct.unpack("i", extra_data[:4])[0]
+                            rank_eventfds[counter_id] = fds
+                            flexkv_logger.debug(
+                                f"[LayerwiseWorker] Received counter_id={counter_id}, "
+                                f"num_fds={len(fds)} from tp_rank_per_node={rank_key}")
+
+                        all_rank_eventfds[rank_key] = rank_eventfds
+                        # Send ACK to client so it knows the fds were received
+                        try:
+                            conn.sendall(b"\x01")
+                        except Exception:
+                            pass
+                        flexkv_logger.info(
+                            f"[LayerwiseWorker] Received all eventfds from tp_rank_per_node={rank_key} "
+                            f"on {socket_path}")
+                    except Exception as e:
+                        # Send NACK while conn is still open (inside `with`) so the client can retry
+                        try:
+                            conn.sendall(b"\x00")
+                        except Exception:
+                            pass
+                        flexkv_logger.warning(
+                            f"[LayerwiseWorker] Failed to receive eventfds from connection {conn_idx} "
+                            f"on {socket_path}: {e}. "
+                            f"Client will retry, continuing accept loop...")
+                        continue
+        except Exception as e:
+            flexkv_logger.error(
+                f"[LayerwiseWorker] Fatal error in accept loop on {socket_path}: {e}")
+        finally:
+            server_sock.close()
+            cleanup_socket()
+
+        if not all_rank_eventfds:
+            flexkv_logger.warning(
+                f"[LayerwiseWorker] No connections received on {socket_path}")
+            return torch.empty(0, dtype=torch.int32)
+
+        # Build tensor: [num_counters, tp_size, num_layers]
+        eventfds_list = []
+        for counter_id in range(num_counters):
+            for tp_rank in range(tp_group_size):
+                fds = all_rank_eventfds.get(tp_rank, {}).get(counter_id, [-1] * num_layers)
+                eventfds_list.extend(fds)
+
+        tensor = torch.tensor(eventfds_list, dtype=torch.int32)
+        flexkv_logger.info(
+            f"[LayerwiseWorker] Eventfd setup complete: "
+            f"socket={socket_path}, tensor_shape={tensor.shape}, "
+            f"counters={num_counters}, tp_size_per_rank={tp_group_size}, layers={num_layers}"
+        )
+        return tensor
+
+    def _transfer_impl(self,
+                       src_block_ids_h2d: torch.Tensor,
+                       dst_block_ids_h2d: torch.Tensor,
+                       src_block_ids_disk2h: Optional[torch.Tensor],
+                       dst_block_ids_disk2h: Optional[torch.Tensor],
+                       layer_granularity: int,
+                       counter_id: int = 0,
+                       indexer_src_block_ids: Optional[torch.Tensor] = None,
+                       indexer_dst_block_ids: Optional[torch.Tensor] = None,
+                       **kwargs: Any) -> None:
+        assert src_block_ids_h2d.dtype == torch.int64
+        assert dst_block_ids_h2d.dtype == torch.int64
+        assert len(src_block_ids_h2d) == len(dst_block_ids_h2d)
+        if src_block_ids_disk2h is not None:
+            assert src_block_ids_disk2h.dtype == torch.int64
+            assert dst_block_ids_disk2h.dtype == torch.int64
+            assert len(src_block_ids_disk2h) == len(dst_block_ids_disk2h)
+
+        # Use unified layerwise transfer C++ interface
+        ssd_block_ids = src_block_ids_disk2h if src_block_ids_disk2h is not None else torch.empty(0, dtype=torch.int64)
+        cpu_block_ids_d2h = dst_block_ids_disk2h if dst_block_ids_disk2h is not None \
+            else torch.empty(0, dtype=torch.int64)
+
+        # Prepare indexer block_ids for fused transfer
+        indexer_gpu_block_id_tensor = torch.Tensor()
+        indexer_cpu_block_id_tensor = torch.Tensor()
+        if self.enable_indexer and indexer_dst_block_ids is not None and len(indexer_dst_block_ids) > 0:
+            indexer_gpu_block_id_tensor = indexer_dst_block_ids
+            indexer_cpu_block_id_tensor = indexer_src_block_ids
+
+        # Prepare indexer SSD block_ids for fused DISK2H transfer
+        indexer_ssd_block_ids_tensor = torch.Tensor()
+        indexer_cpu_block_ids_d2h_tensor = torch.Tensor()
+        if self.enable_indexer_ssd and src_block_ids_disk2h is not None:
+            # Indexer SSD block_ids mirror main KV's DISK2H block_ids (1:1 mapping)
+            indexer_ssd_block_ids_tensor = ssd_block_ids
+            indexer_cpu_block_ids_d2h_tensor = cpu_block_ids_d2h
+
+        self.layerwise_transfer_group.layerwise_transfer(
+            ssd_block_ids,
+            cpu_block_ids_d2h,
+            self.ssd_layer_stride_in_bytes,
+            self.ssd_kv_stride_in_bytes,
+            self.num_blocks_per_file,
+            self.round_robin,
+            32,  # num_threads_per_device
+            dst_block_ids_h2d,
+            src_block_ids_h2d,
+            self.cpu_kv_stride_in_bytes,
+            self.cpu_layer_stride_in_bytes,
+            self.cpu_block_stride_in_bytes,
+            self.cpu_chunk_size_in_bytes,
+            self.h2d_cpu_kv_stride_in_bytes,
+            self.h2d_cpu_layer_stride_in_bytes,
+            self.cpu_tp_stride_in_bytes,
+            self.h2d_cta_num,
+            self.use_ce_transfer_h2d,
+            self.num_layers,
+            layer_granularity,
+            self.is_mla,
+            counter_id,
+            indexer_gpu_block_id_tensor,
+            indexer_cpu_block_id_tensor,
+            self.indexer_cpu_block_stride_in_bytes,
+            self.indexer_cpu_layer_stride_in_bytes,
+            self.indexer_h2d_cpu_kv_stride_in_bytes,
self.indexer_h2d_cpu_layer_stride_in_bytes, + indexer_ssd_block_ids_tensor, + indexer_cpu_block_ids_d2h_tensor, + self.indexer_ssd_layer_stride_in_bytes, + self.indexer_ssd_kv_stride_in_bytes, + self.indexer_cpu_chunk_size_in_bytes, + self.indexer_num_blocks_per_file, + ) + + def launch_transfer(self, transfer_op: WorkerLayerwiseTransferOp) -> bool: + layer_granularity = transfer_op.layer_granularity + if layer_granularity == -1: + layer_granularity = self.num_layers + + src_block_ids_h2d = torch.from_numpy(transfer_op.src_block_ids_h2d).to(dtype=torch.int64).pin_memory() + dst_block_ids_h2d = torch.from_numpy(transfer_op.dst_block_ids_h2d).to(dtype=torch.int64).pin_memory() + + if transfer_op.src_block_ids_disk2h.size > 0: + src_block_ids_disk2h = torch.from_numpy(transfer_op.src_block_ids_disk2h).to(dtype=torch.int64) + dst_block_ids_disk2h = torch.from_numpy(transfer_op.dst_block_ids_disk2h).to(dtype=torch.int64) + else: + src_block_ids_disk2h = None + dst_block_ids_disk2h = None + + # Extract indexer block_ids if available + indexer_src_block_ids = None + indexer_dst_block_ids = None + if self.enable_indexer and transfer_op.indexer_src_block_ids.size > 0: + indexer_src_block_ids = torch.from_numpy( + transfer_op.indexer_src_block_ids).to(dtype=torch.int64).pin_memory() + indexer_dst_block_ids = torch.from_numpy( + transfer_op.indexer_dst_block_ids).to(dtype=torch.int64).pin_memory() + + num_h2d_blocks = len(src_block_ids_h2d) + + start_time = time.time() + self._transfer_impl( + src_block_ids_h2d, + dst_block_ids_h2d, + src_block_ids_disk2h, + dst_block_ids_disk2h, + layer_granularity, + transfer_op.counter_id, + indexer_src_block_ids=indexer_src_block_ids, + indexer_dst_block_ids=indexer_dst_block_ids, + ) + end_time = time.time() + + kv_dim = 2 if not self.is_mla else 1 + transfer_size = self.cpu_chunk_size_in_bytes * self.num_layers * num_h2d_blocks * kv_dim + + if self.is_mla: + transfer_size *= self.tp_group_size + + self._log_transfer_performance( + 
transfer_op, + transfer_size, + start_time, + end_time, + ) + + return True diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index c3b9539f41..092baa214d 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -18,18 +18,17 @@ import multiprocessing as mp import selectors import os -from queue import Queue from typing import Dict, List, Optional, Tuple, Union import contextlib import nvtx +import numpy as np import torch from flexkv.common.debug import flexkv_logger from flexkv.common.storage import StorageHandle -from flexkv.common.transfer import TransferOp, TransferOpGraph, TransferType, CompletedOp +from flexkv.common.transfer import TransferOp, TransferOpGraph, TransferType, CompletedOp, WorkerKey from flexkv.common.transfer import get_nvtx_range_color -from flexkv.common.storage import KVCacheLayoutType from flexkv.transfer.scheduler import TransferScheduler from flexkv.transfer.worker import ( WorkerHandle, @@ -41,6 +40,10 @@ tpGDSTransferWorker, PEER2CPUTransferWorker, ) +from flexkv.transfer.layerwise import ( + LayerwiseTransferWorker, + build_layerwise_eventfd_socket_path, +) from flexkv.common.config import CacheConfig, ModelConfig, GLOBAL_CONFIG_FROM_ENV from flexkv.common.ring_buffer import SharedOpPool @@ -52,6 +55,8 @@ def register_op_to_buffer(op: TransferOp, pin_buffer: SharedOpPool) -> None: Device type prefixes prevent hash collisions when different device types use the same block ID values (e.g., CPU block 0 vs SSD block 0). 
""" + if op.transfer_type == TransferType.LAYERWISE: + return # Map TransferType to (src_device_type, dst_device_type) for hash prefix # This prevents hash collisions when different devices use the same block IDs transfer_type_to_devices = { @@ -82,17 +87,21 @@ def free_op_from_buffer(op: TransferOp, pin_buffer: SharedOpPool) -> None: class TransferEngine: def __init__(self, - gpu_handles: Dict[int, List[StorageHandle]], + gpu_handles: Dict[WorkerKey, List[StorageHandle]], model_config: ModelConfig, cache_config: CacheConfig, cpu_handle: Optional[StorageHandle] = None, ssd_handle: Optional[StorageHandle] = None, - remote_handle: Optional[StorageHandle] = None): + remote_handle: Optional[StorageHandle] = None, + indexer_gpu_handles: Optional[Dict[WorkerKey, List[StorageHandle]]] = None, + indexer_cpu_handle: Optional[StorageHandle] = None, + indexer_ssd_handle: Optional[StorageHandle] = None, + indexer_remote_handle: Optional[StorageHandle] = None): """ Initialize transfer engine Args: - gpu_handles: Dict mapping dp_client_id -> list of GPU handles for that TP group + gpu_handles: Dict mapping WorkerKey(dp_rank, pp_rank) -> list of GPU handles for that TP group cpu_handle: CPU handle ssd_handle: Optional SSD handle remote_handle: Optional remote handle @@ -114,56 +123,99 @@ def __init__(self, # Create shutdown pipe for zero-latency selector self.shutdown_read_fd, self.shutdown_write_fd = os.pipe() - self.gpu_handles = gpu_handles + self.gpu_handle_groups = gpu_handles # WorkerKey -> list of GPU handles for that TP group self._cpu_handle = cpu_handle self._ssd_handle = ssd_handle self._remote_handle = remote_handle self._cache_config = cache_config - self._enable_pcfs_sharing = GLOBAL_CONFIG_FROM_ENV.index_accel and cache_config.enable_kv_sharing # TODO: is this correct? + # TODO: is this correct? 
+ self._enable_pcfs_sharing = ( + GLOBAL_CONFIG_FROM_ENV.index_accel and cache_config.enable_kv_sharing + ) + + self._indexer_gpu_handles = indexer_gpu_handles + self._indexer_cpu_handle = indexer_cpu_handle + self._indexer_ssd_handle = indexer_ssd_handle + self._indexer_remote_handle = indexer_remote_handle self.pin_buffer = SharedOpPool(2048, self.cache_config.num_cpu_blocks) self.op_id_to_nvtx_range: Dict[int, str] = {} - self.dp_size = model_config.dp_size - self.tp_size = model_config.tp_size - self.num_gpu_groups = len(self.gpu_handles) + # self.dp_size = model_config.dp_size + self.tp_size_per_node = model_config.tp_size_per_node + self.num_gpu_groups = len(self.gpu_handle_groups) self._running = False + self._has_indexer = False + + self._indexer_op_to_parent_op: Dict[int, int] = {} + self._indexer_op_map: Dict[int, TransferOp] = {} + + # Same-node PP layerwise fan-out: replica op_id → parent op_id + self._pp_replica_to_parent_op: Dict[int, int] = {} + self._pp_replica_op_map: Dict[int, TransferOp] = {} def _init_workers(self) -> None: if self._running: return - self._worker_map: Dict[TransferType, Union[WorkerHandle, List[WorkerHandle]]] = {} + self._worker_map: Dict[TransferType, Union[WorkerHandle, Dict[WorkerKey, WorkerHandle]]] = {} assert self._cpu_handle is not None + _enable_layerwise = GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer # Use num_gpu_groups to support multi-instance mode # Use gpu_device_id from StorageHandle for correct CUDA device selection - if self.tp_size == 1: - self.h2d_workers: List[WorkerHandle] = [ - GPUCPUTransferWorker.create_worker( - mp_ctx=self.mp_ctx, - finished_ops_queue=self.finished_ops_queue, - op_buffer_tensor=self.pin_buffer.get_buffer(), - gpu_blocks=gpu_handles[0].get_tensor_handle_list(), - cpu_blocks=self._cpu_handle.get_tensor(), - gpu_kv_layout=gpu_handles[0].kv_layout, - cpu_kv_layout=self._cpu_handle.kv_layout, - dtype=gpu_handles[0].dtype, - gpu_device_id=gpu_handles[0].gpu_device_id, - 
use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, - use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, - transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, - transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, - ) - for _, gpu_handles in self.gpu_handles.items() - ] - self.d2h_workers: List[WorkerHandle] = [ - GPUCPUTransferWorker.create_worker( + + # H2D worker + if not _enable_layerwise: + if self.tp_size_per_node == 1: + self.h2d_workers: Dict[WorkerKey, WorkerHandle] = { + worker_key: GPUCPUTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self.finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=gpu_handles[0].get_tensor_handle_list(), + cpu_blocks=self._cpu_handle.get_worker_tensor(), + gpu_kv_layout=gpu_handles[0].kv_layout, + cpu_kv_layout=self._cpu_handle.kv_layout, + dtype=gpu_handles[0].dtype, + gpu_device_id=gpu_handles[0].gpu_device_id, + use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, + use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, + transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, + transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, + ) + for worker_key, gpu_handles in self.gpu_handle_groups.items() + } + else: + self.h2d_workers = { + worker_key: tpGPUCPUTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self.finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=[gpu_handle.get_tensor_handle_list() for gpu_handle in gpu_handles], + cpu_blocks=self._cpu_handle.get_worker_tensor(), + gpu_kv_layouts=[gpu_handle.kv_layout for gpu_handle in gpu_handles], + cpu_kv_layout=self._cpu_handle.kv_layout, + dtype=gpu_handles[0].dtype, + tp_group_size=self.tp_size_per_node, + use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, + use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, + 
transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, + transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, + ) + for worker_key, gpu_handles in self.gpu_handle_groups.items() + } + self._worker_map[TransferType.H2D] = self.h2d_workers + + # D2H worker + if self.tp_size_per_node == 1: + self.d2h_workers: Dict[WorkerKey, WorkerHandle] = { + worker_key: GPUCPUTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), gpu_blocks=gpu_handles[0].get_tensor_handle_list(), - cpu_blocks=self._cpu_handle.get_tensor(), + cpu_blocks=self._cpu_handle.get_worker_tensor(), gpu_kv_layout=gpu_handles[0].kv_layout, cpu_kv_layout=self._cpu_handle.kv_layout, dtype=gpu_handles[0].dtype, @@ -173,68 +225,52 @@ def _init_workers(self) -> None: transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, ) - for _, gpu_handles in self.gpu_handles.items() - ] + for worker_key, gpu_handles in self.gpu_handle_groups.items() + } else: - self.h2d_workers = [ - tpGPUCPUTransferWorker.create_worker( + self.d2h_workers = { + worker_key: tpGPUCPUTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), gpu_blocks=[gpu_handle.get_tensor_handle_list() for gpu_handle in gpu_handles], - cpu_blocks=self._cpu_handle.get_tensor(), + cpu_blocks=self._cpu_handle.get_worker_tensor(), gpu_kv_layouts=[gpu_handle.kv_layout for gpu_handle in gpu_handles], cpu_kv_layout=self._cpu_handle.kv_layout, dtype=gpu_handles[0].dtype, - tp_group_size=self.tp_size, - dp_group_id=dp_client_id, + tp_group_size=self.tp_size_per_node, use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, 
transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, ) - for dp_client_id, gpu_handles in self.gpu_handles.items() - ] - self.d2h_workers = [ - tpGPUCPUTransferWorker.create_worker( + for worker_key, gpu_handles in self.gpu_handle_groups.items() + } + self._worker_map[TransferType.D2H] = self.d2h_workers + + if self._ssd_handle is not None and self._cpu_handle is not None: + # DISK2H worker + if not _enable_layerwise: + self.cpussd_read_worker: WorkerHandle = CPUSSDDiskTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, - op_buffer_tensor=self.pin_buffer.get_buffer(), - gpu_blocks=[gpu_handle.get_tensor_handle_list() for gpu_handle in gpu_handles], - cpu_blocks=self._cpu_handle.get_tensor(), - gpu_kv_layouts=[gpu_handle.kv_layout for gpu_handle in gpu_handles], + op_buffer_tensor = self.pin_buffer.get_buffer(), + cpu_blocks=self._cpu_handle.get_worker_tensor(), + ssd_files=self._ssd_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, - dtype=gpu_handles[0].dtype, - tp_group_size=self.tp_size, - dp_group_id=dp_client_id, - use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, - use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, - transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, - transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, + ssd_kv_layout=self._ssd_handle.kv_layout, + dtype=self._cpu_handle.dtype, + num_blocks_per_file=self._ssd_handle.num_blocks_per_file, + cache_config=self._cache_config, ) - for dp_client_id, gpu_handles in self.gpu_handles.items() - ] - self._worker_map[TransferType.H2D] = self.h2d_workers - self._worker_map[TransferType.D2H] = self.d2h_workers + self._worker_map[TransferType.DISK2H] = self.cpussd_read_worker - if self._ssd_handle is not None and self._cpu_handle is not None: - self.cpussd_read_worker: WorkerHandle = CPUSSDDiskTransferWorker.create_worker( - mp_ctx=self.mp_ctx, - 
finished_ops_queue=self.finished_ops_queue, - op_buffer_tensor = self.pin_buffer.get_buffer(), - cpu_blocks=self._cpu_handle.get_tensor(), - ssd_files=self._ssd_handle.get_file_list(), - cpu_kv_layout=self._cpu_handle.kv_layout, - ssd_kv_layout=self._ssd_handle.kv_layout, - dtype=self._cpu_handle.dtype, - num_blocks_per_file=self._ssd_handle.num_blocks_per_file, - cache_config=self._cache_config, - ) + # H2DISK worker self.cpussd_write_worker: WorkerHandle = CPUSSDDiskTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, op_buffer_tensor = self.pin_buffer.get_buffer(), - cpu_blocks=self._cpu_handle.get_tensor(), + cpu_blocks=self._cpu_handle.get_worker_tensor(), ssd_files=self._ssd_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, ssd_kv_layout=self._ssd_handle.kv_layout, @@ -243,13 +279,12 @@ def _init_workers(self) -> None: cache_config=self._cache_config, ) self._worker_map[TransferType.H2DISK] = self.cpussd_write_worker - self._worker_map[TransferType.DISK2H] = self.cpussd_read_worker if self._remote_handle is not None and self._cpu_handle is not None: self.remotecpu_read_worker: WorkerHandle = CPURemoteTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, op_buffer_tensor = self.pin_buffer.get_buffer(), - cpu_blocks=self._cpu_handle.get_tensor(), + cpu_blocks=self._cpu_handle.get_worker_tensor(), remote_file=self._remote_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, remote_kv_layout=self._remote_handle.kv_layout, @@ -261,7 +296,7 @@ def _init_workers(self) -> None: mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, op_buffer_tensor = self.pin_buffer.get_buffer(), - cpu_blocks=self._cpu_handle.get_tensor(), + cpu_blocks=self._cpu_handle.get_worker_tensor(), remote_file=self._remote_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, remote_kv_layout=self._remote_handle.kv_layout, @@ -271,9 +306,9 @@ def 
_init_workers(self) -> None: self._worker_map[TransferType.H2REMOTE] = self.remotecpu_write_worker self._worker_map[TransferType.REMOTE2H] = self.remotecpu_read_worker if self.cache_config.enable_gds: - if self.tp_size == 1: - self.gds_workers = [ - GDSTransferWorker.create_worker( + if self.tp_size_per_node == 1: + self.gds_workers: Dict[WorkerKey, WorkerHandle] = { + worker_key: GDSTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), @@ -285,11 +320,11 @@ def _init_workers(self) -> None: dtype=self._ssd_handle.dtype, gpu_device_id=gpu_handles[0].gpu_device_id, ) - for _, gpu_handles in self.gpu_handles.items() - ] + for worker_key, gpu_handles in self.gpu_handle_groups.items() + } else: - self.gds_workers = [ - tpGDSTransferWorker.create_worker( + self.gds_workers = { + worker_key: tpGDSTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), @@ -299,13 +334,71 @@ def _init_workers(self) -> None: gpu_kv_layouts=[gpu_handle.kv_layout for gpu_handle in gpu_handles], ssd_kv_layout=self._ssd_handle.kv_layout, dtype=self._ssd_handle.dtype, - tp_group_size=self.tp_size, - dp_group_id=dp_client_id, + tp_group_size=self.tp_size_per_node, ) - for dp_client_id, gpu_handles in self.gpu_handles.items() - ] + for worker_key, gpu_handles in self.gpu_handle_groups.items() + } self._worker_map[TransferType.DISK2D] = self.gds_workers self._worker_map[TransferType.D2DISK] = self.gds_workers + if GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer: + ssd_files = {} if self._ssd_handle is None else self._ssd_handle.get_file_list() + ssd_kv_layout = None if self._ssd_handle is None else self._ssd_handle.kv_layout + num_blocks_per_file = 0 if self._ssd_handle is None else self._ssd_handle.num_blocks_per_file + + # Prepare indexer handles for fused layerwise transfer + has_indexer_for_layerwise = ( + 
self._indexer_gpu_handles is not None and + self._indexer_cpu_handle is not None + ) + + self.layerwise_workers: Dict[WorkerKey, WorkerHandle] = {} + for worker_key, gpu_handles in self.gpu_handle_groups.items(): + _layerwise_eventfd_socket = build_layerwise_eventfd_socket_path( + pp_rank=worker_key.pp_rank, + dp_rank=worker_key.dp_rank, + pp_size=self.model_config.pp_size, + dp_size=self.model_config.dp_size, + ) + # Resolve indexer handles for this WorkerKey + idx_handles = None + if has_indexer_for_layerwise: + idx_handles = self._indexer_gpu_handles.get(worker_key) + + worker = LayerwiseTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self.finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=[handle.get_tensor_handle_list() for handle in gpu_handles], + cpu_blocks=self._cpu_handle.get_worker_tensor(), + ssd_files=ssd_files, + gpu_kv_layouts=[handle.kv_layout for handle in gpu_handles], + cpu_kv_layout=self._cpu_handle.kv_layout, + ssd_kv_layout=ssd_kv_layout, + dtype=gpu_handles[0].dtype, + tp_group_size=self.tp_size_per_node, + layerwise_eventfd_socket=_layerwise_eventfd_socket, + num_blocks_per_file=num_blocks_per_file, + use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, + use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, + h2d_cta_num=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, + d2h_cta_num=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, + indexer_gpu_blocks=[h.get_tensor_handle_list() for h in idx_handles] if idx_handles else None, + indexer_cpu_blocks=self._indexer_cpu_handle.get_worker_tensor() if idx_handles else None, + indexer_gpu_kv_layouts=[h.kv_layout for h in idx_handles] if idx_handles else None, + indexer_cpu_kv_layout=self._indexer_cpu_handle.kv_layout if idx_handles else None, + indexer_dtype=idx_handles[0].dtype if idx_handles else None, + indexer_ssd_files=self._indexer_ssd_handle.get_file_list() if (idx_handles and self._indexer_ssd_handle) else None, + 
indexer_ssd_kv_layout=self._indexer_ssd_handle.kv_layout if (idx_handles and self._indexer_ssd_handle) else None, + indexer_num_blocks_per_file=self._indexer_ssd_handle.num_blocks_per_file if (idx_handles and self._indexer_ssd_handle) else 0, + ) + self.layerwise_workers[worker_key] = worker + + flexkv_logger.debug( + f"[TransferEngine] Created layerwise worker for {worker_key}: " + f"tp_size_per_node={self.tp_size_per_node}, has_indexer={idx_handles is not None}, " + f"has_ssd={len(ssd_files) > 0}") + + self._worker_map[TransferType.LAYERWISE] = self.layerwise_workers if self.cache_config.enable_kv_sharing and self._cpu_handle is not None and (self.cache_config.enable_p2p_cpu \ or (self._ssd_handle and self.cache_config.enable_p2p_ssd)): @@ -317,7 +410,7 @@ def _init_workers(self) -> None: mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, op_buffer_tensor = self.pin_buffer.get_buffer(), - cpu_blocks=self._cpu_handle.get_tensor(), + cpu_blocks=self._cpu_handle.get_worker_tensor(), cpu_kv_layout=self._cpu_handle.kv_layout, # TODO: get remote kv_layout, now we can assume that remote kv layout is same as current node remote_kv_layout=self._cpu_handle.kv_layout, @@ -325,7 +418,8 @@ def _init_workers(self) -> None: cache_config = self.cache_config, ssd_kv_layout = self._ssd_handle.kv_layout if self._ssd_handle else None, ssd_files = self._ssd_handle.get_file_list() if self._ssd_handle else None, - num_blocks_per_file = self._ssd_handle.num_blocks_per_file if self._ssd_handle else None + num_blocks_per_file = self._ssd_handle.num_blocks_per_file if self._ssd_handle else None, + mooncake_config_path = getattr(self.cache_config, 'mooncake_config_path', None) or os.environ.get("MOONCAKE_CONFIG_PATH"), ) # NOTE: now peerH2H and peerSSD2H op use the same worker if self.cache_config.enable_p2p_cpu: @@ -333,20 +427,252 @@ def _init_workers(self) -> None: if self.cache_config.enable_p2p_ssd: self._worker_map[TransferType.PEERSSD2H] = self.cpu_remote_cpu_worker + # 
Initialize indexer workers + if (self._indexer_gpu_handles is not None + and self._indexer_cpu_handle is not None): + self._indexer_finished_ops_queue = self.mp_ctx.Queue() + self._indexer_worker_map: Dict[TransferType, Union[WorkerHandle, Dict[WorkerKey, WorkerHandle]]] = {} + # H2D indexer worker + if not _enable_layerwise: + if self.tp_size_per_node == 1: + self._indexer_h2d_workers: Dict[WorkerKey, WorkerHandle] = { + worker_key: GPUCPUTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=indexer_gpu_handles_list[0].get_tensor_handle_list(), + cpu_blocks=self._indexer_cpu_handle.get_worker_tensor(), + gpu_kv_layout=indexer_gpu_handles_list[0].kv_layout, + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + dtype=indexer_gpu_handles_list[0].dtype, + gpu_device_id=indexer_gpu_handles_list[0].gpu_device_id, + use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, + use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, + transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, + transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, + ) + for worker_key, indexer_gpu_handles_list in self._indexer_gpu_handles.items() + } + else: + self._indexer_h2d_workers = { + worker_key: tpGPUCPUTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=[h.get_tensor_handle_list() for h in indexer_gpu_handles_list], + cpu_blocks=self._indexer_cpu_handle.get_worker_tensor(), + gpu_kv_layouts=[h.kv_layout for h in indexer_gpu_handles_list], + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + dtype=indexer_gpu_handles_list[0].dtype, + tp_group_size=self.tp_size_per_node, + use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, + use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, + 
transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, + transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, + ) + for worker_key, indexer_gpu_handles_list in self._indexer_gpu_handles.items() + } + self._indexer_worker_map[TransferType.H2D] = self._indexer_h2d_workers + + # D2H indexer worker + if self.tp_size_per_node == 1: + self._indexer_d2h_workers: Dict[WorkerKey, WorkerHandle] = { + worker_key: GPUCPUTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=indexer_gpu_handles_list[0].get_tensor_handle_list(), + cpu_blocks=self._indexer_cpu_handle.get_worker_tensor(), + gpu_kv_layout=indexer_gpu_handles_list[0].kv_layout, + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + dtype=indexer_gpu_handles_list[0].dtype, + gpu_device_id=indexer_gpu_handles_list[0].gpu_device_id, + use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, + use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, + transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, + transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, + ) + for worker_key, indexer_gpu_handles_list in self._indexer_gpu_handles.items() + } + else: + self._indexer_d2h_workers = { + worker_key: tpGPUCPUTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=[h.get_tensor_handle_list() for h in indexer_gpu_handles_list], + cpu_blocks=self._indexer_cpu_handle.get_worker_tensor(), + gpu_kv_layouts=[h.kv_layout for h in indexer_gpu_handles_list], + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + dtype=indexer_gpu_handles_list[0].dtype, + tp_group_size=self.tp_size_per_node, + use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, + use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, + 
transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, + transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, + ) + for worker_key, indexer_gpu_handles_list in self._indexer_gpu_handles.items() + } + self._indexer_worker_map[TransferType.D2H] = self._indexer_d2h_workers + if self._indexer_ssd_handle is not None and self._indexer_cpu_handle is not None: + # H2DISK indexer worker + self._indexer_h2disk_worker = CPUSSDDiskTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + cpu_blocks=self._indexer_cpu_handle.get_worker_tensor(), + ssd_files=self._indexer_ssd_handle.get_file_list(), + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + ssd_kv_layout=self._indexer_ssd_handle.kv_layout, + dtype=self._indexer_cpu_handle.dtype, + num_blocks_per_file=self._indexer_ssd_handle.num_blocks_per_file, + cache_config=self._cache_config, + ) + self._indexer_worker_map[TransferType.H2DISK] = self._indexer_h2disk_worker + # DISK2H indexer worker + if not _enable_layerwise: + self._indexer_disk2h_worker = CPUSSDDiskTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + cpu_blocks=self._indexer_cpu_handle.get_worker_tensor(), + ssd_files=self._indexer_ssd_handle.get_file_list(), + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + ssd_kv_layout=self._indexer_ssd_handle.kv_layout, + dtype=self._indexer_cpu_handle.dtype, + num_blocks_per_file=self._indexer_ssd_handle.num_blocks_per_file, + cache_config=self._cache_config, + ) + self._indexer_worker_map[TransferType.DISK2H] = self._indexer_disk2h_worker + flexkv_logger.info("TransferEngine: indexer SSD workers initialized") + if self._indexer_remote_handle is not None and self._indexer_cpu_handle is not None: + self._indexer_h2remote_worker = CPURemoteTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + 
finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + cpu_blocks=self._indexer_cpu_handle.get_worker_tensor(), + remote_file=self._indexer_remote_handle.get_file_list(), + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + remote_kv_layout=self._indexer_remote_handle.kv_layout, + dtype=self._indexer_cpu_handle.dtype, + remote_config_custom=self._indexer_remote_handle.remote_config_custom, + enable_pcfs_sharing=self._enable_pcfs_sharing, + ) + self._indexer_remote2h_worker = CPURemoteTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + cpu_blocks=self._indexer_cpu_handle.get_worker_tensor(), + remote_file=self._indexer_remote_handle.get_file_list(), + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + remote_kv_layout=self._indexer_remote_handle.kv_layout, + dtype=self._indexer_cpu_handle.dtype, + remote_config_custom=self._indexer_remote_handle.remote_config_custom, + ) + self._indexer_worker_map[TransferType.H2REMOTE] = self._indexer_h2remote_worker + self._indexer_worker_map[TransferType.REMOTE2H] = self._indexer_remote2h_worker + flexkv_logger.info("TransferEngine: indexer Remote workers initialized") + if self.cache_config.enable_gds and self._indexer_ssd_handle is not None: + if self.tp_size_per_node == 1: + self._indexer_gds_workers: Dict[WorkerKey, WorkerHandle] = { + worker_key: GDSTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=indexer_gpu_handles_list[0].get_tensor_handle_list(), + ssd_files=self._indexer_ssd_handle.get_file_list(), + num_blocks_per_file=self._indexer_ssd_handle.num_blocks_per_file, + gpu_kv_layout=indexer_gpu_handles_list[0].kv_layout, + ssd_kv_layout=self._indexer_ssd_handle.kv_layout, + dtype=self._indexer_ssd_handle.dtype, + 
gpu_device_id=indexer_gpu_handles_list[0].gpu_device_id, + ) + for worker_key, indexer_gpu_handles_list in self._indexer_gpu_handles.items() + } + else: + self._indexer_gds_workers = { + worker_key: tpGDSTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=[h.get_tensor_handle_list() for h in indexer_gpu_handles_list], + ssd_files=self._indexer_ssd_handle.get_file_list(), + num_blocks_per_file=self._indexer_ssd_handle.num_blocks_per_file, + gpu_kv_layouts=[h.kv_layout for h in indexer_gpu_handles_list], + ssd_kv_layout=self._indexer_ssd_handle.kv_layout, + dtype=self._indexer_ssd_handle.dtype, + tp_group_size=self.tp_size_per_node, + ) + for worker_key, indexer_gpu_handles_list in self._indexer_gpu_handles.items() + } + self._indexer_worker_map[TransferType.DISK2D] = self._indexer_gds_workers + self._indexer_worker_map[TransferType.D2DISK] = self._indexer_gds_workers + flexkv_logger.info("TransferEngine: indexer GDS workers initialized") + if self.cache_config.enable_kv_sharing and self._indexer_cpu_handle is not None and ( + self.cache_config.enable_p2p_cpu + or (self._indexer_ssd_handle and self.cache_config.enable_p2p_ssd)): + flexkv_logger.info("[transfer_engine] initializing the indexer PEER2CPUTransferWorker!") + self._indexer_cpu_remote_cpu_worker: WorkerHandle = PEER2CPUTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + cpu_blocks=self._indexer_cpu_handle.get_worker_tensor(), + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + remote_kv_layout=self._indexer_cpu_handle.kv_layout, + dtype=self._indexer_cpu_handle.dtype, + cache_config=self._cache_config, + ssd_kv_layout=self._indexer_ssd_handle.kv_layout if self._indexer_ssd_handle else None, + ssd_files=self._indexer_ssd_handle.get_file_list() if self._indexer_ssd_handle else None, + 
num_blocks_per_file=self._indexer_ssd_handle.num_blocks_per_file if self._indexer_ssd_handle else None, + ) + if self.cache_config.enable_p2p_cpu: + self._indexer_worker_map[TransferType.PEERH2H] = self._indexer_cpu_remote_cpu_worker + if self.cache_config.enable_p2p_ssd: + self._indexer_worker_map[TransferType.PEERSSD2H] = self._indexer_cpu_remote_cpu_worker + flexkv_logger.info("TransferEngine: indexer P2P workers initialized") + self._has_indexer = True + if not _enable_layerwise: + flexkv_logger.info( + f"TransferEngine: indexer inline workers initialized " + f"({len(self._indexer_h2d_workers)} H2D + {len(self._indexer_d2h_workers)} D2H)") + else: + flexkv_logger.info( + f"TransferEngine: indexer inline workers initialized " + f"(H2D fused into layerwise, {len(self._indexer_d2h_workers)} D2H)") if len(self._worker_map) == 0: raise ValueError("No workers initialized, please check the config") - # Wait for all workers to ready + # Wait for all main KV workers to ready for transfer_type, worker in self._worker_map.items(): - if isinstance(worker, List): - for w in worker: - flexkv_logger.info(f"waiting for {transfer_type.name} worker {w.worker_id} to ready") + if isinstance(worker, dict): + for w in worker.values(): + flexkv_logger.debug(f"waiting for {transfer_type.name} worker {w.worker_id} to ready") w.ready_event.wait() - flexkv_logger.info(f"{transfer_type.name} worker {w.worker_id} is ready") + flexkv_logger.debug(f"{transfer_type.name} worker {w.worker_id} is ready") else: - flexkv_logger.info(f"waiting for {transfer_type.name} worker {worker.worker_id} to ready") + flexkv_logger.debug(f"waiting for {transfer_type.name} worker {worker.worker_id} to ready") worker.ready_event.wait() - flexkv_logger.info(f"{transfer_type.name} worker {worker.worker_id} is ready") + flexkv_logger.debug(f"{transfer_type.name} worker {worker.worker_id} is ready") + # Wait for all indexer workers to ready + if self._has_indexer: + for transfer_type, worker in 
self._indexer_worker_map.items(): + if isinstance(worker, dict): + for w in worker.values(): + flexkv_logger.debug(f"waiting for indexer {transfer_type.name} worker {w.worker_id} to ready") + w.ready_event.wait() + flexkv_logger.debug(f"indexer {transfer_type.name} worker {w.worker_id} is ready") + else: + flexkv_logger.debug(f"waiting for indexer {transfer_type.name} worker {worker.worker_id} to ready") + worker.ready_event.wait() + flexkv_logger.debug(f"indexer {transfer_type.name} worker {worker.worker_id} is ready") + # Startup assertions: verify layerwise mode worker map consistency + if _enable_layerwise: + assert TransferType.H2D not in self._worker_map, \ + "H2D worker should not exist in layerwise mode (fused into layerwise worker)" + assert TransferType.DISK2H not in self._worker_map, \ + "DISK2H worker should not exist in layerwise mode (fused into layerwise worker)" + assert TransferType.LAYERWISE in self._worker_map, \ + "LAYERWISE worker must exist when layerwise transfer is enabled" # Start scheduler thread self._running = True self._scheduler_thread = threading.Thread(target=self._scheduler_loop) @@ -366,6 +692,10 @@ def _scheduler_loop(self) -> None: sel.register(self.task_queue._reader, selectors.EVENT_READ, data="new_graph") sel.register(self.finished_ops_queue._reader, selectors.EVENT_READ, data="finished_op") + # Register indexer finished_ops_queue when indexer is enabled + if self._has_indexer: + sel.register(self._indexer_finished_ops_queue._reader, selectors.EVENT_READ, data="indexer_finished_op") + # Register shutdown pipe for zero-latency shutdown sel.register(self.shutdown_read_fd, selectors.EVENT_READ, data="shutdown") @@ -406,21 +736,60 @@ def _scheduler_loop(self) -> None: nvtx.end_range(nvtx_r1) elif key.data == "finished_op": - # Collect finished ops (batch get all available) + # Collect finished ops from main KV worker (batch get all available) nvtx_r2 = nvtx.start_range(message="transfer scheduler. 
collect finished ops", color="orange") # Get all available ops in one go to reduce system calls while True: try: op_id = self.finished_ops_queue.get_nowait() - op = self.op_id_to_op[op_id] - free_op_from_buffer(op, self.pin_buffer) - self.completed_queue.put(CompletedOp(graph_id=op.graph_id, op_id=op.op_id)) - finished_ops.append(op) - del self.op_id_to_op[op_id] + # Check if this is a PP-replica op (same-node PP fan-out) + if op_id in self._pp_replica_to_parent_op: + # PP-replica op: decrement parent's pending_count + replica_op = self._pp_replica_op_map.pop(op_id) + parent_op_id = self._pp_replica_to_parent_op.pop(op_id) + parent_op = self.op_id_to_op[parent_op_id] + parent_op.pending_count -= 1 + # Clean up replica from op_id_to_op and NVTX + del self.op_id_to_op[op_id] + if op_id in self.op_id_to_nvtx_range: + nvtx.end_range(self.op_id_to_nvtx_range[op_id]) + self.op_id_to_nvtx_range.pop(op_id) + if parent_op.pending_count == 0: + self._finalize_op(parent_op, finished_ops) + flexkv_logger.debug( + f"[TransferEngine] PP replica op {op_id} completed, " + f"parent op {parent_op_id} pending_count={parent_op.pending_count}") + else: + op = self.op_id_to_op[op_id] + op.pending_count -= 1 + if op.pending_count == 0: + self._finalize_op(op, finished_ops) except queue.Empty: break nvtx.end_range(nvtx_r2) + elif key.data == "indexer_finished_op": + # Collect finished ops from indexer worker (batch get all available) + nvtx_r2i = nvtx.start_range(message="transfer scheduler. collect indexer finished ops", color="blue") + while True: + try: + op_id = self._indexer_finished_ops_queue.get_nowait() + assert op_id in self._indexer_op_to_parent_op, ( + f"[TransferEngine] Indexer op {op_id} not found in " + f"_indexer_op_to_parent_op. All indexer ops must be " + f"registered with a parent op." 
+ ) + indexer_op = self._indexer_op_map.pop(op_id) + free_op_from_buffer(indexer_op, self.pin_buffer) + parent_op_id = self._indexer_op_to_parent_op.pop(op_id) + parent_op = self.op_id_to_op[parent_op_id] + parent_op.pending_count -= 1 + if parent_op.pending_count == 0: + self._finalize_op(parent_op, finished_ops) + except queue.Empty: + break + nvtx.end_range(nvtx_r2i) + # Exit loop if shutdown requested if should_shutdown: break @@ -456,6 +825,106 @@ def _scheduler_loop(self) -> None: sel.close() flexkv_logger.info("TransferEngine scheduler loop stopped") + def _finalize_op(self, op: TransferOp, finished_ops: List[TransferOp]) -> None: + """Finalize a completed op: release pin buffer, notify upper layer, and clean up. + + Called only when op.pending_count reaches 0, i.e., all workers (main KV + indexer) + have completed this op. This ensures atomic eviction semantics. + """ + free_op_from_buffer(op, self.pin_buffer) + # Compute transfer metrics for this completed op + num_blocks = len(op.src_block_ids) if op.src_block_ids is not None else 0 + num_bytes = num_blocks * self.cache_config.tokens_per_block * self.model_config.token_size_in_bytes_per_pp_stage + transfer_type_str = op.transfer_type.value if op.transfer_type != TransferType.VIRTUAL else None + self.completed_queue.put(CompletedOp( + graph_id=op.graph_id, + op_id=op.op_id, + transfer_type=transfer_type_str, + num_blocks=num_blocks, + num_bytes=num_bytes, + )) + finished_ops.append(op) + del self.op_id_to_op[op.op_id] + + def _assign_layerwise_op_to_workers(self, op: TransferOp) -> None: + """Fan-out a LAYERWISE op to all PP-stage layerwise workers on the same dp_rank. + + In cross-node PP, the remote TransferManagerOnRemote handles this by + rebinding WorkerKey. In same-node PP there is no remote TM, so we + replicate the op here for every PP-stage worker under the same dp_rank. 
+ + Replicas are tracked via ``_pp_replica_to_parent_op`` so that their + completion decrements the parent op's ``pending_count`` (identical to + how indexer ops are tracked). + """ + from flexkv.common.transfer import LayerwiseTransferOp + assert isinstance(op, LayerwiseTransferOp) + + worker_map = self._worker_map[TransferType.LAYERWISE] + assert isinstance(worker_map, dict), \ + "LAYERWISE worker map must be a Dict[WorkerKey, WorkerHandle]" + + # Find all layerwise workers sharing the same dp_rank + sibling_keys = [wk for wk in worker_map if wk.dp_rank == op.dp_rank] + + if not sibling_keys: + raise ValueError( + f"No layerwise worker found for dp_rank={op.dp_rank}, pp_rank={op.pp_rank}") + + # Submit to the original pp_rank's worker + primary_key = WorkerKey(dp_rank=op.dp_rank, pp_rank=op.pp_rank) + if primary_key in worker_map: + worker_map[primary_key].submit_transfer(op) + else: + # Original worker not found — this shouldn't happen, but handle gracefully + raise ValueError( + f"No layerwise worker found for primary key {primary_key}") + + # If there's only one worker for this dp_rank, no fan-out needed + if len(sibling_keys) <= 1: + return + + # Create replicas for every other pp_rank under the same dp_rank + for wk in sibling_keys: + if wk == primary_key: + continue + + # Create a replica LayerwiseTransferOp with the target pp_rank + replica = LayerwiseTransferOp( + graph_id=op.graph_id, + src_block_ids_h2d=op.src_block_ids_h2d.copy(), + dst_block_ids_h2d=op.dst_block_ids_h2d.copy(), + src_block_ids_disk2h=op.src_block_ids_disk2h.copy(), + dst_block_ids_disk2h=op.dst_block_ids_disk2h.copy(), + layer_id=op.layer_id, + layer_granularity=op.layer_granularity, + dp_rank=op.dp_rank, + pp_rank=wk.pp_rank, + counter_id=op.counter_id, + indexer_src_block_ids=op.indexer_src_block_ids.copy(), + indexer_dst_block_ids=op.indexer_dst_block_ids.copy(), + ) + + # Track replica → parent so that completion decrements parent's pending_count + 
self._pp_replica_to_parent_op[replica.op_id] = op.op_id + self._pp_replica_op_map[replica.op_id] = replica + op.pending_count += 1 + + # Register in op_id_to_op so scheduler can find it on completion + self.op_id_to_op[replica.op_id] = replica + self.op_id_to_nvtx_range[replica.op_id] = nvtx.start_range( + f"schedule LAYERWISE_REPLICA op_id: {replica.op_id}, " + f"graph_id: {replica.graph_id}, pp_rank={wk.pp_rank}", + color=get_nvtx_range_color(replica.graph_id)) + + worker_map[wk].submit_transfer(replica) + + flexkv_logger.debug( + f"[TransferEngine] === Layerwise PP Replica Dispatched ===" + f"\n parent_op_id={op.op_id}, replica_op_id={replica.op_id}" + f"\n dp_rank={op.dp_rank}, pp_rank={wk.pp_rank}" + f"\n pending_count={op.pending_count}") + def _assign_op_to_worker(self, op: TransferOp) -> None: self.op_id_to_nvtx_range[op.op_id] = nvtx.start_range(f"schedule {op.transfer_type.name} " f"op_id: {op.op_id}, " @@ -468,9 +937,57 @@ def _assign_op_to_worker(self, op: TransferOp) -> None: if op.transfer_type not in self._worker_map: raise ValueError(f"Unsupported transfer type: {op.transfer_type}") + # --- Same-node PP fan-out for LAYERWISE ops --- + # In cross-node PP, TransferManagerOnRemote rebinds WorkerKey so that + # PP1+ workers receive the op. In same-node PP, only one local + # TransferManager exists, so we fan-out here: for every layerwise + # worker that shares the same dp_rank but has a different pp_rank, we + # create a replica op with the correct pp_rank. + if op.transfer_type == TransferType.LAYERWISE: + self._assign_layerwise_op_to_workers(op) + return + + worker_key = WorkerKey(dp_rank=op.dp_rank, pp_rank=op.pp_rank) + + if self._has_indexer and op.transfer_type in self._indexer_worker_map: + # Indexer maps 1:1 with main KV blocks, use block_ids directly. 
+ src_page_ids = op.src_block_ids + dst_page_ids = op.dst_block_ids + num_pages = src_page_ids.size + + if num_pages > 0: + # Always create a separate indexer_op to avoid sharing the same op + # object between indexer worker and main KV worker. + indexer_op = TransferOp( + graph_id=op.graph_id, + transfer_type=op.transfer_type, + src_block_ids=src_page_ids, + dst_block_ids=dst_page_ids, + layer_id=op.layer_id, + layer_granularity=op.layer_granularity, + dp_rank=op.dp_rank, + pp_rank=op.pp_rank, + ) + register_op_to_buffer(indexer_op, self.pin_buffer) + self._indexer_op_to_parent_op[indexer_op.op_id] = op.op_id + self._indexer_op_map[indexer_op.op_id] = indexer_op + op.pending_count += 1 + + flexkv_logger.debug( + f"[TransferEngine] === Indexer Op Dispatched (non-layerwise) ===" + f"\n parent_op_id={op.op_id}, indexer_op_id={indexer_op.op_id}" + f"\n type={op.transfer_type.name}, dp_rank={op.dp_rank}, pp_rank={op.pp_rank}" + f"\n num_pages={num_pages}, pending_count={op.pending_count}") + + indexer_worker = self._indexer_worker_map[op.transfer_type] + if isinstance(indexer_worker, dict): + indexer_worker[worker_key].submit_transfer(indexer_op) + else: + indexer_worker.submit_transfer(indexer_op) + worker = self._worker_map[op.transfer_type] - if isinstance(worker, List): - worker[op.dp_id].submit_transfer(op) + if isinstance(worker, dict): + worker[worker_key].submit_transfer(op) else: worker.submit_transfer(op) @@ -536,10 +1053,18 @@ def shutdown(self) -> None: else: flexkv_logger.debug(f"Shutdown pipes already closed: {e}") - # shutdown all workers + # shutdown indexer workers first + if self._has_indexer: + for worker in self._indexer_worker_map.values(): + if isinstance(worker, dict): + for w in worker.values(): + w.shutdown() + else: + worker.shutdown() + # shutdown main KV workers for worker in self._worker_map.values(): - if isinstance(worker, List): - for w in worker: + if isinstance(worker, dict): + for w in worker.values(): w.shutdown() else: 
worker.shutdown() diff --git a/flexkv/transfer/worker.py b/flexkv/transfer/worker.py index a853c2c2a8..bae3b7bcc9 100644 --- a/flexkv/transfer/worker.py +++ b/flexkv/transfer/worker.py @@ -10,7 +10,6 @@ from threading import Thread from typing import List, Any, Dict, Union, Optional, Tuple -import ctypes import numpy as np import nvtx import torch @@ -32,8 +31,15 @@ from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType from flexkv.common.transfer import TransferOp, TransferType, PartitionBlockType -from flexkv.common.transfer import get_nvtx_range_color +from flexkv.common.transfer import get_nvtx_range_color, LayerwiseTransferOp from flexkv.common.config import CacheConfig, GLOBAL_CONFIG_FROM_ENV, MooncakeTransferEngineConfig +from flexkv.storage.allocator import HugePageTensorHandle, materialize_worker_tensor +from flexkv.transfer.host_buffer import ( + allocate_host_buffer, + cudaHostRegister, +) +from flexkv.transfer.worker_op import WorkerTransferOp, WorkerLayerwiseTransferOp + from flexkv.mooncakeEngineWrapper import MoonCakeTransferEngineWrapper from flexkv.transfer.zmqHelper import NotifyMsg, NotifyStatus, SSDZMQServer, SSDZMQClient from flexkv.cache.redis_meta import RedisMeta @@ -47,58 +53,6 @@ transfer_kv_blocks_remote = None shared_transfer_kv_blocks_remote_read = None - -cudart = ctypes.CDLL('libcudart.so') - -def cudaHostRegister(tensor: torch.Tensor) -> None: - """Register a CPU tensor with CUDA for pinned memory access""" - ptr = tensor.data_ptr() - size = tensor.numel() * tensor.element_size() - ret = cudart.cudaHostRegister(ctypes.c_void_p(ptr), ctypes.c_size_t(size), 1) # 1 means cudaHostRegisterPortable - if ret != 0: - raise RuntimeError(f"cudaHostRegister failed with error code {ret}") - -def cudaHostUnregister(tensor: torch.Tensor) -> None: - """Unregister a CPU tensor from CUDA for pinned memory access""" - ptr = tensor.data_ptr() - size = tensor.numel() * 
tensor.element_size() - ret = cudart.cudaHostUnregister(ctypes.c_void_p(ptr)) - -@dataclass -class WorkerTransferOp: - transfer_op_id: int - transfer_graph_id: int - transfer_type: TransferType - layer_id: int - layer_granularity: int - src_slot_id: int - dst_slot_id: int - valid_block_num: int - src_block_ids: np.ndarray - dst_block_ids: np.ndarray - src_block_node_ids: Optional[np.ndarray] - # successors: List[int] - - def __init__(self, transfer_op: TransferOp): - self.transfer_op_id = transfer_op.op_id - self.transfer_graph_id = transfer_op.graph_id - self.transfer_type = transfer_op.transfer_type - self.layer_id = transfer_op.layer_id - self.layer_granularity = transfer_op.layer_granularity - self.src_slot_id = transfer_op.src_slot_id - self.dst_slot_id = transfer_op.dst_slot_id - self.valid_block_num = transfer_op.valid_block_num - # Always preserve optional src_block_node_ids from TransferOp - self.src_block_node_ids = transfer_op.src_block_node_ids - - if self.src_slot_id == -1 or self.dst_slot_id == -1: - self.src_block_ids = transfer_op.src_block_ids - self.dst_block_ids = transfer_op.dst_block_ids - else: - self.src_block_ids = np.empty(0) - self.dst_block_ids = np.empty(0) - # self.successors = list(transfer_op.successors) # for nvtx - class TransferWorkerBase(ABC): _worker_id_counter = 0 _worker_id_lock = threading.Lock() @@ -250,8 +204,7 @@ def run(self) -> None: transfer_status = False try: nvtx.push_range(f"launch {op.transfer_type.name} op_id: {op.transfer_op_id}, " - f"graph_id: {op.transfer_graph_id}, " - f"num_blocks: {op.valid_block_num}", + f"graph_id: {op.transfer_graph_id}", color=get_nvtx_range_color(op.transfer_graph_id)) transfer_status = self.launch_transfer(op) nvtx.pop_range() @@ -278,8 +231,12 @@ def __init__(self, worker_id: int, transfer_conn: Connection, process: mp.Proces self.process = process self.ready_event = ready_event - def submit_transfer(self, op: TransferOp) -> None: - self.transfer_conn.send(WorkerTransferOp(op)) + def 
submit_transfer(self, op: Union[TransferOp, LayerwiseTransferOp]) -> None: + if isinstance(op, LayerwiseTransferOp): + worker_op = WorkerLayerwiseTransferOp(op) + else: + worker_op = WorkerTransferOp(op) + self.transfer_conn.send(worker_op) def shutdown(self) -> None: try: @@ -305,7 +262,7 @@ def __init__(self, finished_ops_queue: MPQueue, op_buffer_tensor: torch.Tensor, gpu_blocks: List[TensorSharedHandle], - cpu_blocks: torch.Tensor, + cpu_blocks: Union[torch.Tensor, HugePageTensorHandle], gpu_kv_layout: KVCacheLayout, cpu_kv_layout: KVCacheLayout, dtype: torch.dtype, @@ -317,6 +274,7 @@ def __init__(self, # initialize worker in a new process super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) # Register CPU tensors with CUDA + cpu_blocks = materialize_worker_tensor(cpu_blocks) flexkv_logger.info(f"Pinning CPU Memory: {cpu_blocks.numel() * cpu_blocks.element_size() / (1024 ** 3):.2f} GB") cudaHostRegister(cpu_blocks) self.gpu_blocks = [wrapper.get_tensor() for wrapper in gpu_blocks] @@ -456,12 +414,11 @@ def __init__(self, finished_ops_queue: MPQueue, op_buffer_tensor: torch.Tensor, gpu_blocks: List[List[TensorSharedHandle]], - cpu_blocks: torch.Tensor, + cpu_blocks: Union[torch.Tensor, HugePageTensorHandle], gpu_kv_layouts: List[KVCacheLayout], cpu_kv_layout: KVCacheLayout, dtype: torch.dtype, tp_group_size: int, - dp_group_id: int, use_ce_transfer_h2d: bool = False, use_ce_transfer_d2h: bool = False, transfer_num_cta_h2d: int = 4, @@ -469,6 +426,7 @@ def __init__(self, super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) assert len(gpu_blocks) == tp_group_size + cpu_blocks = materialize_worker_tensor(cpu_blocks) # Handle tensor import for multi-process case imported_gpu_blocks = [] for handles_in_one_gpu in gpu_blocks: @@ -482,7 +440,6 @@ def __init__(self, self.num_gpus = len(self.gpu_blocks) self.tp_group_size = tp_group_size - self.dp_group_id = dp_group_id flexkv_logger.info(f"Pinning CPU Memory: 
{cpu_blocks.numel() * cpu_blocks.element_size() / (1024 ** 3):.2f} GB") cudaHostRegister(cpu_blocks) @@ -527,12 +484,13 @@ def __init__(self, gpu_device_ids = [self.gpu_blocks[i][0].device.index for i in range(self.num_gpus)] num_tensors_per_gpu = len(self.gpu_blocks[0]) + flexkv_logger.info(f"num_tensors_per_gpu: {num_tensors_per_gpu}") + self.tp_transfer_thread_group = TPTransferThreadGroup( self.num_gpus, gpu_block_ptrs_flat, num_tensors_per_gpu, cpu_blocks_ptr, - dp_group_id, self.num_layers, self.gpu_kv_strides_in_bytes, self.gpu_block_strides_in_bytes, @@ -626,7 +584,7 @@ def __init__(self, transfer_conn: Connection, finished_ops_queue: MPQueue, op_buffer_tensor: torch.Tensor, - cpu_blocks: torch.Tensor, + cpu_blocks: Union[torch.Tensor, HugePageTensorHandle], ssd_files: Dict[int, List[str]], # ssd_device_id -> file_paths cpu_kv_layout: KVCacheLayout, ssd_kv_layout: KVCacheLayout, @@ -634,6 +592,7 @@ def __init__(self, num_blocks_per_file: int, cache_config: CacheConfig): super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) + cpu_blocks = materialize_worker_tensor(cpu_blocks) self.ssd_files = ssd_files self.num_blocks_per_file = num_blocks_per_file self.num_files = sum(len(file_list) for file_list in ssd_files.values()) @@ -750,7 +709,7 @@ def __init__(self, transfer_conn: Connection, finished_ops_queue: MPQueue, op_buffer_tensor: torch.Tensor, - cpu_blocks: List[torch.Tensor], + cpu_blocks: Union[List[torch.Tensor], torch.Tensor, HugePageTensorHandle], remote_file: List[str], cpu_kv_layout: KVCacheLayout, remote_kv_layout: KVCacheLayout, @@ -761,6 +720,8 @@ def __init__(self, raise RuntimeError("transfer_kv_blocks_remote not available, please build with FLEXKV_ENABLE_CFS=1") super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) + cpu_blocks = materialize_worker_tensor(cpu_blocks) + self.cpu_layer_ptrs = self._get_layer_ptrs(cpu_blocks) self.remote_files = remote_file self.num_remote_files = 
len(remote_file) @@ -1176,7 +1137,6 @@ def __init__( ssd_kv_layout: KVCacheLayout, dtype: torch.dtype, tp_group_size: int, - dp_group_id: int, ) -> None: """ Initialize TP GDS Transfer Worker @@ -1192,7 +1152,6 @@ def __init__( ssd_kv_layout: Layout of SSD KV cache dtype: Data type tp_group_size: Size of tensor parallel group - dp_group_id: Data parallel group ID """ # Initialize base class first super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) @@ -1213,7 +1172,6 @@ def __init__( self.is_mla = gpu_kv_layouts[0].is_mla self.num_gpus = len(self.gpu_blocks) self.tp_group_size = tp_group_size - self.dp_group_id = dp_group_id # Layout information self.num_layers = gpu_kv_layouts[0].num_layer @@ -1236,7 +1194,7 @@ def __init__( # SSD layout calculations self.ssd_layer_stride_in_bytes = ssd_kv_layout_per_file.get_layer_stride() * self.dtype.itemsize self.ssd_kv_stride_in_bytes = ssd_kv_layout_per_file.get_kv_stride() * self.dtype.itemsize - self.ssd_tp_stride_in_bytes = self.ssd_block_stride_in_bytes // self.tp_group_size if not self.is_mla else self.ssd_block_stride_in_bytes + self.ssd_tp_stride_in_bytes = self.ssd_block_stride_in_bytes // self.tp_size_per_node if not self.is_mla else self.ssd_block_stride_in_bytes # Resolve pointers in Python (where storage is valid); pass them to C++ so we avoid # "Tensor that doesn't have storage" when C++ calls .data_ptr() on tensors passed @@ -1255,7 +1213,6 @@ def __init__( gpu_block_ptrs_flat, num_tensors_per_gpu, ssd_files, - dp_group_id, self.num_layers, self.gpu_kv_strides_in_bytes, self.gpu_block_strides_in_bytes, @@ -1350,7 +1307,7 @@ def __init__(self, transfer_conn: Connection, finished_ops_queue: MPQueue, op_buffer_tensor: torch.Tensor, - cpu_blocks: torch.Tensor, + cpu_blocks: Union[torch.Tensor, HugePageTensorHandle], cpu_kv_layout: KVCacheLayout, remote_kv_layout: KVCacheLayout, dtype: torch.dtype, @@ -1358,8 +1315,10 @@ def __init__(self, ssd_kv_layout: KVCacheLayout = None, ssd_files: 
Dict[int, List[str]] = None, # ssd_device_id -> file_paths num_blocks_per_file: int = 0, + mooncake_config_path: str = None, ): super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) + cpu_blocks = materialize_worker_tensor(cpu_blocks) self.cpu_layer_ptrs = self._get_layer_ptrs(cpu_blocks) self.num_layers = cpu_kv_layout.num_layer self.num_cpu_blocks = cpu_kv_layout.num_block @@ -1390,6 +1349,7 @@ def __init__(self, self.cache_config.redis_port, self.cache_config.redis_password, self.cache_config.local_ip, + node_ttl_seconds=self.cache_config.node_ttl_seconds, ) self.redis_meta_client.set_node_id(self.cache_config.distributed_node_id) @@ -1401,8 +1361,18 @@ def __init__(self, # step2: initialize mooncake transfer engine for the whole flexkv - # NOTE:now we read the config file by env paras - mooncake_config_path = os.environ["MOONCAKE_CONFIG_PATH"] + # NOTE: prefer explicit parameter > cache_config > env variable + # (spawn subprocesses may lose env vars, but cache_config is pickle-serialized) + if mooncake_config_path is None: + mooncake_config_path = getattr(self.cache_config, 'mooncake_config_path', None) + if mooncake_config_path is None: + mooncake_config_path = os.environ.get("MOONCAKE_CONFIG_PATH") + if mooncake_config_path is None: + raise RuntimeError( + "MOONCAKE_CONFIG_PATH is not set. Please either pass mooncake_config_path " + "parameter, set cache_config.mooncake_config_path, or set the " + "MOONCAKE_CONFIG_PATH environment variable." + ) self.mooncake_config = MooncakeTransferEngineConfig.from_file( mooncake_config_path ) @@ -1445,14 +1415,25 @@ def __init__(self, self.cpu_kv_layout.num_head, self.cpu_kv_layout.head_size, self.cpu_kv_layout.is_mla, - self.cpu_kv_layout._kv_shape, + self.cpu_kv_layout.kv_shape, ) - self.tmp_cpu_buffer = torch.empty( - self.tmp_cpu_buffer_layout.get_total_elements(), + # Allocate the temporary SSD->CPU staging buffer. 
+ # + # Two backends are supported: + # (a) HugePage-backed mmap (when ``cache_config.use_hugepage_tmp_buffer`` + # is True and the kernel has huge pages reserved). We still need + # to pin it for CUDA via ``cudaHostRegister`` because the region + # is not allocated through PyTorch's pinned-memory allocator. + # (b) Pinned ``torch.empty`` (the original behavior, default). + tmp_num_elements = self.tmp_cpu_buffer_layout.get_total_elements() + self._tmp_cpu_buffer_handle = allocate_host_buffer( + num_elements=tmp_num_elements, dtype=self.dtype, - device="cpu", - pin_memory=True, + use_hugepage=self.cache_config.use_hugepage_tmp_buffer, + hugepage_size_bytes=self.cache_config.hugepage_size_bytes, ) + self.tmp_cpu_buffer = self._tmp_cpu_buffer_handle.tensor + self.mooncake_transfer_engine.regist_buffer( self.tmp_cpu_buffer.data_ptr(), self.tmp_cpu_buffer.numel() * self.tmp_cpu_buffer.element_size(), @@ -1525,6 +1506,9 @@ def shutdown(self): self.mooncake_transfer_engine.unregist_buffer(self.cpu_blocks.data_ptr()) if self.cache_config.enable_p2p_ssd: self.mooncake_transfer_engine.unregist_buffer(self.tmp_cpu_buffer.data_ptr()) + # Release CUDA pinning & HugePage mapping, if any. + if hasattr(self, "_tmp_cpu_buffer_handle"): + self._tmp_cpu_buffer_handle.release() # unregist node info from redis server self.unregist_node_meta() @@ -2143,8 +2127,27 @@ def unregist_node_meta(self, node_id: int = None) -> None: flexkv_logger.info(f"Unregistered node {self.redis_meta_client.get_node_id()} from Redis.") def get_node_meta(self, node_id: int) -> Optional[NodeMetaInfo]: - # TODO: how to remove the invalid node meta info in node_metas - """Get the node meta info by node id.""" + """Get the node meta info by node id. + + Before returning cached or freshly-fetched meta, we verify that the + node is still active (its node: key exists in Redis and has not + expired). This prevents RDMA transfers to stale addresses after a + remote node has crashed. 
+ """ + # ===== Active-node validation (Scheme 4) ===== + if not self.redis_meta_client.is_node_active(node_id): + # Node is no longer active – purge cached meta if any + if node_id in self.node_metas: + del self.node_metas[node_id] + flexkv_logger.warning( + f"Node {node_id} is no longer active, removed cached meta." + ) + else: + flexkv_logger.warning( + f"Node {node_id} is not active, skipping meta fetch." + ) + return None + if node_id not in self.node_metas: ## fetch from redis node_redis_data = self.redis_meta_client.get_node_meta(node_id) diff --git a/flexkv/transfer/worker_op.py b/flexkv/transfer/worker_op.py new file mode 100644 index 0000000000..ecc7b29f9c --- /dev/null +++ b/flexkv/transfer/worker_op.py @@ -0,0 +1,72 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np + +from flexkv.common.transfer import TransferOp, TransferType, LayerwiseTransferOp + + +@dataclass +class WorkerTransferOp: + transfer_op_id: int + transfer_graph_id: int + transfer_type: TransferType + layer_id: int + layer_granularity: int + src_slot_id: int + dst_slot_id: int + valid_block_num: int + src_block_ids: np.ndarray + dst_block_ids: np.ndarray + src_block_node_ids: Optional[np.ndarray] + + def __init__(self, transfer_op: TransferOp): + self.transfer_op_id = transfer_op.op_id + self.transfer_graph_id = transfer_op.graph_id + self.transfer_type = transfer_op.transfer_type + self.layer_id = transfer_op.layer_id + self.layer_granularity = transfer_op.layer_granularity + self.src_slot_id = transfer_op.src_slot_id + self.dst_slot_id = transfer_op.dst_slot_id + self.valid_block_num = transfer_op.valid_block_num + # Always preserve optional src_block_node_ids from TransferOp + self.src_block_node_ids = transfer_op.src_block_node_ids + + if self.src_slot_id == -1 or self.dst_slot_id == -1: + self.src_block_ids = transfer_op.src_block_ids + self.dst_block_ids = transfer_op.dst_block_ids + else: + self.src_block_ids = np.empty(0) + self.dst_block_ids 
= np.empty(0) + + +@dataclass +class WorkerLayerwiseTransferOp: + transfer_op_id: int + transfer_graph_id: int + transfer_type: TransferType + layer_id: int + layer_granularity: int + src_block_ids_h2d: np.ndarray + dst_block_ids_h2d: np.ndarray + src_block_ids_disk2h: np.ndarray + dst_block_ids_disk2h: np.ndarray + counter_id: int # Counter set index for triple buffering eventfd notification + # Indexer block_ids for fused indexer transfer + indexer_src_block_ids: np.ndarray + indexer_dst_block_ids: np.ndarray + + def __init__(self, transfer_op: LayerwiseTransferOp): + self.transfer_op_id = transfer_op.op_id + self.transfer_graph_id = transfer_op.graph_id + assert transfer_op.transfer_type == TransferType.LAYERWISE + self.transfer_type = transfer_op.transfer_type + self.layer_id = transfer_op.layer_id + self.layer_granularity = transfer_op.layer_granularity + self.src_block_ids_h2d = transfer_op.src_block_ids_h2d + self.dst_block_ids_h2d = transfer_op.dst_block_ids_h2d + self.src_block_ids_disk2h = transfer_op.src_block_ids_disk2h + self.dst_block_ids_disk2h = transfer_op.dst_block_ids_disk2h + self.counter_id = transfer_op.counter_id + self.indexer_src_block_ids = transfer_op.indexer_src_block_ids + self.indexer_dst_block_ids = transfer_op.indexer_dst_block_ids diff --git a/flexkv/transfer_manager.py b/flexkv/transfer_manager.py index 0633447360..d8afeabd09 100644 --- a/flexkv/transfer_manager.py +++ b/flexkv/transfer_manager.py @@ -1,5 +1,6 @@ import os import multiprocessing as mp +import signal import time import queue import selectors @@ -18,7 +19,7 @@ import pickle import sys -from flexkv.common.transfer import TransferOpGraph, CompletedOp +from flexkv.common.transfer import TransferOpGraph, CompletedOp, WorkerKey from flexkv.common.config import CacheConfig, ModelConfig, GLOBAL_CONFIG_FROM_ENV from flexkv.common.debug import flexkv_logger from flexkv.common.memory_handle import TensorSharedHandle @@ -42,12 +43,16 @@ def __init__(self, # Multi-instance 
support: get instance_num from environment self.instance_num = GLOBAL_CONFIG_FROM_ENV.instance_num - # Calculate total expected GPUs across all instances - self.expected_gpus = self.instance_num * model_config.tp_size * model_config.dp_size + # Calculate total expected GPUs on this node across all instances + self.expected_gpus = self.instance_num * model_config.gpus_per_node self.all_gpu_layouts: Dict[int, KVCacheLayout] = {} self.all_gpu_blocks: Dict[int, List[TensorSharedHandle]] = {} # device_id -> gpu_blocks - self.gpu_client_mapping: Dict[int, int] = {} # device_id -> dp_client_id + self.gpu_worker_key_mapping: Dict[int, WorkerKey] = {} # device_id -> WorkerKey(dp_rank, pp_rank) + + # Indexer GPU registration data + self.all_indexer_gpu_blocks: Dict[int, List[TensorSharedHandle]] = {} # device_id -> indexer_gpu_blocks + self.all_indexer_gpu_layouts: Dict[int, KVCacheLayout] = {} self.context = zmq.Context(2) self.recv_from_client = get_zmq_socket( @@ -67,7 +72,14 @@ def _handle_gpu_blocks_registration(self, req: RegisterTPClientRequest) -> None: try: self.all_gpu_blocks[device_id] = req.handles self.all_gpu_layouts[device_id] = req.gpu_layout - self.gpu_client_mapping[device_id] = req.dp_client_id + self.gpu_worker_key_mapping[device_id] = WorkerKey(dp_rank=req.dp_rank, pp_rank=req.pp_rank) + # Store indexer GPU data if present + if req.indexer_handles is not None: + self.all_indexer_gpu_blocks[device_id] = req.indexer_handles + self.all_indexer_gpu_layouts[device_id] = req.indexer_gpu_layout + flexkv_logger.info( + f"GPU {device_id}: registered indexer handles " + f"({len(req.indexer_handles)} layers)") except Exception as e: flexkv_logger.error(f"Failed to register GPU {device_id}: {e}") @@ -75,18 +87,31 @@ def _register_gpu_blocks_via_socket(self) -> None: try: flexkv_logger.info(f"GPU tensor registration server started on port {self.gpu_register_port}, " f"expected {self.expected_gpus} GPUs to register " - f"(instance_num={self.instance_num}, 
tp={self.model_config.tp_size}, " - f"dp={self.model_config.dp_size})") + f"(instance_num={self.instance_num}, gpus_per_node={self.model_config.gpus_per_node}, " + f"total_gpus={self.model_config.total_gpus}, pp_rank={self.model_config.pp_rank}, " + f"node_rank={self.model_config.node_rank}, nnodes={self.model_config.nnodes})") + last_log_time = time.time() while len(self.all_gpu_blocks) < self.expected_gpus: try: # Recv from: flexkv.server.client.KVTPClient.register_to_server req = self.recv_from_client.recv_pyobj(zmq.NOBLOCK) except zmq.Again: + # Periodically log waiting status for debugging + now = time.time() + if now - last_log_time >= 5.0: + registered_ids = sorted(self.all_gpu_blocks.keys()) + flexkv_logger.info( + f"Still waiting for GPU registrations: " + f"{len(self.all_gpu_blocks)}/{self.expected_gpus} registered " + f"(registered_device_ids={registered_ids}, " + f"port={self.gpu_register_port})") + last_log_time = now time.sleep(0.001) continue if isinstance(req, RegisterTPClientRequest): - flexkv_logger.info(f"Received GPU blocks registration request: {type(req)}") + flexkv_logger.info(f"Received GPU blocks registration request: {type(req)}, " + f"device_id={req.device_id}, dp_rank={req.dp_rank}") self._handle_gpu_blocks_registration(req) flexkv_logger.info(f"GPU {req.device_id} registered successfully, " f"waiting for {self.expected_gpus - len(self.all_gpu_blocks)} GPUs to register") @@ -115,18 +140,28 @@ def initialize_transfer_engine(self) -> None: # Register GPU blocks with their global device IDs for device_id, gpu_blocks_wrapper in self.all_gpu_blocks.items(): - self.storage_engine.register_gpu_blocks(gpu_blocks_wrapper, - self.all_gpu_layouts[device_id], - device_id, - dtype=self.model_config.dtype) + # Get indexer data for this device if available + indexer_gpu_blocks = self.all_indexer_gpu_blocks.get(device_id) + indexer_gpu_layout = self.all_indexer_gpu_layouts.get(device_id) + indexer_dtype = (self.cache_config.indexer.dtype + if 
self.cache_config.indexer is not None else None) + self.storage_engine.register_gpu_blocks( + gpu_blocks_wrapper, + self.all_gpu_layouts[device_id], + device_id, + dtype=self.model_config.dtype, + indexer_gpu_blocks=indexer_gpu_blocks, + indexer_gpu_layout=indexer_gpu_layout, + indexer_dtype=indexer_dtype, + ) - # Group GPU handles by dp_client_id - grouped_gpu_handles: Dict[int, List] = {} + # Group GPU handles by WorkerKey + grouped_gpu_handles: Dict[WorkerKey, List] = {} for device_id in sorted(self.all_gpu_blocks.keys()): - dp_client_id = self.gpu_client_mapping[device_id] - if dp_client_id not in grouped_gpu_handles: - grouped_gpu_handles[dp_client_id] = [] - grouped_gpu_handles[dp_client_id].append( + worker_key = self.gpu_worker_key_mapping[device_id] + if worker_key not in grouped_gpu_handles: + grouped_gpu_handles[worker_key] = [] + grouped_gpu_handles[worker_key].append( self.storage_engine.get_storage_handle(DeviceType.GPU, device_id)) cpu_handle = self.storage_engine.get_storage_handle(DeviceType.CPU) \ @@ -138,13 +173,64 @@ def initialize_transfer_engine(self) -> None: if self.cache_config.enable_remote \ else None ) - self.transfer_engine = TransferEngine(gpu_handles=grouped_gpu_handles, - model_config=self.model_config, - cache_config=self.cache_config, - cpu_handle=cpu_handle, - ssd_handle=ssd_handle, - remote_handle=remote_handle) - flexkv_logger.info("Initialized TransferEngine successfully") + + indexer_gpu_handles: Optional[Dict[WorkerKey, List]] = None + if self.storage_engine.has_storage_handle(DeviceType.CPU, is_indexer=True): + indexer_gpu_handles = {} + for device_id in sorted(self.all_gpu_blocks.keys()): + if self.storage_engine.has_storage_handle(DeviceType.GPU, device_id, is_indexer=True): + worker_key = self.gpu_worker_key_mapping[device_id] + if worker_key not in indexer_gpu_handles: + indexer_gpu_handles[worker_key] = [] + indexer_gpu_handles[worker_key].append( + self.storage_engine.get_storage_handle(DeviceType.GPU, device_id, 
is_indexer=True)) + indexer_cpu_handle = ( + self.storage_engine.get_storage_handle(DeviceType.CPU, is_indexer=True) + if self.storage_engine.has_storage_handle(DeviceType.CPU, is_indexer=True) + else None + ) + indexer_ssd_handle = ( + self.storage_engine.get_storage_handle(DeviceType.SSD, is_indexer=True) + if self.storage_engine.has_storage_handle(DeviceType.SSD, is_indexer=True) + else None + ) + indexer_remote_handle = ( + self.storage_engine.get_storage_handle(DeviceType.REMOTE, is_indexer=True) + if self.storage_engine.has_storage_handle(DeviceType.REMOTE, is_indexer=True) + else None + ) + + self.transfer_engine = TransferEngine( + gpu_handles=grouped_gpu_handles, + model_config=self.model_config, + cache_config=self.cache_config, + cpu_handle=cpu_handle, + ssd_handle=ssd_handle, + remote_handle=remote_handle, + indexer_gpu_handles=indexer_gpu_handles, + indexer_cpu_handle=indexer_cpu_handle, + indexer_ssd_handle=indexer_ssd_handle, + indexer_remote_handle=indexer_remote_handle, + ) + + # Derive local pp_rank from GPU registrations rather than model_config. + # In cross-node PP, TransferManagerOnRemote receives model_config from + # the PP0 master (pp_rank=0), but local GPUs register with their true + # pp_rank (e.g. pp_rank=1). All local workers share the same pp_rank + # because they belong to the same PP stage on this node. 
def resolve_master_host_and_ports(
    master_host: Optional[str] = None,
) -> Tuple[str, Tuple[str, str, str]]:
    """Resolve the (master_host, master_ports) tuple for multi-node transfer.

    ``master_host`` resolution order:
      1. explicit ``master_host`` argument (when provided by the caller,
         e.g. via sglang ``--dist-init-addr``);
      2. ``FLEXKV_MASTER_HOST`` env var (used by framework-agnostic
         launchers such as TRT-LLM's ``multi_node_launch.sh``);
      3. ``"localhost"`` default.

    ``master_ports`` always comes from ``FLEXKV_MASTER_PORTS`` (or the
    default ``"5556,5557,5558"``), split on commas.

    Args:
        master_host: Host name or address without the ``tcp://`` scheme, or
            ``None`` to fall back to the environment / default.

    Returns:
        A tuple ``("tcp://<host>", ports)`` where ``ports`` is a tuple of the
        comma-separated port strings.
    """
    # Record where the host came from *before* overwriting ``master_host``;
    # afterwards it is never None and the source could no longer be told apart.
    host_source = "arg" if master_host is not None else "env/default"
    if master_host is None:
        master_host = os.getenv("FLEXKV_MASTER_HOST", "localhost")

    raw_ports = os.getenv("FLEXKV_MASTER_PORTS", "5556,5557,5558")
    # Tolerate "5556, 5557, 5558": stray blanks would otherwise end up inside
    # the "<host>:<port>" endpoint strings built from these values.
    master_ports = tuple(port.strip() for port in raw_ports.split(","))

    flexkv_logger.info(
        f"[TransferManager] resolved master endpoint: "
        f"host={master_host!r} (source={host_source}), "
        f"ports={master_ports}"
    )
    return "tcp://" + master_host, master_ports
def _handle_set_slot_mapping(self, task_id: int, slot_mapping: np.ndarray) -> None:
    """Handle a ``set_slot_mapping`` command.

    If the graph for ``task_id`` is already parked in ``_pending_graphs``,
    its GPU blocks are set from ``slot_mapping``, it is rebound to the local
    pp_rank and submitted. Otherwise the mapping is parked in
    ``_pending_slot_mappings`` until ``_handle_submit`` sees the graph.

    Args:
        task_id: Key used for matching; compared against ``graph.graph_id``.
        slot_mapping: Value handed to ``graph.set_gpu_blocks``.
    """
    # The polling loop reads both fields with ``message.get(...)``, so a
    # malformed message yields None; parking that would later feed None into
    # ``set_gpu_blocks`` or create an entry that can never be matched.
    if task_id is None or slot_mapping is None:
        flexkv_logger.warning(
            f"[TransferManagerOnRemote] set_slot_mapping: ignoring malformed message "
            f"(task_id={task_id!r}, has_slot_mapping={slot_mapping is not None})"
        )
        return

    with self._pending_lock:
        if task_id not in self._pending_graphs:
            # Graph not yet arrived, store slot_mapping for later matching
            self._pending_slot_mappings[task_id] = slot_mapping
            flexkv_logger.debug(
                f"[TransferManagerOnRemote] set_slot_mapping: "
                f"slot_mapping stored for task_id={task_id}, waiting for graph"
            )
            return

        # Graph already arrived, set GPU blocks and prepare for submit
        graph, task_end_op_id = self._pending_graphs.pop(task_id)
        graph.set_gpu_blocks(slot_mapping)
        self._rebind_graph_to_local_worker(graph)
        flexkv_logger.debug(
            f"[TransferManagerOnRemote] set_slot_mapping: "
            f"graph for task_id={task_id} submitted (graph arrived first)"
        )

    self._track_and_submit(graph, task_end_op_id)


def _handle_submit(self, graph: TransferOpGraph, task_end_op_id: int = -1) -> None:
    """Handle a ``submit`` command with pending-matching support.

    ``graph.graph_id`` is used as the matching key. If a slot mapping for it
    is already parked, the graph gets its GPU blocks, is rebound to the local
    pp_rank and is submitted. Otherwise the graph and ``task_end_op_id`` are
    parked in ``_pending_graphs`` until ``_handle_set_slot_mapping`` runs.

    Args:
        graph: Transfer graph received from the master.
        task_end_op_id: Value recorded in ``_active_graphs`` for this graph
            once it is submitted (``-1`` by default).
    """
    task_id = graph.graph_id  # Use graph_id as task_id for matching
    with self._pending_lock:
        if task_id not in self._pending_slot_mappings:
            # slot_mapping not yet arrived, store graph and task_end_op_id for later matching
            self._pending_graphs[task_id] = (graph, task_end_op_id)
            flexkv_logger.debug(
                f"[TransferManagerOnRemote] submit: "
                f"graph stored for task_id={task_id}, waiting for slot_mapping"
            )
            return  # Don't submit yet, wait for slot_mapping

        # slot_mapping already arrived, set GPU blocks and submit
        slot_mapping = self._pending_slot_mappings.pop(task_id)
        graph.set_gpu_blocks(slot_mapping)
        self._rebind_graph_to_local_worker(graph)
        flexkv_logger.debug(
            f"[TransferManagerOnRemote] submit: "
            f"graph for task_id={task_id} submitted (slot_mapping arrived first)"
        )

    self._track_and_submit(graph, task_end_op_id)


def _track_and_submit(self, graph: TransferOpGraph, task_end_op_id: int) -> None:
    """Record ``graph`` in ``_active_graphs`` and submit it.

    Shared tail of ``_handle_submit`` and ``_handle_set_slot_mapping``; called
    after ``_pending_lock`` has been released.
    """
    with self._active_graphs_lock:
        self._active_graphs[graph.graph_id] = task_end_op_id
    self.submit(graph)


def _rebind_graph_to_local_worker(self, graph: TransferOpGraph) -> None:
    """Rebind transfer graph ops to the local pp_rank.

    In cross-node PP setups, the master (PP0) creates transfer graphs with
    its own pp_rank=0. When these graphs are sent to a remote node (e.g. PP1),
    the ops' pp_rank must be updated to the local pp_rank so the
    TransferEngine can find the correct workers.

    Each op's dp_rank is preserved; only ``pp_rank`` is overwritten here.
    Nothing is changed when ``model_config.pp_rank`` already equals
    ``_local_pp_rank``.
    """
    if self.model_config.pp_rank == self._local_pp_rank:
        return  # No rebinding needed

    # NOTE(review): reaches into the graph's private ``_op_map``; confirm no
    # public accessor exists on TransferOpGraph.
    for op in graph._op_map.values():
        op.pp_rank = self._local_pp_rank

    flexkv_logger.debug(
        f"[TransferManagerOnRemote] Rebound graph {graph.graph_id} "
        f"pp_rank from {self.model_config.pp_rank} to {self._local_pp_rank}"
    )
Use a handler that calls waitpid() + # with WNOHANG so that multiprocessing.Process.join() still works + # correctly (SIG_IGN would cause join() to raise ChildProcessError). + def _reap_children(signum, frame): + while True: + try: + pid, _ = os.waitpid(-1, os.WNOHANG) + if pid == 0: + break + except ChildProcessError: + break + signal.signal(signal.SIGCHLD, _reap_children) try: + flexkv_logger.debug(f"_process_worker started, pid={os.getpid()}, " + f"gpu_register_port={gpu_register_port}, " + f"pp_rank={model_config.pp_rank}, node_rank={model_config.node_rank}") start_event.set() os.environ['MPI4PY_RC_INITIALIZE'] = 'false' transfer_manager = TransferManager(model_config, cache_config, gpu_register_port) transfer_manager.initialize_transfer_engine() transfer_manager.start() + flexkv_logger.debug("TransferEngine started successfully, setting ready_event") ready_event.set() # Setup selector for event-driven processing (complete zero polling!) @@ -656,6 +876,13 @@ def _process_worker(self, except Exception as e: flexkv_logger.error(f"Error closing selector: {e}") + # Gracefully shut down transfer engine and its worker subprocesses + if 'transfer_manager' in locals(): + try: + transfer_manager.shutdown() + except Exception as e: + flexkv_logger.error(f"Error shutting down transfer manager: {e}") + command_conn.close() result_conn.close() @@ -717,7 +944,7 @@ def __del__(self): self.shutdown() -class TranserManagerMultiNodeHandle(TransferManagerHandleBase): +class TransferManagerMultiNodeHandle(TransferManagerHandleBase): def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, @@ -756,20 +983,20 @@ def _bind_master_ports(self) -> None: try: command_addr = f"{self.master_host}:{self.master_ports[0]}" self.command_socket.bind(command_addr) - flexkv_logger.debug(f"Master bound command port at {command_addr}") + flexkv_logger.info(f"Master bound command port at {command_addr}") result_addr = f"{self.master_host}:{self.master_ports[1]}" 
self.result_socket.bind(result_addr) - flexkv_logger.debug(f"Master bound result port at {result_addr}") + flexkv_logger.info(f"Master bound result port at {result_addr}") query_addr = f"{self.master_host}:{self.master_ports[2]}" self.query_socket.bind(query_addr) - flexkv_logger.debug(f"Master bound query port at {query_addr}") + flexkv_logger.info(f"Master bound query port at {query_addr}") self.result_socket.setsockopt(zmq.RCVTIMEO, 0) self._connected = True - flexkv_logger.debug("Master transfer manager ready for remote connections") + flexkv_logger.info("Master transfer manager ready for remote connections") except Exception as e: flexkv_logger.error(f"Master failed to bind ports: {e}") @@ -915,6 +1142,10 @@ def __init__(self, gpu_register_port: Optional[str] = None, mode: str = "process", **kwargs): # process or thread or remote + flexkv_logger.debug( + f"Creating TransferManagerHandle: mode={mode}, " + f"pp_rank={model_config.pp_rank}, node_rank={model_config.node_rank}, " + f"gpu_register_port={gpu_register_port}") if gpu_register_port is None: gpu_register_port = f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" if mode == "process": @@ -928,7 +1159,7 @@ def __init__(self, elif mode == "remote": master_host = kwargs["master_host"] master_ports = kwargs["master_ports"] - self._handle: TransferManagerHandleBase = TranserManagerMultiNodeHandle( + self._handle: TransferManagerHandleBase = TransferManagerMultiNodeHandle( model_config, cache_config, gpu_register_port, master_host, master_ports ) else: diff --git a/install.sh b/install.sh new file mode 100755 index 0000000000..4f5d126b03 --- /dev/null +++ b/install.sh @@ -0,0 +1,549 @@ +#!/bin/bash +# ============================================================================= +# FlexKV One-Click Install Script +# ============================================================================= +# Usage: +# bash install.sh [OPTIONS] +# +# Options: +# --venv PATH Specify virtual environment path (default: 
# Print the usage text taken from this script's leading comment block, then
# exit with status 0.
#
# The header is read structurally (every comment line after the shebang, up to
# the first non-comment line) rather than by fixed line numbers, so options
# added to the header are never cut off.
usage() {
    awk '
        NR == 1    { next }                     # skip the shebang
        !/^#/      { exit }                     # first non-comment line ends the header
        /^# ?=+$/  { next }                     # drop decorative "# ====" separators
                   { sub(/^# ?/, ""); print }   # strip the comment marker
    ' "$0"
    exit 0
}
# Report whether the executable named by $1 is on PATH.
# Logs its resolved location via success() and returns 0 when found;
# logs a warning via warn() and returns 1 otherwise.
check_command() {
    local tool="$1"
    local location

    if ! location="$(command -v "$tool" 2>/dev/null)"; then
        warn "$tool not found"
        return 1
    fi

    success "$tool found: $location"
    return 0
}
python3 -m venv "$_VENV_TEST_DIR/test_venv" &>/dev/null 2>&1; then + rm -rf "$_VENV_TEST_DIR" + warn "python3-venv not available (ensurepip missing)" + PY_MINOR=$(python3 -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')") + MISSING_PKGS+=("python3.${PY_MINOR#3.}-venv" "python3-venv" "python3-full") + else + rm -rf "$_VENV_TEST_DIR" + fi +fi + +# Check for liburing-dev (required by setup.py: -luring) +if ! dpkg -s liburing-dev &>/dev/null 2>&1 && ! rpm -q liburing-devel &>/dev/null 2>&1; then + if [ -f /etc/debian_version ]; then + warn "liburing-dev not found" + MISSING_PKGS+=("liburing-dev") + elif [ -f /etc/redhat-release ]; then + warn "liburing-devel not found" + MISSING_PKGS+=("liburing-devel") + fi +fi + +# Check for hiredis if P2P enabled +if [ "$ENABLE_P2P" -eq 1 ]; then + if ! dpkg -s libhiredis-dev &>/dev/null 2>&1 && ! rpm -q hiredis-devel &>/dev/null 2>&1; then + if [ -f /etc/debian_version ]; then + MISSING_PKGS+=("libhiredis-dev") + elif [ -f /etc/redhat-release ]; then + MISSING_PKGS+=("hiredis-devel") + fi + fi +fi + +# Install missing packages +if [ ${#MISSING_PKGS[@]} -gt 0 ] && [ "$SKIP_DEPS" -eq 0 ]; then + # Deduplicate + UNIQUE_PKGS=($(echo "${MISSING_PKGS[@]}" | tr ' ' '\n' | sort -u | tr '\n' ' ')) + info "Installing missing packages: ${UNIQUE_PKGS[*]}" + + if command -v apt-get &>/dev/null; then + $SUDO apt-get update -qq + $SUDO apt-get install -y -qq "${UNIQUE_PKGS[@]}" + elif command -v yum &>/dev/null; then + $SUDO yum install -y "${UNIQUE_PKGS[@]}" + elif command -v dnf &>/dev/null; then + $SUDO dnf install -y "${UNIQUE_PKGS[@]}" + else + error "Cannot auto-install packages. Please manually install: ${UNIQUE_PKGS[*]}" + fi + success "System packages installed." +elif [ ${#MISSING_PKGS[@]} -gt 0 ] && [ "$SKIP_DEPS" -eq 1 ]; then + warn "Skipping dependency installation (--skip-deps). 
Missing: ${MISSING_PKGS[*]}" +fi + +# Final check for critical commands +for cmd in python3 cmake git gcc g++; do + command -v "$cmd" &>/dev/null || error "$cmd is still not available. Please install it manually." +done + +# Check NVIDIA CUDA toolkit +if ! command -v nvcc &>/dev/null; then + warn "nvcc not found. CUDA toolkit is required for building FlexKV." + warn "Please install CUDA toolkit from: https://developer.nvidia.com/cuda-downloads" + warn "Or load it via: module load cuda" +fi + +success "System dependencies check passed." + +# ======================== Step 2: Setup Python Virtual Environment ======================== +info "============================================" +info "Step 2: Setting up Python environment" +info "============================================" + +if [ "$USE_VENV" -eq 1 ]; then + if [ -d "$VENV_PATH" ] && [ -f "$VENV_PATH/bin/activate" ]; then + info "Using existing virtual environment: $VENV_PATH" + else + info "Creating virtual environment at: $VENV_PATH" + python3 -m venv "$VENV_PATH" + success "Virtual environment created." + fi + + # Activate virtual environment + source "$VENV_PATH/bin/activate" + success "Virtual environment activated: $(which python3)" + + # Upgrade pip + info "Upgrading pip..." + pip install --upgrade pip -q +else + warn "Skipping virtual environment (--no-venv). Installing to system Python." + warn "If you encounter 'externally-managed-environment' error, use --venv instead." +fi + +# Install Python build dependencies +info "Installing Python build dependencies..." +pip install -q setuptools wheel +if [ "$BUILD_TYPE" = "release" ]; then + pip install -q "Cython>=3.0.10" +fi + +# Check if torch is installed +if ! python3 -c "import torch" &>/dev/null 2>&1; then + warn "PyTorch not found. Installing PyTorch..." + warn "If you need a specific CUDA version, please install PyTorch manually first." + pip install torch +fi +success "Python environment ready." 
+ +# ======================== Step 3: Initialize Git Submodules ======================== +info "============================================" +info "Step 3: Initializing git submodules" +info "============================================" + +if [ "$ENABLE_METRICS" -eq 1 ]; then + info "Metrics enabled: initializing all submodules (including prometheus-cpp)..." + git submodule update --init --recursive +else + info "Metrics disabled: initializing only xxHash submodule..." + git submodule update --init --recursive third_party/xxHash +fi +success "Git submodules initialized." + +# ======================== Step 4: Build C++ Libraries ======================== +info "============================================" +info "Step 4: Building C++ libraries (CMake)" +info "============================================" + +mkdir -p build +cd build + +CMAKE_ARGS="" +if [ "$ENABLE_METRICS" -eq 0 ]; then + CMAKE_ARGS="-DFLEXKV_ENABLE_MONITORING=OFF" +fi + +info "Running CMake configuration..." +cmake .. $CMAKE_ARGS + +info "Building C++ libraries..." +cmake --build . -j"$(nproc)" + +BUILD_LIB_PATH="$(pwd)/lib" +cd "$PROJECT_ROOT" + +# Set LD_LIBRARY_PATH +export LD_LIBRARY_PATH="$BUILD_LIB_PATH:$LD_LIBRARY_PATH" + +# Copy shared libraries to package directory +info "Copying shared libraries to package directory..." +PACKAGE_LIB_DIR="flexkv/lib" +mkdir -p "$PACKAGE_LIB_DIR" +if [ -d "$BUILD_LIB_PATH" ]; then + for lib_file in "$BUILD_LIB_PATH"/*.so*; do + if [ -f "$lib_file" ]; then + cp "$lib_file" "$PACKAGE_LIB_DIR/" + fi + done +fi +success "C++ libraries built successfully." 
+ +# ======================== Step 4.5: Install Python Runtime Dependencies ======================== +info "============================================" +info "Step 4.5: Installing Python runtime dependencies" +info "============================================" + +# Core runtime dependencies (always needed) +RUNTIME_DEPS="numpy pyzmq psutil nvtx pyyaml expiring-dict" + +# Additional dependencies for P2P/distributed mode +if [ "$ENABLE_P2P" -eq 1 ]; then + RUNTIME_DEPS="$RUNTIME_DEPS redis" + info "mooncake-transfer-engine will be built from source in Step 4.6" +fi + +info "Installing runtime dependencies: $RUNTIME_DEPS" +pip install -q $RUNTIME_DEPS +success "Python runtime dependencies installed." + +# ======================== Step 4.6: Build Mooncake from Source ======================== +if [ "$ENABLE_P2P" -eq 1 ]; then + info "============================================" + info "Step 4.6: Building mooncake-transfer-engine from source" + info "============================================" + if [ -n "$MOONCAKE_VERSION" ]; then + info "Target version: $MOONCAKE_VERSION" + else + info "Target version: latest (main branch)" + fi + + MOONCAKE_BUILD_DIR="${PROJECT_ROOT}/.mooncake-build" + + # Clone or update mooncake source + if [ -d "$MOONCAKE_BUILD_DIR" ] && [ -d "$MOONCAKE_BUILD_DIR/.git" ]; then + info "Found existing mooncake source at $MOONCAKE_BUILD_DIR, updating..." + cd "$MOONCAKE_BUILD_DIR" + git fetch --tags + else + info "Cloning mooncake source to $MOONCAKE_BUILD_DIR..." + rm -rf "$MOONCAKE_BUILD_DIR" + git clone --recurse-submodules https://github.com/kvcache-ai/Mooncake.git "$MOONCAKE_BUILD_DIR" + cd "$MOONCAKE_BUILD_DIR" + fi + + # Checkout target version if specified + if [ -n "$MOONCAKE_VERSION" ]; then + info "Checking out $MOONCAKE_VERSION..." + git checkout "$MOONCAKE_VERSION" + else + info "Using latest main branch..." 
+ git checkout main 2>/dev/null || git checkout master 2>/dev/null || true + git pull --ff-only 2>/dev/null || true + fi + git submodule sync --recursive + git submodule update --init --recursive + + # Install mooncake system dependencies + if [ "$SKIP_DEPS" -eq 0 ]; then + info "Installing mooncake system dependencies..." + $SUDO bash -x dependencies.sh -y + else + warn "Skipping mooncake dependency installation (--skip-deps)" + fi + + # Configure: only build transfer-engine with Redis support + info "Configuring mooncake-transfer-engine with Redis metadata backend support..." + mkdir -p build && cd build + + # Detect CUDA stubs path + CUDA_STUBS_PATH="" + if [ -d "/usr/local/cuda/lib64/stubs" ]; then + CUDA_STUBS_PATH="/usr/local/cuda/lib64/stubs" + elif [ -n "$CUDA_HOME" ] && [ -d "$CUDA_HOME/lib64/stubs" ]; then + CUDA_STUBS_PATH="$CUDA_HOME/lib64/stubs" + fi + + CMAKE_EXTRA_FLAGS="" + if [ -n "$CUDA_STUBS_PATH" ]; then + CMAKE_EXTRA_FLAGS="-DCMAKE_EXE_LINKER_FLAGS=-L${CUDA_STUBS_PATH}" + fi + + cmake -G Ninja .. \ + -DWITH_TE=ON \ + -DUSE_REDIS=ON \ + -DUSE_HTTP=ON \ + -DUSE_ETCD=OFF \ + -DUSE_CUDA=ON \ + -DWITH_STORE=OFF \ + -DWITH_P2P_STORE=OFF \ + -DWITH_EP=OFF \ + -DWITH_METRICS=OFF \ + -DBUILD_UNIT_TESTS=OFF \ + -DBUILD_EXAMPLES=ON \ + -DCMAKE_BUILD_TYPE=Release \ + $CMAKE_EXTRA_FLAGS + + # Build + info "Building mooncake-transfer-engine (this may take a while)..." + if [ -n "$CUDA_STUBS_PATH" ]; then + export LD_LIBRARY_PATH="${CUDA_STUBS_PATH}:$LD_LIBRARY_PATH" + export LIBRARY_PATH="${CUDA_STUBS_PATH}:$LIBRARY_PATH" + fi + cmake --build . -j"$(nproc)" + $SUDO cmake --install . + + # Build and install Python wheel + info "Building mooncake-transfer-engine Python wheel..." 
+ cd "$MOONCAKE_BUILD_DIR" + + # Uninstall any existing mooncake pip package + pip uninstall -y mooncake-transfer-engine mooncake-transfer-engine-cuda13 2>/dev/null || true + + # Detect if CUDA 13 build + CUDA_MAJOR_VERSION="" + if command -v nvcc &>/dev/null; then + CUDA_MAJOR_VERSION=$(nvcc --version | grep -oP 'release \K[0-9]+') + elif [ -n "$CUDA_HOME" ] && [ -f "$CUDA_HOME/bin/nvcc" ]; then + CUDA_MAJOR_VERSION=$("$CUDA_HOME/bin/nvcc" --version | grep -oP 'release \K[0-9]+') + fi + + MOONCAKE_BUILD_ENV="" + if [ -n "$CUDA_MAJOR_VERSION" ] && [ "$CUDA_MAJOR_VERSION" -ge 13 ] 2>/dev/null; then + MOONCAKE_BUILD_ENV="CU13_BUILD=1" + fi + + eval $MOONCAKE_BUILD_ENV OUTPUT_DIR=dist ./scripts/build_wheel.sh + + # build_wheel.sh outputs wheel to mooncake-wheel/dist/ + MOONCAKE_WHEEL=$(ls mooncake-wheel/dist/*.whl 2>/dev/null | head -n 1) + if [ -z "$MOONCAKE_WHEEL" ]; then + error "mooncake-transfer-engine wheel not found in mooncake-wheel/dist/" + fi + pip install "$MOONCAKE_WHEEL" + + cd "$PROJECT_ROOT" + success "mooncake-transfer-engine built from source with Redis support!" + + # Verify Redis metadata backend support + info "Verifying mooncake Redis metadata backend support..." + python3 -c " +from mooncake import engine +e = engine.TransferEngine() +print('mooncake-transfer-engine loaded successfully (built from source with Redis support)') +" && success "mooncake verification passed!" || warn "mooncake verification had warnings, see above." 
+fi + +# ======================== Step 5: Install Python Package ======================== +info "============================================" +info "Step 5: Installing FlexKV Python package" +info "============================================" + +# Set environment variables for build +export FLEXKV_ENABLE_METRICS="$ENABLE_METRICS" +export FLEXKV_ENABLE_P2P="$ENABLE_P2P" +export FLEXKV_ENABLE_GDS="$ENABLE_GDS" +export FLEXKV_ENABLE_CFS="$ENABLE_CFS" + +if [ "$BUILD_TYPE" = "debug" ]; then + export FLEXKV_DEBUG=1 + info "Installing in debug mode (editable, no Cython)..." + pip install -v --no-build-isolation -e . +elif [ "$BUILD_TYPE" = "release" ]; then + export FLEXKV_DEBUG=0 + info "Building release wheel..." + python3 setup.py bdist_wheel -v + # Install the built wheel + WHEEL_FILE=$(ls dist/flexkv-*.whl 2>/dev/null | head -n 1) + if [ -n "$WHEEL_FILE" ]; then + pip install "$WHEEL_FILE" + else + error "Wheel file not found in dist/" + fi +fi +success "FlexKV Python package installed." + +# ======================== Step 6: Verify Installation ======================== +info "============================================" +info "Step 6: Verifying installation" +info "============================================" + +python3 -c " +import flexkv +print('FlexKV imported successfully') +try: + print(f'Version: {flexkv.__version__}') +except AttributeError: + pass +try: + from flexkv import c_ext + print('C extension loaded successfully') +except ImportError as e: + print(f'Warning: C extension not loaded: {e}') +" && success "FlexKV installation verified!" || warn "Verification had warnings, see above." + +# ======================== Summary ======================== +echo "" +info "============================================" +success "FlexKV installation completed!" 
def detect_cuda_arch():
    """Auto-detect GPU compute capability.

    Returns a semicolon-separated arch list such as ``"8.0;9.0"`` built from
    the capabilities of all visible CUDA devices. Falls back to
    ``"8.0;8.6;9.0"`` when CUDA is unavailable, no device is reported, or
    detection raises.
    """
    try:
        import torch
        if torch.cuda.is_available():
            capabilities = {
                tuple(torch.cuda.get_device_capability(index))
                for index in range(torch.cuda.device_count())
            }
            if capabilities:
                # Sort numerically on (major, minor); sorting the formatted
                # strings would place "10.0" before "8.0".
                arch_list = ";".join(
                    f"{major}.{minor}" for major, minor in sorted(capabilities)
                )
                print(f"Auto-detected GPU architectures: {arch_list}")
                return arch_list
    except Exception as e:
        print(f"GPU architecture auto-detection failed: {e}")
    # Fallback: common architectures (Ampere + Hopper)
    fallback = "8.0;8.6;9.0"
    print(f"No GPU detected, using fallback architectures: {fallback}")
    return fallback


def get_version():
    """Derive the package version from ``git describe``.

    Runs ``git describe --tags --long --match "v*"`` in the directory holding
    this file and expects output shaped like ``v1.0.0-0-gabc1234``
    (``<tag>-<commits since tag>-g<abbreviated hash>``).

    Returns:
        ``"<tag without the leading v>"`` when the distance is 0,
        ``"<tag>+git<hash>"`` otherwise, and ``"0.0.0+unknown"`` when git
        fails or the output does not describe a version-like tag.
    """
    import re
    import subprocess
    try:
        # e.g. "v1.0.0-0-gabc1234" or "v1.0.0-3-gabc1234"
        raw = subprocess.check_output(
            ["git", "describe", "--tags", "--long", "--match", "v*"],
            stderr=subprocess.PIPE,
            cwd=os.path.dirname(os.path.abspath(__file__)),
        ).decode().strip()
        # ``--match "v*"`` also accepts tags such as "validation-suite"; require
        # a digit right after the single leading "v" so those fall back instead
        # of producing an invalid version string.
        parsed = re.fullmatch(r"v(\d[0-9A-Za-z._-]*)-(\d+)-g([0-9a-f]+)", raw)
        if parsed is None:
            raise ValueError(f"Unexpected git describe output format: {raw!r}")
        tag, distance, git_hash = parsed.groups()
        if distance == "0":
            return tag  # clean release
        return f"{tag}+git{git_hash}"  # dev build
    except Exception:
        return "0.0.0+unknown"
def alloc_hugepage_or_skip(
    num_elements: int,
    dtype: torch.dtype,
    page_size_bytes: int,
) -> torch.Tensor:
    """Allocate a hugepage tensor, skipping the current test if that fails.

    Forwards the three arguments to ``alloc_hugepage_tensor`` by keyword; any
    exception it raises becomes a ``pytest.skip`` carrying the error text.
    """
    request = {
        "num_elements": num_elements,
        "dtype": dtype,
        "page_size_bytes": page_size_bytes,
    }
    try:
        tensor = alloc_hugepage_tensor(**request)
    except Exception as exc:
        pytest.skip(f"hugepage allocation failed: {exc}")
    return tensor


def cuda_ops_or_skip() -> tuple:
    """Return ``(cudaHostRegister, cudaHostUnregister)`` from ``host_buffer``.

    Skips the current test when ``torch.cuda.is_available()`` is false.
    """
    cuda_ready = torch.cuda.is_available()
    if not cuda_ready:
        pytest.skip("CUDA not available")
    register = host_buffer.cudaHostRegister
    unregister = host_buffer.cudaHostUnregister
    return register, unregister


def unregister_suppress(tensor: torch.Tensor) -> None:
    """Call ``host_buffer.cudaHostUnregister(tensor)``, ignoring any ``Exception``."""
    try:
        host_buffer.cudaHostUnregister(tensor)
    except Exception:
        pass
KVCacheLayout, KVCacheLayoutType +from flexkv.storage.allocator import ( + DEFAULT_HUGE_PAGE_SIZE, + free_hugepage_tensor, +) +from tests.hugepage.conftest import ( + alloc_hugepage_or_skip, + cuda_ops_or_skip, + unregister_suppress, +) + +PAGE = DEFAULT_HUGE_PAGE_SIZE +_NUM_LAYERS = 1 +_NUM_BLOCKS = 4 +_TOKENS_PER_BLOCK = 16 +_NUM_HEADS = 8 +_HEAD_SIZE = 128 +_DTYPE = torch.bfloat16 +_ELEM_SIZE = _DTYPE.itemsize + +_CPU_LAYOUT = KVCacheLayout( + type=KVCacheLayoutType.LAYERFIRST, + num_layer=_NUM_LAYERS, + num_block=_NUM_BLOCKS, + tokens_per_block=_TOKENS_PER_BLOCK, + num_head=_NUM_HEADS, + head_size=_HEAD_SIZE, + is_mla=False, +) + +_CHUNK = _CPU_LAYOUT.get_chunk_size() +_BLOCK_STRIDE = _CPU_LAYOUT.get_block_stride() +_KV_STRIDE = _CPU_LAYOUT.get_kv_stride() +_LAYER_STRIDE = _CPU_LAYOUT.get_layer_stride() + + +def _ensure_c_ext(): + try: + from flexkv.c_ext import SSDIOCTX, transfer_kv_blocks_ssd + except ImportError: + pytest.skip("c_ext not built or SSD support disabled") + return SSDIOCTX, transfer_kv_blocks_ssd + + +def _ssd_layout_for(num_blocks_per_file: int) -> KVCacheLayout: + return KVCacheLayout( + type=KVCacheLayoutType.LAYERFIRST, + num_layer=_NUM_LAYERS, + num_block=num_blocks_per_file, + tokens_per_block=_TOKENS_PER_BLOCK, + num_head=_NUM_HEADS, + head_size=_HEAD_SIZE, + is_mla=False, + ) + + +def _verify_ssd_read( + buf: np.ndarray, + pattern: np.ndarray, + kv_stride_bytes: int, + layer_stride_bytes: int, + chunk_bytes: int, + num_blocks: int, + kv_dim: int, +) -> None: + pattern_u8 = pattern.view(np.uint8) + lid = 0 + for bid in range(num_blocks): + for kv in range(kv_dim): + buf_start = lid * layer_stride_bytes + kv * kv_stride_bytes + bid * chunk_bytes + pat_start = buf_start + actual = buf[buf_start:buf_start + chunk_bytes] + expected = pattern_u8[pat_start:pat_start + chunk_bytes] + assert np.array_equal(actual, expected), ( + f"block {bid} {'K' if kv == 0 else 'V'} mismatch " + f"at offset={buf_start}, size={chunk_bytes}" + ) + + +def 
test_hugepage_ssd_to_gpu_roundtrip() -> None: + SSDIOCTX, transfer_kv_blocks_ssd = _ensure_c_ext() + cudaHostRegister, _ = cuda_ops_or_skip() + + num_blocks_per_file = _NUM_BLOCKS + kv_dim = 2 + _ssd_layout_for(num_blocks_per_file) + + chunk_bytes = _CHUNK * _ELEM_SIZE + block_stride_bytes = _BLOCK_STRIDE * _ELEM_SIZE + kv_stride_bytes = _KV_STRIDE * _ELEM_SIZE + layer_stride_bytes = _LAYER_STRIDE * _ELEM_SIZE + cpu_chunk_bytes = chunk_bytes + cpu_kv_stride_bytes = kv_stride_bytes + cpu_layer_stride_bytes = layer_stride_bytes + ssd_kv_stride_bytes = kv_stride_bytes + ssd_layer_stride_bytes = layer_stride_bytes + ssd_chunk_bytes = chunk_bytes + ssd_block_stride_bytes = block_stride_bytes + file_size = kv_dim * num_blocks_per_file * chunk_bytes + + pattern = np.arange(file_size // 2, dtype=np.int16) + pattern_bytes = pattern.view(np.uint8) + + tmpdir = Path(tempfile.mkdtemp(prefix="flexkv_e2e_")) + ssd_path = tmpdir / "ssd_0.bin" + pattern_bytes.tofile(ssd_path) + + hugepage_tensor = alloc_hugepage_or_skip( + _CPU_LAYOUT.get_total_elements(), + _DTYPE, + PAGE, + ) + ptr = hugepage_tensor.data_ptr() + needs_unpin = False + + try: + cudaHostRegister(hugepage_tensor) + needs_unpin = True + + ioctx = SSDIOCTX({0: [str(ssd_path)]}, 1, 0, 0) + layer_ids = torch.arange(0, _NUM_LAYERS, dtype=torch.int32) + ssd_block_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int64) + cpu_block_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int64) + + transfer_kv_blocks_ssd( + ioctx=ioctx, + cpu_layer_id_list=layer_ids, + cpu_tensor_ptr=hugepage_tensor.data_ptr(), + ssd_block_ids=ssd_block_ids, + cpu_block_ids=cpu_block_ids, + cpu_layer_stride_in_bytes=cpu_layer_stride_bytes, + cpu_kv_stride_in_bytes=cpu_kv_stride_bytes, + ssd_layer_stride_in_bytes=ssd_layer_stride_bytes, + ssd_kv_stride_in_bytes=ssd_kv_stride_bytes, + chunk_size_in_bytes=ssd_chunk_bytes, + block_stride_in_bytes=ssd_block_stride_bytes, + is_read=True, + num_blocks_per_file=num_blocks_per_file, + round_robin=1, + 
num_threads_per_device=1, + is_mla=False, + ) + + buf_np = np.frombuffer( + (ctypes.c_uint8 * (hugepage_tensor.numel() * _ELEM_SIZE)).from_address(ptr), + dtype=np.uint8, + ) + _verify_ssd_read( + buf_np, + pattern, + cpu_kv_stride_bytes, + cpu_layer_stride_bytes, + cpu_chunk_bytes, + num_blocks_per_file, + kv_dim, + ) + + gpu_tensor = torch.empty_like(hugepage_tensor, device="cuda") + gpu_tensor.copy_(hugepage_tensor, non_blocking=True) + torch.cuda.synchronize() + + roundtrip = torch.empty_like(hugepage_tensor) + roundtrip.copy_(gpu_tensor, non_blocking=True) + torch.cuda.synchronize() + + assert torch.equal( + hugepage_tensor.view(torch.int16), + roundtrip.view(torch.int16), + ) + finally: + if needs_unpin: + unregister_suppress(hugepage_tensor) + free_hugepage_tensor(hugepage_tensor) + shutil.rmtree(tmpdir) diff --git a/tests/hugepage/test_hugepage_unit.py b/tests/hugepage/test_hugepage_unit.py new file mode 100644 index 0000000000..48de07092f --- /dev/null +++ b/tests/hugepage/test_hugepage_unit.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +import os + +import pytest +import torch + +from flexkv.common.config import CacheConfig, ModelConfig +from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType +from flexkv.common.transfer import DeviceType +from flexkv.storage.storage_engine import StorageEngine +from flexkv.storage.allocator import ( + DEFAULT_HUGE_PAGE_SIZE, + HugePageTensorHandle, + HugePageAllocator, + _live_hugepage_mappings, + alloc_hugepage_tensor, + free_hugepage_tensor, + get_worker_hugepage_handle, + materialize_worker_tensor, +) +from tests.hugepage.conftest import ( + alloc_hugepage_or_skip, + cuda_ops_or_skip, + unregister_suppress, +) + +PAGE = DEFAULT_HUGE_PAGE_SIZE + + +def test_basic_alloc_free() -> None: + n_bytes = 16 * 1024 * 1024 + n_elem = n_bytes // 2 + tensor = alloc_hugepage_or_skip(n_elem, torch.bfloat16, PAGE) + addr = tensor.data_ptr() + + try: + assert isinstance(tensor, torch.Tensor) + assert 
tensor.numel() == n_elem + assert tensor.dtype == torch.bfloat16 + assert tensor.device.type == "cpu" + assert addr != 0 + assert addr % PAGE == 0 + assert addr in _live_hugepage_mappings + + tensor.view(torch.int16).fill_(0x5A5A) + assert int(tensor.view(torch.int16)[0].item()) == 0x5A5A + finally: + free_hugepage_tensor(tensor) + + assert addr not in _live_hugepage_mappings + + +def test_invalid_args() -> None: + with pytest.raises(ValueError): + alloc_hugepage_tensor(0, torch.float32, page_size_bytes=PAGE) + + with pytest.raises(ValueError): + alloc_hugepage_tensor(1, torch.float32, page_size_bytes=PAGE + 1) + + +def test_non_hugetlbfs_fallback_is_rejected(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None: + monkeypatch.setenv("FLEXKV_HUGETLBFS_DIR", str(tmp_path)) + + with pytest.raises(RuntimeError, match="is not a hugetlbfs mount"): + alloc_hugepage_tensor(1024, torch.float32, page_size_bytes=1024 * 1024 * 1024) + + +def test_worker_hugepage_handle_round_trip() -> None: + tensor = alloc_hugepage_or_skip(1024 * 1024, torch.bfloat16, PAGE) + + try: + handle = get_worker_hugepage_handle(tensor, tensor.numel(), tensor.dtype) + if handle is None: + pytest.skip("non-shareable hugepage allocation path on this host") + + rebuilt = materialize_worker_tensor(handle) + rebuilt.view(torch.int16)[0] = 0x1234 + + assert isinstance(handle, HugePageTensorHandle) + assert int(tensor.view(torch.int16)[0].item()) == 0x1234 + free_hugepage_tensor(rebuilt) + finally: + free_hugepage_tensor(tensor) + + +def test_hugepage_allocator_fallback() -> None: + layout = KVCacheLayout( + type=KVCacheLayoutType.LAYERFIRST, + num_layer=2, + num_block=64, + tokens_per_block=16, + num_head=8, + head_size=128, + is_mla=False, + ) + dtype = torch.bfloat16 + page_size = 1024 * 1024 * 1024 + old_dir = os.environ.get("FLEXKV_HUGETLBFS_DIR") + os.environ["FLEXKV_HUGETLBFS_DIR"] = "/nonexistent/flexkv_hugetlbfs" + try: + handle = HugePageAllocator.allocate( + layout=layout, + dtype=dtype, + 
page_size_bytes=page_size, + ) + finally: + if old_dir is None: + os.environ.pop("FLEXKV_HUGETLBFS_DIR", None) + else: + os.environ["FLEXKV_HUGETLBFS_DIR"] = old_dir + + tensor = handle.get_tensor() + assert isinstance(tensor, torch.Tensor) + assert tensor.numel() == layout.get_total_elements() + assert tensor.dtype == dtype + assert tensor.data_ptr() not in _live_hugepage_mappings + HugePageAllocator.free(handle) + + +def test_cuda_host_register() -> None: + cudaHostRegister, _ = cuda_ops_or_skip() + tensor = alloc_hugepage_or_skip(1024 * 1024, torch.bfloat16, PAGE) + + try: + cudaHostRegister(tensor) + + gpu_tensor = torch.empty_like(tensor, device="cuda") + tensor.fill_(1.25) + gpu_tensor.copy_(tensor, non_blocking=True) + torch.cuda.synchronize() + out = torch.empty_like(tensor) + out.copy_(gpu_tensor, non_blocking=True) + torch.cuda.synchronize() + assert torch.all(out == 1.25).item() + finally: + unregister_suppress(tensor) + free_hugepage_tensor(tensor) + + +def test_host_buffer_release_is_idempotent() -> None: + from flexkv.transfer.host_buffer import allocate_host_buffer + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available for pinned host buffer allocation") + + handle = allocate_host_buffer( + num_elements=1024, + dtype=torch.bfloat16, + use_hugepage=False, + hugepage_size_bytes=PAGE, + ) + + handle.release() + handle.release() + + assert not handle.is_hugepage + assert not handle.is_cuda_registered + + +def test_storage_engine_cpu_cache_uses_hugepage_when_enabled() -> None: + if not os.path.isdir(os.environ.get("FLEXKV_HUGETLBFS_DIR", "/mnt/hugepages")): + pytest.skip("hugetlbfs mount not available") + + model_config = ModelConfig( + num_layers=1, + num_kv_heads=1, + head_size=128, + use_mla=False, + dtype=torch.bfloat16, + ) + cache_config = CacheConfig( + enable_cpu=True, + enable_ssd=False, + num_cpu_blocks=8, + tokens_per_block=16, + use_hugepage_cpu_buffer=True, + hugepage_size_bytes=PAGE, + ) + + storage_engine = 
StorageEngine(model_config, cache_config) + cpu_handle = storage_engine.get_storage_handle(DeviceType.CPU) + cpu_tensor = cpu_handle.get_tensor() + + try: + if cpu_tensor.data_ptr() not in _live_hugepage_mappings: + pytest.skip("hugepage CPU cache allocation fell back on this host") + assert cpu_tensor.data_ptr() in _live_hugepage_mappings + worker_tensor = cpu_handle.get_worker_tensor() + assert isinstance(worker_tensor, HugePageTensorHandle) + finally: + HugePageAllocator.free(cpu_handle) + + +def test_storage_engine_cpu_cache_falls_back_when_hugepage_unavailable(monkeypatch: pytest.MonkeyPatch) -> None: + model_config = ModelConfig( + num_layers=1, + num_kv_heads=1, + head_size=128, + use_mla=False, + dtype=torch.bfloat16, + ) + cache_config = CacheConfig( + enable_cpu=True, + enable_ssd=False, + num_cpu_blocks=8, + tokens_per_block=16, + use_hugepage_cpu_buffer=True, + hugepage_size_bytes=1024 * 1024 * 1024, + ) + monkeypatch.setenv("FLEXKV_HUGETLBFS_DIR", "/nonexistent/flexkv_hugetlbfs") + + storage_engine = StorageEngine(model_config, cache_config) + cpu_handle = storage_engine.get_storage_handle(DeviceType.CPU) + cpu_tensor = cpu_handle.get_tensor() + + assert cpu_tensor.data_ptr() not in _live_hugepage_mappings + HugePageAllocator.free(cpu_handle) diff --git a/tests/hugepage/test_hugepage_worker_integration.py b/tests/hugepage/test_hugepage_worker_integration.py new file mode 100644 index 0000000000..abab51821d --- /dev/null +++ b/tests/hugepage/test_hugepage_worker_integration.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import gc +from pathlib import Path +from unittest.mock import patch + +import pytest +import torch + +from flexkv.storage.allocator import ( + DEFAULT_HUGE_PAGE_SIZE, + alloc_hugepage_tensor, + free_hugepage_tensor, +) +from tests.hugepage.conftest import cuda_ops_or_skip, unregister_suppress + +PAGE = DEFAULT_HUGE_PAGE_SIZE +_NUM_PAGES = 8 +_NUM_BYTES = _NUM_PAGES * PAGE +_NUM_ELEMENTS = _NUM_BYTES // 2 + + +def 
_read_meminfo_hugepages() -> tuple[int, int, int]: + total = free = size_kb = 0 + with open("/proc/meminfo", encoding="utf-8") as f: + for line in f: + if line.startswith("HugePages_Total:"): + total = int(line.split()[1]) + elif line.startswith("HugePages_Free:"): + free = int(line.split()[1]) + elif line.startswith("Hugepagesize:"): + size_kb = int(line.split()[1]) + return total, free, size_kb * 1024 + + +def _require_hugepages(num_pages: int) -> tuple[int, int]: + total, free, _ = _read_meminfo_hugepages() + if total < num_pages: + pytest.skip(f"need at least {num_pages} huge pages") + return total, free + + +def _simulate_tmp_cpu_buffer_init(tmp_num_elements: int) -> tuple[torch.Tensor, bool]: + tmp_cpu_buffer = torch.empty( + tmp_num_elements, + dtype=torch.bfloat16, + device="cpu", + pin_memory=True, + ) + needs_unpin = False + hugepage_buf = None + cudaHostRegister, _ = cuda_ops_or_skip() + try: + hugepage_buf = alloc_hugepage_tensor( + num_elements=tmp_num_elements, + dtype=torch.bfloat16, + page_size_bytes=PAGE, + ) + cudaHostRegister(hugepage_buf) + except Exception: + if hugepage_buf is not None: + free_hugepage_tensor(hugepage_buf) + else: + tmp_cpu_buffer = hugepage_buf + needs_unpin = True + + return tmp_cpu_buffer, needs_unpin + + +class MockMooncakeEngine: + def __init__(self) -> None: + self.registered: set[int] = set() + + def regist_buffer(self, ptr: int, size: int) -> int: + assert ptr != 0 + assert size > 0 + self.registered.add(ptr) + return 0 + + def unregist_buffer(self, ptr: int) -> int: + assert ptr in self.registered + self.registered.discard(ptr) + return 0 + + +def test_full_lifecycle_hugepage() -> None: + _, free_before = _require_hugepages(_NUM_PAGES) + mooncake = MockMooncakeEngine() + tmp_cpu_buffer, needs_unpin = _simulate_tmp_cpu_buffer_init(_NUM_ELEMENTS) + + mooncake.regist_buffer( + tmp_cpu_buffer.data_ptr(), + tmp_cpu_buffer.numel() * tmp_cpu_buffer.element_size(), + ) + + _, free_after_alloc, _ = _read_meminfo_hugepages() + 
consumed = free_before - free_after_alloc + if needs_unpin: + assert consumed == _NUM_PAGES + assert len(mooncake.registered) == 1 + else: + assert consumed == 0 + + tmp_cpu_buffer.view(torch.int16).fill_(0x7B7B) + assert int(tmp_cpu_buffer.view(torch.int16)[0].item()) == 0x7B7B + + mooncake.unregist_buffer(tmp_cpu_buffer.data_ptr()) + if needs_unpin: + unregister_suppress(tmp_cpu_buffer) + free_hugepage_tensor(tmp_cpu_buffer) + + assert len(mooncake.registered) == 0 + + del tmp_cpu_buffer + gc.collect() + + _, free_after_free, _ = _read_meminfo_hugepages() + assert free_after_free == free_before + + +def test_fallback_when_cuda_host_register_fails() -> None: + _, free_before = _require_hugepages(_NUM_PAGES) + + with patch( + "flexkv.transfer.host_buffer.cudaHostRegister", + side_effect=RuntimeError("injected cudaHostRegister failure"), + ): + tmp_cpu_buffer, needs_unpin = _simulate_tmp_cpu_buffer_init(_NUM_ELEMENTS) + assert not needs_unpin + + mooncake = MockMooncakeEngine() + mooncake.regist_buffer( + tmp_cpu_buffer.data_ptr(), + tmp_cpu_buffer.numel() * tmp_cpu_buffer.element_size(), + ) + mooncake.unregist_buffer(tmp_cpu_buffer.data_ptr()) + + del tmp_cpu_buffer + gc.collect() + + _, free_after, _ = _read_meminfo_hugepages() + assert free_after == free_before diff --git a/tests/replay_from_tracer.py b/tests/replay_from_tracer.py index 2887d0723b..ba7f557820 100644 --- a/tests/replay_from_tracer.py +++ b/tests/replay_from_tracer.py @@ -147,7 +147,6 @@ def parse_config_event(self, event: Dict[str, Any]): num_remote_blocks=cache_config_data['num_remote_blocks'], ssd_cache_dir=cache_config_data['ssd_cache_dir'], gds_cache_dir=cache_config_data['gds_cache_dir'], - remote_cache_size_mode=cache_config_data['remote_cache_size_mode'], remote_file_size=cache_config_data['remote_file_size'], remote_file_num=cache_config_data['remote_file_num'], remote_file_prefix=cache_config_data['remote_file_prefix'], @@ -234,7 +233,8 @@ def register_gpu_blocks_to_kvmanager(self, 
gpu_register_port: str): # Create registration request register_req = RegisterTPClientRequest( - dp_client_id=gpu_id // self.model_config.tp_size, # DP client ID + dp_rank=gpu_id // self.model_config.tp_size, # DP client ID + pp_rank=0, # single PP stage for replay device_id=gpu_id, handles=handles, gpu_layout=self.gpu_layout @@ -305,7 +305,8 @@ def replay_request_event(self, event: Dict[str, Any]) -> int: slot_mapping = np.array(data['slot_mapping'], dtype=np.int64) token_mask = np.array(data['token_mask'], dtype=bool) if data['token_mask'] else None layer_granularity = data.get('layer_granularity', -1) - dp_id = data.get('dp_id', 0) + dp_rank = data.get('dp_id', 0) + pp_rank = data.get('pp_rank', 0) self.log(f"Replaying {request_type} request with {len(token_ids)} tokens") @@ -319,7 +320,8 @@ def replay_request_event(self, event: Dict[str, Any]) -> int: slot_mapping=slot_mapping, token_mask=token_mask, layer_granularity=layer_granularity, - dp_id=dp_id + dp_rank=dp_rank, + pp_rank=pp_rank ) elif request_type == "PUT": print(f"✅✅✅PUT token_ids: {token_ids[:128]}") @@ -330,7 +332,8 @@ def replay_request_event(self, event: Dict[str, Any]) -> int: token_ids=token_ids, slot_mapping=slot_mapping, token_mask=token_mask, - dp_id=dp_id + dp_rank=dp_rank, + pp_rank=pp_rank ) elif request_type == "GET_MATCH": print(f"🔍📝GET_MATCH token_ids: {token_ids[:128]}") @@ -341,7 +344,8 @@ def replay_request_event(self, event: Dict[str, Any]) -> int: token_ids=token_ids, token_mask=token_mask, layer_granularity=layer_granularity, - dp_id=dp_id + dp_rank=dp_rank, + pp_rank=pp_rank ) elif request_type == "PUT_MATCH": print(f"✅📝PUT_MATCH token_ids: {token_ids[:128]}") @@ -351,7 +355,8 @@ def replay_request_event(self, event: Dict[str, Any]) -> int: task_id, return_mask = self.kvmanager.put_match( token_ids=token_ids, token_mask=token_mask, - dp_id=dp_id + dp_rank=dp_rank, + pp_rank=pp_rank ) else: raise ValueError(f"Unknown request type: {request_type}") diff --git 
a/tests/test_cache_engine.py b/tests/test_cache_engine.py index 3fe2365c69..012d91a1a4 100644 --- a/tests/test_cache_engine.py +++ b/tests/test_cache_engine.py @@ -176,8 +176,9 @@ def test_mempool(): with pytest.raises(ValueError): mempool.recycle_blocks(np.array([1, 2, 3], dtype=np.int32)) - # recycle_blocks no longer raises ValueError for already free blocks - mempool.recycle_blocks(np.array([1, 2, 3], dtype=np.int64)) + # Recycle already free blocks raises + with pytest.raises(ValueError): + mempool.recycle_blocks(np.array([1, 2, 3], dtype=np.int64)) assert mempool.num_free_blocks == DEFAULT_NUM_TOTAL_BLOCKS # Recycle wrong ndim raises @@ -756,3 +757,69 @@ def test_eviction_policy_reinsert_after_eviction(engine_cls): assert engine.match(seqs[2]).num_matched_blocks == 0, ( "C should be evicted to make room for re-inserted B" ) + + +# --------------------------------------------------------------------------- +# Tests – MatchResultAccel matched_node_id field +# --------------------------------------------------------------------------- +class TestMatchResultAccelNodeId: + """Verify the single-node matching constraint data structures.""" + + def test_matched_node_id_default_none(self): + """matched_node_id defaults to None when not set.""" + from flexkv.common.type import MatchResultAccel + result = MatchResultAccel( + num_ready_matched_blocks=0, + num_matched_blocks=0, + physical_blocks=np.array([], dtype=np.int64), + ) + assert result.matched_node_id is None + + def test_matched_node_id_set(self): + """matched_node_id can be set to a single integer.""" + from flexkv.common.type import MatchResultAccel + result = MatchResultAccel( + num_ready_matched_blocks=5, + num_matched_blocks=5, + physical_blocks=np.arange(5, dtype=np.int64), + matched_node_id=42, + ) + assert result.matched_node_id == 42 + assert isinstance(result.matched_node_id, int) + + def test_backward_compat_block_node_ids(self): + """block_node_ids (deprecated) still works alongside 
matched_node_id.""" + from flexkv.common.type import MatchResultAccel + bnids = np.array([42, 42, 42], dtype=np.uint32) + result = MatchResultAccel( + num_ready_matched_blocks=3, + num_matched_blocks=3, + physical_blocks=np.arange(3, dtype=np.int64), + matched_node_id=42, + block_node_ids=bnids, + ) + assert result.matched_node_id == 42 + assert np.all(result.block_node_ids == 42) + + +# --------------------------------------------------------------------------- +# Tests – CMatchResult matched_node_id field (C++ binding) +# --------------------------------------------------------------------------- +class TestCMatchResultNodeId: + """Verify the C++ CMatchResult exposes matched_node_id.""" + + def test_cmatch_result_default_node_id(self): + """CMatchResult.matched_node_id defaults to -1.""" + import torch + from flexkv.c_ext import CMatchResult + result = CMatchResult(0, 0, 0, None, None, torch.empty(0, dtype=torch.int64)) + assert result.matched_node_id == -1 + + def test_cmatch_result_with_node_id(self): + """CMatchResult.matched_node_id can be set via constructor.""" + import torch + from flexkv.c_ext import CMatchResult + blocks = torch.arange(3, dtype=torch.int64) + result = CMatchResult(3, 3, 0, None, None, blocks, 7) + assert result.matched_node_id == 7 + assert result.physical_blocks.shape[0] == 3 diff --git a/tests/test_config_hugepage.py b/tests/test_config_hugepage.py new file mode 100644 index 0000000000..1a14291ab8 --- /dev/null +++ b/tests/test_config_hugepage.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from flexkv.common.config import ( + CacheConfig, + ModelConfig, + UserConfig, + load_user_config_from_env, + update_default_config_from_user_config, +) + + +def test_load_user_config_from_env_reads_hugepage_flags(monkeypatch) -> None: + monkeypatch.setenv("FLEXKV_USE_HUGEPAGE_CPU_BUFFER", "1") + monkeypatch.setenv("FLEXKV_USE_HUGEPAGE_TMP_BUFFER", "1") + monkeypatch.setenv("FLEXKV_HUGEPAGE_SIZE_BYTES", str(1 << 30)) + + user_config = 
load_user_config_from_env() + + assert user_config.use_hugepage_cpu_buffer is True + assert user_config.use_hugepage_tmp_buffer is True + assert user_config.hugepage_size_bytes == 1 << 30 + + +def test_update_default_config_from_user_config_applies_hugepage_flags() -> None: + model_config = ModelConfig( + num_layers=1, + num_kv_heads=1, + head_size=128, + use_mla=False, + ) + cache_config = CacheConfig() + user_config = UserConfig( + cpu_cache_gb=16, + ssd_cache_gb=0, + use_hugepage_cpu_buffer=True, + use_hugepage_tmp_buffer=True, + hugepage_size_bytes=1 << 30, + ) + + update_default_config_from_user_config(model_config, cache_config, user_config) + + assert cache_config.use_hugepage_cpu_buffer is True + assert cache_config.use_hugepage_tmp_buffer is True + assert cache_config.hugepage_size_bytes == 1 << 30 diff --git a/tests/test_kvmanager.py b/tests/test_kvmanager.py index aef87bba8d..eff1c48e2c 100644 --- a/tests/test_kvmanager.py +++ b/tests/test_kvmanager.py @@ -8,13 +8,15 @@ import multiprocessing as mp from multiprocessing import Process, Pipe -from flexkv.common.config import ModelConfig, CacheConfig, GLOBAL_CONFIG_FROM_ENV +from flexkv.common.config import ModelConfig, CacheConfig from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType from flexkv.common.request import KVResponseStatus from flexkv.kvtask import KVTaskEngine from flexkv.kvmanager import KVManager from flexkv.common.memory_handle import TensorSharedHandle from flexkv.server.client import KVTPClient +import traceback + from flexkv.common.debug import flexkv_logger # Import utilities from common_utils @@ -38,7 +40,7 @@ def _fp8_cuda_ops_unavailable(): except NotImplementedError: return True -def run_tp_client(dp_client_id, +def run_tp_client(dp_rank, tp_rank, server_recv_port, model_config, @@ -48,8 +50,8 @@ def run_tp_client(dp_client_id, gpu_layout_type): """Run tp_client process""" try: - device_id = tp_rank + dp_client_id * model_config.tp_size - tp_client = 
KVTPClient(server_recv_port, dp_client_id, device_id) + device_id = tp_rank + dp_rank * model_config.tp_size + tp_client = KVTPClient(server_recv_port, dp_rank, pp_rank=0, device_id=device_id) gpu_kv_layout = create_gpu_kv_layout(model_config, cache_config, num_gpu_blocks, gpu_layout_type) @@ -240,14 +242,14 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): initial_write_num = int(num_requests * initial_write_ratio) print("writing initial data...") put_ids = [] - for token_ids, block_ids, dp_id in request_pairs[:initial_write_num]: + for token_ids, block_ids, dp_rank in request_pairs[:initial_write_num]: if gpu_kv_verifier is not None: gpu_kv_verifier.fill_gpu_blocks(token_ids, block_ids) write_request = kvmanager.put_async( token_ids=token_ids, slot_mapping=block_ids_2_slot_mapping(block_ids, tokens_per_block), token_mask=None, - dp_id=dp_id, + dp_rank=dp_rank, namespace=namespace, ) kvmanager.wait([write_request], completely=True) @@ -259,7 +261,7 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): token_ids=torch.randint(0, 100, size=(8,), dtype=torch.int64), slot_mapping=block_ids_2_slot_mapping(torch.arange(0,1, dtype=torch.int64), tokens_per_block, actual_length=8), token_mask=None, - dp_id=0, + dp_rank=0, namespace=namespace, ) kvmanager.wait([write_request], completely=True) @@ -270,7 +272,7 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): # token_ids=torch.randint(0, 100, size=(16,), dtype=torch.int64), # slot_mapping=block_ids_2_slot_mapping(torch.arange(0,1, dtype=torch.int64), tokens_per_block, actual_length=8), # token_mask=my_mask, - # dp_id=0, + # dp_rank=0, #) #kvmanager.wait_for_graph_finished(write_request) @@ -287,13 +289,13 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): for i in range(initial_write_num, num_requests): print(f"performing mixed read/write {i} / {num_requests} ...") read_idx = i - initial_write_num - 
token_ids, block_ids, dp_id = request_pairs[read_idx] + token_ids, block_ids, dp_rank = request_pairs[read_idx] slot_mapping = block_ids_2_slot_mapping(block_ids, tokens_per_block) request_id, _ = kvmanager.get_match( token_ids=token_ids, layer_granularity=-1, token_mask=None, - dp_id=dp_id, + dp_rank=dp_rank, namespace=namespace, ) kvmanager.launch(request_id, slot_mapping) @@ -301,14 +303,14 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): running_get_requests.append(request_id) req_id2block_ids[request_id] = block_ids req_id2token_ids[request_id] = token_ids - token_ids, block_ids, dp_id = request_pairs[i] + token_ids, block_ids, dp_rank = request_pairs[i] if gpu_kv_verifier is not None: gpu_kv_verifier.fill_gpu_blocks(token_ids, block_ids) request_id = kvmanager.put_async( token_ids=token_ids, slot_mapping=block_ids_2_slot_mapping(block_ids, tokens_per_block), token_mask=None, - dp_id=dp_id, + dp_rank=dp_rank, namespace=namespace, ) req_id2block_ids[request_id] = block_ids @@ -366,31 +368,31 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): # =============== Test batched launched get =============== if not enable_gds: print("\n========== Testing batched launched get ==========") - + # Use the first few request_pairs that were written in initial phase batch_size = 6 - + batched_get_task_ids = [] batched_slot_mappings = [] batched_req_info = [] # Store (token_ids, block_ids) for verification - + # Create multiple get_match requests for i in range(batch_size): - token_ids, block_ids, dp_id = request_pairs[random.randint(0, num_requests - 1)] + token_ids, block_ids, dp_rank = request_pairs[random.randint(0, num_requests - 1)] slot_mapping = block_ids_2_slot_mapping(block_ids, tokens_per_block) - + request_id, return_mask = kvmanager.get_match( token_ids=token_ids, layer_granularity=-1, token_mask=None, - dp_id=dp_id, + dp_rank=dp_rank, namespace=namespace, ) batched_get_task_ids.append(request_id) 
batched_slot_mappings.append(slot_mapping) batched_req_info.append((token_ids, block_ids, request_id)) print(f"Created get_match request {request_id} for request_pair[{i}]") - + # Launch all get requests as a batch print(f"Launching {len(batched_get_task_ids)} get requests as batch...") batch_id = kvmanager.launch( @@ -399,12 +401,12 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): as_batch=True )[0] print(f"Returned task_ids after batch launch: {batch_id}") - + # Wait for the batched get to complete # When as_batch=True, launch returns [batch_id], we need to wait on batch_id batch_results = kvmanager.wait(batch_id, completely=True) print(f"Batch wait returned {len(batch_results)} results") - + # Verify results batched_cache_hit = 0 batched_cache_miss = 0 @@ -415,7 +417,7 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): batched_cache_hit += return_mask.sum().item() batched_cache_miss += len(return_mask) - return_mask.sum().item() print(f"Task {batch_id}: cache_hit={batched_cache_hit}, cache_miss={batched_cache_miss}") - + # GPU KV cache verification for batched get if gpu_kv_verifier is not None: for idx, (token_ids, block_ids, req_id) in enumerate(batched_req_info): @@ -429,9 +431,9 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): token_ids[:valid_fetched_tokens], block_ids[:valid_fetched_tokens // tokens_per_block] ) - + print(f"Batched get test completed: hit={batched_cache_hit}, miss={batched_cache_miss}") - + # Since we read data that was written before, cache hit should be high if enable_cpu and num_cpu_blocks >= num_gpu_blocks: assert batched_cache_miss == 0, \ @@ -455,3 +457,662 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): return elif total_cache_miss > 0: print(f"verify skipped, because of total_cache_miss={total_cache_miss} > 0") + + +class GPUIndexerCacheVerifier: + def __init__(self, + shared_indexer_blocks, + 
indexer_kv_layout: KVCacheLayout, + tp_size: int, + dtype: torch.dtype) -> None: + if not shared_indexer_blocks: + raise ValueError("shared_indexer_blocks must not be empty") + + if isinstance(shared_indexer_blocks[0][0], torch.Tensor): + self.gpu_blocks = shared_indexer_blocks + else: + imported_gpu_blocks = [] + for handles_in_one_gpu in shared_indexer_blocks: + imported_gpu_blocks.append([handle.get_tensor() for handle in handles_in_one_gpu]) + self.gpu_blocks = imported_gpu_blocks + + self.num_layers = indexer_kv_layout.num_layer + self.tokens_per_block = indexer_kv_layout.tokens_per_block + self.head_size = indexer_kv_layout.head_size + self.tp_size = tp_size + self.dtype = dtype + + def hash_all_values(self, layer_id, token_ids): + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.tolist() + + token_hash = 0 + for i, token_id in enumerate(token_ids): + token_hash += int(token_id) * (i + 17) + return torch.tensor(((layer_id + 1) * 29 + token_hash) % 251 + 1, dtype=self.dtype).item() + + def fill_gpu_blocks(self, block_ids, main_kv_tokens_per_block, token_ids): + """Fill indexer GPU blocks with deterministic hash values. + + Indexer uses tokens_per_block=1 on CPU/SSD side. Each indexer block + corresponds to one main-KV block (1:1 page mapping). We hash the + *entire page* of token_ids from the main KV request to produce a + single deterministic value per (layer, block). + + Args: + block_ids: block IDs to fill (same as main KV block_ids). + main_kv_tokens_per_block: tokens_per_block of main KV (e.g. 16). + token_ids: full token_ids tensor from the request. 
+ """ + if not isinstance(token_ids, torch.Tensor): + token_ids = torch.tensor(token_ids, dtype=torch.int64) + if not isinstance(block_ids, torch.Tensor): + block_ids = torch.tensor(block_ids, dtype=torch.int64) + + for tp_id in range(self.tp_size): + for layer_id in range(self.num_layers): + gpu_tensor = self.gpu_blocks[tp_id][layer_id] + for block_idx, block_id in enumerate(block_ids): + start_token_idx = block_idx * main_kv_tokens_per_block + end_token_idx = start_token_idx + main_kv_tokens_per_block + hash_value = self.hash_all_values( + layer_id, + token_ids[start_token_idx:end_token_idx], + ) + # gpu_tensor shape: (num_blocks, tokens_per_block=1, head_size) + gpu_tensor[block_id, :, :] = hash_value + + def clear_gpu_blocks(self, block_ids): + if not isinstance(block_ids, torch.Tensor): + block_ids = torch.tensor(block_ids, dtype=torch.int64) + + for tp_id in range(self.tp_size): + for layer_id in range(self.num_layers): + self.gpu_blocks[tp_id][layer_id][block_ids, :, :] = 0 + + def verify_gpu_blocks(self, block_ids, main_kv_tokens_per_block, token_ids) -> bool: + """Verify indexer GPU blocks after round-trip transfer. + + Args: + block_ids: block IDs to verify. + main_kv_tokens_per_block: tokens_per_block of main KV. + token_ids: full token_ids tensor from the request. 
+ """ + if not isinstance(token_ids, torch.Tensor): + token_ids = torch.tensor(token_ids, dtype=torch.int64) + if not isinstance(block_ids, torch.Tensor): + block_ids = torch.tensor(block_ids, dtype=torch.int64) + + verification_passed = True + errors = [] + + for tp_id in range(self.tp_size): + for layer_id in range(self.num_layers): + gpu_tensor = self.gpu_blocks[tp_id][layer_id] + for block_idx, block_id in enumerate(block_ids): + start_token_idx = block_idx * main_kv_tokens_per_block + end_token_idx = start_token_idx + main_kv_tokens_per_block + expected_hash_value = self.hash_all_values( + layer_id, + token_ids[start_token_idx:end_token_idx], + ) + actual_values = gpu_tensor[block_id, :, :] + expected_tensor = torch.full_like(actual_values, expected_hash_value) + if not torch.equal(actual_values, expected_tensor): + verification_passed = False + max_abs_diff = ( + actual_values.to(torch.int32) - expected_tensor.to(torch.int32) + ).abs().max().item() + errors.append( + f"Mismatch at tp={tp_id}, layer={layer_id}, block={block_id}: " + f"expected={expected_hash_value}, max_abs_diff={max_abs_diff}" + ) + + if not verification_passed: + print(f"Indexer verification failed with {len(errors)} errors:") + for error in errors[:10]: + print(f" {error}") + if len(errors) > 10: + print(f" ... and {len(errors) - 10} more errors") + else: + print("Indexer GPU blocks verification passed!") + assert verification_passed + return verification_passed + + +def run_tp_client_with_indexer(dp_rank, + tp_rank, + server_recv_port, + model_config, + cache_config, + num_gpu_blocks, + child_conn, + gpu_layout_type): + """Run tp_client process with indexer support (shadow transfer mode). + + Indexer configuration is read from cache_config.indexer (IndexerCacheConfig). 
+ """ + try: + device_id = tp_rank + dp_rank * model_config.tp_size + + gpu_kv_layout = create_gpu_kv_layout(model_config, cache_config, num_gpu_blocks, gpu_layout_type) + + # Create main GPU blocks + gpu_blocks_for_tp = [] + if gpu_layout_type == 0: + for _ in range(model_config.num_layers): + gpu_blocks_for_tp.append( + torch.empty(size=tuple(gpu_kv_layout.kv_shape[1:]), dtype=model_config.dtype).cuda(device_id) + ) + elif gpu_layout_type == 2: + kv_dim = 1 if model_config.use_mla else 2 + for _ in range(model_config.num_layers * kv_dim): + gpu_blocks_for_tp.append( + torch.empty(size=tuple(gpu_kv_layout.kv_shape[2:]), dtype=model_config.dtype).cuda(device_id) + ) + else: + raise ValueError(f"Invalid GPU layout type for indexer test: {gpu_layout_type}") + + # Derive indexer params from cache_config.indexer (IndexerCacheConfig). + # Indexer uses tokens_per_block=1 (one indexer entry per page/block), + # matching the CPU/SSD layout in StorageEngine. + indexer_cfg = cache_config.indexer + assert indexer_cfg is not None, "cache_config.indexer must be set for indexer shadow transfer tests" + indexer_tokens_per_block = 1 # indexer: 1 entry per page (not main KV tokens_per_block) + indexer_num_layers = model_config.num_layers + + # Create indexer GPU blocks (MLA-style: 3D tensors) + indexer_blocks = [] + for _ in range(indexer_num_layers): + indexer_blocks.append( + torch.empty( + num_gpu_blocks, + indexer_tokens_per_block, + indexer_cfg.head_size, + dtype=indexer_cfg.dtype, + ).cuda(device_id) + ) + + from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType + indexer_layout = KVCacheLayout( + type=KVCacheLayoutType.LAYERFIRST, + num_layer=indexer_num_layers, + num_block=num_gpu_blocks, + tokens_per_block=indexer_tokens_per_block, + num_head=indexer_cfg.num_kv_heads, + head_size=indexer_cfg.head_size, + is_mla=True, + ) + + # Use KVTPClient directly with indexer buffers (shadow transfer mode) + tp_client = KVTPClient( + gpu_register_port=server_recv_port + 
"_gpu_register", + dp_rank=dp_rank, + pp_rank=0, + device_id=device_id, + ) + tp_client.register_to_server( + kv_caches=gpu_blocks_for_tp, + kv_layout=gpu_kv_layout, + indexer_buffers=indexer_blocks, + indexer_layout=indexer_layout, + ) + + # Send GPU blocks back to main process via pipe + if child_conn is not None: + shared_gpu_blocks = [TensorSharedHandle(tensor) for tensor in gpu_blocks_for_tp] + shared_indexer_blocks = [TensorSharedHandle(tensor) for tensor in indexer_blocks] + child_conn.send({ + "main": shared_gpu_blocks, + "indexer": shared_indexer_blocks, + }) + child_conn.close() + + # Keep the process running + while True: + time.sleep(1) + except Exception as e: + print(f"[TP Client {tp_rank}] Exception occurred: {type(e).__name__}: {str(e)}") + traceback.print_exc() + if child_conn is not None: + child_conn.send(None) + child_conn.close() + + +def _run_indexer_test(model_config, cache_config, test_config, gpu_layout_type, test_label="indexer", layerwise=False): + """Core test logic for KVManager with indexer shadow transfer. + + Shared by test_kvmanager_with_indexer (non-layerwise) and + test_kvmanager_with_indexer_layerwise (layerwise mode). 
+ """ + tp_size = model_config.tp_size + tokens_per_block = cache_config.tokens_per_block + num_gpu_blocks = test_config["num_gpu_blocks"] + block_per_request = test_config['requests_per_block'] + initial_write_ratio = test_config['initial_write_ratio'] + num_requests = num_gpu_blocks // block_per_request + + skip_if_insufficient_gpus(tp_size) + + from flexkv.common.config import IndexerCacheConfig + cache_config.indexer = IndexerCacheConfig( + head_size=64, + num_kv_heads=1, + dtype=torch.uint8, + ) + + kvmanager = KVManager( + model_config, + cache_config, + ) + kvmanager.start() + + mp_ctx = mp.get_context('spawn') + pipe_connections = [] + tp_client_processes = [] + + for tp_rank in range(tp_size): + parent_conn, child_conn = mp_ctx.Pipe() + pipe_connections.append(parent_conn) + + tp_client_process = mp_ctx.Process( + target=run_tp_client_with_indexer, + args=(0, tp_rank, kvmanager.server_recv_port, + model_config, cache_config, num_gpu_blocks, child_conn, + gpu_layout_type), + daemon=True + ) + tp_client_processes.append(tp_client_process) + tp_client_process.start() + + all_gpu_blocks = [] + all_indexer_blocks = [] + for tp_rank, parent_conn in enumerate(pipe_connections): + try: + shared_payload = parent_conn.recv() + if shared_payload is not None: + if isinstance(shared_payload, dict): + shared_gpu_blocks = shared_payload.get("main") + shared_indexer_blocks = shared_payload.get("indexer") + else: + shared_gpu_blocks = shared_payload + shared_indexer_blocks = None + if shared_gpu_blocks is not None: + all_gpu_blocks.append(shared_gpu_blocks) + print(f"[Main Process] Received GPU blocks from TP client {tp_rank}") + if shared_indexer_blocks is not None: + all_indexer_blocks.append(shared_indexer_blocks) + parent_conn.close() + except Exception as e: + print(f"[Main Process] Error receiving from TP client {tp_rank}: {e}") + + gpu_kv_verifier = None + if all_gpu_blocks and len(all_gpu_blocks) == tp_size: + gpu_kv_layout = create_gpu_kv_layout(model_config, 
cache_config, num_gpu_blocks, gpu_layout_type) + gpu_kv_verifier = GPUKVCacheVerifier( + shared_gpu_blocks=all_gpu_blocks, + gpu_kv_layout=gpu_kv_layout, + tp_size=model_config.tp_size, + tokens_per_block=cache_config.tokens_per_block, + dtype=model_config.dtype, + gpu_layout_type=gpu_layout_type, + ) + + indexer_kv_verifier = None + indexer_cfg = cache_config.indexer + if all_indexer_blocks and len(all_indexer_blocks) == tp_size and indexer_cfg is not None: + indexer_gpu_layout = KVCacheLayout( + type=KVCacheLayoutType.LAYERFIRST, + num_layer=model_config.num_layers, + num_block=num_gpu_blocks, + tokens_per_block=1, # indexer: 1 entry per page + num_head=indexer_cfg.num_kv_heads, + head_size=indexer_cfg.head_size, + is_mla=True, + ) + indexer_kv_verifier = GPUIndexerCacheVerifier( + shared_indexer_blocks=all_indexer_blocks, + indexer_kv_layout=indexer_gpu_layout, + tp_size=model_config.tp_size, + dtype=indexer_cfg.dtype, + ) + + while not kvmanager.is_ready(): + time.sleep(1) + flexkv_logger.info(f"waiting for flexkv ({test_label}) to be ready") + print(f"[Test] KVManager ({test_label}) is ready") + + request_pairs = [generate_request_pair(i, block_per_request, num_gpu_blocks, tokens_per_block, 1) + for i in range(num_requests)] + initial_write_num = int(num_requests * initial_write_ratio) + + print(f"[Test] Testing put flow ({test_label})...") + for token_ids, block_ids, dp_rank in request_pairs[:initial_write_num]: + if gpu_kv_verifier is not None: + gpu_kv_verifier.fill_gpu_blocks(token_ids, block_ids) + if indexer_kv_verifier is not None: + indexer_kv_verifier.fill_gpu_blocks(block_ids, tokens_per_block, token_ids) + write_request = kvmanager.put_async( + token_ids=token_ids, + slot_mapping=block_ids_2_slot_mapping(block_ids, tokens_per_block), + token_mask=None, + dp_rank=dp_rank, + ) + put_results = kvmanager.wait([write_request], completely=True) + assert put_results[write_request].status == KVResponseStatus.SUCCESS + if gpu_kv_verifier is not None: + 
gpu_kv_verifier.clear_gpu_blocks(block_ids) + if indexer_kv_verifier is not None: + indexer_kv_verifier.clear_gpu_blocks(block_ids) + print(f"[Test] Initial {initial_write_num} put operations completed ({test_label})") + + print(f"[Test] Testing get flow ({test_label})...") + total_cache_hit = 0 + total_cache_miss = 0 + running_get_requests = [] + req_id2block_ids = {} + req_id2token_ids = {} + + batch_task_ids = [] + batch_slot_mappings = [] + + for i in range(min(initial_write_num, num_requests)): + token_ids, block_ids, dp_rank = request_pairs[i] + slot_mapping = block_ids_2_slot_mapping(block_ids, tokens_per_block) + request_id, _ = kvmanager.get_match( + token_ids=token_ids, + layer_granularity=-1, + token_mask=None, + dp_rank=dp_rank, + ) + batch_task_ids.append(request_id) + batch_slot_mappings.append(slot_mapping) + req_id2block_ids[request_id] = block_ids + req_id2token_ids[request_id] = token_ids + + if layerwise: + # Layerwise mode: launch all GETs as a single batch so that + # merge_to_batch_graph produces a LAYERWISE op (fused DISK2H+H2D). 
+ returned_ids = kvmanager.launch( + task_ids=batch_task_ids, + slot_mappings=batch_slot_mappings, + as_batch=True, + layerwise_transfer=True, + ) + batch_id = returned_ids[0] + batch_results = kvmanager.wait(batch_id, completely=True) + kvresponse = batch_results[batch_id] + assert kvresponse.status == KVResponseStatus.SUCCESS, \ + f"Layerwise batch GET failed: {kvresponse.status}" + for idx, orig_req_id in enumerate(batch_task_ids): + mask = kvresponse.return_mask[idx] + total_cache_hit += mask.sum().item() + total_cache_miss += len(mask) - mask.sum().item() + if gpu_kv_verifier is not None: + valid_fetched_tokens = mask.sum().item() // tokens_per_block * tokens_per_block + if valid_fetched_tokens > 0: + assert gpu_kv_verifier.verify_kv_blocks( + req_id2token_ids[orig_req_id][:valid_fetched_tokens], + req_id2block_ids[orig_req_id][:valid_fetched_tokens // tokens_per_block]) + if indexer_kv_verifier is not None: + valid_fetched_blocks = mask.sum().item() // tokens_per_block + if valid_fetched_blocks > 0: + assert indexer_kv_verifier.verify_gpu_blocks( + req_id2block_ids[orig_req_id][:valid_fetched_blocks], + tokens_per_block, + req_id2token_ids[orig_req_id][:valid_fetched_blocks * tokens_per_block]) + else: + # Non-layerwise: launch each GET individually + for req_id in batch_task_ids: + kvmanager.launch(req_id, batch_slot_mappings[batch_task_ids.index(req_id)]) + running_get_requests.append(req_id) + + if running_get_requests: + return_results = kvmanager.wait(running_get_requests, completely=True) + for req_id, kvresponse in return_results.items(): + assert kvresponse.status == KVResponseStatus.SUCCESS + total_cache_hit += kvresponse.return_mask.sum().item() + total_cache_miss += len(kvresponse.return_mask) - kvresponse.return_mask.sum().item() + if gpu_kv_verifier is not None: + valid_fetched_tokens = kvresponse.return_mask.sum().item() // tokens_per_block * tokens_per_block + if valid_fetched_tokens > 0: + assert gpu_kv_verifier.verify_kv_blocks( + 
req_id2token_ids[req_id][:valid_fetched_tokens], + req_id2block_ids[req_id][:valid_fetched_tokens // tokens_per_block]) + if indexer_kv_verifier is not None: + valid_fetched_blocks = kvresponse.return_mask.sum().item() // tokens_per_block + if valid_fetched_blocks > 0: + assert indexer_kv_verifier.verify_gpu_blocks( + req_id2block_ids[req_id][:valid_fetched_blocks], + tokens_per_block, + req_id2token_ids[req_id][:valid_fetched_blocks * tokens_per_block]) + print(f"[Test] Get flow completed ({test_label}): hit={total_cache_hit}, miss={total_cache_miss}") + + print(f"[Test] Testing try_wait flow ({test_label})...") + if initial_write_num < num_requests: + token_ids, block_ids, dp_rank = request_pairs[initial_write_num] + if gpu_kv_verifier is not None: + gpu_kv_verifier.fill_gpu_blocks(token_ids, block_ids) + if indexer_kv_verifier is not None: + indexer_kv_verifier.fill_gpu_blocks(block_ids, tokens_per_block, token_ids) + write_request = kvmanager.put_async( + token_ids=token_ids, + slot_mapping=block_ids_2_slot_mapping(block_ids, tokens_per_block), + token_mask=None, + dp_rank=dp_rank, + ) + finished = {} + for _ in range(200): + finished = kvmanager.try_wait([write_request]) + if write_request in finished: + break + time.sleep(0.1) + assert write_request in finished, "try_wait should eventually return the completed task" + assert finished[write_request].status == KVResponseStatus.SUCCESS + if gpu_kv_verifier is not None: + gpu_kv_verifier.clear_gpu_blocks(block_ids) + if indexer_kv_verifier is not None: + indexer_kv_verifier.clear_gpu_blocks(block_ids) + print(f"[Test] try_wait flow completed ({test_label})") + + # Cache miss assertion: when total capacity >= GPU blocks, expect 0 miss + enable_cpu = cache_config.enable_cpu + enable_ssd = cache_config.enable_ssd + num_cpu_blocks = cache_config.num_cpu_blocks + num_ssd_blocks = cache_config.num_ssd_blocks + if (enable_cpu and num_cpu_blocks >= num_gpu_blocks) or \ + (enable_ssd and num_ssd_blocks >= num_gpu_blocks): 
+ assert total_cache_miss == 0, f"Expected 0 cache miss, got {total_cache_miss}" + + shutdown_tp_client(tp_client_processes) + kvmanager.shutdown() + print(f"[Test] {test_label} PASSED") + + +@pytest.mark.parametrize( + "model_config", + [ + {"tp_size": 1, "dp_size": 1}, + ], indirect=True, +) +@pytest.mark.parametrize("cache_config", [ + {'enable_cpu': True, 'enable_ssd': False, 'num_cpu_blocks': 1024}, + {'enable_cpu': True, 'enable_ssd': True, 'num_cpu_blocks': 256, 'num_ssd_blocks': 2048}, +], indirect=True) +@pytest.mark.parametrize("test_config", [ + {'num_gpu_blocks': 256, 'requests_per_block': 16, 'initial_write_ratio': 0.4}, +], indirect=True) +@pytest.mark.parametrize("gpu_layout_type", [0]) +def test_kvmanager_with_indexer(model_config, cache_config, test_config, gpu_layout_type): + """Test KVManager with indexer: GPU↔CPU (and optionally ↔SSD) data correctness.""" + ssd_label = "+ssd" if cache_config.enable_ssd else "" + _run_indexer_test(model_config, cache_config, test_config, gpu_layout_type, + test_label=f"indexer{ssd_label}") + + +import ctypes +import socket +import struct +import threading + +# ---- Mock SGLang eventfd client for layerwise unit tests ---- + +_libc = ctypes.CDLL("libc.so.6", use_errno=True) + + +def _sys_eventfd(initval: int = 0, flags: int = 0) -> int: + """Create an eventfd file descriptor via libc.""" + fd = _libc.eventfd(ctypes.c_uint(initval), ctypes.c_int(flags)) + if fd == -1: + err = ctypes.get_errno() + raise OSError(err, f"eventfd failed: {os.strerror(err)}") + return fd + + +_EFD_SEMAPHORE = 0x1 + + +def _send_fds_via_scm(sock: socket.socket, fds: list, extra_data: bytes = b"x"): + """Send fds via SCM_RIGHTS (mirrors SGLang's send_fds).""" + fds_packed = struct.pack(f"{len(fds)}i", *fds) + ancdata = [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds_packed)] + sock.sendmsg([extra_data], ancdata) + + +def _mock_sglang_eventfd_client(socket_path: str, + tp_rank: int, + tp_size: int, + num_layers: int, + num_counters: int = 3, + 
max_retries: int = 120, + retry_interval: float = 0.5): + """Simulate SGLang sending eventfds to the LayerwiseTransferWorker. + + Runs in a background thread. Creates real eventfds so the C++ + LayerwiseTransferGroup receives valid file descriptors. The eventfds + are never read by anyone in the test, but that is fine: the C++ + ``enable_eventfd_`` flag will be ``true`` and ``eventfd_write`` will + simply increment the counter without blocking. + """ + created_fds = [] + try: + # Create real eventfds + for _ in range(num_counters * num_layers): + created_fds.append(_sys_eventfd(0, _EFD_SEMAPHORE)) + + # Retry connecting until the worker process binds the socket + sock = None + for attempt in range(max_retries): + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + sock.connect(socket_path) + print(f"[MockEventfdClient] Connected to {socket_path} " + f"(attempt {attempt + 1})") + break + except (FileNotFoundError, ConnectionRefusedError): + sock.close() + sock = None + time.sleep(retry_interval) + + if sock is None: + print(f"[MockEventfdClient] FAILED to connect to {socket_path} " + f"after {max_retries} attempts") + return + + metadata = struct.pack("iiii", + tp_rank, tp_size, + num_layers, num_counters) + sock.sendall(metadata) + + # Send eventfds for each counter via SCM_RIGHTS + fd_idx = 0 + for counter_id in range(num_counters): + fds = created_fds[fd_idx:fd_idx + num_layers] + fd_idx += num_layers + _send_fds_via_scm(sock, fds, struct.pack("i", counter_id)) + + # Wait for ACK + sock.settimeout(30.0) + ack = sock.recv(1) + if ack and ack[0] == 1: + print(f"[MockEventfdClient] Eventfd handshake OK " + f"(counters={num_counters}, layers={num_layers})") + else: + print(f"[MockEventfdClient] Unexpected ACK: {ack!r}") + sock.close() + except Exception as e: + print(f"[MockEventfdClient] Error: {e}") + traceback.print_exc() + # Note: we intentionally do NOT close the eventfds here. 
+ # They must remain valid for the lifetime of the LayerwiseTransferGroup + # in the worker subprocess. They will be cleaned up when the worker + # process exits and the OS reclaims the file descriptors. + + +@pytest.mark.parametrize( + "model_config", + [ + {"tp_size": 1, "dp_size": 1}, + ], indirect=True, +) +@pytest.mark.parametrize("cache_config", [ + {'enable_cpu': True, 'enable_ssd': False, 'num_cpu_blocks': 1024}, + {'enable_cpu': True, 'enable_ssd': True, 'num_cpu_blocks': 256, 'num_ssd_blocks': 2048}, +], indirect=True) +@pytest.mark.parametrize("test_config", [ + {'num_gpu_blocks': 256, 'requests_per_block': 16, 'initial_write_ratio': 0.4}, +], indirect=True) +@pytest.mark.parametrize("gpu_layout_type", [0]) +def test_kvmanager_with_indexer_layerwise(model_config, cache_config, test_config, gpu_layout_type): + """Test KVManager with indexer in LAYERWISE mode. + + Validates the full round-trip: + PUT: D2H + H2DISK (non-layerwise, same as normal) + GET: LAYERWISE (fused DISK2H + H2D) + Data correctness is verified for both the main KV cache and the + indexer (DSA) KV cache after the round-trip. + + A background thread simulates the SGLang eventfd client so the + LayerwiseTransferWorker can complete its initialization handshake + without any source-code changes. + """ + from flexkv.common.config import GLOBAL_CONFIG_FROM_ENV + + # Save original values + orig_layerwise_env = os.environ.get('FLEXKV_ENABLE_LAYERWISE_TRANSFER') + orig_layerwise_flag = GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer + + # Determine the socket path that the worker will listen on. + # For tp_size=1, pp_size=1, dp_size=1, there is no suffix. 
+ socket_path = os.environ.get('FLEXKV_LAYERWISE_EVENTFD_SOCKET', + '/tmp/flexkv_layerwise_eventfd.sock') + + try: + # Enable layerwise transfer + os.environ['FLEXKV_ENABLE_LAYERWISE_TRANSFER'] = '1' + GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer = True + + # Start mock SGLang eventfd client thread BEFORE kvmanager.start() + # so it is ready to connect once the worker process binds the socket. + eventfd_thread = threading.Thread( + target=_mock_sglang_eventfd_client, + args=(socket_path, 0, 1, model_config.num_layers), + daemon=True, + ) + eventfd_thread.start() + + ssd_label = "+ssd" if cache_config.enable_ssd else "" + _run_indexer_test(model_config, cache_config, test_config, gpu_layout_type, + test_label=f"layerwise+indexer{ssd_label}", layerwise=True) + + eventfd_thread.join(timeout=10) + finally: + # Restore original environment and config + if orig_layerwise_env is None: + os.environ.pop('FLEXKV_ENABLE_LAYERWISE_TRANSFER', None) + else: + os.environ['FLEXKV_ENABLE_LAYERWISE_TRANSFER'] = orig_layerwise_env + GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer = orig_layerwise_flag + + + diff --git a/tests/test_memory_handle.py b/tests/test_memory_handle.py index 3df9c228c3..18ff724406 100644 --- a/tests/test_memory_handle.py +++ b/tests/test_memory_handle.py @@ -84,6 +84,35 @@ def _worker_test_tensor_from_tensor_direct_ipc(conn, device_id): raise +def _worker_test_fp8_tensor_from_bytes(conn, device_id): + """Test construction from bytes with fp8 dtype""" + try: + handle = conn.recv() + assert isinstance(handle, TensorSharedHandle) + assert handle.use_direct_ipc + assert handle.tensor_dtype == torch.float8_e4m3fn + assert handle.tensor_shape == (10, 20) + + tensor = handle.get_tensor() + assert isinstance(tensor, torch.Tensor) + assert tensor.is_cuda + assert tensor.device.index == device_id + assert tensor.shape == (10, 20) + assert tensor.dtype == torch.float8_e4m3fn + + expected = ( + torch.arange(200, dtype=torch.float32) + .reshape(10, 20) + 
.cuda(device_id) + .to(torch.float8_e4m3fn) + ) + max_diff = (tensor.to(torch.float32) - expected.to(torch.float32)).abs().max().item() + conn.send(max_diff) + except Exception as e: + conn.send(f"Error: {e}") + raise + + def _worker_test_tensor_from_bytes(conn, device_id): """Test construction from bytes (IPC handle)""" try: diff --git a/tests/test_transfer_engine_atomic_eviction.py b/tests/test_transfer_engine_atomic_eviction.py new file mode 100644 index 0000000000..83b8ca4831 --- /dev/null +++ b/tests/test_transfer_engine_atomic_eviction.py @@ -0,0 +1,425 @@ +""" +Unit tests for atomic indexer eviction in TransferEngine. + +These tests verify that: +1. TransferOp.pending_count defaults to 1. +2. _finalize_op is called only when pending_count reaches 0. +3. With indexer enabled: CompletedOp is NOT emitted until both main KV and indexer + workers complete (pending_count == 0). +4. With indexer disabled: behavior is identical to the original (pending_count starts + at 1, _finalize_op is called immediately after main KV completes). 
+""" +import queue +import unittest +from typing import List +from unittest.mock import MagicMock, patch, call + +import numpy as np + +from flexkv.common.transfer import TransferOp, TransferType, CompletedOp, LayerwiseTransferOp, WorkerKey + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_op(transfer_type: TransferType = TransferType.D2H) -> TransferOp: + """Create a minimal TransferOp for testing.""" + if transfer_type == TransferType.LAYERWISE: + return _make_layerwise_op() + return TransferOp( + graph_id=0, + transfer_type=transfer_type, + src_block_ids=np.array([0, 1], dtype=np.int64), + dst_block_ids=np.array([2, 3], dtype=np.int64), + ) + + +def _make_layerwise_op(**kwargs) -> LayerwiseTransferOp: + """Create a minimal LayerwiseTransferOp for testing.""" + defaults = dict( + graph_id=0, + src_block_ids_h2d=np.array([0, 1], dtype=np.int64), + dst_block_ids_h2d=np.array([2, 3], dtype=np.int64), + src_block_ids_disk2h=np.array([], dtype=np.int64), + dst_block_ids_disk2h=np.array([], dtype=np.int64), + ) + defaults.update(kwargs) + return LayerwiseTransferOp(**defaults) + + +# --------------------------------------------------------------------------- +# Tests – TransferOp.pending_count field +# --------------------------------------------------------------------------- + +class TestTransferOpPendingCount(unittest.TestCase): + """Requirement 5: TransferOp supports pending_count field.""" + + def test_default_pending_count_is_one(self): + """pending_count SHALL default to 1 (req 5.1).""" + op = _make_op() + self.assertEqual(op.pending_count, 1) + + def test_pending_count_is_mutable(self): + """pending_count SHALL be mutable (dataclass, not frozen).""" + op = _make_op() + op.pending_count += 1 + self.assertEqual(op.pending_count, 2) + op.pending_count -= 1 + self.assertEqual(op.pending_count, 1) + op.pending_count -= 1 + 
self.assertEqual(op.pending_count, 0) + + +# --------------------------------------------------------------------------- +# Tests – _finalize_op logic (unit-level, no real workers) +# --------------------------------------------------------------------------- + +class TestFinalizeOpLogic(unittest.TestCase): + """ + Requirement 1, 3, 4: _finalize_op is called only when pending_count == 0. + We test the logic directly by simulating what _scheduler_loop does. + """ + + def _simulate_worker_done(self, op: TransferOp, finished_ops: List[TransferOp], + finalize_fn) -> None: + """Simulate what _scheduler_loop does when a worker completes an op.""" + op.pending_count -= 1 + if op.pending_count == 0: + finalize_fn(op, finished_ops) + + def test_no_indexer_finalize_called_immediately(self): + """Without indexer: pending_count starts at 1, finalize called after main KV done (req 6.1).""" + op = _make_op() + self.assertEqual(op.pending_count, 1) + + finalize_mock = MagicMock() + finished_ops: List[TransferOp] = [] + + # Main KV worker completes + self._simulate_worker_done(op, finished_ops, finalize_mock) + + # pending_count should be 0 and finalize should have been called once + self.assertEqual(op.pending_count, 0) + finalize_mock.assert_called_once_with(op, finished_ops) + + def test_with_indexer_finalize_not_called_after_main_kv_only(self): + """With indexer: finalize NOT called when only main KV completes (req 3.1, 4.1).""" + op = _make_op() + # Simulate _assign_op_to_worker incrementing pending_count before submitting to indexer + op.pending_count += 1 + self.assertEqual(op.pending_count, 2) + + finalize_mock = MagicMock() + finished_ops: List[TransferOp] = [] + + # Main KV worker completes first + self._simulate_worker_done(op, finished_ops, finalize_mock) + + # pending_count should be 1, finalize should NOT have been called + self.assertEqual(op.pending_count, 1) + finalize_mock.assert_not_called() + self.assertEqual(len(finished_ops), 0) + + def 
test_with_indexer_finalize_called_after_both_complete(self): + """With indexer: finalize called exactly once when both workers complete (req 3.2, 4.2).""" + op = _make_op() + # Simulate _assign_op_to_worker incrementing pending_count before submitting to indexer + op.pending_count += 1 + self.assertEqual(op.pending_count, 2) + + finalize_mock = MagicMock() + finished_ops: List[TransferOp] = [] + + # Main KV worker completes first + self._simulate_worker_done(op, finished_ops, finalize_mock) + self.assertEqual(op.pending_count, 1) + finalize_mock.assert_not_called() + + # Indexer worker completes + self._simulate_worker_done(op, finished_ops, finalize_mock) + self.assertEqual(op.pending_count, 0) + finalize_mock.assert_called_once_with(op, finished_ops) + + def test_with_indexer_finalize_called_once_regardless_of_order(self): + """Finalize called exactly once even if indexer completes before main KV (req 3.2, 4.2).""" + op = _make_op() + op.pending_count += 1 # indexer registered + self.assertEqual(op.pending_count, 2) + + finalize_mock = MagicMock() + finished_ops: List[TransferOp] = [] + + # Indexer worker completes first + self._simulate_worker_done(op, finished_ops, finalize_mock) + self.assertEqual(op.pending_count, 1) + finalize_mock.assert_not_called() + + # Main KV worker completes + self._simulate_worker_done(op, finished_ops, finalize_mock) + self.assertEqual(op.pending_count, 0) + finalize_mock.assert_called_once_with(op, finished_ops) + + +# --------------------------------------------------------------------------- +# Tests – _finalize_op method behavior +# --------------------------------------------------------------------------- + +class TestFinalizeOpMethod(unittest.TestCase): + """ + Test that _finalize_op correctly calls free_op_from_buffer, puts CompletedOp, + appends to finished_ops, and deletes from op_id_to_op. 
+ """ + + def _make_engine_stub(self): + """Create a minimal stub of TransferEngine with the real _finalize_op method.""" + from flexkv.transfer.transfer_engine import TransferEngine, free_op_from_buffer + + engine = object.__new__(TransferEngine) + engine.op_id_to_op = {} + engine.completed_queue = MagicMock() + engine.pin_buffer = MagicMock() + engine.cache_config = MagicMock() + engine.cache_config.tokens_per_block = 16 + engine.model_config = MagicMock() + engine.model_config.token_size_in_bytes = 2 + return engine + + def test_finalize_op_releases_buffer_and_notifies(self): + """_finalize_op SHALL call free_op_from_buffer and put CompletedOp (req 3.2, 4.2).""" + from flexkv.transfer.transfer_engine import TransferEngine, free_op_from_buffer + + engine = self._make_engine_stub() + op = _make_op() + engine.op_id_to_op[op.op_id] = op + + finished_ops: List[TransferOp] = [] + + with patch('flexkv.transfer.transfer_engine.free_op_from_buffer') as mock_free: + engine._finalize_op(op, finished_ops) + + # free_op_from_buffer called once + mock_free.assert_called_once_with(op, engine.pin_buffer) + # CompletedOp put to completed_queue once + engine.completed_queue.put.assert_called_once() + completed_op_arg = engine.completed_queue.put.call_args[0][0] + self.assertIsInstance(completed_op_arg, CompletedOp) + self.assertEqual(completed_op_arg.graph_id, op.graph_id) + self.assertEqual(completed_op_arg.op_id, op.op_id) + # op appended to finished_ops + self.assertIn(op, finished_ops) + # op removed from op_id_to_op + self.assertNotIn(op.op_id, engine.op_id_to_op) + + def test_finalize_op_removes_op_from_tracking_dict(self): + """_finalize_op SHALL delete op from op_id_to_op (req 3.2 - no double free).""" + engine = self._make_engine_stub() + op = _make_op() + engine.op_id_to_op[op.op_id] = op + + finished_ops: List[TransferOp] = [] + + with patch('flexkv.transfer.transfer_engine.free_op_from_buffer'): + engine._finalize_op(op, finished_ops) + + self.assertNotIn(op.op_id, 
engine.op_id_to_op) + + def test_finalize_op_not_called_twice(self): + """op_id_to_op deletion prevents double finalization (req 3.2 - exactly once).""" + engine = self._make_engine_stub() + op = _make_op() + engine.op_id_to_op[op.op_id] = op + + finished_ops: List[TransferOp] = [] + + with patch('flexkv.transfer.transfer_engine.free_op_from_buffer'): + engine._finalize_op(op, finished_ops) + # Second call should raise KeyError since op was already removed + with self.assertRaises(KeyError): + engine._finalize_op(op, finished_ops) + + +# --------------------------------------------------------------------------- +# Tests – Indexer Layerwise Worker initialization and op dispatch +# --------------------------------------------------------------------------- + +class TestIndexerLayerwiseWorkerInit(unittest.TestCase): + """ + Tests for indexer LayerwiseTransferWorker initialization and LAYERWISE op dispatch. + Verifies requirements 1.1, 1.3, 2.1, 2.2, 5.1, 5.3. + """ + + def _make_engine_stub_with_indexer(self, enable_layerwise: bool = True): + """ + Create a minimal TransferEngine stub with _has_indexer=True and + a pre-populated _indexer_worker_map (simulating post-_init_workers state). 
+ """ + from flexkv.transfer.transfer_engine import TransferEngine + + engine = object.__new__(TransferEngine) + engine._has_indexer = True + engine._worker_map = {} + engine._indexer_worker_map = {} + engine._indexer_op_to_parent_op = {} + engine._indexer_op_map = {} + engine.op_id_to_op = {} + engine.op_id_to_nvtx_range = {} + engine.completed_queue = MagicMock() + engine.pin_buffer = MagicMock() + engine.cache_config = MagicMock() + engine.cache_config.tokens_per_block = 16 + engine.model_config = MagicMock() + engine.model_config.token_size_in_bytes = 2 + + # Create mock workers for main KV + main_layerwise_worker = MagicMock() + engine._worker_map[TransferType.H2D] = [MagicMock()] + engine._worker_map[TransferType.D2H] = [MagicMock()] + if enable_layerwise: + # LAYERWISE worker map must be Dict[WorkerKey, WorkerHandle] + wk0 = WorkerKey(dp_rank=0, pp_rank=0) + engine._worker_map[TransferType.LAYERWISE] = {wk0: main_layerwise_worker} + + # Create mock workers for indexer + indexer_h2d_worker = MagicMock() + indexer_layerwise_worker = MagicMock() + engine._indexer_worker_map[TransferType.H2D] = [indexer_h2d_worker] + engine._indexer_worker_map[TransferType.D2H] = [MagicMock()] + if enable_layerwise: + engine._indexer_worker_map[TransferType.LAYERWISE] = [indexer_layerwise_worker] + + # PP replica tracking (needed by _assign_layerwise_op_to_workers) + engine._pp_replica_to_parent_op = {} + engine._pp_replica_op_map = {} + + return engine, main_layerwise_worker, indexer_layerwise_worker + + def test_indexer_worker_map_contains_layerwise_when_enabled(self): + """ + WHEN enable_layerwise_transfer=True AND indexer handles exist + THEN _indexer_worker_map SHALL contain TransferType.LAYERWISE (req 1.1). 
+ """ + engine, _, _ = self._make_engine_stub_with_indexer(enable_layerwise=True) + self.assertIn(TransferType.LAYERWISE, engine._indexer_worker_map) + + def test_indexer_worker_map_no_layerwise_when_disabled(self): + """ + IF enable_layerwise_transfer=False + THEN _indexer_worker_map SHALL NOT contain TransferType.LAYERWISE (req 5.1). + """ + engine, _, _ = self._make_engine_stub_with_indexer(enable_layerwise=False) + self.assertNotIn(TransferType.LAYERWISE, engine._indexer_worker_map) + + def test_layerwise_op_pending_count_not_incremented_for_single_pp_stage(self): + """ + WHEN _assign_op_to_worker processes a LAYERWISE op with only one PP stage + THEN op.pending_count SHALL remain 1 (no fan-out needed). + LAYERWISE indexer is fused inside the worker, not dispatched separately. + """ + from flexkv.transfer.transfer_engine import register_op_to_buffer + import nvtx + + engine, main_worker, indexer_worker = self._make_engine_stub_with_indexer(enable_layerwise=True) + + op = _make_op(TransferType.LAYERWISE) + op.dp_rank = 0 + op.pp_rank = 0 + engine.op_id_to_op[op.op_id] = op + + initial_pending_count = op.pending_count # should be 1 + + with patch('flexkv.transfer.transfer_engine.register_op_to_buffer'), \ + patch('nvtx.start_range', return_value=MagicMock()): + engine._assign_op_to_worker(op) + + # With a single PP stage, no fan-out → pending_count stays at 1 + self.assertEqual(op.pending_count, initial_pending_count) + + def test_layerwise_op_submitted_to_main_worker(self): + """ + WHEN _assign_op_to_worker processes a LAYERWISE op + THEN op SHALL be submitted to the main KV layerwise worker. + LAYERWISE indexer is fused inside the worker, not dispatched separately. 
+ """ + from flexkv.transfer.transfer_engine import register_op_to_buffer + + engine, main_worker, indexer_worker = self._make_engine_stub_with_indexer(enable_layerwise=True) + + op = _make_op(TransferType.LAYERWISE) + op.dp_rank = 0 + op.pp_rank = 0 + engine.op_id_to_op[op.op_id] = op + + with patch('flexkv.transfer.transfer_engine.register_op_to_buffer'), \ + patch('nvtx.start_range', return_value=MagicMock()): + engine._assign_op_to_worker(op) + + # Main KV layerwise worker should have received the op + main_worker.submit_transfer.assert_called_once_with(op) + + def test_layerwise_op_no_indexer_pending_count_stays_one(self): + """ + WHEN no indexer exists and LAYERWISE op is dispatched + THEN pending_count SHALL remain 1 (req 5.3). + """ + from flexkv.transfer.transfer_engine import TransferEngine + + engine = object.__new__(TransferEngine) + engine._has_indexer = False + engine._worker_map = {} + engine._indexer_worker_map = {} + engine.op_id_to_op = {} + engine.op_id_to_nvtx_range = {} + engine._pp_replica_to_parent_op = {} + engine._pp_replica_op_map = {} + + main_layerwise_worker = MagicMock() + wk0 = WorkerKey(dp_rank=0, pp_rank=0) + engine._worker_map[TransferType.LAYERWISE] = {wk0: main_layerwise_worker} + + op = _make_op(TransferType.LAYERWISE) + op.dp_rank = 0 + op.pp_rank = 0 + engine.op_id_to_op[op.op_id] = op + + initial_pending_count = op.pending_count # should be 1 + + with patch('flexkv.transfer.transfer_engine.register_op_to_buffer'), \ + patch('nvtx.start_range', return_value=MagicMock()): + engine._assign_op_to_worker(op) + + # pending_count should remain 1 (no indexer to increment for) + self.assertEqual(op.pending_count, initial_pending_count) + main_layerwise_worker.submit_transfer.assert_called_once_with(op) + + def test_finalize_called_after_both_layerwise_workers_complete(self): + """ + WHEN both main KV and indexer layerwise workers complete + THEN _finalize_op SHALL be called exactly once (req 3.2, 4.2). 
+ """ + op = _make_op(TransferType.LAYERWISE) + # Simulate _assign_op_to_worker incrementing pending_count for indexer + op.pending_count += 1 + self.assertEqual(op.pending_count, 2) + + finalize_mock = MagicMock() + finished_ops: List[TransferOp] = [] + + def simulate_done(o, fo, fn): + o.pending_count -= 1 + if o.pending_count == 0: + fn(o, fo) + + # Main KV layerwise worker completes + simulate_done(op, finished_ops, finalize_mock) + self.assertEqual(op.pending_count, 1) + finalize_mock.assert_not_called() + + # Indexer layerwise worker completes + simulate_done(op, finished_ops, finalize_mock) + self.assertEqual(op.pending_count, 0) + finalize_mock.assert_called_once_with(op, finished_ops) + + +if __name__ == "__main__": + unittest.main()