From 4d083b47188d0d82d994772a3d953736cb594b40 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Wed, 17 Dec 2025 09:44:55 +0000 Subject: [PATCH 01/59] initialize layerwise worker --- flexkv/kvtask.py | 9 +- flexkv/transfer/layerwise.py | 171 +++++++++++++++++++++++++++++++++++ 2 files changed, 178 insertions(+), 2 deletions(-) create mode 100644 flexkv/transfer/layerwise.py diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index fa77bd0c1f..1930e2543b 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -826,7 +826,8 @@ def launch_tasks(self, task_ids: List[int], slot_mappings: List[np.ndarray], as_batch: bool = False, - batch_id: int = -1) -> List[int]: + batch_id: int = -1, + layerwise_transfer: bool = False) -> List[int]: assert isinstance(slot_mappings[0], np.ndarray) # trace launch tasks self.tracer.trace_launch_tasks(task_ids, slot_mappings, as_batch) @@ -841,7 +842,11 @@ def launch_tasks(self, if batch_id == -1: batch_id = self._gen_task_id() batch_task_type = TaskType.BATCH_GET if all_get else TaskType.BATCH_PUT - transfer_graphs = [self.merge_to_batch_kvtask(batch_id, task_ids, batch_task_type)] + batch_task_graph = self.merge_to_batch_kvtask(batch_id, task_ids, batch_task_type) + if layerwise_transfer: + # TODO: merge all ops into one layerwise transfer op + pass + transfer_graphs = [batch_task_graph] self.tasks[batch_id].status = TaskStatus.RUNNING task_ids = [batch_id] else: diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py new file mode 100644 index 0000000000..bda77cbd9e --- /dev/null +++ b/flexkv/transfer/layerwise.py @@ -0,0 +1,171 @@ +import copy +import torch.multiprocessing as mp +import threading +import time +from abc import ABC, abstractmethod +from dataclasses import dataclass +from torch.multiprocessing import Queue as MPQueue, Pipe as MPPipe +from multiprocessing.connection import Connection +from threading import Thread +from typing import List, Any, Dict, Union, Optional + +import ctypes +import numpy as np +import 
nvtx +import torch + +from flexkv import c_ext + +from flexkv.c_ext import transfer_kv_blocks, transfer_kv_blocks_ssd, \ + transfer_kv_blocks_gds, TPTransferThreadGroup, TPGDSTransferThreadGroup +from flexkv.common.debug import flexkv_logger +from flexkv.common.memory_handle import TensorSharedHandle +from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType +from flexkv.common.transfer import TransferOp, TransferType, PartitionBlockType +from flexkv.common.transfer import get_nvtx_range_color +from flexkv.common.config import CacheConfig, GLOBAL_CONFIG_FROM_ENV + +try: + from flexkv.c_ext import transfer_kv_blocks_remote +except ImportError: + transfer_kv_blocks_remote = None + +from flexkv.transfer.worker import WorkerTransferOp +from flexkv.transfer.worker import TransferWorkerBase, cudaHostRegister + +class LayerwiseTransferWorker(TransferWorkerBase): + def __init__(self, + worker_id: int, + transfer_conn: Connection, + finished_ops_queue: MPQueue, + op_buffer_tensor: torch.Tensor, + gpu_blocks: List[List[TensorSharedHandle]], + cpu_blocks: torch.Tensor, + ssd_files: Dict[int, List[str]], + gpu_kv_layouts: List[KVCacheLayout], + cpu_kv_layout: KVCacheLayout, + ssd_kv_layout: KVCacheLayout, + dtype: torch.dtype, + tp_group_size: int, + dp_group_id: int, + gpu_device_id: int, + num_blocks_per_file: int, + use_ce_transfer_h2d: bool = False, + use_ce_transfer_d2h: bool = False, + transfer_sms_h2d: int = 8, + transfer_sms_d2h: int = 8) -> None: + super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) + assert len(gpu_blocks) == tp_group_size + imported_gpu_blocks = [] + for handles_in_one_gpu in gpu_blocks: + blocks_in_one_gpu = [] + for handle in handles_in_one_gpu: + blocks_in_one_gpu.append(handle.get_tensor()) + imported_gpu_blocks.append(blocks_in_one_gpu) + self.gpu_blocks = imported_gpu_blocks + self.dtype = dtype # note this should be quantized data type + self.is_mla = gpu_kv_layouts[0].is_mla + + self.num_gpus = 
len(self.gpu_blocks) + self.tp_group_size = tp_group_size + self.dp_group_id = dp_group_id + + # initialize GPU storage + self.num_layers = gpu_kv_layouts[0].num_layer + # here the chunk size doesn't include the layer info + self.gpu_chunk_sizes_in_bytes = [gpu_kv_layout.get_chunk_size() * self.dtype.itemsize \ + for gpu_kv_layout in gpu_kv_layouts] + self.gpu_kv_strides_in_bytes = [gpu_kv_layout.get_kv_stride() * self.dtype.itemsize \ + for gpu_kv_layout in gpu_kv_layouts] + self.gpu_block_strides_in_bytes = [gpu_kv_layout.get_block_stride() * self.dtype.itemsize \ + for gpu_kv_layout in gpu_kv_layouts] + self.gpu_layer_strides_in_bytes = [gpu_kv_layout.get_layer_stride() * self.dtype.itemsize \ + for gpu_kv_layout in gpu_kv_layouts] + + self.gpu_blocks = [wrapper.get_tensor() for wrapper in gpu_blocks] + + # 确定 GPU block 类型 + if len(self.gpu_blocks) == 1: + self.gpu_block_type_ = 1 # TRTLLM + elif len(self.gpu_blocks) == self.num_layers: + self.gpu_block_type_ = 0 # VLLM + elif len(self.gpu_blocks) == self.num_layers * 2: + self.gpu_block_type_ = 2 # SGLANG + else: + raise ValueError(f"Invalid GPU block type: {len(self.gpu_blocks)}") + + # initialize CPU storage + flexkv_logger.info(f"[LayerwiseWorker-{worker_id}] Pinning CPU Memory: " + f"{cpu_blocks.numel() * cpu_blocks.element_size() / (1024 ** 3):.2f} GB") + cudaHostRegister(cpu_blocks) + self.cpu_blocks = cpu_blocks + + self.cpu_chunk_size_in_bytes = cpu_kv_layout.get_chunk_size() * self.dtype.itemsize + self.cpu_kv_stride_in_bytes = cpu_kv_layout.get_kv_stride() * self.dtype.itemsize + self.cpu_block_stride_in_bytes = cpu_kv_layout.get_block_stride() * self.dtype.itemsize + self.cpu_layer_stride_in_bytes = cpu_kv_layout.get_layer_stride() * self.dtype.itemsize + + self.use_ce_transfer_h2d = use_ce_transfer_h2d + self.use_ce_transfer_d2h = use_ce_transfer_d2h + self.transfer_sms_h2d = transfer_sms_h2d + self.transfer_sms_d2h = transfer_sms_d2h + + # initialize SSD storage + self.ssd_files = ssd_files + 
self.num_blocks_per_file = num_blocks_per_file + self.num_files = sum(len(file_list) for file_list in ssd_files.values()) + self.round_robin = 1 + + ssd_kv_layout_per_file = ssd_kv_layout.div_block(self.num_files, padding=True) + self.ssd_kv_stride_in_bytes = ssd_kv_layout_per_file.get_kv_stride() * self.dtype.itemsize + self.ssd_layer_stride_in_bytes = ssd_kv_layout_per_file.get_layer_stride() * self.dtype.itemsize + self.ssd_block_stride_in_bytes = ssd_kv_layout_per_file.get_block_stride() * self.dtype.itemsize + + # assert self.ssd_block_stride_in_bytes == self.cpu_block_stride_in_bytes + + try: + self.ioctx = c_ext.SSDIOCTX( + ssd_files, + len(ssd_files), + GLOBAL_CONFIG_FROM_ENV.iouring_entries, + GLOBAL_CONFIG_FROM_ENV.iouring_flags + ) + except Exception as e: + flexkv_logger.error(f"Error setting ssd ioctx: {e}\n") + raise RuntimeError("SSD Worker init failed") from e + + gpu_kv_strides_tensor = torch.tensor(self.gpu_kv_strides_in_bytes, dtype=torch.int64) + gpu_block_strides_tensor = torch.tensor(self.gpu_block_strides_in_bytes, dtype=torch.int64) + gpu_chunk_sizes_tensor = torch.tensor(self.gpu_chunk_sizes_in_bytes, dtype=torch.int64) + gpu_layer_strides_tensor = torch.tensor(self.gpu_layer_strides_in_bytes, dtype=torch.int64) + self.tp_transfer_thread_group = TPTransferThreadGroup(self.num_gpus, self.gpu_blocks, cpu_blocks, dp_group_id, + self.num_layers, gpu_kv_strides_tensor, + gpu_block_strides_tensor, gpu_layer_strides_tensor, + gpu_chunk_sizes_tensor) + + def _transfer_impl(self, + src_block_ids: torch.Tensor, + dst_block_ids: torch.Tensor, + transfer_type: TransferType, + layer_id: int, + layer_granularity: int, + **kwargs: Any) -> None: + pass + + def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: + layer_id = transfer_op.layer_id + layer_granularity = transfer_op.layer_granularity + if layer_id == -1: + layer_id = 0 + if layer_granularity == -1: + layer_granularity = self.num_layers + + src_block_ids, dst_block_ids = 
self.get_transfer_block_ids(transfer_op) + + self._transfer_impl( + src_block_ids, + dst_block_ids, + transfer_op.transfer_type, + layer_id, + layer_granularity, + ) From 9085d521baefb2f99519c44370a7e18311681cf1 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Wed, 17 Dec 2025 10:16:35 +0000 Subject: [PATCH 02/59] add layerwise transfer op --- flexkv/common/transfer.py | 54 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 3 deletions(-) diff --git a/flexkv/common/transfer.py b/flexkv/common/transfer.py index 83f3c38518..7a3b226929 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -5,6 +5,8 @@ import numpy as np +from flexkv.common.debug import flexkv_logger + @dataclass(frozen=True) class CompletedOp: @@ -53,6 +55,7 @@ class TransferType(Enum): # so that the op 3 will not be executed actually, but can indicate the completion of # a group of transfer ops VIRTUAL = "Virtual" + LAYERWISE = "LAYERWISE" # class DistType(Enum): # DISTH = "DISTH" @@ -105,8 +108,30 @@ def __post_init__(self) -> None: TransferOp._next_op_id += 1 assert self.src_block_ids.dtype == np.int64 assert self.dst_block_ids.dtype == np.int64 - self.valid_block_num = self.src_block_ids.size +@dataclass +class LayerwiseTransferOp(TransferOp): + + src_block_ids_h2d: np.ndarray + dst_block_ids_h2d: np.ndarray + src_block_ids_disk2h: np.ndarray + dst_block_ids_disk2h: np.ndarray + + def __post_init__(self) -> None: + self.transfer_type = TransferType.LAYERWISE + if self.layer_granularity == -1: + flexkv_logger.warning("layer_granularity is not set, using default value 1") + self.layer_granularity = 1 + assert self.src_block_ids_h2d.size == self.dst_block_ids_h2d.size + assert self.src_block_ids_disk2h.size == self.dst_block_ids_disk2h.size + with LayerwiseTransferOp._lock: + self.op_id = LayerwiseTransferOp._next_op_id + LayerwiseTransferOp._next_op_id += 1 + + assert self.src_block_ids_h2d.dtype == np.int64 + assert self.dst_block_ids_h2d.dtype == np.int64 
+ assert self.src_block_ids_disk2h.dtype == np.int64 + assert self.dst_block_ids_disk2h.dtype == np.int64 class TransferOpGraph: _next_graph_id = 0 @@ -336,8 +361,11 @@ def _merge_ops(ops: List[TransferOp], transfer_type: TransferType, return merged_op -def merge_to_batch_graph(batch_id: int, transfer_graphs: List[TransferOpGraph], task_end_op_ids: List[int], - op_callback_dict: Dict[int, Callable]) -> Tuple[TransferOpGraph, int, Dict[int, Callable]]: +def merge_to_batch_graph(batch_id: int, + transfer_graphs: List[TransferOpGraph], + task_end_op_ids: List[int], + op_callback_dict: Dict[int, Callable], + layerwise_transfer: bool = False) -> Tuple[TransferOpGraph, int, Dict[int, Callable]]: """ Merge multiple TransferOpGraphs into a single batch graph. @@ -351,6 +379,7 @@ def merge_to_batch_graph(batch_id: int, transfer_graphs: List[TransferOpGraph], transfer_graphs: List of graphs to merge task_end_op_ids: List of end op IDs for each task (one per graph) op_callback_dict: Dict mapping old op_id -> callback + layerwise_transfer: Whether to merge the graphs into a layerwise transfer op Returns: (merged_graph, batch_end_op_id, new_op_callback_dict) @@ -394,6 +423,24 @@ def merge_to_batch_graph(batch_id: int, transfer_graphs: List[TransferOpGraph], merged_graph, callbacks_by_type[TransferType.H2D], new_op_callback_dict) if merged_disk2h_op is not None and merged_h2d_op is not None: merged_graph.add_dependency(merged_h2d_op.op_id, merged_disk2h_op.op_id) + if layerwise_transfer: # FIXME: rebase issue + if merged_h2d_op is not None: + layerwise_transfer_op = LayerwiseTransferOp( + graph_id=merged_graph.graph_id, + src_block_ids_h2d=merged_h2d_op.src_block_ids, + dst_block_ids_h2d=merged_h2d_op.dst_block_ids, + src_block_ids_disk2h=merged_disk2h_op.src_block_ids \ + if merged_disk2h_op is not None \ + else np.array([], dtype=np.int64), + dst_block_ids_disk2h=merged_disk2h_op.dst_block_ids \ + if merged_disk2h_op is not None \ + else np.array([], dtype=np.int64), + 
layer_id=0, + layer_granularity=1, + dp_id=h2d_ops[0].dp_id, + ) + merged_graph.add_transfer_op(layerwise_transfer_op) + batch_end_op_id = -1 # PUT path: D2H -> H2DISK merged_d2h_op = _merge_ops(ops_by_type[TransferType.D2H], TransferType.D2H, @@ -415,6 +462,7 @@ def merge_to_batch_graph(batch_id: int, transfer_graphs: List[TransferOpGraph], else: batch_end_op_id = -1 + return merged_graph, batch_end_op_id, new_op_callback_dict From 645a22c9e380e4c8a077b039efa149ae0a7bcc20 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Wed, 17 Dec 2025 22:15:00 -0800 Subject: [PATCH 03/59] clear op callback if layerwise --- flexkv/common/transfer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flexkv/common/transfer.py b/flexkv/common/transfer.py index 7a3b226929..723491bb23 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -441,6 +441,8 @@ def merge_to_batch_graph(batch_id: int, ) merged_graph.add_transfer_op(layerwise_transfer_op) batch_end_op_id = -1 + # layerwise transfer op does not need callbacks + new_op_callback_dict.clear() # PUT path: D2H -> H2DISK merged_d2h_op = _merge_ops(ops_by_type[TransferType.D2H], TransferType.D2H, From be7cc027fa9d5c53e41811770220bae5f0c43037 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Thu, 18 Dec 2025 00:03:22 -0800 Subject: [PATCH 04/59] layerwise worker naive impl --- flexkv/common/config.py | 2 + flexkv/kvtask.py | 17 +++--- flexkv/transfer/layerwise.py | 87 ++++++++++++++++++++++-------- flexkv/transfer/transfer_engine.py | 30 ++++++++++- flexkv/transfer/worker.py | 46 ++++------------ flexkv/transfer/worker_op.py | 65 ++++++++++++++++++++++ 6 files changed, 180 insertions(+), 67 deletions(-) create mode 100644 flexkv/transfer/worker_op.py diff --git a/flexkv/common/config.py b/flexkv/common/config.py index df792f5f8b..7f75a7d560 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -101,6 +101,8 @@ def __post_init__(self): 
remote_layout_type=KVCacheLayoutType(os.getenv('FLEXKV_REMOTE_LAYOUT', 'BLOCKFIRST').upper()), gds_layout_type=KVCacheLayoutType(os.getenv('FLEXKV_GDS_LAYOUT', 'BLOCKFIRST').upper()), + enable_layerwise_transfer=bool(int(os.getenv('FLEXKV_ENABLE_LAYERWISE_TRANSFER', 0))), + use_ce_transfer_h2d=bool(int(os.getenv('FLEXKV_USE_CE_TRANSFER_H2D', 0))), use_ce_transfer_d2h=bool(int(os.getenv('FLEXKV_USE_CE_TRANSFER_D2H', 0))), transfer_num_cta_h2d=int(os.getenv('FLEXKV_TRANSFER_NUM_CTA_H2D', 4)), diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index 1930e2543b..18c1312167 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -12,7 +12,7 @@ import torch import numpy as np -from flexkv.common.config import CacheConfig, ModelConfig +from flexkv.common.config import CacheConfig, ModelConfig, GLOBAL_CONFIG_FROM_ENV from flexkv.common.debug import flexkv_logger from flexkv.common.block import hash_token from flexkv.common.transfer import TransferOpGraph, merge_to_batch_graph, get_nvtx_default_color, CompletedOp @@ -777,9 +777,12 @@ def prefetch_async(self, return task_id def merge_to_batch_kvtask(self, + batch_id: int, + task_ids: List[int], - batch_task_type: TaskType) -> TransferOpGraph: + batch_task_type: TaskType, + layerwise_transfer: bool = False) -> TransferOpGraph: op_callback_dict = {} task_end_op_ids = [] callbacks = [] @@ -800,7 +803,8 @@ def merge_to_batch_kvtask(self, batch_task_graph, task_end_op_id, op_callback_dict = merge_to_batch_graph(batch_id, transfer_graphs, task_end_op_ids, - op_callback_dict) + op_callback_dict, + layerwise_transfer) self.tasks[batch_id] = KVTask( task_id=batch_id, token_ids=np.concatenate([self.tasks[task_id].token_ids for task_id in task_ids]), @@ -841,11 +845,10 @@ def launch_tasks(self, if len(task_ids) > 1 and as_batch and (all_get or all_put): if batch_id == -1: batch_id = self._gen_task_id() + if not GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer: + layerwise_transfer = False batch_task_type = TaskType.BATCH_GET if all_get 
else TaskType.BATCH_PUT - batch_task_graph = self.merge_to_batch_kvtask(batch_id, task_ids, batch_task_type) - if layerwise_transfer: - # TODO: merge all ops into one layerwise transfer op - pass + batch_task_graph = self.merge_to_batch_kvtask(batch_id, task_ids, batch_task_type, layerwise_transfer) transfer_graphs = [batch_task_graph] self.tasks[batch_id].status = TaskStatus.RUNNING task_ids = [batch_id] diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index bda77cbd9e..b150680a1f 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -30,7 +30,7 @@ except ImportError: transfer_kv_blocks_remote = None -from flexkv.transfer.worker import WorkerTransferOp +from flexkv.transfer.worker_op import WorkerTransferOp, WorkerLayerwiseTransferOp from flexkv.transfer.worker import TransferWorkerBase, cudaHostRegister class LayerwiseTransferWorker(TransferWorkerBase): @@ -48,7 +48,6 @@ def __init__(self, dtype: torch.dtype, tp_group_size: int, dp_group_id: int, - gpu_device_id: int, num_blocks_per_file: int, use_ce_transfer_h2d: bool = False, use_ce_transfer_d2h: bool = False, @@ -82,17 +81,16 @@ def __init__(self, self.gpu_layer_strides_in_bytes = [gpu_kv_layout.get_layer_stride() * self.dtype.itemsize \ for gpu_kv_layout in gpu_kv_layouts] - self.gpu_blocks = [wrapper.get_tensor() for wrapper in gpu_blocks] - - # 确定 GPU block 类型 - if len(self.gpu_blocks) == 1: + # 确定 GPU block 类型 (使用第一个 GPU 的 block 数量来判断) + num_blocks_first_gpu = len(imported_gpu_blocks[0]) if imported_gpu_blocks else 0 + if num_blocks_first_gpu == 1: self.gpu_block_type_ = 1 # TRTLLM - elif len(self.gpu_blocks) == self.num_layers: + elif num_blocks_first_gpu == self.num_layers: self.gpu_block_type_ = 0 # VLLM - elif len(self.gpu_blocks) == self.num_layers * 2: + elif num_blocks_first_gpu == self.num_layers * 2: self.gpu_block_type_ = 2 # SGLANG else: - raise ValueError(f"Invalid GPU block type: {len(self.gpu_blocks)}") + raise ValueError(f"Invalid GPU block 
type: {num_blocks_first_gpu}") # initialize CPU storage flexkv_logger.info(f"[LayerwiseWorker-{worker_id}] Pinning CPU Memory: " @@ -138,21 +136,62 @@ def __init__(self, gpu_block_strides_tensor = torch.tensor(self.gpu_block_strides_in_bytes, dtype=torch.int64) gpu_chunk_sizes_tensor = torch.tensor(self.gpu_chunk_sizes_in_bytes, dtype=torch.int64) gpu_layer_strides_tensor = torch.tensor(self.gpu_layer_strides_in_bytes, dtype=torch.int64) - self.tp_transfer_thread_group = TPTransferThreadGroup(self.num_gpus, self.gpu_blocks, cpu_blocks, dp_group_id, - self.num_layers, gpu_kv_strides_tensor, - gpu_block_strides_tensor, gpu_layer_strides_tensor, - gpu_chunk_sizes_tensor) + self.tp_transfer_thread_group = TPTransferThreadGroup( + self.num_gpus, self.gpu_blocks, cpu_blocks, dp_group_id, + self.num_layers, gpu_kv_strides_tensor, + gpu_block_strides_tensor, gpu_layer_strides_tensor, + gpu_chunk_sizes_tensor) def _transfer_impl(self, - src_block_ids: torch.Tensor, - dst_block_ids: torch.Tensor, - transfer_type: TransferType, + src_block_ids_h2d: torch.Tensor, + dst_block_ids_h2d: torch.Tensor, + src_block_ids_disk2h: torch.Tensor, + dst_block_ids_disk2h: torch.Tensor, layer_id: int, layer_granularity: int, **kwargs: Any) -> None: - pass + assert src_block_ids_h2d.dtype == torch.int64 + assert dst_block_ids_h2d.dtype == torch.int64 + assert src_block_ids_disk2h.dtype == torch.int64 + assert dst_block_ids_disk2h.dtype == torch.int64 + assert len(src_block_ids_h2d) == len(dst_block_ids_h2d) + assert len(src_block_ids_disk2h) == len(dst_block_ids_disk2h) + layer_id_list = torch.arange(layer_id, layer_id + layer_granularity, dtype=torch.int32) + if len(src_block_ids_disk2h) > 0: + transfer_kv_blocks_ssd( + ioctx=self.ioctx, + cpu_layer_id_list=layer_id_list, + cpu_tensor_ptr=self.cpu_layer_ptrs[0].item(), + ssd_block_ids=src_block_ids_disk2h, + cpu_block_ids=dst_block_ids_disk2h, + cpu_layer_stride_in_bytes=self.cpu_layer_stride_in_bytes, + 
cpu_kv_stride_in_bytes=self.cpu_kv_stride_in_bytes, + ssd_layer_stride_in_bytes=self.ssd_layer_stride_in_bytes, + ssd_kv_stride_in_bytes=self.ssd_kv_stride_in_bytes, + chunk_size_in_bytes=self.chunk_size_in_bytes, + block_stride_in_bytes=self.block_stride_in_bytes, + is_read=True, + num_blocks_per_file=self.num_blocks_per_file, + round_robin=self.round_robin, + num_threads_per_device=32, + is_mla=self.is_mla, + ) + self.tp_transfer_thread_group.tp_group_transfer( + dst_block_ids_h2d, + src_block_ids_h2d, + self.cpu_kv_stride_in_bytes, + self.cpu_layer_stride_in_bytes, + self.cpu_block_stride_in_bytes, + self.cpu_chunk_size_in_bytes, + self.transfer_sms_h2d, + True, # is H2D + self.use_ce_transfer_h2d, + layer_id, + layer_granularity, + self.is_mla, + ) - def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: + def launch_transfer(self, transfer_op: WorkerLayerwiseTransferOp) -> None: layer_id = transfer_op.layer_id layer_granularity = transfer_op.layer_granularity if layer_id == -1: @@ -160,12 +199,16 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: if layer_granularity == -1: layer_granularity = self.num_layers - src_block_ids, dst_block_ids = self.get_transfer_block_ids(transfer_op) + src_block_ids_h2d = transfer_op.src_block_ids_h2d + dst_block_ids_h2d = transfer_op.dst_block_ids_h2d + src_block_ids_disk2h = transfer_op.src_block_ids_disk2h + dst_block_ids_disk2h = transfer_op.dst_block_ids_disk2h self._transfer_impl( - src_block_ids, - dst_block_ids, - transfer_op.transfer_type, + src_block_ids_h2d, + dst_block_ids_h2d, + src_block_ids_disk2h, + dst_block_ids_disk2h, layer_id, layer_granularity, ) diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index c3b9539f41..a7bc7b08c4 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -41,6 +41,7 @@ tpGDSTransferWorker, PEER2CPUTransferWorker, ) +from flexkv.transfer.layerwise import LayerwiseTransferWorker from 
flexkv.common.config import CacheConfig, ModelConfig, GLOBAL_CONFIG_FROM_ENV from flexkv.common.ring_buffer import SharedOpPool @@ -52,6 +53,8 @@ def register_op_to_buffer(op: TransferOp, pin_buffer: SharedOpPool) -> None: Device type prefixes prevent hash collisions when different device types use the same block ID values (e.g., CPU block 0 vs SSD block 0). """ + if op.transfer_type == TransferType.LAYERWISE: + return # Map TransferType to (src_device_type, dst_device_type) for hash prefix # This prevents hash collisions when different devices use the same block IDs transfer_type_to_devices = { @@ -306,7 +309,32 @@ def _init_workers(self) -> None: ] self._worker_map[TransferType.DISK2D] = self.gds_workers self._worker_map[TransferType.D2DISK] = self.gds_workers - + if GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer: + self.layerwise_workers = [ + LayerwiseTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self.finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=[self.gpu_handles[j].get_tensor_handle_list() \ + for j in range(i * self.tp_size, (i + 1) * self.tp_size)], + cpu_blocks=self._cpu_handle.get_tensor(), + ssd_files=self._ssd_handle.get_file_list(), + gpu_kv_layouts=[self.gpu_handles[i].kv_layout \ + for i in range(i * self.tp_size, (i + 1) * self.tp_size)], + cpu_kv_layout=self._cpu_handle.kv_layout, + ssd_kv_layout=self._ssd_handle.kv_layout, + dtype=self.gpu_handles[i].dtype, + tp_group_size=self.tp_size, + dp_group_id=i, + num_blocks_per_file=self._ssd_handle.num_blocks_per_file, + use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, + use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, + transfer_sms_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_sms_h2d, + transfer_sms_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_sms_d2h, + ) + for i in range(self.dp_size) + ] + if self.cache_config.enable_kv_sharing and self._cpu_handle is not None and (self.cache_config.enable_p2p_cpu \ or (self._ssd_handle and 
self.cache_config.enable_p2p_ssd)): ## NOTE:if we have the cpu handle and enable p2p cpu transfer we need this worker diff --git a/flexkv/transfer/worker.py b/flexkv/transfer/worker.py index a853c2c2a8..2b90def54f 100644 --- a/flexkv/transfer/worker.py +++ b/flexkv/transfer/worker.py @@ -32,8 +32,10 @@ from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType from flexkv.common.transfer import TransferOp, TransferType, PartitionBlockType -from flexkv.common.transfer import get_nvtx_range_color +from flexkv.common.transfer import get_nvtx_range_color, LayerwiseTransferOp from flexkv.common.config import CacheConfig, GLOBAL_CONFIG_FROM_ENV, MooncakeTransferEngineConfig +from flexkv.transfer.worker_op import WorkerTransferOp, WorkerLayerwiseTransferOp + from flexkv.mooncakeEngineWrapper import MoonCakeTransferEngineWrapper from flexkv.transfer.zmqHelper import NotifyMsg, NotifyStatus, SSDZMQServer, SSDZMQClient from flexkv.cache.redis_meta import RedisMeta @@ -64,40 +66,6 @@ def cudaHostUnregister(tensor: torch.Tensor) -> None: size = tensor.numel() * tensor.element_size() ret = cudart.cudaHostUnregister(ctypes.c_void_p(ptr)) -@dataclass -class WorkerTransferOp: - transfer_op_id: int - transfer_graph_id: int - transfer_type: TransferType - layer_id: int - layer_granularity: int - src_slot_id: int - dst_slot_id: int - valid_block_num: int - src_block_ids: np.ndarray - dst_block_ids: np.ndarray - src_block_node_ids: Optional[np.ndarray] - # successors: List[int] - - def __init__(self, transfer_op: TransferOp): - self.transfer_op_id = transfer_op.op_id - self.transfer_graph_id = transfer_op.graph_id - self.transfer_type = transfer_op.transfer_type - self.layer_id = transfer_op.layer_id - self.layer_granularity = transfer_op.layer_granularity - self.src_slot_id = transfer_op.src_slot_id - self.dst_slot_id = transfer_op.dst_slot_id - self.valid_block_num = transfer_op.valid_block_num - # Always preserve 
optional src_block_node_ids from TransferOp - self.src_block_node_ids = transfer_op.src_block_node_ids - - if self.src_slot_id == -1 or self.dst_slot_id == -1: - self.src_block_ids = transfer_op.src_block_ids - self.dst_block_ids = transfer_op.dst_block_ids - else: - self.src_block_ids = np.empty(0) - self.dst_block_ids = np.empty(0) - # self.successors = list(transfer_op.successors) # for nvtx class TransferWorkerBase(ABC): _worker_id_counter = 0 @@ -278,8 +246,12 @@ def __init__(self, worker_id: int, transfer_conn: Connection, process: mp.Proces self.process = process self.ready_event = ready_event - def submit_transfer(self, op: TransferOp) -> None: - self.transfer_conn.send(WorkerTransferOp(op)) + def submit_transfer(self, op: Union[TransferOp, LayerwiseTransferOp]) -> None: + if isinstance(op, LayerwiseTransferOp): + worker_op = WorkerLayerwiseTransferOp(op) + else: + worker_op = WorkerTransferOp(op) + self.transfer_conn.send(worker_op) def shutdown(self) -> None: try: diff --git a/flexkv/transfer/worker_op.py b/flexkv/transfer/worker_op.py new file mode 100644 index 0000000000..3fb76cc8f8 --- /dev/null +++ b/flexkv/transfer/worker_op.py @@ -0,0 +1,65 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np + +from flexkv.common.transfer import TransferOp, TransferType, LayerwiseTransferOp + + +@dataclass +class WorkerTransferOp: + transfer_op_id: int + transfer_graph_id: int + transfer_type: TransferType + layer_id: int + layer_granularity: int + src_slot_id: int + dst_slot_id: int + valid_block_num: int + src_block_ids: np.ndarray + dst_block_ids: np.ndarray + src_block_node_ids: Optional[np.ndarray] + + def __init__(self, transfer_op: TransferOp): + self.transfer_op_id = transfer_op.op_id + self.transfer_graph_id = transfer_op.graph_id + self.transfer_type = transfer_op.transfer_type + self.layer_id = transfer_op.layer_id + self.layer_granularity = transfer_op.layer_granularity + self.src_slot_id = transfer_op.src_slot_id + 
self.dst_slot_id = transfer_op.dst_slot_id + self.valid_block_num = transfer_op.valid_block_num + # Always preserve optional src_block_node_ids from TransferOp + self.src_block_node_ids = transfer_op.src_block_node_ids + + if self.src_slot_id == -1: + self.src_block_ids = transfer_op.src_block_ids + self.dst_block_ids = transfer_op.dst_block_ids + else: + self.src_block_ids = np.empty(0) + self.dst_block_ids = np.empty(0) + + +@dataclass +class WorkerLayerwiseTransferOp: + transfer_op_id: int + transfer_graph_id: int + transfer_type: TransferType + layer_id: int + layer_granularity: int + src_block_ids_h2d: np.ndarray + dst_block_ids_h2d: np.ndarray + src_block_ids_disk2h: np.ndarray + dst_block_ids_disk2h: np.ndarray + + def __init__(self, transfer_op: LayerwiseTransferOp): + self.transfer_op_id = transfer_op.op_id + self.transfer_graph_id = transfer_op.graph_id + assert transfer_op.transfer_type == TransferType.LAYERWISE + self.transfer_type = transfer_op.transfer_type + self.layer_id = transfer_op.layer_id + self.layer_granularity = transfer_op.layer_granularity + self.src_block_ids_h2d = transfer_op.src_block_ids_h2d + self.dst_block_ids_h2d = transfer_op.dst_block_ids_h2d + self.src_block_ids_disk2h = transfer_op.src_block_ids_disk2h + self.dst_block_ids_disk2h = transfer_op.dst_block_ids_disk2h From 52a19f3714a816754abbc871a5218fcfe1b04afa Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Thu, 18 Dec 2025 00:07:26 -0800 Subject: [PATCH 05/59] check layerwise condition --- flexkv/kvtask.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index 18c1312167..bc3bcd86fd 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -845,8 +845,16 @@ def launch_tasks(self, if len(task_ids) > 1 and as_batch and (all_get or all_put): if batch_id == -1: batch_id = self._gen_task_id() - if not GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer: - layerwise_transfer = False + if layerwise_transfer: + if not 
GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer: + flexkv_logger.warning("layerwise transfer is not enabled") + layerwise_transfer = False + else: + for task_id in task_ids: + if self.tasks[task_id].task_type != TaskType.GET: + flexkv_logger.warning("only support layerwise get") + layerwise_transfer = False + break batch_task_type = TaskType.BATCH_GET if all_get else TaskType.BATCH_PUT batch_task_graph = self.merge_to_batch_kvtask(batch_id, task_ids, batch_task_type, layerwise_transfer) transfer_graphs = [batch_task_graph] From fb6ce1262ace604cdec66a28038e4f68d32fb586 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Thu, 18 Dec 2025 00:36:57 -0800 Subject: [PATCH 06/59] add default value --- flexkv/common/transfer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flexkv/common/transfer.py b/flexkv/common/transfer.py index 723491bb23..1b1f42ed6d 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -112,10 +112,10 @@ def __post_init__(self) -> None: @dataclass class LayerwiseTransferOp(TransferOp): - src_block_ids_h2d: np.ndarray - dst_block_ids_h2d: np.ndarray - src_block_ids_disk2h: np.ndarray - dst_block_ids_disk2h: np.ndarray + src_block_ids_h2d: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) + dst_block_ids_h2d: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) + src_block_ids_disk2h: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) + dst_block_ids_disk2h: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) def __post_init__(self) -> None: self.transfer_type = TransferType.LAYERWISE From 4dc9fb4dd1cb82de1edd8b32360419b91b2b206d Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Thu, 18 Dec 2025 00:58:56 -0800 Subject: [PATCH 07/59] fix bug and benchmark --- benchmarks/benchmark_single_batch.py | 13 +++++++------ flexkv/common/transfer.py | 1 + 2 files changed, 8 insertions(+), 6 deletions(-) diff --git 
a/benchmarks/benchmark_single_batch.py b/benchmarks/benchmark_single_batch.py index c629b545ea..3b9e77c4a5 100644 --- a/benchmarks/benchmark_single_batch.py +++ b/benchmarks/benchmark_single_batch.py @@ -133,19 +133,20 @@ def benchmark_flexkv(model_config: ModelConfig, all_tokens = 0 start_time = time.time() batch_get_ids = [] + return_masks = [] + cached_tokens = 0 for i in range(batch_size): all_tokens += len(batch_sequence_tensor[i]) - task_id, _ = kvmanager.get_match(batch_sequence_tensor[i], + task_id, return_mask = kvmanager.get_match(batch_sequence_tensor[i], token_mask=None) batch_get_ids.append(task_id) + cached_tokens += return_mask.sum().item() get_match_time = time.time() - start_time - kvmanager.launch(batch_get_ids, batch_slot_mapping) - get_result = kvmanager.wait(batch_get_ids) + batch_id_list =kvmanager.launch(batch_get_ids, batch_slot_mapping, as_batch=True) + get_result = kvmanager.wait(batch_id_list) elapsed_time_get = time.time() - start_time - cached_tokens = 0 for _, response in get_result.items(): - if response.status == KVResponseStatus.SUCCESS: - cached_tokens += response.return_mask.sum().item() + assert response.status == KVResponseStatus.SUCCESS transfer_data_size_GB = cached_tokens * model_config.token_size_in_bytes / 1024 / 1024 / 1024 transfer_bandwidth_get = transfer_data_size_GB / elapsed_time_get print(f"get {cached_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " diff --git a/flexkv/common/transfer.py b/flexkv/common/transfer.py index 1b1f42ed6d..f689fa39c5 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -108,6 +108,7 @@ def __post_init__(self) -> None: TransferOp._next_op_id += 1 assert self.src_block_ids.dtype == np.int64 assert self.dst_block_ids.dtype == np.int64 + self.valid_block_num = self.src_block_ids.size @dataclass class LayerwiseTransferOp(TransferOp): From 80c2077e687a4975b16ff435856c58ac2ae5cb71 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Thu, 18 Dec 2025 01:05:28 -0800 
Subject: [PATCH 08/59] add layerwise param --- benchmarks/benchmark_single_batch.py | 2 +- flexkv/common/config.py | 2 +- flexkv/kvmanager.py | 7 ++++--- flexkv/server/client.py | 3 ++- flexkv/server/request.py | 1 + flexkv/server/server.py | 6 +++++- flexkv/transfer/layerwise.py | 3 +-- 7 files changed, 15 insertions(+), 9 deletions(-) diff --git a/benchmarks/benchmark_single_batch.py b/benchmarks/benchmark_single_batch.py index 3b9e77c4a5..69af2f46c7 100644 --- a/benchmarks/benchmark_single_batch.py +++ b/benchmarks/benchmark_single_batch.py @@ -142,7 +142,7 @@ def benchmark_flexkv(model_config: ModelConfig, batch_get_ids.append(task_id) cached_tokens += return_mask.sum().item() get_match_time = time.time() - start_time - batch_id_list =kvmanager.launch(batch_get_ids, batch_slot_mapping, as_batch=True) + batch_id_list =kvmanager.launch(batch_get_ids, batch_slot_mapping, as_batch=True, layerwise_transfer=True) get_result = kvmanager.wait(batch_id_list) elapsed_time_get = time.time() - start_time for _, response in get_result.items(): diff --git a/flexkv/common/config.py b/flexkv/common/config.py index 7f75a7d560..c2b4e38574 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -101,7 +101,7 @@ def __post_init__(self): remote_layout_type=KVCacheLayoutType(os.getenv('FLEXKV_REMOTE_LAYOUT', 'BLOCKFIRST').upper()), gds_layout_type=KVCacheLayoutType(os.getenv('FLEXKV_GDS_LAYOUT', 'BLOCKFIRST').upper()), - enable_layerwise_transfer=bool(int(os.getenv('FLEXKV_ENABLE_LAYERWISE_TRANSFER', 0))), + enable_layerwise_transfer=bool(int(os.getenv('FLEXKV_ENABLE_LAYERWISE_TRANSFER', 1))), use_ce_transfer_h2d=bool(int(os.getenv('FLEXKV_USE_CE_TRANSFER_H2D', 0))), use_ce_transfer_d2h=bool(int(os.getenv('FLEXKV_USE_CE_TRANSFER_D2H', 0))), diff --git a/flexkv/kvmanager.py b/flexkv/kvmanager.py index 8e61a15136..aaff75361c 100644 --- a/flexkv/kvmanager.py +++ b/flexkv/kvmanager.py @@ -239,7 +239,8 @@ def prefetch_async(self, def launch(self, task_ids: Union[int, 
List[int]], slot_mappings: Union[np.ndarray, List[np.ndarray], torch.Tensor, List[torch.Tensor]], - as_batch: bool = False) -> List[int]: + as_batch: bool = False, + layerwise_transfer: bool = False) -> List[int]: if isinstance(task_ids, int): task_ids = [task_ids] if not isinstance(slot_mappings, List): @@ -247,9 +248,9 @@ def launch(self, if isinstance(slot_mappings[0], torch.Tensor): slot_mappings = [slot_mapping.numpy() for slot_mapping in slot_mappings] if self.server_client_mode: - return self.dp_client.launch_tasks(task_ids, slot_mappings, as_batch) + return self.dp_client.launch_tasks(task_ids, slot_mappings, as_batch, layerwise_transfer) else: - return self.kv_task_engine.launch_tasks(task_ids, slot_mappings, as_batch) + return self.kv_task_engine.launch_tasks(task_ids, slot_mappings, as_batch, layerwise_transfer) def cancel(self, task_ids: Union[int, List[int]]) -> None: if isinstance(task_ids, int): diff --git a/flexkv/server/client.py b/flexkv/server/client.py index 9947fa2cf7..4a8bc241de 100644 --- a/flexkv/server/client.py +++ b/flexkv/server/client.py @@ -177,11 +177,12 @@ def launch_tasks( task_ids: List[int], slot_mappings: List[np.ndarray], as_batch: bool = False, + layerwise_transfer: bool = False, ) -> List[int]: batch_id = -1 if as_batch: batch_id = self._get_task_id() - req = LaunchTaskRequest(self.dp_client_id, task_ids, slot_mappings, as_batch, batch_id) + req = LaunchTaskRequest(self.dp_client_id, task_ids, slot_mappings, as_batch, batch_id, layerwise_transfer) self.send_to_server.send_pyobj(req) return [batch_id] if as_batch else task_ids diff --git a/flexkv/server/request.py b/flexkv/server/request.py index e540f495cb..dde39c7b42 100644 --- a/flexkv/server/request.py +++ b/flexkv/server/request.py @@ -78,6 +78,7 @@ class LaunchTaskRequest: slot_mappings: List[np.ndarray] as_batch: bool = False batch_id: int = -1 + layerwise_transfer: bool = False @dataclass class CancelTaskRequest: diff --git a/flexkv/server/server.py 
b/flexkv/server/server.py index dbefb260e2..94f574d711 100644 --- a/flexkv/server/server.py +++ b/flexkv/server/server.py @@ -383,7 +383,11 @@ def _handle_prefetch_request(self, req: PrefetchRequest) -> None: def _handle_launch_task_request(self, req: LaunchTaskRequest) -> None: """Handle LaunchTask request""" - self.kv_task_engine.launch_tasks(req.task_ids, req.slot_mappings, req.as_batch, req.batch_id) + self.kv_task_engine.launch_tasks(req.task_ids, + req.slot_mappings, + req.as_batch, + req.batch_id, + req.layerwise_transfer) def _handle_cancel_task_request(self, req: CancelTaskRequest) -> None: """Handle CancelTask request""" diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index b150680a1f..75e6a202d0 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -81,7 +81,6 @@ def __init__(self, self.gpu_layer_strides_in_bytes = [gpu_kv_layout.get_layer_stride() * self.dtype.itemsize \ for gpu_kv_layout in gpu_kv_layouts] - # 确定 GPU block 类型 (使用第一个 GPU 的 block 数量来判断) num_blocks_first_gpu = len(imported_gpu_blocks[0]) if imported_gpu_blocks else 0 if num_blocks_first_gpu == 1: self.gpu_block_type_ = 1 # TRTLLM @@ -93,7 +92,7 @@ def __init__(self, raise ValueError(f"Invalid GPU block type: {num_blocks_first_gpu}") # initialize CPU storage - flexkv_logger.info(f"[LayerwiseWorker-{worker_id}] Pinning CPU Memory: " + flexkv_logger.info(f"Pinning CPU Memory: " f"{cpu_blocks.numel() * cpu_blocks.element_size() / (1024 ** 3):.2f} GB") cudaHostRegister(cpu_blocks) self.cpu_blocks = cpu_blocks From f64b9c0044b4c2d98bc535ad2685c7c89d57c632 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Thu, 18 Dec 2025 02:33:53 -0800 Subject: [PATCH 09/59] fix bugs --- flexkv/common/transfer.py | 30 ++++++++++++++++++++++++++---- flexkv/kvmanager.py | 7 ++++++- flexkv/kvtask.py | 1 - flexkv/transfer/layerwise.py | 20 ++++++++++---------- flexkv/transfer/transfer_engine.py | 1 + flexkv/transfer/worker.py | 3 +-- 6 files changed, 44 
insertions(+), 18 deletions(-) diff --git a/flexkv/common/transfer.py b/flexkv/common/transfer.py index f689fa39c5..b96ba2e820 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -118,16 +118,38 @@ class LayerwiseTransferOp(TransferOp): src_block_ids_disk2h: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) dst_block_ids_disk2h: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) + def __init__(self, + graph_id: int, + src_block_ids_h2d: np.ndarray, + dst_block_ids_h2d: np.ndarray, + src_block_ids_disk2h: np.ndarray, + dst_block_ids_disk2h: np.ndarray, + layer_id: int = 0, + layer_granularity: int = 1, + dp_id: int = 0) -> None: + self.src_block_ids_h2d = src_block_ids_h2d + self.dst_block_ids_h2d = dst_block_ids_h2d + self.src_block_ids_disk2h = src_block_ids_disk2h + self.dst_block_ids_disk2h = dst_block_ids_disk2h + + super().__init__( + graph_id=graph_id, + transfer_type=TransferType.LAYERWISE, + src_block_ids=np.array([], dtype=np.int64), + dst_block_ids=np.array([], dtype=np.int64), + layer_id=layer_id, + layer_granularity=layer_granularity, + dp_id=dp_id, + ) + def __post_init__(self) -> None: - self.transfer_type = TransferType.LAYERWISE + super().__post_init__() + if self.layer_granularity == -1: flexkv_logger.warning("layer_granularity is not set, using default value 1") self.layer_granularity = 1 assert self.src_block_ids_h2d.size == self.dst_block_ids_h2d.size assert self.src_block_ids_disk2h.size == self.dst_block_ids_disk2h.size - with LayerwiseTransferOp._lock: - self.op_id = LayerwiseTransferOp._next_op_id - LayerwiseTransferOp._next_op_id += 1 assert self.src_block_ids_h2d.dtype == np.int64 assert self.dst_block_ids_h2d.dtype == np.int64 diff --git a/flexkv/kvmanager.py b/flexkv/kvmanager.py index aaff75361c..bf484fb24b 100644 --- a/flexkv/kvmanager.py +++ b/flexkv/kvmanager.py @@ -250,7 +250,12 @@ def launch(self, if self.server_client_mode: return 
self.dp_client.launch_tasks(task_ids, slot_mappings, as_batch, layerwise_transfer) else: - return self.kv_task_engine.launch_tasks(task_ids, slot_mappings, as_batch, layerwise_transfer) + return self.kv_task_engine.launch_tasks( + task_ids, + slot_mappings, + as_batch=as_batch, + layerwise_transfer=layerwise_transfer + ) def cancel(self, task_ids: Union[int, List[int]]) -> None: if isinstance(task_ids, int): diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index bc3bcd86fd..462a6e3093 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -799,7 +799,6 @@ def merge_to_batch_kvtask(self, task_end_op_ids.append(self.tasks[task_id].task_end_op_id) callbacks.append(self.tasks[task_id].callback) return_masks.append(self.tasks[task_id].return_mask) - batch_task_graph, task_end_op_id, op_callback_dict = merge_to_batch_graph(batch_id, transfer_graphs, task_end_op_ids, diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index 75e6a202d0..a34834c9ec 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -54,7 +54,7 @@ def __init__(self, transfer_sms_h2d: int = 8, transfer_sms_d2h: int = 8) -> None: super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) - assert len(gpu_blocks) == tp_group_size + assert len(gpu_blocks) == tp_group_size, f"len(gpu_blocks) = {len(gpu_blocks)}, tp_group_size = {tp_group_size}" imported_gpu_blocks = [] for handles_in_one_gpu in gpu_blocks: blocks_in_one_gpu = [] @@ -118,8 +118,6 @@ def __init__(self, self.ssd_layer_stride_in_bytes = ssd_kv_layout_per_file.get_layer_stride() * self.dtype.itemsize self.ssd_block_stride_in_bytes = ssd_kv_layout_per_file.get_block_stride() * self.dtype.itemsize - # assert self.ssd_block_stride_in_bytes == self.cpu_block_stride_in_bytes - try: self.ioctx = c_ext.SSDIOCTX( ssd_files, @@ -141,6 +139,8 @@ def __init__(self, gpu_block_strides_tensor, gpu_layer_strides_tensor, gpu_chunk_sizes_tensor) + self.cpu_layer_ptrs = 
self._get_layer_ptrs(cpu_blocks) + def _transfer_impl(self, src_block_ids_h2d: torch.Tensor, dst_block_ids_h2d: torch.Tensor, @@ -167,8 +167,8 @@ def _transfer_impl(self, cpu_kv_stride_in_bytes=self.cpu_kv_stride_in_bytes, ssd_layer_stride_in_bytes=self.ssd_layer_stride_in_bytes, ssd_kv_stride_in_bytes=self.ssd_kv_stride_in_bytes, - chunk_size_in_bytes=self.chunk_size_in_bytes, - block_stride_in_bytes=self.block_stride_in_bytes, + chunk_size_in_bytes=self.cpu_chunk_size_in_bytes, + block_stride_in_bytes=self.cpu_block_stride_in_bytes, is_read=True, num_blocks_per_file=self.num_blocks_per_file, round_robin=self.round_robin, @@ -198,11 +198,11 @@ def launch_transfer(self, transfer_op: WorkerLayerwiseTransferOp) -> None: if layer_granularity == -1: layer_granularity = self.num_layers - src_block_ids_h2d = transfer_op.src_block_ids_h2d - dst_block_ids_h2d = transfer_op.dst_block_ids_h2d - src_block_ids_disk2h = transfer_op.src_block_ids_disk2h - dst_block_ids_disk2h = transfer_op.dst_block_ids_disk2h - + src_block_ids_h2d = torch.from_numpy(transfer_op.src_block_ids_h2d).to(dtype=torch.int64) + dst_block_ids_h2d = torch.from_numpy(transfer_op.dst_block_ids_h2d).to(dtype=torch.int64) + src_block_ids_disk2h = torch.from_numpy(transfer_op.src_block_ids_disk2h).to(dtype=torch.int64) + dst_block_ids_disk2h = torch.from_numpy(transfer_op.dst_block_ids_disk2h).to(dtype=torch.int64) + layer_granularity = self.num_layers # TODO: remove this self._transfer_impl( src_block_ids_h2d, dst_block_ids_h2d, diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index a7bc7b08c4..84aab15028 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -334,6 +334,7 @@ def _init_workers(self) -> None: ) for i in range(self.dp_size) ] + self._worker_map[TransferType.LAYERWISE] = self.layerwise_workers if self.cache_config.enable_kv_sharing and self._cpu_handle is not None and (self.cache_config.enable_p2p_cpu \ or 
(self._ssd_handle and self.cache_config.enable_p2p_ssd)): diff --git a/flexkv/transfer/worker.py b/flexkv/transfer/worker.py index 2b90def54f..13980932cf 100644 --- a/flexkv/transfer/worker.py +++ b/flexkv/transfer/worker.py @@ -218,8 +218,7 @@ def run(self) -> None: transfer_status = False try: nvtx.push_range(f"launch {op.transfer_type.name} op_id: {op.transfer_op_id}, " - f"graph_id: {op.transfer_graph_id}, " - f"num_blocks: {op.valid_block_num}", + f"graph_id: {op.transfer_graph_id}", color=get_nvtx_range_color(op.transfer_graph_id)) transfer_status = self.launch_transfer(op) nvtx.pop_range() From d86335b6d6e1c6b83c7ceeae0c0e2bc350008237 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Thu, 18 Dec 2025 21:48:03 -0800 Subject: [PATCH 10/59] disable layerwise in benchmark --- benchmarks/benchmark_single_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmark_single_batch.py b/benchmarks/benchmark_single_batch.py index 69af2f46c7..4bf078a2ac 100644 --- a/benchmarks/benchmark_single_batch.py +++ b/benchmarks/benchmark_single_batch.py @@ -142,7 +142,7 @@ def benchmark_flexkv(model_config: ModelConfig, batch_get_ids.append(task_id) cached_tokens += return_mask.sum().item() get_match_time = time.time() - start_time - batch_id_list =kvmanager.launch(batch_get_ids, batch_slot_mapping, as_batch=True, layerwise_transfer=True) + batch_id_list =kvmanager.launch(batch_get_ids, batch_slot_mapping, as_batch=True, layerwise_transfer=False) get_result = kvmanager.wait(batch_id_list) elapsed_time_get = time.time() - start_time for _, response in get_result.items(): From 7a746655c1a42919bce543adb5a52535795460cc Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Thu, 18 Dec 2025 23:08:54 -0800 Subject: [PATCH 11/59] pin memory of block ids --- benchmarks/benchmark_single_batch.py | 2 +- flexkv/transfer/layerwise.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/benchmark_single_batch.py 
b/benchmarks/benchmark_single_batch.py index 4bf078a2ac..69af2f46c7 100644 --- a/benchmarks/benchmark_single_batch.py +++ b/benchmarks/benchmark_single_batch.py @@ -142,7 +142,7 @@ def benchmark_flexkv(model_config: ModelConfig, batch_get_ids.append(task_id) cached_tokens += return_mask.sum().item() get_match_time = time.time() - start_time - batch_id_list =kvmanager.launch(batch_get_ids, batch_slot_mapping, as_batch=True, layerwise_transfer=False) + batch_id_list =kvmanager.launch(batch_get_ids, batch_slot_mapping, as_batch=True, layerwise_transfer=True) get_result = kvmanager.wait(batch_id_list) elapsed_time_get = time.time() - start_time for _, response in get_result.items(): diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index a34834c9ec..4487da5db0 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -198,8 +198,8 @@ def launch_transfer(self, transfer_op: WorkerLayerwiseTransferOp) -> None: if layer_granularity == -1: layer_granularity = self.num_layers - src_block_ids_h2d = torch.from_numpy(transfer_op.src_block_ids_h2d).to(dtype=torch.int64) - dst_block_ids_h2d = torch.from_numpy(transfer_op.dst_block_ids_h2d).to(dtype=torch.int64) + src_block_ids_h2d = torch.from_numpy(transfer_op.src_block_ids_h2d).to(dtype=torch.int64).pin_memory() + dst_block_ids_h2d = torch.from_numpy(transfer_op.dst_block_ids_h2d).to(dtype=torch.int64).pin_memory() src_block_ids_disk2h = torch.from_numpy(transfer_op.src_block_ids_disk2h).to(dtype=torch.int64) dst_block_ids_disk2h = torch.from_numpy(transfer_op.dst_block_ids_disk2h).to(dtype=torch.int64) layer_granularity = self.num_layers # TODO: remove this From 7641dccdeab4bb70aa7dcfaec0c915433065e499 Mon Sep 17 00:00:00 2001 From: linhu-nv Date: Fri, 9 Jan 2026 14:30:04 +0800 Subject: [PATCH 12/59] make ssd optional --- flexkv/transfer/layerwise.py | 66 +++++++++++++++++------------- flexkv/transfer/transfer_engine.py | 9 ++-- 2 files changed, 44 insertions(+), 31 deletions(-) 
diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index 4487da5db0..260cdbc26a 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -108,26 +108,28 @@ def __init__(self, self.transfer_sms_d2h = transfer_sms_d2h # initialize SSD storage - self.ssd_files = ssd_files - self.num_blocks_per_file = num_blocks_per_file - self.num_files = sum(len(file_list) for file_list in ssd_files.values()) - self.round_robin = 1 - - ssd_kv_layout_per_file = ssd_kv_layout.div_block(self.num_files, padding=True) - self.ssd_kv_stride_in_bytes = ssd_kv_layout_per_file.get_kv_stride() * self.dtype.itemsize - self.ssd_layer_stride_in_bytes = ssd_kv_layout_per_file.get_layer_stride() * self.dtype.itemsize - self.ssd_block_stride_in_bytes = ssd_kv_layout_per_file.get_block_stride() * self.dtype.itemsize - - try: - self.ioctx = c_ext.SSDIOCTX( - ssd_files, - len(ssd_files), - GLOBAL_CONFIG_FROM_ENV.iouring_entries, - GLOBAL_CONFIG_FROM_ENV.iouring_flags - ) - except Exception as e: - flexkv_logger.error(f"Error setting ssd ioctx: {e}\n") - raise RuntimeError("SSD Worker init failed") from e + self.enable_ssd = len(ssd_files) > 0 + if self.enable_ssd: + self.ssd_files = ssd_files + self.num_blocks_per_file = num_blocks_per_file + self.num_files = sum(len(file_list) for file_list in ssd_files.values()) + self.round_robin = 1 + + ssd_kv_layout_per_file = ssd_kv_layout.div_block(self.num_files, padding=True) + self.ssd_kv_stride_in_bytes = ssd_kv_layout_per_file.get_kv_stride() * self.dtype.itemsize + self.ssd_layer_stride_in_bytes = ssd_kv_layout_per_file.get_layer_stride() * self.dtype.itemsize + self.ssd_block_stride_in_bytes = ssd_kv_layout_per_file.get_block_stride() * self.dtype.itemsize + + try: + self.ioctx = c_ext.SSDIOCTX( + ssd_files, + len(ssd_files), + GLOBAL_CONFIG_FROM_ENV.iouring_entries, + GLOBAL_CONFIG_FROM_ENV.iouring_flags + ) + except Exception as e: + flexkv_logger.error(f"Error setting ssd ioctx: {e}\n") + raise 
RuntimeError("SSD Worker init failed") from e gpu_kv_strides_tensor = torch.tensor(self.gpu_kv_strides_in_bytes, dtype=torch.int64) gpu_block_strides_tensor = torch.tensor(self.gpu_block_strides_in_bytes, dtype=torch.int64) @@ -144,19 +146,22 @@ def __init__(self, def _transfer_impl(self, src_block_ids_h2d: torch.Tensor, dst_block_ids_h2d: torch.Tensor, - src_block_ids_disk2h: torch.Tensor, - dst_block_ids_disk2h: torch.Tensor, + src_block_ids_disk2h: Optional[torch.Tensor], + dst_block_ids_disk2h: Optional[torch.Tensor], layer_id: int, layer_granularity: int, **kwargs: Any) -> None: assert src_block_ids_h2d.dtype == torch.int64 assert dst_block_ids_h2d.dtype == torch.int64 - assert src_block_ids_disk2h.dtype == torch.int64 - assert dst_block_ids_disk2h.dtype == torch.int64 assert len(src_block_ids_h2d) == len(dst_block_ids_h2d) - assert len(src_block_ids_disk2h) == len(dst_block_ids_disk2h) + if src_block_ids_disk2h is not None: + assert src_block_ids_disk2h.dtype == torch.int64 + assert dst_block_ids_disk2h.dtype == torch.int64 + assert len(src_block_ids_disk2h) == len(dst_block_ids_disk2h) layer_id_list = torch.arange(layer_id, layer_id + layer_granularity, dtype=torch.int32) - if len(src_block_ids_disk2h) > 0: + if src_block_ids_disk2h is not None and \ + len(src_block_ids_disk2h) > 0: + assert self.enable_ssd, "SSD is not enabled for LayerwiseTransferWorker" transfer_kv_blocks_ssd( ioctx=self.ioctx, cpu_layer_id_list=layer_id_list, @@ -200,8 +205,13 @@ def launch_transfer(self, transfer_op: WorkerLayerwiseTransferOp) -> None: src_block_ids_h2d = torch.from_numpy(transfer_op.src_block_ids_h2d).to(dtype=torch.int64).pin_memory() dst_block_ids_h2d = torch.from_numpy(transfer_op.dst_block_ids_h2d).to(dtype=torch.int64).pin_memory() - src_block_ids_disk2h = torch.from_numpy(transfer_op.src_block_ids_disk2h).to(dtype=torch.int64) - dst_block_ids_disk2h = torch.from_numpy(transfer_op.dst_block_ids_disk2h).to(dtype=torch.int64) + + if 
transfer_op.src_block_ids_disk2h.size > 0: + src_block_ids_disk2h = torch.from_numpy(transfer_op.src_block_ids_disk2h).to(dtype=torch.int64) + dst_block_ids_disk2h = torch.from_numpy(transfer_op.dst_block_ids_disk2h).to(dtype=torch.int64) + else: + src_block_ids_disk2h = None + dst_block_ids_disk2h = None layer_granularity = self.num_layers # TODO: remove this self._transfer_impl( src_block_ids_h2d, diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index 84aab15028..29f757770f 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -310,6 +310,9 @@ def _init_workers(self) -> None: self._worker_map[TransferType.DISK2D] = self.gds_workers self._worker_map[TransferType.D2DISK] = self.gds_workers if GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer: + ssd_files = {} if self._ssd_handle is None else self._ssd_handle.get_file_list() + ssd_kv_layout = None if self._ssd_handle is None else self._ssd_handle.kv_layout + num_blocks_per_file = 0 if self._ssd_handle is None else self._ssd_handle.num_blocks_per_file self.layerwise_workers = [ LayerwiseTransferWorker.create_worker( mp_ctx=self.mp_ctx, @@ -318,15 +321,15 @@ def _init_workers(self) -> None: gpu_blocks=[self.gpu_handles[j].get_tensor_handle_list() \ for j in range(i * self.tp_size, (i + 1) * self.tp_size)], cpu_blocks=self._cpu_handle.get_tensor(), - ssd_files=self._ssd_handle.get_file_list(), + ssd_files=ssd_files, gpu_kv_layouts=[self.gpu_handles[i].kv_layout \ for i in range(i * self.tp_size, (i + 1) * self.tp_size)], cpu_kv_layout=self._cpu_handle.kv_layout, - ssd_kv_layout=self._ssd_handle.kv_layout, + ssd_kv_layout=ssd_kv_layout, dtype=self.gpu_handles[i].dtype, tp_group_size=self.tp_size, dp_group_id=i, - num_blocks_per_file=self._ssd_handle.num_blocks_per_file, + num_blocks_per_file=num_blocks_per_file, use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, 
transfer_sms_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_sms_h2d, From 08355554e343f31531215e9f9576b369af9d9d73 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Thu, 8 Jan 2026 22:59:45 -0800 Subject: [PATCH 13/59] initial layerwise cpp impl --- csrc/bindings.cpp | 28 ++++ csrc/layerwise.cpp | 242 +++++++++++++++++++++++++++++++++++ csrc/layerwise.h | 88 +++++++++++++ flexkv/transfer/layerwise.py | 74 +++++------ setup.py | 2 + 5 files changed, 391 insertions(+), 43 deletions(-) create mode 100644 csrc/layerwise.cpp create mode 100644 csrc/layerwise.h diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index 79d68f5bbe..c1123a881c 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -17,6 +17,9 @@ #include #include "cache_utils.h" +#include "layerwise.h" +#include "pcfs/pcfs.h" +#include "tp_transfer_thread_group.h" #include "gds/gds_manager.h" #include "gds/tp_gds_transfer_thread_group.h" #include "pcfs/pcfs.h" @@ -414,6 +417,31 @@ PYBIND11_MODULE(c_ext, m) { py::arg("is_read"), py::arg("num_blocks_per_file"), py::arg("round_robin") = 1, py::arg("num_threads_per_device") = 16, py::arg("is_mla") = false); + py::class_(m, "LayerwiseTransferGroup") + .def(py::init> &, + torch::Tensor &, std::map> &, + int, int, torch::Tensor &, torch::Tensor &, torch::Tensor &, + torch::Tensor &, int, int>(), + py::arg("num_gpus"), py::arg("gpu_blocks"), py::arg("cpu_blocks"), + py::arg("ssd_files"), py::arg("dp_group_id"), py::arg("num_layers"), + py::arg("gpu_kv_strides_tensor"), + py::arg("gpu_block_strides_tensor"), + py::arg("gpu_layer_strides_tensor"), + py::arg("gpu_chunk_sizes_tensor"), py::arg("iouring_entries"), + py::arg("iouring_flags")) + .def("layerwise_transfer", + &flexkv::LayerwiseTransferGroup::layerwise_transfer, + py::arg("ssd_block_ids"), py::arg("cpu_block_ids_d2h"), + py::arg("ssd_layer_stride_in_bytes"), + py::arg("ssd_kv_stride_in_bytes"), py::arg("num_blocks_per_file"), + py::arg("round_robin"), py::arg("num_threads_per_device"), + 
py::arg("gpu_block_id_tensor"), py::arg("cpu_block_id_tensor"), + py::arg("cpu_kv_stride_in_bytes"), + py::arg("cpu_layer_stride_in_bytes"), + py::arg("cpu_block_stride_in_bytes"), + py::arg("cpu_chunk_size_in_bytes"), py::arg("transfer_sms"), + py::arg("use_ce_transfer"), py::arg("layer_id"), + py::arg("layer_granularity"), py::arg("is_mla")); #ifdef FLEXKV_ENABLE_CFS m.def("transfer_kv_blocks_remote", &transfer_kv_blocks_remote, diff --git a/csrc/layerwise.cpp b/csrc/layerwise.cpp new file mode 100644 index 0000000000..12f53369fe --- /dev/null +++ b/csrc/layerwise.cpp @@ -0,0 +1,242 @@ +#include "layerwise.h" +#include +#include + +namespace flexkv { + +LayerwiseTransferGroup::LayerwiseTransferGroup( + int num_gpus, const std::vector> &gpu_blocks, + torch::Tensor &cpu_blocks, + std::map> &ssd_files, int dp_group_id, + int num_layers, torch::Tensor &gpu_kv_strides_tensor, + torch::Tensor &gpu_block_strides_tensor, + torch::Tensor &gpu_layer_strides_tensor, + torch::Tensor &gpu_chunk_sizes_tensor, int iouring_entries, + int iouring_flags) { + + num_gpus_ = num_gpus; + + gpu_kv_strides_in_bytes_ = new int64_t[num_gpus]; + gpu_block_strides_in_bytes_ = new int64_t[num_gpus]; + gpu_layer_strides_in_bytes_ = new int64_t[num_gpus]; + gpu_chunk_sizes_in_bytes_ = new int64_t[num_gpus]; + + int64_t *kv_strides_ptr = gpu_kv_strides_tensor.data_ptr(); + int64_t *block_strides_ptr = gpu_block_strides_tensor.data_ptr(); + int64_t *layer_strides_ptr = gpu_layer_strides_tensor.data_ptr(); + int64_t *chunk_sizes_ptr = gpu_chunk_sizes_tensor.data_ptr(); + + for (int i = 0; i < num_gpus; i++) { + gpu_kv_strides_in_bytes_[i] = kv_strides_ptr[i]; + gpu_block_strides_in_bytes_[i] = block_strides_ptr[i]; + gpu_chunk_sizes_in_bytes_[i] = chunk_sizes_ptr[i]; + gpu_layer_strides_in_bytes_[i] = layer_strides_ptr[i]; + } + + queues_.resize(num_gpus_); + mtxs_ = std::vector(num_gpus_); + cvs_ = std::vector(num_gpus_); + + num_tensors_per_gpu_ = gpu_blocks[0].size(); + cudaMallocHost((void 
**)&gpu_blocks_, + num_gpus_ * num_tensors_per_gpu_ * sizeof(void *)); + for (int i = 0; i < num_gpus_; ++i) { + for (int j = 0; j < num_tensors_per_gpu_; ++j) { + gpu_blocks_[i * num_tensors_per_gpu_ + j] = gpu_blocks[i][j].data_ptr(); + } + } + + if (num_tensors_per_gpu_ == 1) { + backend_type_ = BackendType::TRTLLM; + } else if (num_tensors_per_gpu_ == num_layers) { + backend_type_ = BackendType::VLLM; + } else if (num_tensors_per_gpu_ == num_layers * 2) { + backend_type_ = BackendType::SGLANG; + } else { + throw std::runtime_error("Unsupported GPU block type: " + + std::to_string(num_tensors_per_gpu_)); + } + + gpu_tensor_handlers_.reserve(num_gpus_); + for (int i = 0; i < num_gpus_; i++) { + int64_t **gpu_blocks_ptr = + reinterpret_cast(gpu_blocks_ + i * num_tensors_per_gpu_); + gpu_tensor_handlers_.emplace_back( + backend_type_, gpu_blocks_ptr, num_layers, gpu_kv_strides_in_bytes_[i], + gpu_block_strides_in_bytes_[i], gpu_layer_strides_in_bytes_[i]); + } + + cpu_blocks_ = cpu_blocks.data_ptr(); + + dp_group_id_ = dp_group_id; + streams_.resize(num_gpus_); + for (int i = 0; i < num_gpus_; i += 1) { + cudaSetDevice(dp_group_id * num_gpus_ + i); + cudaStreamCreate(&streams_[i]); + } + + // Initialize SSD IO context if ssd_files is not empty + enable_ssd_ = !ssd_files.empty(); + if (enable_ssd_) { + ioctx_ = std::make_unique(ssd_files, ssd_files.size(), + iouring_entries, iouring_flags); + } + + // Create the thread pool + stop_pool_ = false; + for (int i = 0; i < num_gpus_; ++i) { + threads_.emplace_back([this, i]() { + int device_id = dp_group_id_ * num_gpus_ + i; + cudaSetDevice(device_id); + + while (true) { + Task task; + { + std::unique_lock lk(mtxs_[i]); + cvs_[i].wait(lk, [&] { return stop_pool_ || !queues_[i].empty(); }); + if (stop_pool_ && queues_[i].empty()) + return; + + task = std::move(queues_[i].front()); + queues_[i].pop(); + } + task(); + } + }); + } +} + +LayerwiseTransferGroup::~LayerwiseTransferGroup() { + stop_pool_ = true; + for (auto &cv : 
cvs_) + cv.notify_all(); + for (auto &t : threads_) + if (t.joinable()) + t.join(); + + cudaFreeHost(gpu_blocks_); + + gpu_tensor_handlers_.clear(); + delete[] gpu_kv_strides_in_bytes_; + delete[] gpu_block_strides_in_bytes_; + delete[] gpu_layer_strides_in_bytes_; + delete[] gpu_chunk_sizes_in_bytes_; +} + +std::future LayerwiseTransferGroup::enqueue_for_gpu(int gpu_idx, + Task task) { + auto pkg = std::make_shared>(std::move(task)); + auto fut = pkg->get_future(); + { + std::lock_guard lk(mtxs_[gpu_idx]); + queues_[gpu_idx].emplace([pkg] { (*pkg)(); }); + } + cvs_[gpu_idx].notify_one(); + return fut; +} + +void LayerwiseTransferGroup::layerwise_transfer( + const torch::Tensor &ssd_block_ids, const torch::Tensor &cpu_block_ids_d2h, + const int64_t ssd_layer_stride_in_bytes, + const int64_t ssd_kv_stride_in_bytes, const int num_blocks_per_file, + const int round_robin, const int num_threads_per_device, + const torch::Tensor &gpu_block_id_tensor, + const torch::Tensor &cpu_block_id_tensor, + const int64_t cpu_kv_stride_in_bytes, + const int64_t cpu_layer_stride_in_bytes, + const int64_t cpu_block_stride_in_bytes, + const int64_t cpu_chunk_size_in_bytes, const int transfer_sms, + const bool use_ce_transfer, const int layer_id, const int layer_granularity, + const bool is_mla) { + + // Step 1: SSD -> CPU transfer (if ssd_block_ids is not empty) + if (enable_ssd_ && ssd_block_ids.numel() > 0) { + torch::Tensor layer_id_list = + torch::arange(layer_id, layer_id + layer_granularity, + torch::TensorOptions().dtype(torch::kInt32)); + transfer_kv_blocks_ssd( + *ioctx_, layer_id_list, reinterpret_cast(cpu_blocks_), + ssd_block_ids, cpu_block_ids_d2h, cpu_layer_stride_in_bytes, + cpu_kv_stride_in_bytes, ssd_layer_stride_in_bytes, + ssd_kv_stride_in_bytes, cpu_chunk_size_in_bytes, + cpu_block_stride_in_bytes, + true, // is_read: SSD -> CPU + num_blocks_per_file, round_robin, num_threads_per_device, is_mla); + } + + // Step 2: CPU -> GPU transfer + std::atomic failed{false}; + 
std::string error_msg; + std::vector> futures; + futures.reserve(num_gpus_); + + for (int i = 0; i < num_gpus_; ++i) { + futures.emplace_back(enqueue_for_gpu(i, [&, i]() { + try { + int num_blocks = gpu_block_id_tensor.numel(); + + int64_t *gpu_block_ids = + static_cast(gpu_block_id_tensor.data_ptr()); + int64_t *cpu_block_ids = + static_cast(cpu_block_id_tensor.data_ptr()); + void *cpu_ptr = cpu_blocks_; + int64_t cpu_startoff_inside_chunks = i * gpu_chunk_sizes_in_bytes_[i]; + if (is_mla) { + cpu_startoff_inside_chunks = 0; + } + int64_t gpu_startoff_inside_chunks = 0; + int64_t chunk_size = gpu_chunk_sizes_in_bytes_[i]; + + // Dispatch to the appropriate template based on backend type + switch (backend_type_) { + case BackendType::VLLM: + flexkv::transfer_kv_blocks( + num_blocks, layer_id, layer_granularity, gpu_block_ids, + gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, + cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, + cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, + cpu_startoff_inside_chunks, chunk_size, streams_[i], transfer_sms, + true, use_ce_transfer, is_mla); + break; + case BackendType::TRTLLM: + flexkv::transfer_kv_blocks( + num_blocks, layer_id, layer_granularity, gpu_block_ids, + gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, + cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, + cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, + cpu_startoff_inside_chunks, chunk_size, streams_[i], transfer_sms, + true, use_ce_transfer, is_mla); + break; + case BackendType::SGLANG: + flexkv::transfer_kv_blocks( + num_blocks, layer_id, layer_granularity, gpu_block_ids, + gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, + cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, + cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, + cpu_startoff_inside_chunks, chunk_size, streams_[i], transfer_sms, + true, use_ce_transfer, is_mla); + break; + } + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + failed = true; + error_msg = 
cudaGetErrorString(err); + } + } catch (const std::exception &e) { + failed = true; + error_msg = e.what(); + } + })); + } + + for (auto &f : futures) { + f.get(); + } + + if (failed) { + throw std::runtime_error("layerwise_transfer failed: " + error_msg); + } +} + +} // namespace flexkv diff --git a/csrc/layerwise.h b/csrc/layerwise.h new file mode 100644 index 0000000000..83f894f875 --- /dev/null +++ b/csrc/layerwise.h @@ -0,0 +1,88 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gtensor_handler.cuh" +#include "transfer.cuh" +#include "transfer_ssd.h" + +namespace flexkv { + +class LayerwiseTransferGroup { +public: + LayerwiseTransferGroup( + int num_gpus, const std::vector> &gpu_blocks, + torch::Tensor &cpu_blocks, + std::map> &ssd_files, int dp_group_id, + int num_layers, torch::Tensor &gpu_kv_strides_tensor, + torch::Tensor &gpu_block_strides_tensor, + torch::Tensor &gpu_layer_strides_tensor, + torch::Tensor &gpu_chunk_sizes_tensor, int iouring_entries, + int iouring_flags); + + ~LayerwiseTransferGroup(); + + // Layerwise transfer: SSD->CPU + CPU->GPU in one call + void layerwise_transfer( + const torch::Tensor + &ssd_block_ids, // SSD source block ids (for disk2host) + const torch::Tensor + &cpu_block_ids_d2h, // CPU dest block ids (for disk2host) + const int64_t ssd_layer_stride_in_bytes, + const int64_t ssd_kv_stride_in_bytes, const int num_blocks_per_file, + const int round_robin, const int num_threads_per_device, + const torch::Tensor + &gpu_block_id_tensor, // GPU dest block ids (for host2device) + const torch::Tensor + &cpu_block_id_tensor, // CPU source block ids (for host2device) + const int64_t cpu_kv_stride_in_bytes, + const int64_t cpu_layer_stride_in_bytes, + const int64_t cpu_block_stride_in_bytes, + const int64_t cpu_chunk_size_in_bytes, const int transfer_sms, + const bool use_ce_transfer, const int layer_id, + const int 
layer_granularity, const bool is_mla); + +private: + using Task = std::function; + std::future enqueue_for_gpu(int gpu_idx, Task task); + + int num_gpus_; + int dp_group_id_; + void **gpu_blocks_; + void *cpu_blocks_; + int num_tensors_per_gpu_; + int64_t *gpu_kv_strides_in_bytes_; + int64_t *gpu_block_strides_in_bytes_; + int64_t *gpu_layer_strides_in_bytes_; + int64_t *gpu_chunk_sizes_in_bytes_; + + BackendType backend_type_; + std::vector gpu_tensor_handlers_; + + std::vector threads_; + std::vector streams_; + + std::vector> queues_; + std::vector mtxs_; + std::vector cvs_; + std::atomic stop_pool_; + + // SSD IO context + bool enable_ssd_; + std::unique_ptr ioctx_; +}; + +} // namespace flexkv diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index 260cdbc26a..f99087157a 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -17,7 +17,8 @@ from flexkv import c_ext from flexkv.c_ext import transfer_kv_blocks, transfer_kv_blocks_ssd, \ - transfer_kv_blocks_gds, TPTransferThreadGroup, TPGDSTransferThreadGroup + transfer_kv_blocks_gds, TPTransferThreadGroup, TPGDSTransferThreadGroup, \ + LayerwiseTransferGroup from flexkv.common.debug import flexkv_logger from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType @@ -109,8 +110,8 @@ def __init__(self, # initialize SSD storage self.enable_ssd = len(ssd_files) > 0 + self.ssd_files = ssd_files if self.enable_ssd: - self.ssd_files = ssd_files self.num_blocks_per_file = num_blocks_per_file self.num_files = sum(len(file_list) for file_list in ssd_files.values()) self.round_robin = 1 @@ -119,29 +120,26 @@ def __init__(self, self.ssd_kv_stride_in_bytes = ssd_kv_layout_per_file.get_kv_stride() * self.dtype.itemsize self.ssd_layer_stride_in_bytes = ssd_kv_layout_per_file.get_layer_stride() * self.dtype.itemsize self.ssd_block_stride_in_bytes = ssd_kv_layout_per_file.get_block_stride() * self.dtype.itemsize - - 
try: - self.ioctx = c_ext.SSDIOCTX( - ssd_files, - len(ssd_files), - GLOBAL_CONFIG_FROM_ENV.iouring_entries, - GLOBAL_CONFIG_FROM_ENV.iouring_flags - ) - except Exception as e: - flexkv_logger.error(f"Error setting ssd ioctx: {e}\n") - raise RuntimeError("SSD Worker init failed") from e + else: + self.num_blocks_per_file = 0 + self.round_robin = 1 + self.ssd_kv_stride_in_bytes = 0 + self.ssd_layer_stride_in_bytes = 0 + self.ssd_block_stride_in_bytes = 0 gpu_kv_strides_tensor = torch.tensor(self.gpu_kv_strides_in_bytes, dtype=torch.int64) gpu_block_strides_tensor = torch.tensor(self.gpu_block_strides_in_bytes, dtype=torch.int64) gpu_chunk_sizes_tensor = torch.tensor(self.gpu_chunk_sizes_in_bytes, dtype=torch.int64) gpu_layer_strides_tensor = torch.tensor(self.gpu_layer_strides_in_bytes, dtype=torch.int64) - self.tp_transfer_thread_group = TPTransferThreadGroup( - self.num_gpus, self.gpu_blocks, cpu_blocks, dp_group_id, - self.num_layers, gpu_kv_strides_tensor, - gpu_block_strides_tensor, gpu_layer_strides_tensor, - gpu_chunk_sizes_tensor) - self.cpu_layer_ptrs = self._get_layer_ptrs(cpu_blocks) + # Create LayerwiseTransferGroup which handles both SSD->CPU and CPU->GPU transfers + self.layerwise_transfer_group = LayerwiseTransferGroup( + self.num_gpus, self.gpu_blocks, cpu_blocks, ssd_files, + dp_group_id, self.num_layers, + gpu_kv_strides_tensor, gpu_block_strides_tensor, + gpu_layer_strides_tensor, gpu_chunk_sizes_tensor, + GLOBAL_CONFIG_FROM_ENV.iouring_entries, + GLOBAL_CONFIG_FROM_ENV.iouring_flags) def _transfer_impl(self, src_block_ids_h2d: torch.Tensor, @@ -158,29 +156,20 @@ def _transfer_impl(self, assert src_block_ids_disk2h.dtype == torch.int64 assert dst_block_ids_disk2h.dtype == torch.int64 assert len(src_block_ids_disk2h) == len(dst_block_ids_disk2h) - layer_id_list = torch.arange(layer_id, layer_id + layer_granularity, dtype=torch.int32) - if src_block_ids_disk2h is not None and \ - len(src_block_ids_disk2h) > 0: - assert self.enable_ssd, "SSD is not 
enabled for LayerwiseTransferWorker" - transfer_kv_blocks_ssd( - ioctx=self.ioctx, - cpu_layer_id_list=layer_id_list, - cpu_tensor_ptr=self.cpu_layer_ptrs[0].item(), - ssd_block_ids=src_block_ids_disk2h, - cpu_block_ids=dst_block_ids_disk2h, - cpu_layer_stride_in_bytes=self.cpu_layer_stride_in_bytes, - cpu_kv_stride_in_bytes=self.cpu_kv_stride_in_bytes, - ssd_layer_stride_in_bytes=self.ssd_layer_stride_in_bytes, - ssd_kv_stride_in_bytes=self.ssd_kv_stride_in_bytes, - chunk_size_in_bytes=self.cpu_chunk_size_in_bytes, - block_stride_in_bytes=self.cpu_block_stride_in_bytes, - is_read=True, - num_blocks_per_file=self.num_blocks_per_file, - round_robin=self.round_robin, - num_threads_per_device=32, - is_mla=self.is_mla, - ) - self.tp_transfer_thread_group.tp_group_transfer( + + # Use unified layerwise transfer C++ interface + ssd_block_ids = src_block_ids_disk2h if src_block_ids_disk2h is not None else torch.empty(0, dtype=torch.int64) + cpu_block_ids_d2h = dst_block_ids_disk2h if dst_block_ids_disk2h is not None \ + else torch.empty(0, dtype=torch.int64) + + self.layerwise_transfer_group.layerwise_transfer( + ssd_block_ids, + cpu_block_ids_d2h, + self.ssd_layer_stride_in_bytes, + self.ssd_kv_stride_in_bytes, + self.num_blocks_per_file, + self.round_robin, + 32, # num_threads_per_device dst_block_ids_h2d, src_block_ids_h2d, self.cpu_kv_stride_in_bytes, @@ -188,7 +177,6 @@ def _transfer_impl(self, self.cpu_block_stride_in_bytes, self.cpu_chunk_size_in_bytes, self.transfer_sms_h2d, - True, # is H2D self.use_ce_transfer_h2d, layer_id, layer_granularity, diff --git a/setup.py b/setup.py index bbbbeb7972..641296232c 100755 --- a/setup.py +++ b/setup.py @@ -34,6 +34,7 @@ def get_version(): "csrc/tp_transfer_thread_group.cpp", "csrc/transfer_ssd.cpp", "csrc/radix_tree.cpp", + "csrc/layerwise.cpp" "csrc/monitoring/metrics_manager.cpp", # Monitoring support ] @@ -42,6 +43,7 @@ def get_version(): "csrc/tp_transfer_thread_group.h", "csrc/transfer_ssd.h", "csrc/radix_tree.h", + 
"csrc/layerwise.h", "csrc/monitoring/metrics_manager.h", # Monitoring support ] From 68814138eb22bb6f1c5ed5774c62ebe9d4f17e2c Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Fri, 9 Jan 2026 01:57:28 -0800 Subject: [PATCH 14/59] add callback && fix some bugs --- csrc/bindings.cpp | 37 +++++- csrc/layerwise.cpp | 239 ++++++++++++++++------------------- csrc/layerwise.h | 23 +--- csrc/transfer.cu | 16 ++- csrc/transfer.cuh | 3 +- flexkv/transfer/layerwise.py | 9 +- 6 files changed, 161 insertions(+), 166 deletions(-) diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index c1123a881c..a1d5627d00 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -48,7 +48,8 @@ void transfer_kv_blocks_binding( int64_t cpu_layer_stride_in_bytes, int64_t cpu_block_stride_in_bytes, int64_t chunk_size_in_bytes, int start_layer_id, int num_layers, int transfer_num_cta = 4, bool is_host_to_device = true, - bool use_ce_transfer = false, bool is_mla = false, int gpu_block_type = 0) { + bool use_ce_transfer = false, bool is_mla = false, int gpu_block_type = 0, + bool sync = true) { int num_blocks = gpu_block_id_tensor.numel(); int64_t *gpu_block_ids = @@ -88,7 +89,7 @@ void transfer_kv_blocks_binding( cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_num_cta, is_host_to_device, - use_ce_transfer, is_mla); + use_ce_transfer, is_mla, sync); break; case flexkv::BackendType::TRTLLM: flexkv::transfer_kv_blocks( @@ -96,7 +97,7 @@ void transfer_kv_blocks_binding( cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_num_cta, is_host_to_device, - use_ce_transfer, is_mla); + use_ce_transfer, is_mla, sync); break; case flexkv::BackendType::SGLANG: flexkv::transfer_kv_blocks( @@ -104,7 +105,7 @@ void transfer_kv_blocks_binding( cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, 
cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_num_cta, is_host_to_device, - use_ce_transfer, is_mla); + use_ce_transfer, is_mla, sync); break; } @@ -406,7 +407,7 @@ PYBIND11_MODULE(c_ext, m) { py::arg("start_layer_id"), py::arg("num_layers"), py::arg("transfer_num_cta") = 4, py::arg("is_host_to_device") = true, py::arg("use_ce_transfer") = false, py::arg("is_mla") = false, - py::arg("gpu_block_type") = 0); + py::arg("gpu_block_type") = 0, py::arg("sync") = true); m.def("transfer_kv_blocks_ssd", &transfer_kv_blocks_ssd_binding, "Transfer KV blocks between SSD and CPU memory", py::arg("ioctx"), py::arg("cpu_layer_id_list"), py::arg("cpu_tensor_ptr"), @@ -531,6 +532,32 @@ PYBIND11_MODULE(c_ext, m) { py::arg("layer_granularity"), py::arg("is_mla")); #endif + py::class_(m, "LayerwiseTransferGroup") + .def(py::init> &, + torch::Tensor &, std::map> &, + int, int, torch::Tensor &, torch::Tensor &, torch::Tensor &, + torch::Tensor &, int, int>(), + py::arg("num_gpus"), py::arg("gpu_blocks"), py::arg("cpu_blocks"), + py::arg("ssd_files"), py::arg("dp_group_id"), py::arg("num_layers"), + py::arg("gpu_kv_strides_tensor"), + py::arg("gpu_block_strides_tensor"), + py::arg("gpu_layer_strides_tensor"), + py::arg("gpu_chunk_sizes_tensor"), py::arg("iouring_entries"), + py::arg("iouring_flags")) + .def("layerwise_transfer", + &flexkv::LayerwiseTransferGroup::layerwise_transfer, + py::arg("ssd_block_ids"), py::arg("cpu_block_ids_d2h"), + py::arg("ssd_layer_stride_in_bytes"), + py::arg("ssd_kv_stride_in_bytes"), py::arg("num_blocks_per_file"), + py::arg("round_robin"), py::arg("num_threads_per_device"), + py::arg("gpu_block_id_tensor"), py::arg("cpu_block_id_tensor"), + py::arg("cpu_kv_stride_in_bytes"), + py::arg("cpu_layer_stride_in_bytes"), + py::arg("cpu_block_stride_in_bytes"), + py::arg("cpu_chunk_size_in_bytes"), py::arg("transfer_sms"), + py::arg("use_ce_transfer"), py::arg("num_layers"), + py::arg("layer_granularity"), py::arg("is_mla")); + // Add Hasher 
class binding py::class_(m, "Hasher") .def(py::init<>()) diff --git a/csrc/layerwise.cpp b/csrc/layerwise.cpp index 12f53369fe..904f466f4e 100644 --- a/csrc/layerwise.cpp +++ b/csrc/layerwise.cpp @@ -1,9 +1,31 @@ #include "layerwise.h" +#include +#include #include #include namespace flexkv { +struct LayerCallbackData { + int start_layer; + int layers_this_batch; + int num_gpus; + std::atomic *counter; +}; + +static void CUDART_CB layer_done_host_callback(void *userData) { + LayerCallbackData *data = static_cast(userData); + int completed = data->counter->fetch_add(1) + 1; + if (completed == data->num_gpus) { + printf( + "[LayerwiseTransfer] All %d GPUs: Layers [%d, %d) transfer completed\n", + data->num_gpus, data->start_layer, + data->start_layer + data->layers_this_batch); + delete data->counter; + } + delete data; +} + LayerwiseTransferGroup::LayerwiseTransferGroup( int num_gpus, const std::vector> &gpu_blocks, torch::Tensor &cpu_blocks, @@ -33,10 +55,6 @@ LayerwiseTransferGroup::LayerwiseTransferGroup( gpu_layer_strides_in_bytes_[i] = layer_strides_ptr[i]; } - queues_.resize(num_gpus_); - mtxs_ = std::vector(num_gpus_); - cvs_ = std::vector(num_gpus_); - num_tensors_per_gpu_ = gpu_blocks[0].size(); cudaMallocHost((void **)&gpu_blocks_, num_gpus_ * num_tensors_per_gpu_ * sizeof(void *)); @@ -69,10 +87,14 @@ LayerwiseTransferGroup::LayerwiseTransferGroup( cpu_blocks_ = cpu_blocks.data_ptr(); dp_group_id_ = dp_group_id; + + // Create CUDA streams for each GPU streams_.resize(num_gpus_); - for (int i = 0; i < num_gpus_; i += 1) { + events_.resize(num_gpus_); + for (int i = 0; i < num_gpus_; i++) { cudaSetDevice(dp_group_id * num_gpus_ + i); cudaStreamCreate(&streams_[i]); + cudaEventCreate(&events_[i]); } // Initialize SSD IO context if ssd_files is not empty @@ -81,38 +103,14 @@ LayerwiseTransferGroup::LayerwiseTransferGroup( ioctx_ = std::make_unique(ssd_files, ssd_files.size(), iouring_entries, iouring_flags); } - - // Create the thread pool - stop_pool_ = 
false; - for (int i = 0; i < num_gpus_; ++i) { - threads_.emplace_back([this, i]() { - int device_id = dp_group_id_ * num_gpus_ + i; - cudaSetDevice(device_id); - - while (true) { - Task task; - { - std::unique_lock lk(mtxs_[i]); - cvs_[i].wait(lk, [&] { return stop_pool_ || !queues_[i].empty(); }); - if (stop_pool_ && queues_[i].empty()) - return; - - task = std::move(queues_[i].front()); - queues_[i].pop(); - } - task(); - } - }); - } } LayerwiseTransferGroup::~LayerwiseTransferGroup() { - stop_pool_ = true; - for (auto &cv : cvs_) - cv.notify_all(); - for (auto &t : threads_) - if (t.joinable()) - t.join(); + for (int i = 0; i < num_gpus_; i++) { + cudaSetDevice(dp_group_id_ * num_gpus_ + i); + cudaStreamDestroy(streams_[i]); + cudaEventDestroy(events_[i]); + } cudaFreeHost(gpu_blocks_); @@ -123,16 +121,14 @@ LayerwiseTransferGroup::~LayerwiseTransferGroup() { delete[] gpu_chunk_sizes_in_bytes_; } -std::future LayerwiseTransferGroup::enqueue_for_gpu(int gpu_idx, - Task task) { - auto pkg = std::make_shared>(std::move(task)); - auto fut = pkg->get_future(); - { - std::lock_guard lk(mtxs_[gpu_idx]); - queues_[gpu_idx].emplace([pkg] { (*pkg)(); }); +void LayerwiseTransferGroup::layer_done_callback(int start_layer, + int layers_this_batch) { + std::atomic *counter = new std::atomic(0); + for (int i = 0; i < num_gpus_; ++i) { + LayerCallbackData *data = new LayerCallbackData{ + start_layer, layers_this_batch, num_gpus_, counter}; + cudaLaunchHostFunc(streams_[i], layer_done_host_callback, data); } - cvs_[gpu_idx].notify_one(); - return fut; } void LayerwiseTransferGroup::layerwise_transfer( @@ -146,96 +142,83 @@ void LayerwiseTransferGroup::layerwise_transfer( const int64_t cpu_layer_stride_in_bytes, const int64_t cpu_block_stride_in_bytes, const int64_t cpu_chunk_size_in_bytes, const int transfer_sms, - const bool use_ce_transfer, const int layer_id, const int layer_granularity, - const bool is_mla) { - - // Step 1: SSD -> CPU transfer (if ssd_block_ids is not 
empty) - if (enable_ssd_ && ssd_block_ids.numel() > 0) { - torch::Tensor layer_id_list = - torch::arange(layer_id, layer_id + layer_granularity, - torch::TensorOptions().dtype(torch::kInt32)); - transfer_kv_blocks_ssd( - *ioctx_, layer_id_list, reinterpret_cast(cpu_blocks_), - ssd_block_ids, cpu_block_ids_d2h, cpu_layer_stride_in_bytes, - cpu_kv_stride_in_bytes, ssd_layer_stride_in_bytes, - ssd_kv_stride_in_bytes, cpu_chunk_size_in_bytes, - cpu_block_stride_in_bytes, - true, // is_read: SSD -> CPU - num_blocks_per_file, round_robin, num_threads_per_device, is_mla); - } + const bool use_ce_transfer, const int num_layers, + const int layer_granularity, const bool is_mla) { + + int num_blocks = gpu_block_id_tensor.numel(); + int64_t *gpu_block_ids = + static_cast(gpu_block_id_tensor.data_ptr()); + int64_t *cpu_block_ids = + static_cast(cpu_block_id_tensor.data_ptr()); + void *cpu_ptr = cpu_blocks_; + + for (int start_layer = 0; start_layer < num_layers; + start_layer += layer_granularity) { + int layers_this_batch = + std::min(layer_granularity, num_layers - start_layer); + // Step 1: SSD -> CPU transfer + if (enable_ssd_ && ssd_block_ids.numel() > 0) { + torch::Tensor layer_id_list = + torch::arange(start_layer, start_layer + layers_this_batch, + torch::TensorOptions().dtype(torch::kInt32)); + transfer_kv_blocks_ssd( + *ioctx_, layer_id_list, reinterpret_cast(cpu_blocks_), + ssd_block_ids, cpu_block_ids_d2h, cpu_layer_stride_in_bytes, + cpu_kv_stride_in_bytes, ssd_layer_stride_in_bytes, + ssd_kv_stride_in_bytes, cpu_chunk_size_in_bytes, + cpu_block_stride_in_bytes, + true, // is_read: SSD -> CPU + num_blocks_per_file, round_robin, num_threads_per_device, is_mla); + } - // Step 2: CPU -> GPU transfer - std::atomic failed{false}; - std::string error_msg; - std::vector> futures; - futures.reserve(num_gpus_); + // Step 2: CPU -> GPU transfer + for (int i = 0; i < num_gpus_; ++i) { + cudaSetDevice(dp_group_id_ * num_gpus_ + i); - for (int i = 0; i < num_gpus_; ++i) { - 
futures.emplace_back(enqueue_for_gpu(i, [&, i]() { - try { - int num_blocks = gpu_block_id_tensor.numel(); - - int64_t *gpu_block_ids = - static_cast(gpu_block_id_tensor.data_ptr()); - int64_t *cpu_block_ids = - static_cast(cpu_block_id_tensor.data_ptr()); - void *cpu_ptr = cpu_blocks_; - int64_t cpu_startoff_inside_chunks = i * gpu_chunk_sizes_in_bytes_[i]; - if (is_mla) { - cpu_startoff_inside_chunks = 0; - } - int64_t gpu_startoff_inside_chunks = 0; - int64_t chunk_size = gpu_chunk_sizes_in_bytes_[i]; - - // Dispatch to the appropriate template based on backend type - switch (backend_type_) { - case BackendType::VLLM: - flexkv::transfer_kv_blocks( - num_blocks, layer_id, layer_granularity, gpu_block_ids, - gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, - cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, - cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, - cpu_startoff_inside_chunks, chunk_size, streams_[i], transfer_sms, - true, use_ce_transfer, is_mla); - break; - case BackendType::TRTLLM: - flexkv::transfer_kv_blocks( - num_blocks, layer_id, layer_granularity, gpu_block_ids, - gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, - cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, - cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, - cpu_startoff_inside_chunks, chunk_size, streams_[i], transfer_sms, - true, use_ce_transfer, is_mla); - break; - case BackendType::SGLANG: - flexkv::transfer_kv_blocks( - num_blocks, layer_id, layer_granularity, gpu_block_ids, - gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, - cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, - cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, - cpu_startoff_inside_chunks, chunk_size, streams_[i], transfer_sms, - true, use_ce_transfer, is_mla); - break; - } - - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) { - failed = true; - error_msg = cudaGetErrorString(err); - } - } catch (const std::exception &e) { - failed = true; - error_msg = e.what(); + int64_t 
cpu_startoff_inside_chunks = i * gpu_chunk_sizes_in_bytes_[i]; + if (is_mla) { + cpu_startoff_inside_chunks = 0; } - })); - } + int64_t gpu_startoff_inside_chunks = 0; + int64_t chunk_size = gpu_chunk_sizes_in_bytes_[i]; + + switch (backend_type_) { + case BackendType::VLLM: + flexkv::transfer_kv_blocks( + num_blocks, start_layer, layers_this_batch, gpu_block_ids, + gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, cpu_block_ids, + cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, + cpu_block_stride_in_bytes, cpu_startoff_inside_chunks, chunk_size, + streams_[i], transfer_sms, true, use_ce_transfer, is_mla, false); + break; + case BackendType::TRTLLM: + flexkv::transfer_kv_blocks( + num_blocks, start_layer, layers_this_batch, gpu_block_ids, + gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, cpu_block_ids, + cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, + cpu_block_stride_in_bytes, cpu_startoff_inside_chunks, chunk_size, + streams_[i], transfer_sms, true, use_ce_transfer, is_mla, false); + break; + case BackendType::SGLANG: + flexkv::transfer_kv_blocks( + num_blocks, start_layer, layers_this_batch, gpu_block_ids, + gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, cpu_block_ids, + cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, + cpu_block_stride_in_bytes, cpu_startoff_inside_chunks, chunk_size, + streams_[i], transfer_sms, true, use_ce_transfer, is_mla, false); + break; + } + } - for (auto &f : futures) { - f.get(); + layer_done_callback(start_layer, layers_this_batch); } - - if (failed) { - throw std::runtime_error("layerwise_transfer failed: " + error_msg); + for (int i = 0; i < num_gpus_; ++i) { + cudaError_t err = cudaStreamSynchronize(streams_[i]); + if (err != cudaSuccess) { + throw std::runtime_error("layerwise_transfer failed on GPU " + + std::to_string(i) + ": " + + cudaGetErrorString(err)); + } } } diff --git a/csrc/layerwise.h b/csrc/layerwise.h index 83f894f875..94487de565 100644 --- a/csrc/layerwise.h +++ 
b/csrc/layerwise.h @@ -1,17 +1,10 @@ #pragma once -#include -#include #include #include -#include -#include #include #include -#include -#include #include -#include #include #include @@ -35,7 +28,7 @@ class LayerwiseTransferGroup { ~LayerwiseTransferGroup(); - // Layerwise transfer: SSD->CPU + CPU->GPU in one call + // Layerwise transfer: SSD->CPU + CPU->GPU void layerwise_transfer( const torch::Tensor &ssd_block_ids, // SSD source block ids (for disk2host) @@ -52,13 +45,10 @@ class LayerwiseTransferGroup { const int64_t cpu_layer_stride_in_bytes, const int64_t cpu_block_stride_in_bytes, const int64_t cpu_chunk_size_in_bytes, const int transfer_sms, - const bool use_ce_transfer, const int layer_id, + const bool use_ce_transfer, const int num_layers, const int layer_granularity, const bool is_mla); private: - using Task = std::function; - std::future enqueue_for_gpu(int gpu_idx, Task task); - int num_gpus_; int dp_group_id_; void **gpu_blocks_; @@ -72,17 +62,14 @@ class LayerwiseTransferGroup { BackendType backend_type_; std::vector gpu_tensor_handlers_; - std::vector threads_; std::vector streams_; - - std::vector> queues_; - std::vector mtxs_; - std::vector cvs_; - std::atomic stop_pool_; + std::vector events_; // SSD IO context bool enable_ssd_; std::unique_ptr ioctx_; + + void layer_done_callback(int start_layer, int layers_this_batch); }; } // namespace flexkv diff --git a/csrc/transfer.cu b/csrc/transfer.cu index 60ac276857..f46412e406 100644 --- a/csrc/transfer.cu +++ b/csrc/transfer.cu @@ -87,7 +87,7 @@ void transfer_kv_blocks( int64_t cpu_layer_stride_in_bytes, int64_t cpu_block_stride_in_bytes, int64_t cpu_startoff_inside_chunks, int64_t chunk_size_in_bytes, cudaStream_t stream, int transfer_num_cta, bool is_host_to_device, - bool use_ce_transfer, bool is_mla) { + bool use_ce_transfer, bool is_mla, bool sync) { int block_size = 1024; @@ -120,8 +120,8 @@ void transfer_kv_blocks( j * cpu_kv_stride_int64 + cpu_block_idx * cpu_block_stride_int64 + 
cpu_startoff_inside_chunks_int64; - int64_t *gpu_ptr = - ptr_at(gpu_tensor_handler, i, j, gpu_block_idx); + int64_t *gpu_ptr = ptr_at(gpu_tensor_handler, + i + start_layer_id, j, gpu_block_idx); int64_t *gpu_chunk_ptr = reinterpret_cast(gpu_ptr) + gpu_startoff_inside_chunks_int64; @@ -167,7 +167,9 @@ void transfer_kv_blocks( actual_chunk_bytes * static_cast(num_layers) * static_cast(kv_dim) * static_cast(num_blocks)); } - cudaStreamSynchronize(stream); + if (sync) { + cudaStreamSynchronize(stream); + } } // Explicit template instantiations @@ -176,16 +178,16 @@ template void transfer_kv_blocks(int, int, int, int64_t *, int64_t *, void *, int64_t, int64_t, int64_t, int64_t, int64_t, cudaStream_t, int, - bool, bool, bool); + bool, bool, bool, bool); template void transfer_kv_blocks( int, int, int, int64_t *, GTensorHandler, int64_t, int64_t *, void *, int64_t, int64_t, int64_t, int64_t, int64_t, cudaStream_t, int, bool, bool, - bool); + bool, bool); template void transfer_kv_blocks( int, int, int, int64_t *, GTensorHandler, int64_t, int64_t *, void *, int64_t, int64_t, int64_t, int64_t, int64_t, cudaStream_t, int, bool, bool, - bool); + bool, bool); } // namespace flexkv diff --git a/csrc/transfer.cuh b/csrc/transfer.cuh index 7436b2e887..5aab0af1d8 100644 --- a/csrc/transfer.cuh +++ b/csrc/transfer.cuh @@ -30,6 +30,7 @@ void transfer_kv_blocks( int64_t cpu_kv_stride_in_bytes, int64_t cpu_layer_stride_in_bytes, int64_t cpu_block_stride_in_bytes, int64_t cpu_startoff_inside_chunks, int64_t chunk_size_in_bytes, cudaStream_t stream, int transfer_num_cta, - bool is_host_to_device, bool use_ce_transfer, bool is_mla); + bool is_host_to_device, bool use_ce_transfer, bool is_mla, + bool sync = true); } // namespace flexkv diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index f99087157a..8eba9adb1d 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -146,7 +146,6 @@ def _transfer_impl(self, dst_block_ids_h2d: torch.Tensor, 
src_block_ids_disk2h: Optional[torch.Tensor], dst_block_ids_disk2h: Optional[torch.Tensor], - layer_id: int, layer_granularity: int, **kwargs: Any) -> None: assert src_block_ids_h2d.dtype == torch.int64 @@ -178,16 +177,13 @@ def _transfer_impl(self, self.cpu_chunk_size_in_bytes, self.transfer_sms_h2d, self.use_ce_transfer_h2d, - layer_id, + self.num_layers, layer_granularity, self.is_mla, ) def launch_transfer(self, transfer_op: WorkerLayerwiseTransferOp) -> None: - layer_id = transfer_op.layer_id layer_granularity = transfer_op.layer_granularity - if layer_id == -1: - layer_id = 0 if layer_granularity == -1: layer_granularity = self.num_layers @@ -200,12 +196,11 @@ def launch_transfer(self, transfer_op: WorkerLayerwiseTransferOp) -> None: else: src_block_ids_disk2h = None dst_block_ids_disk2h = None - layer_granularity = self.num_layers # TODO: remove this + self._transfer_impl( src_block_ids_h2d, dst_block_ids_h2d, src_block_ids_disk2h, dst_block_ids_disk2h, - layer_id, layer_granularity, ) From 7c7ab2763f29c07f976c21453c25bc2f57150f85 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Fri, 9 Jan 2026 02:40:30 -0800 Subject: [PATCH 15/59] fix --- csrc/bindings.cpp | 261 +++++++++++++++++++++++----------------------- 1 file changed, 131 insertions(+), 130 deletions(-) diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index a1d5627d00..3ac1a2089a 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -1,8 +1,8 @@ #include #include -#include #include #include +#include #include "transfer.cuh" #include @@ -17,7 +17,6 @@ #include #include "cache_utils.h" -#include "layerwise.h" #include "pcfs/pcfs.h" #include "tp_transfer_thread_group.h" #include "gds/gds_manager.h" @@ -61,7 +60,7 @@ void transfer_kv_blocks_binding( void *cpu_ptr = static_cast(cpu_tensor.data_ptr()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - + // Determine backend type from gpu_block_type parameter flexkv::BackendType backend_type; if (gpu_block_type == 0) { @@ -71,16 +70,19 @@ void 
transfer_kv_blocks_binding( } else if (gpu_block_type == 2) { backend_type = flexkv::BackendType::SGLANG; } else { - throw std::runtime_error("Unsupported gpu_block_type: " + - std::to_string(gpu_block_type)); + throw std::runtime_error("Unsupported gpu_block_type: " + std::to_string(gpu_block_type)); } - + // Create GTensorHandler flexkv::GTensorHandler handler( - backend_type, reinterpret_cast(gpu_tensor_ptrs), num_layers, - gpu_kv_stride_in_bytes, gpu_block_stride_in_bytes, - gpu_layer_stride_in_bytes); - + backend_type, + reinterpret_cast(gpu_tensor_ptrs), + num_layers, + gpu_kv_stride_in_bytes, + gpu_block_stride_in_bytes, + gpu_layer_stride_in_bytes + ); + // Dispatch to appropriate template instantiation switch (backend_type) { case flexkv::BackendType::VLLM: @@ -108,7 +110,7 @@ void transfer_kv_blocks_binding( use_ce_transfer, is_mla, sync); break; } - + cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { throw std::runtime_error(cudaGetErrorString(err)); @@ -116,21 +118,22 @@ void transfer_kv_blocks_binding( } void transfer_kv_blocks_ssd_binding( - flexkv::SSDIOCTX &ioctx, const torch::Tensor &cpu_layer_id_list, - int64_t cpu_tensor_ptr, const torch::Tensor &ssd_block_ids, - const torch::Tensor &cpu_block_ids, int64_t cpu_layer_stride_in_bytes, - int64_t cpu_kv_stride_in_bytes, int64_t ssd_layer_stride_in_bytes, - int64_t ssd_kv_stride_in_bytes, int64_t chunk_size_in_bytes, - int64_t block_stride_in_bytes, bool is_read, int num_blocks_per_file, - int round_robin = 1, int num_threads_per_device = 8, bool is_mla = false) { + flexkv::SSDIOCTX &ioctx, + const torch::Tensor &cpu_layer_id_list, int64_t cpu_tensor_ptr, + const torch::Tensor &ssd_block_ids, const torch::Tensor &cpu_block_ids, + int64_t cpu_layer_stride_in_bytes, int64_t cpu_kv_stride_in_bytes, + int64_t ssd_layer_stride_in_bytes, int64_t ssd_kv_stride_in_bytes, + int64_t chunk_size_in_bytes, int64_t block_stride_in_bytes, bool is_read, + int num_blocks_per_file, int round_robin = 1, + 
int num_threads_per_device = 8, bool is_mla = false) { TORCH_CHECK(ssd_block_ids.dtype() == torch::kInt64, "ssd_block_ids must be int64"); TORCH_CHECK(cpu_block_ids.dtype() == torch::kInt64, "cpu_block_ids must be int64"); flexkv::transfer_kv_blocks_ssd( - ioctx, cpu_layer_id_list, cpu_tensor_ptr, ssd_block_ids, cpu_block_ids, - cpu_layer_stride_in_bytes, cpu_kv_stride_in_bytes, + ioctx, cpu_layer_id_list, cpu_tensor_ptr, ssd_block_ids, + cpu_block_ids, cpu_layer_stride_in_bytes, cpu_kv_stride_in_bytes, ssd_layer_stride_in_bytes, ssd_kv_stride_in_bytes, chunk_size_in_bytes, block_stride_in_bytes, is_read, num_blocks_per_file, round_robin, num_threads_per_device, is_mla); @@ -279,111 +282,112 @@ void transfer_kv_blocks_gds_binding( } // GDS Manager Python bindings -py::list gds_batch_write_binding(GDSManager &manager, +py::list gds_batch_write_binding(GDSManager& manager, py::list operations_list) { - size_t batch_size = operations_list.size(); - std::vector operations(batch_size); - std::vector results(batch_size); - - for (size_t i = 0; i < batch_size; ++i) { - py::dict op_dict = operations_list[i].cast(); - operations[i].filename = op_dict["filename"].cast().c_str(); - operations[i].gpu_data = - op_dict["gpu_data"].cast().data_ptr(); - operations[i].size = op_dict["size"].cast(); - operations[i].file_offset = op_dict["file_offset"].cast(); - operations[i].result = &results[i]; - } - - int batch_id = manager.batch_write(operations.data(), batch_size); - - py::list result_list; - result_list.append(batch_id); - for (size_t i = 0; i < batch_size; ++i) { - result_list.append(results[i]); - } - - return result_list; + size_t batch_size = operations_list.size(); + std::vector operations(batch_size); + std::vector results(batch_size); + + for (size_t i = 0; i < batch_size; ++i) { + py::dict op_dict = operations_list[i].cast(); + operations[i].filename = op_dict["filename"].cast().c_str(); + operations[i].gpu_data = op_dict["gpu_data"].cast().data_ptr(); + 
operations[i].size = op_dict["size"].cast(); + operations[i].file_offset = op_dict["file_offset"].cast(); + operations[i].result = &results[i]; + } + + int batch_id = manager.batch_write(operations.data(), batch_size); + + py::list result_list; + result_list.append(batch_id); + for (size_t i = 0; i < batch_size; ++i) { + result_list.append(results[i]); + } + + return result_list; } -py::list gds_batch_read_binding(GDSManager &manager, py::list operations_list) { - size_t batch_size = operations_list.size(); - std::vector operations(batch_size); - std::vector results(batch_size); - - for (size_t i = 0; i < batch_size; ++i) { - py::dict op_dict = operations_list[i].cast(); - operations[i].filename = op_dict["filename"].cast().c_str(); - operations[i].gpu_buffer = - op_dict["gpu_buffer"].cast().data_ptr(); - operations[i].size = op_dict["size"].cast(); - operations[i].file_offset = op_dict["file_offset"].cast(); - operations[i].result = &results[i]; - } - - int batch_id = manager.batch_read(operations.data(), batch_size); - - py::list result_list; - result_list.append(batch_id); - for (size_t i = 0; i < batch_size; ++i) { - result_list.append(results[i]); - } - - return result_list; +py::list gds_batch_read_binding(GDSManager& manager, + py::list operations_list) { + size_t batch_size = operations_list.size(); + std::vector operations(batch_size); + std::vector results(batch_size); + + for (size_t i = 0; i < batch_size; ++i) { + py::dict op_dict = operations_list[i].cast(); + operations[i].filename = op_dict["filename"].cast().c_str(); + operations[i].gpu_buffer = op_dict["gpu_buffer"].cast().data_ptr(); + operations[i].size = op_dict["size"].cast(); + operations[i].file_offset = op_dict["file_offset"].cast(); + operations[i].result = &results[i]; + } + + int batch_id = manager.batch_read(operations.data(), batch_size); + + py::list result_list; + result_list.append(batch_id); + for (size_t i = 0; i < batch_size; ++i) { + result_list.append(results[i]); + } + + return 
result_list; } -ssize_t gds_write_binding(GDSManager &manager, const std::string &filename, - torch::Tensor gpu_data, size_t file_offset = 0) { - return manager.write(filename.c_str(), gpu_data.data_ptr(), - gpu_data.numel() * gpu_data.element_size(), file_offset); +ssize_t gds_write_binding(GDSManager& manager, + const std::string& filename, + torch::Tensor gpu_data, + size_t file_offset = 0) { + return manager.write(filename.c_str(), gpu_data.data_ptr(), + gpu_data.numel() * gpu_data.element_size(), file_offset); } -ssize_t gds_read_binding(GDSManager &manager, const std::string &filename, - torch::Tensor gpu_buffer, size_t file_offset = 0) { - return manager.read(filename.c_str(), gpu_buffer.data_ptr(), - gpu_buffer.numel() * gpu_buffer.element_size(), - file_offset); +ssize_t gds_read_binding(GDSManager& manager, + const std::string& filename, + torch::Tensor gpu_buffer, + size_t file_offset = 0) { + return manager.read(filename.c_str(), gpu_buffer.data_ptr(), + gpu_buffer.numel() * gpu_buffer.element_size(), file_offset); } -ssize_t gds_write_async_binding(GDSManager &manager, - const std::string &filename, - torch::Tensor gpu_data, - size_t file_offset = 0) { - return manager.write_async(filename.c_str(), gpu_data.data_ptr(), - gpu_data.numel() * gpu_data.element_size(), - file_offset); +ssize_t gds_write_async_binding(GDSManager& manager, + const std::string& filename, + torch::Tensor gpu_data, + size_t file_offset = 0) { + return manager.write_async(filename.c_str(), gpu_data.data_ptr(), + gpu_data.numel() * gpu_data.element_size(), file_offset); } -ssize_t gds_read_async_binding(GDSManager &manager, const std::string &filename, - torch::Tensor gpu_buffer, - size_t file_offset = 0) { - return manager.read_async(filename.c_str(), gpu_buffer.data_ptr(), - gpu_buffer.numel() * gpu_buffer.element_size(), - file_offset); +ssize_t gds_read_async_binding(GDSManager& manager, + const std::string& filename, + torch::Tensor gpu_buffer, + size_t file_offset = 0) { + 
return manager.read_async(filename.c_str(), gpu_buffer.data_ptr(), + gpu_buffer.numel() * gpu_buffer.element_size(), file_offset); } // Helper function to create and initialize a GDS file with specified size -bool create_gds_file_binding(GDSManager &manager, const std::string &filename, +bool create_gds_file_binding(GDSManager& manager, + const std::string& filename, size_t file_size) { - // First create/truncate the file to the desired size - int fd = open(filename.c_str(), O_CREAT | O_RDWR | O_TRUNC, 0644); - if (fd < 0) { - return false; - } - - // Pre-allocate the file to the specified size - if (ftruncate(fd, file_size) != 0) { + // First create/truncate the file to the desired size + int fd = open(filename.c_str(), O_CREAT | O_RDWR | O_TRUNC, 0644); + if (fd < 0) { + return false; + } + + // Pre-allocate the file to the specified size + if (ftruncate(fd, file_size) != 0) { + close(fd); + return false; + } + + // Ensure data is written to disk + fsync(fd); close(fd); - return false; - } - - // Ensure data is written to disk - fsync(fd); - close(fd); - - // Now add the file to GDS manager (this will open it with O_DIRECT and - // register with cuFile) - return manager.add_file(filename.c_str()); + + // Now add the file to GDS manager (this will open it with O_DIRECT and register with cuFile) + return manager.add_file(filename.c_str()); } #endif @@ -441,7 +445,7 @@ PYBIND11_MODULE(c_ext, m) { py::arg("cpu_layer_stride_in_bytes"), py::arg("cpu_block_stride_in_bytes"), py::arg("cpu_chunk_size_in_bytes"), py::arg("transfer_sms"), - py::arg("use_ce_transfer"), py::arg("layer_id"), + py::arg("use_ce_transfer"), py::arg("num_layers"), py::arg("layer_granularity"), py::arg("is_mla")); #ifdef FLEXKV_ENABLE_CFS @@ -481,8 +485,7 @@ PYBIND11_MODULE(c_ext, m) { py::arg("block_hashes")); py::class_(m, "SSDIOCTX") - .def( - py::init> &, int, int, int>()); + .def(py::init> &, int, int, int>()); py::class_(m, "TPTransferThreadGroup") .def(py::init &, int, int64_t, int, int, @@ 
-642,8 +645,7 @@ PYBIND11_MODULE(c_ext, m) { py::arg("evicted_blocks"), py::arg("evicted_block_hashes"), py::arg("num_evicted"), py::call_guard()) .def("total_cached_blocks", &flexkv::CRadixTreeIndex::total_cached_blocks) - .def("total_unready_blocks", - &flexkv::CRadixTreeIndex::total_unready_blocks) + .def("total_unready_blocks", &flexkv::CRadixTreeIndex::total_unready_blocks) .def("total_ready_blocks", &flexkv::CRadixTreeIndex::total_ready_blocks) .def("match_prefix", &flexkv::CRadixTreeIndex::match_prefix, py::arg("block_hashes"), py::arg("num_blocks"), @@ -675,35 +677,33 @@ PYBIND11_MODULE(c_ext, m) { #ifdef FLEXKV_ENABLE_GDS // Add GDS Manager class binding py::class_(m, "GDSManager") - .def(py::init> &, int, int>(), + .def(py::init>&, int, int>(), "Initialize GDS Manager with device-organized files", - py::arg("ssd_files"), py::arg("num_devices"), - py::arg("round_robin") = 1) + py::arg("ssd_files"), py::arg("num_devices"), py::arg("round_robin") = 1) .def("is_ready", &GDSManager::is_ready, "Check if GDS manager is ready for operations") .def("get_last_error", &GDSManager::get_last_error, "Get the last error message") .def("add_file", &GDSManager::add_file, - "Add and register a file with GDS (creates with O_DIRECT)", - py::arg("filename")) + "Add and register a file with GDS (creates with O_DIRECT)", py::arg("filename")) .def("remove_file", &GDSManager::remove_file, "Remove and unregister a file from GDS", py::arg("filename")) - .def("write", &gds_write_binding, "Write data from GPU memory to file", + .def("write", &gds_write_binding, + "Write data from GPU memory to file", py::arg("filename"), py::arg("gpu_data"), py::arg("file_offset") = 0) - .def("read", &gds_read_binding, "Read data from file to GPU memory", - py::arg("filename"), py::arg("gpu_buffer"), - py::arg("file_offset") = 0) + .def("read", &gds_read_binding, + "Read data from file to GPU memory", + py::arg("filename"), py::arg("gpu_buffer"), py::arg("file_offset") = 0) .def("write_async", 
&gds_write_async_binding, "Write data from GPU memory to file asynchronously", py::arg("filename"), py::arg("gpu_data"), py::arg("file_offset") = 0) .def("read_async", &gds_read_async_binding, "Read data from file to GPU memory asynchronously", - py::arg("filename"), py::arg("gpu_buffer"), - py::arg("file_offset") = 0) - .def("batch_write", &gds_batch_write_binding, "Batch write operations", - py::arg("operations")) - .def("batch_read", &gds_batch_read_binding, "Batch read operations", - py::arg("operations")) + py::arg("filename"), py::arg("gpu_buffer"), py::arg("file_offset") = 0) + .def("batch_write", &gds_batch_write_binding, + "Batch write operations", py::arg("operations")) + .def("batch_read", &gds_batch_read_binding, + "Batch read operations", py::arg("operations")) .def("batch_synchronize", &GDSManager::batch_synchronize, "Wait for batch operations to complete", py::arg("batch_id")) .def("synchronize", &GDSManager::synchronize, @@ -717,7 +717,8 @@ PYBIND11_MODULE(c_ext, m) { .def("get_round_robin", &GDSManager::get_round_robin, "Get round-robin granularity") .def("get_file_paths", &GDSManager::get_file_paths, - "Get file paths for a specific device", py::arg("device_id")) + "Get file paths for a specific device", + py::arg("device_id")) .def("create_gds_file", &create_gds_file_binding, "Create and register a GDS file with specified size", py::arg("filename"), py::arg("file_size")); From ed4187da1e412ef2c70d1adc789eac98cb402bf7 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Mon, 12 Jan 2026 06:23:18 +0000 Subject: [PATCH 16/59] some fix --- csrc/layerwise.cpp | 2 ++ flexkv/transfer/layerwise.py | 9 +-------- tests/test_kvmanager.py | 20 ++++++++++---------- 3 files changed, 13 insertions(+), 18 deletions(-) diff --git a/csrc/layerwise.cpp b/csrc/layerwise.cpp index 904f466f4e..508939f70c 100644 --- a/csrc/layerwise.cpp +++ b/csrc/layerwise.cpp @@ -17,6 +17,8 @@ static void CUDART_CB layer_done_host_callback(void *userData) { LayerCallbackData *data = 
static_cast(userData); int completed = data->counter->fetch_add(1) + 1; if (completed == data->num_gpus) { + // TODO: use eventfd to notify the consumer that [start_layer, start_layer + + // layers_this_batch) transfer completed printf( "[LayerwiseTransfer] All %d GPUs: Layers [%d, %d) transfer completed\n", data->num_gpus, data->start_layer, diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index 8eba9adb1d..c7b4d00bbe 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -16,9 +16,7 @@ from flexkv import c_ext -from flexkv.c_ext import transfer_kv_blocks, transfer_kv_blocks_ssd, \ - transfer_kv_blocks_gds, TPTransferThreadGroup, TPGDSTransferThreadGroup, \ - LayerwiseTransferGroup +from flexkv.c_ext import LayerwiseTransferGroup from flexkv.common.debug import flexkv_logger from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType @@ -26,11 +24,6 @@ from flexkv.common.transfer import get_nvtx_range_color from flexkv.common.config import CacheConfig, GLOBAL_CONFIG_FROM_ENV -try: - from flexkv.c_ext import transfer_kv_blocks_remote -except ImportError: - transfer_kv_blocks_remote = None - from flexkv.transfer.worker_op import WorkerTransferOp, WorkerLayerwiseTransferOp from flexkv.transfer.worker import TransferWorkerBase, cudaHostRegister diff --git a/tests/test_kvmanager.py b/tests/test_kvmanager.py index aef87bba8d..6f8393023d 100644 --- a/tests/test_kvmanager.py +++ b/tests/test_kvmanager.py @@ -366,19 +366,19 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): # =============== Test batched launched get =============== if not enable_gds: print("\n========== Testing batched launched get ==========") - + # Use the first few request_pairs that were written in initial phase batch_size = 6 - + batched_get_task_ids = [] batched_slot_mappings = [] batched_req_info = [] # Store (token_ids, block_ids) for verification - + # 
Create multiple get_match requests for i in range(batch_size): token_ids, block_ids, dp_id = request_pairs[random.randint(0, num_requests - 1)] slot_mapping = block_ids_2_slot_mapping(block_ids, tokens_per_block) - + request_id, return_mask = kvmanager.get_match( token_ids=token_ids, layer_granularity=-1, @@ -390,7 +390,7 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): batched_slot_mappings.append(slot_mapping) batched_req_info.append((token_ids, block_ids, request_id)) print(f"Created get_match request {request_id} for request_pair[{i}]") - + # Launch all get requests as a batch print(f"Launching {len(batched_get_task_ids)} get requests as batch...") batch_id = kvmanager.launch( @@ -399,12 +399,12 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): as_batch=True )[0] print(f"Returned task_ids after batch launch: {batch_id}") - + # Wait for the batched get to complete # When as_batch=True, launch returns [batch_id], we need to wait on batch_id batch_results = kvmanager.wait(batch_id, completely=True) print(f"Batch wait returned {len(batch_results)} results") - + # Verify results batched_cache_hit = 0 batched_cache_miss = 0 @@ -415,7 +415,7 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): batched_cache_hit += return_mask.sum().item() batched_cache_miss += len(return_mask) - return_mask.sum().item() print(f"Task {batch_id}: cache_hit={batched_cache_hit}, cache_miss={batched_cache_miss}") - + # GPU KV cache verification for batched get if gpu_kv_verifier is not None: for idx, (token_ids, block_ids, req_id) in enumerate(batched_req_info): @@ -429,9 +429,9 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): token_ids[:valid_fetched_tokens], block_ids[:valid_fetched_tokens // tokens_per_block] ) - + print(f"Batched get test completed: hit={batched_cache_hit}, miss={batched_cache_miss}") - + # Since we read data that was written before, cache hit 
should be high if enable_cpu and num_cpu_blocks >= num_gpu_blocks: assert batched_cache_miss == 0, \ From 79dc6cebdffe26ff8c872a9b5e7ae0447e78d841 Mon Sep 17 00:00:00 2001 From: jianyingzhu Date: Thu, 15 Jan 2026 08:09:24 +0000 Subject: [PATCH 17/59] add sglang support using eventfd --- csrc/bindings.cpp | 9 +-- csrc/layerwise.cpp | 89 +++++++++++++++++++++++++--- csrc/layerwise.h | 16 ++++- flexkv/common/transfer.py | 9 ++- flexkv/kvmanager.py | 8 ++- flexkv/kvtask.py | 11 ++-- flexkv/server/client.py | 3 +- flexkv/server/request.py | 1 + flexkv/server/server.py | 3 +- flexkv/transfer/layerwise.py | 110 ++++++++++++++++++++++++++++++++++- flexkv/transfer/worker_op.py | 2 + 11 files changed, 233 insertions(+), 28 deletions(-) diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index 3ac1a2089a..7e53608ef3 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -426,14 +426,15 @@ PYBIND11_MODULE(c_ext, m) { .def(py::init> &, torch::Tensor &, std::map> &, int, int, torch::Tensor &, torch::Tensor &, torch::Tensor &, - torch::Tensor &, int, int>(), + torch::Tensor &, int, int, torch::Tensor &, int>(), py::arg("num_gpus"), py::arg("gpu_blocks"), py::arg("cpu_blocks"), py::arg("ssd_files"), py::arg("dp_group_id"), py::arg("num_layers"), py::arg("gpu_kv_strides_tensor"), py::arg("gpu_block_strides_tensor"), py::arg("gpu_layer_strides_tensor"), py::arg("gpu_chunk_sizes_tensor"), py::arg("iouring_entries"), - py::arg("iouring_flags")) + py::arg("iouring_flags"), py::arg("layer_eventfds_tensor"), + py::arg("tp_size")) .def("layerwise_transfer", &flexkv::LayerwiseTransferGroup::layerwise_transfer, py::arg("ssd_block_ids"), py::arg("cpu_block_ids_d2h"), @@ -446,8 +447,8 @@ PYBIND11_MODULE(c_ext, m) { py::arg("cpu_block_stride_in_bytes"), py::arg("cpu_chunk_size_in_bytes"), py::arg("transfer_sms"), py::arg("use_ce_transfer"), py::arg("num_layers"), - py::arg("layer_granularity"), py::arg("is_mla")); - + py::arg("layer_granularity"), py::arg("is_mla"), + py::arg("counter_id") = 
0); #ifdef FLEXKV_ENABLE_CFS m.def("transfer_kv_blocks_remote", &transfer_kv_blocks_remote, "Transfer KV blocks between remote and CPU memory", diff --git a/csrc/layerwise.cpp b/csrc/layerwise.cpp index 508939f70c..1ca7cef77d 100644 --- a/csrc/layerwise.cpp +++ b/csrc/layerwise.cpp @@ -3,6 +3,8 @@ #include #include #include +#include +#include namespace flexkv { @@ -11,18 +13,51 @@ struct LayerCallbackData { int layers_this_batch; int num_gpus; std::atomic *counter; + // Eventfd info for notification + bool enable_eventfd; + int tp_size; + int num_layers; + int *layer_eventfds; // Pointer to eventfds array for current counter set }; static void CUDART_CB layer_done_host_callback(void *userData) { LayerCallbackData *data = static_cast(userData); int completed = data->counter->fetch_add(1) + 1; if (completed == data->num_gpus) { - // TODO: use eventfd to notify the consumer that [start_layer, start_layer + - // layers_this_batch) transfer completed - printf( - "[LayerwiseTransfer] All %d GPUs: Layers [%d, %d) transfer completed\n", - data->num_gpus, data->start_layer, - data->start_layer + data->layers_this_batch); + // Notify via eventfd when all GPUs complete this layer batch + if (data->enable_eventfd && data->layer_eventfds != nullptr) { + // Signal each tp_rank's eventfd for completed layers + for (int layer = data->start_layer; + layer < data->start_layer + data->layers_this_batch; ++layer) { + for (int tp_rank = 0; tp_rank < data->tp_size; ++tp_rank) { + int fd = data->layer_eventfds[tp_rank * data->num_layers + layer]; + if (fd >= 0) { + // Write 2 to support both get_key_buffer and get_value_buffer waits + uint64_t val = 2; + ssize_t ret = write(fd, &val, sizeof(val)); + // if (ret == sizeof(val)) { + // fprintf(stderr, "[LayerwiseTransfer] eventfd_write SUCCESS: tp_rank=%d, layer=%d, fd=%d, val=%lu\n", + // tp_rank, layer, fd, val); + // } else { + // fprintf(stderr, + // "[LayerwiseTransfer] Warning: eventfd_write failed for " + // "tp_rank %d, layer %d, 
fd %d, errno=%d\n", tp_rank, layer, fd, errno); + // } + // fflush(stderr); + } + } + } + } + // else { + // fprintf(stderr, "[LayerwiseTransfer] WARNING: eventfd disabled or null! enable=%d, ptr=%p\n", + // data->enable_eventfd, (void*)data->layer_eventfds); + // fflush(stderr); + // } + + // fprintf(stderr, + // "[LayerwiseTransfer] All %d GPUs: Layers [%d, %d) transfer completed\n", + // data->num_gpus, data->start_layer, + // data->start_layer + data->layers_this_batch); delete data->counter; } delete data; @@ -36,9 +71,31 @@ LayerwiseTransferGroup::LayerwiseTransferGroup( torch::Tensor &gpu_block_strides_tensor, torch::Tensor &gpu_layer_strides_tensor, torch::Tensor &gpu_chunk_sizes_tensor, int iouring_entries, - int iouring_flags) { + int iouring_flags, torch::Tensor &layer_eventfds_tensor, int tp_size) { num_gpus_ = num_gpus; + num_layers_ = num_layers; + tp_size_ = tp_size; + current_counter_id_ = 0; + + // Initialize eventfds + enable_eventfd_ = (layer_eventfds_tensor.numel() > 0); + if (enable_eventfd_) { + // layer_eventfds_tensor layout: [num_counters, tp_size, num_layers] + // Index formula: counter_id * tp_size * num_layers + tp_rank * num_layers + layer + int total_fds = layer_eventfds_tensor.numel(); + num_counters_ = total_fds / (tp_size * num_layers); + + int32_t *fds_ptr = layer_eventfds_tensor.data_ptr(); + layer_eventfds_.assign(fds_ptr, fds_ptr + total_fds); + + printf("[LayerwiseTransferGroup] Initialized with eventfds: " + "tp_size=%d, num_counters=%d, num_layers=%d, total_fds=%d\n", + tp_size_, num_counters_, num_layers_, total_fds); + } else { + num_counters_ = 0; + printf("[LayerwiseTransferGroup] Initialized without eventfds\n"); + } gpu_kv_strides_in_bytes_ = new int64_t[num_gpus]; gpu_block_strides_in_bytes_ = new int64_t[num_gpus]; @@ -126,9 +183,19 @@ LayerwiseTransferGroup::~LayerwiseTransferGroup() { void LayerwiseTransferGroup::layer_done_callback(int start_layer, int layers_this_batch) { std::atomic *counter = new std::atomic(0); 
+ + // Get eventfd pointer for current counter set + int *eventfds_ptr = nullptr; + if (enable_eventfd_ && num_counters_ > 0) { + // Offset into layer_eventfds_ for current counter set + int offset = current_counter_id_ * tp_size_ * num_layers_; + eventfds_ptr = layer_eventfds_.data() + offset; + } + for (int i = 0; i < num_gpus_; ++i) { LayerCallbackData *data = new LayerCallbackData{ - start_layer, layers_this_batch, num_gpus_, counter}; + start_layer, layers_this_batch, num_gpus_, counter, + enable_eventfd_, tp_size_, num_layers_, eventfds_ptr}; cudaLaunchHostFunc(streams_[i], layer_done_host_callback, data); } } @@ -145,7 +212,11 @@ void LayerwiseTransferGroup::layerwise_transfer( const int64_t cpu_block_stride_in_bytes, const int64_t cpu_chunk_size_in_bytes, const int transfer_sms, const bool use_ce_transfer, const int num_layers, - const int layer_granularity, const bool is_mla) { + const int layer_granularity, const bool is_mla, + const int counter_id) { + + // Set current counter ID for eventfd notification + current_counter_id_ = counter_id; int num_blocks = gpu_block_id_tensor.numel(); int64_t *gpu_block_ids = diff --git a/csrc/layerwise.h b/csrc/layerwise.h index 94487de565..7d6e9bae96 100644 --- a/csrc/layerwise.h +++ b/csrc/layerwise.h @@ -7,6 +7,8 @@ #include #include #include +#include +#include #include "gtensor_handler.cuh" #include "transfer.cuh" @@ -24,7 +26,7 @@ class LayerwiseTransferGroup { torch::Tensor &gpu_block_strides_tensor, torch::Tensor &gpu_layer_strides_tensor, torch::Tensor &gpu_chunk_sizes_tensor, int iouring_entries, - int iouring_flags); + int iouring_flags, torch::Tensor &layer_eventfds_tensor, int tp_size); ~LayerwiseTransferGroup(); @@ -46,7 +48,8 @@ class LayerwiseTransferGroup { const int64_t cpu_block_stride_in_bytes, const int64_t cpu_chunk_size_in_bytes, const int transfer_sms, const bool use_ce_transfer, const int num_layers, - const int layer_granularity, const bool is_mla); + const int layer_granularity, const bool 
is_mla, + const int counter_id = 0); // Counter set index for triple buffering private: int num_gpus_; @@ -69,6 +72,15 @@ class LayerwiseTransferGroup { bool enable_ssd_; std::unique_ptr ioctx_; + // Layer eventfds for notification + // Shape: [tp_size, num_counters, num_layers] + bool enable_eventfd_; + int tp_size_; + int num_counters_; + int num_layers_; + std::vector layer_eventfds_; // Flat array + int current_counter_id_; // Current counter set index for this transfer + void layer_done_callback(int start_layer, int layers_this_batch); }; diff --git a/flexkv/common/transfer.py b/flexkv/common/transfer.py index b96ba2e820..0c3529c901 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -117,6 +117,7 @@ class LayerwiseTransferOp(TransferOp): dst_block_ids_h2d: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) src_block_ids_disk2h: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) dst_block_ids_disk2h: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) + counter_id: int = 0 # Counter set index for triple buffering eventfd notification def __init__(self, graph_id: int, @@ -126,11 +127,13 @@ def __init__(self, dst_block_ids_disk2h: np.ndarray, layer_id: int = 0, layer_granularity: int = 1, - dp_id: int = 0) -> None: + dp_id: int = 0, + counter_id: int = 0) -> None: self.src_block_ids_h2d = src_block_ids_h2d self.dst_block_ids_h2d = dst_block_ids_h2d self.src_block_ids_disk2h = src_block_ids_disk2h self.dst_block_ids_disk2h = dst_block_ids_disk2h + self.counter_id = counter_id super().__init__( graph_id=graph_id, @@ -388,7 +391,8 @@ def merge_to_batch_graph(batch_id: int, transfer_graphs: List[TransferOpGraph], task_end_op_ids: List[int], op_callback_dict: Dict[int, Callable], - layerwise_transfer: bool = False) -> Tuple[TransferOpGraph, int, Dict[int, Callable]]: + layerwise_transfer: bool = False, + counter_id: int = 0) -> Tuple[TransferOpGraph, int, Dict[int, Callable]]: """ 
Merge multiple TransferOpGraphs into a single batch graph. @@ -461,6 +465,7 @@ def merge_to_batch_graph(batch_id: int, layer_id=0, layer_granularity=1, dp_id=h2d_ops[0].dp_id, + counter_id=counter_id, ) merged_graph.add_transfer_op(layerwise_transfer_op) batch_end_op_id = -1 diff --git a/flexkv/kvmanager.py b/flexkv/kvmanager.py index bf484fb24b..280242f0aa 100644 --- a/flexkv/kvmanager.py +++ b/flexkv/kvmanager.py @@ -240,7 +240,8 @@ def launch(self, task_ids: Union[int, List[int]], slot_mappings: Union[np.ndarray, List[np.ndarray], torch.Tensor, List[torch.Tensor]], as_batch: bool = False, - layerwise_transfer: bool = False) -> List[int]: + layerwise_transfer: bool = False, + counter_id: int = 0) -> List[int]: if isinstance(task_ids, int): task_ids = [task_ids] if not isinstance(slot_mappings, List): @@ -248,13 +249,14 @@ def launch(self, if isinstance(slot_mappings[0], torch.Tensor): slot_mappings = [slot_mapping.numpy() for slot_mapping in slot_mappings] if self.server_client_mode: - return self.dp_client.launch_tasks(task_ids, slot_mappings, as_batch, layerwise_transfer) + return self.dp_client.launch_tasks(task_ids, slot_mappings, as_batch, layerwise_transfer, counter_id) else: return self.kv_task_engine.launch_tasks( task_ids, slot_mappings, as_batch=as_batch, - layerwise_transfer=layerwise_transfer + layerwise_transfer=layerwise_transfer, + counter_id=counter_id ) def cancel(self, task_ids: Union[int, List[int]]) -> None: diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index 462a6e3093..e549e7ffc7 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -782,7 +782,8 @@ def merge_to_batch_kvtask(self, task_ids: List[int], batch_task_type: TaskType, - layerwise_transfer: bool = False) -> TransferOpGraph: + layerwise_transfer: bool = False, + counter_id: int = 0) -> TransferOpGraph: op_callback_dict = {} task_end_op_ids = [] callbacks = [] @@ -803,7 +804,8 @@ def merge_to_batch_kvtask(self, transfer_graphs, task_end_op_ids, op_callback_dict, - 
layerwise_transfer) + layerwise_transfer, + counter_id) self.tasks[batch_id] = KVTask( task_id=batch_id, token_ids=np.concatenate([self.tasks[task_id].token_ids for task_id in task_ids]), @@ -830,7 +832,8 @@ def launch_tasks(self, slot_mappings: List[np.ndarray], as_batch: bool = False, batch_id: int = -1, - layerwise_transfer: bool = False) -> List[int]: + layerwise_transfer: bool = False, + counter_id: int = 0) -> List[int]: assert isinstance(slot_mappings[0], np.ndarray) # trace launch tasks self.tracer.trace_launch_tasks(task_ids, slot_mappings, as_batch) @@ -855,7 +858,7 @@ def launch_tasks(self, layerwise_transfer = False break batch_task_type = TaskType.BATCH_GET if all_get else TaskType.BATCH_PUT - batch_task_graph = self.merge_to_batch_kvtask(batch_id, task_ids, batch_task_type, layerwise_transfer) + batch_task_graph = self.merge_to_batch_kvtask(batch_id, task_ids, batch_task_type, layerwise_transfer, counter_id) transfer_graphs = [batch_task_graph] self.tasks[batch_id].status = TaskStatus.RUNNING task_ids = [batch_id] diff --git a/flexkv/server/client.py b/flexkv/server/client.py index 4a8bc241de..96fa3596a2 100644 --- a/flexkv/server/client.py +++ b/flexkv/server/client.py @@ -178,11 +178,12 @@ def launch_tasks( slot_mappings: List[np.ndarray], as_batch: bool = False, layerwise_transfer: bool = False, + counter_id: int = 0, ) -> List[int]: batch_id = -1 if as_batch: batch_id = self._get_task_id() - req = LaunchTaskRequest(self.dp_client_id, task_ids, slot_mappings, as_batch, batch_id, layerwise_transfer) + req = LaunchTaskRequest(self.dp_client_id, task_ids, slot_mappings, as_batch, batch_id, layerwise_transfer, counter_id) self.send_to_server.send_pyobj(req) return [batch_id] if as_batch else task_ids diff --git a/flexkv/server/request.py b/flexkv/server/request.py index dde39c7b42..3935bbbb58 100644 --- a/flexkv/server/request.py +++ b/flexkv/server/request.py @@ -79,6 +79,7 @@ class LaunchTaskRequest: as_batch: bool = False batch_id: int = -1 
layerwise_transfer: bool = False + counter_id: int = 0 # Counter set index for triple buffering eventfd notification @dataclass class CancelTaskRequest: diff --git a/flexkv/server/server.py b/flexkv/server/server.py index 94f574d711..2991ce4e3d 100644 --- a/flexkv/server/server.py +++ b/flexkv/server/server.py @@ -387,7 +387,8 @@ def _handle_launch_task_request(self, req: LaunchTaskRequest) -> None: req.slot_mappings, req.as_batch, req.batch_id, - req.layerwise_transfer) + req.layerwise_transfer, + req.counter_id) def _handle_cancel_task_request(self, req: CancelTaskRequest) -> None: """Handle CancelTask request""" diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index c7b4d00bbe..12c63224d5 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -2,12 +2,15 @@ import torch.multiprocessing as mp import threading import time +import os +import socket +import struct from abc import ABC, abstractmethod from dataclasses import dataclass from torch.multiprocessing import Queue as MPQueue, Pipe as MPPipe from multiprocessing.connection import Connection from threading import Thread -from typing import List, Any, Dict, Union, Optional +from typing import List, Any, Dict, Union, Optional, Tuple import ctypes import numpy as np @@ -27,6 +30,25 @@ from flexkv.transfer.worker_op import WorkerTransferOp, WorkerLayerwiseTransferOp from flexkv.transfer.worker import TransferWorkerBase, cudaHostRegister + +def _recv_fds(sock: socket.socket, num_fds: int) -> Tuple[List[int], bytes]: + """Receive multiple fds + extra_data via Unix domain socket (SCM_RIGHTS).""" + data_buf = bytearray(256) + anc_buf_size = socket.CMSG_SPACE(num_fds * struct.calcsize("i")) + + nbytes, ancdata, flags, addr = sock.recvmsg_into([data_buf], anc_buf_size, anc_buf_size) + data = bytes(data_buf[:nbytes]) + + fds = [] + for level, ctype, cdata in ancdata: + if level == socket.SOL_SOCKET and ctype == socket.SCM_RIGHTS: + num_received = len(cdata) // 
struct.calcsize("i") + fds = list(struct.unpack(f"{num_received}i", cdata[:num_received * struct.calcsize("i")])) + break + if not fds: + raise RuntimeError("did not receive fds via SCM_RIGHTS") + return fds, data + class LayerwiseTransferWorker(TransferWorkerBase): def __init__(self, worker_id: int, @@ -125,6 +147,8 @@ def __init__(self, gpu_chunk_sizes_tensor = torch.tensor(self.gpu_chunk_sizes_in_bytes, dtype=torch.int64) gpu_layer_strides_tensor = torch.tensor(self.gpu_layer_strides_in_bytes, dtype=torch.int64) + layer_eventfds_tensor = self._receive_eventfds_from_sglang(tp_group_size) + # Create LayerwiseTransferGroup which handles both SSD->CPU and CPU->GPU transfers self.layerwise_transfer_group = LayerwiseTransferGroup( self.num_gpus, self.gpu_blocks, cpu_blocks, ssd_files, @@ -132,7 +156,86 @@ def __init__(self, gpu_kv_strides_tensor, gpu_block_strides_tensor, gpu_layer_strides_tensor, gpu_chunk_sizes_tensor, GLOBAL_CONFIG_FROM_ENV.iouring_entries, - GLOBAL_CONFIG_FROM_ENV.iouring_flags) + GLOBAL_CONFIG_FROM_ENV.iouring_flags, + layer_eventfds_tensor, tp_group_size) + + def _receive_eventfds_from_sglang(self, tp_group_size: int, + max_retries: int = 180, + retry_interval: float = 1.0) -> torch.Tensor: + """Receive eventfds from SGLang via Unix socket (FlexKV as server).""" + socket_path = os.environ.get('FLEXKV_LAYERWISE_EVENTFD_SOCKET', '/tmp/flexkv_layerwise_eventfd.sock') + + def cleanup_socket(): + try: + if os.path.exists(socket_path): + os.unlink(socket_path) + except OSError: + pass + + cleanup_socket() + server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + + try: + server_sock.bind(socket_path) + server_sock.listen(tp_group_size) + os.chmod(socket_path, 0o777) + flexkv_logger.info(f"[LayerwiseWorker] Listening on {socket_path}, waiting for {tp_group_size} connections") + except Exception as e: + flexkv_logger.error(f"[LayerwiseWorker] Failed to bind/listen: {e}") + 
server_sock.close() + return torch.empty(0, dtype=torch.int32) + + server_sock.settimeout(max_retries * retry_interval) + all_rank_eventfds: Dict[int, Dict[int, List[int]]] = {} + num_layers, num_counters = self.num_layers, 3 + + try: + for conn_idx in range(tp_group_size): + try: + conn, _ = server_sock.accept() + except socket.timeout: + flexkv_logger.warning(f"[LayerwiseWorker] Timeout, received {conn_idx}/{tp_group_size}") + break + + with conn: + metadata = conn.recv(16) + if len(metadata) < 16: + flexkv_logger.error(f"[LayerwiseWorker] Incomplete metadata: {len(metadata)} bytes") + continue + + tp_rank, _, recv_num_layers, recv_num_counters = struct.unpack("iiii", metadata) + if conn_idx == 0: + num_layers, num_counters = recv_num_layers, recv_num_counters + + rank_eventfds = {} + for _ in range(recv_num_counters): + fds, extra_data = _recv_fds(conn, recv_num_layers) + counter_id = struct.unpack("i", extra_data[:4])[0] + rank_eventfds[counter_id] = fds + + all_rank_eventfds[tp_rank] = rank_eventfds + flexkv_logger.info(f"[LayerwiseWorker] Received eventfds from tp_rank={tp_rank}") + except Exception as e: + flexkv_logger.error(f"[LayerwiseWorker] Error in accept loop: {e}") + finally: + server_sock.close() + cleanup_socket() + + if not all_rank_eventfds: + flexkv_logger.warning("[LayerwiseWorker] No connections received") + return torch.empty(0, dtype=torch.int32) + + # Build tensor: [num_counters, tp_size, num_layers] + eventfds_list = [] + for counter_id in range(num_counters): + for tp_rank in range(tp_group_size): + fds = all_rank_eventfds.get(tp_rank, {}).get(counter_id, [-1] * num_layers) + eventfds_list.extend(fds) + + tensor = torch.tensor(eventfds_list, dtype=torch.int32) + flexkv_logger.info(f"[LayerwiseWorker] Eventfds tensor: {tensor.shape}, counters={num_counters}, tp={tp_group_size}, layers={num_layers}") + return tensor def _transfer_impl(self, src_block_ids_h2d: torch.Tensor, @@ -140,6 +243,7 @@ def _transfer_impl(self, src_block_ids_disk2h: 
Optional[torch.Tensor], dst_block_ids_disk2h: Optional[torch.Tensor], layer_granularity: int, + counter_id: int = 0, **kwargs: Any) -> None: assert src_block_ids_h2d.dtype == torch.int64 assert dst_block_ids_h2d.dtype == torch.int64 @@ -173,6 +277,7 @@ def _transfer_impl(self, self.num_layers, layer_granularity, self.is_mla, + counter_id, ) def launch_transfer(self, transfer_op: WorkerLayerwiseTransferOp) -> None: @@ -196,4 +301,5 @@ def launch_transfer(self, transfer_op: WorkerLayerwiseTransferOp) -> None: src_block_ids_disk2h, dst_block_ids_disk2h, layer_granularity, + transfer_op.counter_id, ) diff --git a/flexkv/transfer/worker_op.py b/flexkv/transfer/worker_op.py index 3fb76cc8f8..a906af8b12 100644 --- a/flexkv/transfer/worker_op.py +++ b/flexkv/transfer/worker_op.py @@ -51,6 +51,7 @@ class WorkerLayerwiseTransferOp: dst_block_ids_h2d: np.ndarray src_block_ids_disk2h: np.ndarray dst_block_ids_disk2h: np.ndarray + counter_id: int # Counter set index for triple buffering eventfd notification def __init__(self, transfer_op: LayerwiseTransferOp): self.transfer_op_id = transfer_op.op_id @@ -63,3 +64,4 @@ def __init__(self, transfer_op: LayerwiseTransferOp): self.dst_block_ids_h2d = transfer_op.dst_block_ids_h2d self.src_block_ids_disk2h = transfer_op.src_block_ids_disk2h self.dst_block_ids_disk2h = transfer_op.dst_block_ids_disk2h + self.counter_id = transfer_op.counter_id From b7c86bfee4c866f6cb763daedefccb1f61aa0301 Mon Sep 17 00:00:00 2001 From: jianyingzhu Date: Wed, 21 Jan 2026 09:58:29 +0000 Subject: [PATCH 18/59] print bandwidth for layerwise transfer --- csrc/layerwise.cpp | 68 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/csrc/layerwise.cpp b/csrc/layerwise.cpp index 1ca7cef77d..db35c4cd86 100644 --- a/csrc/layerwise.cpp +++ b/csrc/layerwise.cpp @@ -225,10 +225,29 @@ void LayerwiseTransferGroup::layerwise_transfer( static_cast(cpu_block_id_tensor.data_ptr()); void *cpu_ptr = cpu_blocks_; + // 
Create CUDA events for timing each layer batch (on GPU 0) + int num_batches = (num_layers + layer_granularity - 1) / layer_granularity; + std::vector timing_events(num_batches + 1); // +1 for start event + std::vector batch_start_layers(num_batches); + std::vector batch_layers_count(num_batches); + + cudaSetDevice(dp_group_id_ * num_gpus_); + for (int i = 0; i <= num_batches; ++i) { + cudaEventCreate(&timing_events[i]); + } + + // Record start event + cudaEventRecord(timing_events[0], streams_[0]); + + int batch_idx = 0; for (int start_layer = 0; start_layer < num_layers; start_layer += layer_granularity) { int layers_this_batch = std::min(layer_granularity, num_layers - start_layer); + + batch_start_layers[batch_idx] = start_layer; + batch_layers_count[batch_idx] = layers_this_batch; + // Step 1: SSD -> CPU transfer if (enable_ssd_ && ssd_block_ids.numel() > 0) { torch::Tensor layer_id_list = @@ -245,8 +264,9 @@ void LayerwiseTransferGroup::layerwise_transfer( } // Step 2: CPU -> GPU transfer + cudaSetDevice(dp_group_id_ * num_gpus_ + 0); for (int i = 0; i < num_gpus_; ++i) { - cudaSetDevice(dp_group_id_ * num_gpus_ + i); + int64_t cpu_startoff_inside_chunks = i * gpu_chunk_sizes_in_bytes_[i]; if (is_mla) { @@ -283,7 +303,12 @@ void LayerwiseTransferGroup::layerwise_transfer( } } + // Record event after this batch on GPU 0 + cudaSetDevice(dp_group_id_ * num_gpus_); + cudaEventRecord(timing_events[batch_idx + 1], streams_[0]); + layer_done_callback(start_layer, layers_this_batch); + batch_idx++; } for (int i = 0; i < num_gpus_; ++i) { cudaError_t err = cudaStreamSynchronize(streams_[i]); @@ -293,6 +318,47 @@ void LayerwiseTransferGroup::layerwise_transfer( cudaGetErrorString(err)); } } + + // Calculate and print timing for each layer batch + // chunk_size per GPU * num_gpus * 2 (K+V) * layers_this_batch * num_blocks + fprintf(stderr, "\n[LayerwiseTransfer] CPU->GPU Transfer Timing (num_blocks=%d):\n", num_blocks); + float total_time_ms = 0.0f; + int64_t total_bytes 
= 0; + + for (int i = 0; i < num_batches; ++i) { + float elapsed_ms = 0.0f; + cudaEventElapsedTime(&elapsed_ms, timing_events[i], timing_events[i + 1]); + + // Calculate bytes transferred for this batch + // For each GPU: chunk_size * 2 (K+V) * layers * num_blocks + int64_t bytes_this_batch = 0; + for (int g = 0; g < num_gpus_; ++g) { + bytes_this_batch += gpu_chunk_sizes_in_bytes_[g] * 2 * batch_layers_count[i] * num_blocks; + } + + double bandwidth_gbps = (bytes_this_batch / (1024.0 * 1024.0 * 1024.0)) / (elapsed_ms / 1000.0); + + fprintf(stderr, " Layers [%d, %d): time=%.3f ms, size=%.2f MB, bandwidth=%.2f GB/s\n", + batch_start_layers[i], + batch_start_layers[i] + batch_layers_count[i], + elapsed_ms, + bytes_this_batch / (1024.0 * 1024.0), + bandwidth_gbps); + + total_time_ms += elapsed_ms; + total_bytes += bytes_this_batch; + } + + double total_bandwidth_gbps = (total_bytes / (1024.0 * 1024.0 * 1024.0)) / (total_time_ms / 1000.0); + fprintf(stderr, " Total: time=%.3f ms, size=%.2f MB, avg_bandwidth=%.2f GB/s\n\n", + total_time_ms, total_bytes / (1024.0 * 1024.0), total_bandwidth_gbps); + fflush(stderr); + + // Cleanup timing events + cudaSetDevice(dp_group_id_ * num_gpus_); + for (int i = 0; i <= num_batches; ++i) { + cudaEventDestroy(timing_events[i]); + } } } // namespace flexkv From d635935da91cabc36d419a7eeb13c3cb86584eaf Mon Sep 17 00:00:00 2001 From: jianyingzhu Date: Wed, 28 Jan 2026 10:37:53 +0000 Subject: [PATCH 19/59] add nvtx for layerwise --- csrc/layerwise.cpp | 101 +++++++++++++++++++++++++++++--------- csrc/layerwise.h | 7 ++- flexkv/common/transfer.py | 2 +- 3 files changed, 84 insertions(+), 26 deletions(-) diff --git a/csrc/layerwise.cpp b/csrc/layerwise.cpp index db35c4cd86..997226df80 100644 --- a/csrc/layerwise.cpp +++ b/csrc/layerwise.cpp @@ -5,6 +5,7 @@ #include #include #include +#include namespace flexkv { @@ -18,6 +19,11 @@ struct LayerCallbackData { int tp_size; int num_layers; int *layer_eventfds; // Pointer to eventfds array for 
current counter set + // NVTX range id for CPU->GPU transfer + nvtxRangeId_t *current_range_id_ptr; // Pointer to current layer's range ID + bool is_last_batch; // Whether this is the last batch + char next_range_name[64]; // Name for next layer's range (if not last batch) + nvtxRangeId_t *next_range_id_ptr; // Pointer to next layer's range ID storage }; static void CUDART_CB layer_done_host_callback(void *userData) { @@ -35,29 +41,18 @@ static void CUDART_CB layer_done_host_callback(void *userData) { // Write 2 to support both get_key_buffer and get_value_buffer waits uint64_t val = 2; ssize_t ret = write(fd, &val, sizeof(val)); - // if (ret == sizeof(val)) { - // fprintf(stderr, "[LayerwiseTransfer] eventfd_write SUCCESS: tp_rank=%d, layer=%d, fd=%d, val=%lu\n", - // tp_rank, layer, fd, val); - // } else { - // fprintf(stderr, - // "[LayerwiseTransfer] Warning: eventfd_write failed for " - // "tp_rank %d, layer %d, fd %d, errno=%d\n", tp_rank, layer, fd, errno); - // } - // fflush(stderr); } } } } - // else { - // fprintf(stderr, "[LayerwiseTransfer] WARNING: eventfd disabled or null! 
enable=%d, ptr=%p\n", - // data->enable_eventfd, (void*)data->layer_eventfds); - // fflush(stderr); - // } - - // fprintf(stderr, - // "[LayerwiseTransfer] All %d GPUs: Layers [%d, %d) transfer completed\n", - // data->num_gpus, data->start_layer, - // data->start_layer + data->layers_this_batch); + // End current NVTX range when all GPUs complete + if (data->current_range_id_ptr != nullptr && *data->current_range_id_ptr != 0) { + nvtxRangeEnd(*data->current_range_id_ptr); + } + // Start next layer's NVTX range (so it begins right after current layer ends) + if (!data->is_last_batch && data->next_range_id_ptr != nullptr) { + *data->next_range_id_ptr = nvtxRangeStartA(data->next_range_name); + } delete data->counter; } delete data; @@ -150,9 +145,14 @@ LayerwiseTransferGroup::LayerwiseTransferGroup( // Create CUDA streams for each GPU streams_.resize(num_gpus_); events_.resize(num_gpus_); + + // Get highest priority (lowest value) + int leastPriority, greatestPriority; + cudaDeviceGetStreamPriorityRange(&leastPriority, &greatestPriority); + for (int i = 0; i < num_gpus_; i++) { cudaSetDevice(dp_group_id * num_gpus_ + i); - cudaStreamCreate(&streams_[i]); + cudaStreamCreateWithPriority(&streams_[i], cudaStreamNonBlocking, greatestPriority); cudaEventCreate(&events_[i]); } @@ -181,7 +181,11 @@ LayerwiseTransferGroup::~LayerwiseTransferGroup() { } void LayerwiseTransferGroup::layer_done_callback(int start_layer, - int layers_this_batch) { + int layers_this_batch, + nvtxRangeId_t *current_range_id_ptr, + bool is_last_batch, + const char *next_range_name, + nvtxRangeId_t *next_range_id_ptr) { std::atomic *counter = new std::atomic(0); // Get eventfd pointer for current counter set @@ -195,7 +199,12 @@ void LayerwiseTransferGroup::layer_done_callback(int start_layer, for (int i = 0; i < num_gpus_; ++i) { LayerCallbackData *data = new LayerCallbackData{ start_layer, layers_this_batch, num_gpus_, counter, - enable_eventfd_, tp_size_, num_layers_, eventfds_ptr}; + 
enable_eventfd_, tp_size_, num_layers_, eventfds_ptr, + current_range_id_ptr, is_last_batch, {0}, next_range_id_ptr}; + // Copy next range name + if (next_range_name != nullptr) { + snprintf(data->next_range_name, sizeof(data->next_range_name), "%s", next_range_name); + } cudaLaunchHostFunc(streams_[i], layer_done_host_callback, data); } } @@ -239,6 +248,29 @@ void LayerwiseTransferGroup::layerwise_transfer( // Record start event cudaEventRecord(timing_events[0], streams_[0]); + // Allocate storage for NVTX range IDs (one per batch) + std::vector h2d_range_ids(num_batches, 0); + // Pre-generate all range names with data size info + std::vector h2d_range_names(num_batches); + for (int b = 0; b < num_batches; ++b) { + int sl = b * layer_granularity; + int ltb = std::min(layer_granularity, num_layers - sl); + // Calculate data size for this batch: chunk_size * 2 (K+V) * layers * num_blocks + int64_t bytes_this_batch = 0; + for (int g = 0; g < num_gpus_; ++g) { + bytes_this_batch += gpu_chunk_sizes_in_bytes_[g] * 2 * ltb * num_blocks; + } + double mb_this_batch = bytes_this_batch / (1024.0 * 1024.0); + char name[128]; + snprintf(name, sizeof(name), "CPU->GPU Layer[%d,%d) %.2fMB", sl, sl + ltb, mb_this_batch); + h2d_range_names[b] = name; + } + + // Start the first batch's NVTX range in main thread + if (num_batches > 0) { + h2d_range_ids[0] = nvtxRangeStartA(h2d_range_names[0].c_str()); + } + int batch_idx = 0; for (int start_layer = 0; start_layer < num_layers; start_layer += layer_granularity) { @@ -250,6 +282,15 @@ void LayerwiseTransferGroup::layerwise_transfer( // Step 1: SSD -> CPU transfer if (enable_ssd_ && ssd_block_ids.numel() > 0) { + // Calculate SSD->CPU data size: cpu_chunk_size * 2 (K+V) * layers * num_ssd_blocks + int num_ssd_blocks = ssd_block_ids.numel(); + int64_t ssd_bytes = cpu_chunk_size_in_bytes * 2 * layers_this_batch * num_ssd_blocks; + double ssd_mb = ssd_bytes / (1024.0 * 1024.0); + char ssd_range_name[128]; + snprintf(ssd_range_name, 
sizeof(ssd_range_name), + "SSD->CPU Layer[%d,%d) %.2fMB", start_layer, start_layer + layers_this_batch, ssd_mb); + nvtxRangePushA(ssd_range_name); + torch::Tensor layer_id_list = torch::arange(start_layer, start_layer + layers_this_batch, torch::TensorOptions().dtype(torch::kInt32)); @@ -261,9 +302,14 @@ void LayerwiseTransferGroup::layerwise_transfer( cpu_block_stride_in_bytes, true, // is_read: SSD -> CPU num_blocks_per_file, round_robin, num_threads_per_device, is_mla); + + nvtxRangePop(); } // Step 2: CPU -> GPU transfer + // NVTX range for this batch was already started (by main thread for first batch, + // or by previous batch's callback for subsequent batches) + cudaSetDevice(dp_group_id_ * num_gpus_ + 0); for (int i = 0; i < num_gpus_; ++i) { @@ -307,7 +353,14 @@ void LayerwiseTransferGroup::layerwise_transfer( cudaSetDevice(dp_group_id_ * num_gpus_); cudaEventRecord(timing_events[batch_idx + 1], streams_[0]); - layer_done_callback(start_layer, layers_this_batch); + // NVTX: current range ends in callback, next range starts in callback + bool is_last_batch = (batch_idx == num_batches - 1); + const char *next_name = is_last_batch ? nullptr : h2d_range_names[batch_idx + 1].c_str(); + nvtxRangeId_t *next_id_ptr = is_last_batch ? 
nullptr : &h2d_range_ids[batch_idx + 1]; + + layer_done_callback(start_layer, layers_this_batch, + &h2d_range_ids[batch_idx], is_last_batch, + next_name, next_id_ptr); batch_idx++; } for (int i = 0; i < num_gpus_; ++i) { @@ -328,7 +381,7 @@ void LayerwiseTransferGroup::layerwise_transfer( for (int i = 0; i < num_batches; ++i) { float elapsed_ms = 0.0f; cudaEventElapsedTime(&elapsed_ms, timing_events[i], timing_events[i + 1]); - + // Calculate bytes transferred for this batch // For each GPU: chunk_size * 2 (K+V) * layers * num_blocks int64_t bytes_this_batch = 0; diff --git a/csrc/layerwise.h b/csrc/layerwise.h index 7d6e9bae96..e1941f589d 100644 --- a/csrc/layerwise.h +++ b/csrc/layerwise.h @@ -9,6 +9,7 @@ #include #include #include +#include #include "gtensor_handler.cuh" #include "transfer.cuh" @@ -81,7 +82,11 @@ class LayerwiseTransferGroup { std::vector layer_eventfds_; // Flat array int current_counter_id_; // Current counter set index for this transfer - void layer_done_callback(int start_layer, int layers_this_batch); + void layer_done_callback(int start_layer, int layers_this_batch, + nvtxRangeId_t *current_range_id_ptr, + bool is_last_batch, + const char *next_range_name, + nvtxRangeId_t *next_range_id_ptr); }; } // namespace flexkv diff --git a/flexkv/common/transfer.py b/flexkv/common/transfer.py index 0c3529c901..d29fbdb0b0 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -463,7 +463,7 @@ def merge_to_batch_graph(batch_id: int, if merged_disk2h_op is not None \ else np.array([], dtype=np.int64), layer_id=0, - layer_granularity=1, + layer_granularity=16, dp_id=h2d_ops[0].dp_id, counter_id=counter_id, ) From af4c1d51421321a1cf491330feae8bd008d6bac3 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Thu, 29 Jan 2026 08:17:51 +0000 Subject: [PATCH 20/59] update kernel --- flexkv/common/transfer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flexkv/common/transfer.py b/flexkv/common/transfer.py index 
d29fbdb0b0..0c3529c901 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -463,7 +463,7 @@ def merge_to_batch_graph(batch_id: int, if merged_disk2h_op is not None \ else np.array([], dtype=np.int64), layer_id=0, - layer_granularity=16, + layer_granularity=1, dp_id=h2d_ops[0].dp_id, counter_id=counter_id, ) From a159da93ebc65887d902cd83b9013a8fb841adfa Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Fri, 30 Jan 2026 04:54:55 +0000 Subject: [PATCH 21/59] remove print --- csrc/layerwise.cpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/csrc/layerwise.cpp b/csrc/layerwise.cpp index 997226df80..3aa757b1f0 100644 --- a/csrc/layerwise.cpp +++ b/csrc/layerwise.cpp @@ -374,7 +374,7 @@ void LayerwiseTransferGroup::layerwise_transfer( // Calculate and print timing for each layer batch // chunk_size per GPU * num_gpus * 2 (K+V) * layers_this_batch * num_blocks - fprintf(stderr, "\n[LayerwiseTransfer] CPU->GPU Transfer Timing (num_blocks=%d):\n", num_blocks); + // fprintf(stderr, "\n[LayerwiseTransfer] CPU->GPU Transfer Timing (num_blocks=%d):\n", num_blocks); float total_time_ms = 0.0f; int64_t total_bytes = 0; @@ -391,21 +391,21 @@ void LayerwiseTransferGroup::layerwise_transfer( double bandwidth_gbps = (bytes_this_batch / (1024.0 * 1024.0 * 1024.0)) / (elapsed_ms / 1000.0); - fprintf(stderr, " Layers [%d, %d): time=%.3f ms, size=%.2f MB, bandwidth=%.2f GB/s\n", - batch_start_layers[i], - batch_start_layers[i] + batch_layers_count[i], - elapsed_ms, - bytes_this_batch / (1024.0 * 1024.0), - bandwidth_gbps); + // fprintf(stderr, " Layers [%d, %d): time=%.3f ms, size=%.2f MB, bandwidth=%.2f GB/s\n", + // batch_start_layers[i], + // batch_start_layers[i] + batch_layers_count[i], + // elapsed_ms, + // bytes_this_batch / (1024.0 * 1024.0), + // bandwidth_gbps); total_time_ms += elapsed_ms; total_bytes += bytes_this_batch; } double total_bandwidth_gbps = (total_bytes / (1024.0 * 1024.0 * 1024.0)) / (total_time_ms / 
1000.0); - fprintf(stderr, " Total: time=%.3f ms, size=%.2f MB, avg_bandwidth=%.2f GB/s\n\n", - total_time_ms, total_bytes / (1024.0 * 1024.0), total_bandwidth_gbps); - fflush(stderr); + // fprintf(stderr, " Total: time=%.3f ms, size=%.2f MB, avg_bandwidth=%.2f GB/s\n\n", + // total_time_ms, total_bytes / (1024.0 * 1024.0), total_bandwidth_gbps); + // fflush(stderr); // Cleanup timing events cudaSetDevice(dp_group_id_ * num_gpus_); From 417b59be88f7455ab3c96bafd2443dc8ff6fac75 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Fri, 30 Jan 2026 06:53:25 +0000 Subject: [PATCH 22/59] fix cuda device set --- csrc/layerwise.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/csrc/layerwise.cpp b/csrc/layerwise.cpp index 3aa757b1f0..728c53a6da 100644 --- a/csrc/layerwise.cpp +++ b/csrc/layerwise.cpp @@ -309,11 +309,10 @@ void LayerwiseTransferGroup::layerwise_transfer( // Step 2: CPU -> GPU transfer // NVTX range for this batch was already started (by main thread for first batch, // or by previous batch's callback for subsequent batches) - - cudaSetDevice(dp_group_id_ * num_gpus_ + 0); + for (int i = 0; i < num_gpus_; ++i) { - - + // TODO: support multi-instance + cudaSetDevice(dp_group_id_ * num_gpus_ + i); int64_t cpu_startoff_inside_chunks = i * gpu_chunk_sizes_in_bytes_[i]; if (is_mla) { cpu_startoff_inside_chunks = 0; From ffbcecda5307016c5963d36e3824cc3a1b1b9c74 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Mon, 2 Feb 2026 07:54:03 +0000 Subject: [PATCH 23/59] fix --- csrc/bindings.cpp | 1 + csrc/layerwise.cpp | 19 ++++++++++++------- csrc/layerwise.h | 1 + flexkv/cache/redis_meta.py | 1 + flexkv/common/config.py | 2 +- flexkv/transfer/transfer_engine.py | 26 ++++++++++++-------------- setup.py | 2 +- 7 files changed, 29 insertions(+), 23 deletions(-) diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index 7e53608ef3..3ef097c7d3 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -30,6 +30,7 @@ #include "dist/block_meta.h" #include 
"dist/distributed_radix_tree.h" #include "dist/lease_meta_mempool.h" +#include "layerwise.h" #include "dist/local_radix_tree.h" #include "dist/lock_free_q.h" #include "dist/redis_meta_channel.h" diff --git a/csrc/layerwise.cpp b/csrc/layerwise.cpp index 728c53a6da..0f84c44998 100644 --- a/csrc/layerwise.cpp +++ b/csrc/layerwise.cpp @@ -142,6 +142,12 @@ LayerwiseTransferGroup::LayerwiseTransferGroup( dp_group_id_ = dp_group_id; + // Get GPU device IDs from tensors (like tp_transfer_thread_group.cpp) + gpu_device_ids_.resize(num_gpus_); + for (int i = 0; i < num_gpus_; ++i) { + gpu_device_ids_[i] = gpu_blocks[i][0].device().index(); + } + // Create CUDA streams for each GPU streams_.resize(num_gpus_); events_.resize(num_gpus_); @@ -151,7 +157,7 @@ LayerwiseTransferGroup::LayerwiseTransferGroup( cudaDeviceGetStreamPriorityRange(&leastPriority, &greatestPriority); for (int i = 0; i < num_gpus_; i++) { - cudaSetDevice(dp_group_id * num_gpus_ + i); + cudaSetDevice(gpu_device_ids_[i]); cudaStreamCreateWithPriority(&streams_[i], cudaStreamNonBlocking, greatestPriority); cudaEventCreate(&events_[i]); } @@ -166,7 +172,7 @@ LayerwiseTransferGroup::LayerwiseTransferGroup( LayerwiseTransferGroup::~LayerwiseTransferGroup() { for (int i = 0; i < num_gpus_; i++) { - cudaSetDevice(dp_group_id_ * num_gpus_ + i); + cudaSetDevice(gpu_device_ids_[i]); cudaStreamDestroy(streams_[i]); cudaEventDestroy(events_[i]); } @@ -240,7 +246,7 @@ void LayerwiseTransferGroup::layerwise_transfer( std::vector batch_start_layers(num_batches); std::vector batch_layers_count(num_batches); - cudaSetDevice(dp_group_id_ * num_gpus_); + cudaSetDevice(gpu_device_ids_[0]); for (int i = 0; i <= num_batches; ++i) { cudaEventCreate(&timing_events[i]); } @@ -311,8 +317,7 @@ void LayerwiseTransferGroup::layerwise_transfer( // or by previous batch's callback for subsequent batches) for (int i = 0; i < num_gpus_; ++i) { - // TODO: support multi-instance - cudaSetDevice(dp_group_id_ * num_gpus_ + i); + 
cudaSetDevice(gpu_device_ids_[i]); int64_t cpu_startoff_inside_chunks = i * gpu_chunk_sizes_in_bytes_[i]; if (is_mla) { cpu_startoff_inside_chunks = 0; @@ -349,7 +354,7 @@ void LayerwiseTransferGroup::layerwise_transfer( } // Record event after this batch on GPU 0 - cudaSetDevice(dp_group_id_ * num_gpus_); + cudaSetDevice(gpu_device_ids_[0]); cudaEventRecord(timing_events[batch_idx + 1], streams_[0]); // NVTX: current range ends in callback, next range starts in callback @@ -407,7 +412,7 @@ void LayerwiseTransferGroup::layerwise_transfer( // fflush(stderr); // Cleanup timing events - cudaSetDevice(dp_group_id_ * num_gpus_); + cudaSetDevice(gpu_device_ids_[0]); for (int i = 0; i <= num_batches; ++i) { cudaEventDestroy(timing_events[i]); } diff --git a/csrc/layerwise.h b/csrc/layerwise.h index e1941f589d..d186316836 100644 --- a/csrc/layerwise.h +++ b/csrc/layerwise.h @@ -66,6 +66,7 @@ class LayerwiseTransferGroup { BackendType backend_type_; std::vector gpu_tensor_handlers_; + std::vector gpu_device_ids_; std::vector streams_; std::vector events_; diff --git a/flexkv/cache/redis_meta.py b/flexkv/cache/redis_meta.py index b2888bf838..cca48a11e6 100644 --- a/flexkv/cache/redis_meta.py +++ b/flexkv/cache/redis_meta.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import Iterable, List, Tuple, Optional, Union, Dict from dataclasses import dataclass from enum import IntEnum diff --git a/flexkv/common/config.py b/flexkv/common/config.py index c2b4e38574..7f75a7d560 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -101,7 +101,7 @@ def __post_init__(self): remote_layout_type=KVCacheLayoutType(os.getenv('FLEXKV_REMOTE_LAYOUT', 'BLOCKFIRST').upper()), gds_layout_type=KVCacheLayoutType(os.getenv('FLEXKV_GDS_LAYOUT', 'BLOCKFIRST').upper()), - enable_layerwise_transfer=bool(int(os.getenv('FLEXKV_ENABLE_LAYERWISE_TRANSFER', 1))), + enable_layerwise_transfer=bool(int(os.getenv('FLEXKV_ENABLE_LAYERWISE_TRANSFER', 0))), 
use_ce_transfer_h2d=bool(int(os.getenv('FLEXKV_USE_CE_TRANSFER_H2D', 0))), use_ce_transfer_d2h=bool(int(os.getenv('FLEXKV_USE_CE_TRANSFER_D2H', 0))), diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index 29f757770f..56e85267b4 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -117,7 +117,7 @@ def __init__(self, # Create shutdown pipe for zero-latency selector self.shutdown_read_fd, self.shutdown_write_fd = os.pipe() - self.gpu_handles = gpu_handles + self.gpu_handle_groups = gpu_handles # dp_client_id -> list of GPU handles for that TP group self._cpu_handle = cpu_handle self._ssd_handle = ssd_handle self._remote_handle = remote_handle @@ -128,9 +128,9 @@ def __init__(self, self.op_id_to_nvtx_range: Dict[int, str] = {} - self.dp_size = model_config.dp_size + # self.dp_size = model_config.dp_size self.tp_size = model_config.tp_size - self.num_gpu_groups = len(self.gpu_handles) + self.num_gpu_groups = len(self.gpu_handle_groups) self._running = False def _init_workers(self) -> None: @@ -176,7 +176,7 @@ def _init_workers(self) -> None: transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, ) - for _, gpu_handles in self.gpu_handles.items() + for _, gpu_handles in self.gpu_handle_groups.items() ] else: self.h2d_workers = [ @@ -196,7 +196,7 @@ def _init_workers(self) -> None: transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, ) - for dp_client_id, gpu_handles in self.gpu_handles.items() + for dp_client_id, gpu_handles in self.gpu_handle_groups.items() ] self.d2h_workers = [ tpGPUCPUTransferWorker.create_worker( @@ -288,7 +288,7 @@ def _init_workers(self) -> None: dtype=self._ssd_handle.dtype, gpu_device_id=gpu_handles[0].gpu_device_id, ) - for _, gpu_handles in self.gpu_handles.items() + for _, gpu_handles in 
self.gpu_handle_groups.items() ] else: self.gds_workers = [ @@ -305,7 +305,7 @@ def _init_workers(self) -> None: tp_group_size=self.tp_size, dp_group_id=dp_client_id, ) - for dp_client_id, gpu_handles in self.gpu_handles.items() + for dp_client_id, gpu_handles in self.gpu_handle_groups.items() ] self._worker_map[TransferType.DISK2D] = self.gds_workers self._worker_map[TransferType.D2DISK] = self.gds_workers @@ -318,24 +318,22 @@ def _init_workers(self) -> None: mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), - gpu_blocks=[self.gpu_handles[j].get_tensor_handle_list() \ - for j in range(i * self.tp_size, (i + 1) * self.tp_size)], + gpu_blocks=[handle.get_tensor_handle_list() for handle in gpu_handles], cpu_blocks=self._cpu_handle.get_tensor(), ssd_files=ssd_files, - gpu_kv_layouts=[self.gpu_handles[i].kv_layout \ - for i in range(i * self.tp_size, (i + 1) * self.tp_size)], + gpu_kv_layouts=[handle.kv_layout for handle in gpu_handles], cpu_kv_layout=self._cpu_handle.kv_layout, ssd_kv_layout=ssd_kv_layout, - dtype=self.gpu_handles[i].dtype, + dtype=gpu_handles[0].dtype, tp_group_size=self.tp_size, - dp_group_id=i, + dp_group_id=dp_client_id, num_blocks_per_file=num_blocks_per_file, use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, transfer_sms_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_sms_h2d, transfer_sms_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_sms_d2h, ) - for i in range(self.dp_size) + for dp_client_id, gpu_handles in self.gpu_handle_groups.items() ] self._worker_map[TransferType.LAYERWISE] = self.layerwise_workers diff --git a/setup.py b/setup.py index 641296232c..02bb0b5d01 100755 --- a/setup.py +++ b/setup.py @@ -34,7 +34,7 @@ def get_version(): "csrc/tp_transfer_thread_group.cpp", "csrc/transfer_ssd.cpp", "csrc/radix_tree.cpp", - "csrc/layerwise.cpp" + "csrc/layerwise.cpp", "csrc/monitoring/metrics_manager.cpp", # Monitoring support 
] From fd8ce04c6728c5ff4ff16e0fc204ce318260ffab Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Tue, 3 Feb 2026 03:48:16 +0000 Subject: [PATCH 24/59] fix mempool --- flexkv/cache/mempool.py | 10 ++-------- flexkv/common/memory_handle.py | 8 ++++---- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/flexkv/cache/mempool.py b/flexkv/cache/mempool.py index 4decae1b8a..7d3df3797d 100644 --- a/flexkv/cache/mempool.py +++ b/flexkv/cache/mempool.py @@ -43,17 +43,11 @@ def recycle_blocks(self, block_ids: np.ndarray) -> None: if block_ids.ndim != 1 or block_ids.dtype != np.int64: raise ValueError("block_ids must be a 1D tensor of int64") - # Remove duplicates first (same block ID appearing multiple times) block_ids = np.unique(block_ids) - - # Filter out already-free blocks to avoid double-free errors - # This can happen due to race conditions or eviction edge cases + already_free = self._free_mask[block_ids] if already_free.any(): - # Only recycle blocks that are actually in use - block_ids = block_ids[~already_free] - if len(block_ids) == 0: - return # Nothing to recycle + raise ValueError(f"block_ids {block_ids[already_free]} are already free") self._free_mask[block_ids] = True self._num_free += len(block_ids) diff --git a/flexkv/common/memory_handle.py b/flexkv/common/memory_handle.py index 9838c1ba6c..e5fff581fc 100644 --- a/flexkv/common/memory_handle.py +++ b/flexkv/common/memory_handle.py @@ -170,10 +170,10 @@ def _init_from_ipc_handle( self.offset = offset - flexkv_logger.info( - f"TensorSharedHandle constructed from external IPC handle {self.ipc_handle.hex()} on device {self.device} \ - with shape {self.tensor_shape} and dtype {self.tensor_dtype}, ptr offset={offset}" - ) + # flexkv_logger.info( + # f"TensorSharedHandle constructed from external IPC handle {self.ipc_handle.hex()} on device {self.device} \ + # with shape {self.tensor_shape} and dtype {self.tensor_dtype}, ptr offset={offset}" + # ) @staticmethod def _ensure_torch_dtype(dtype: 
Union[torch.dtype, str]) -> torch.dtype: From 2463aa4e3be9078a6fdafe4d8e07cfb884c9d5ab Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Tue, 3 Feb 2026 04:01:43 +0000 Subject: [PATCH 25/59] refactor transfer config, set num of cta instead of sm --- csrc/bindings.cpp | 2 +- csrc/layerwise.cpp | 8 ++++---- csrc/layerwise.h | 2 +- flexkv/transfer/layerwise.py | 10 +++++----- flexkv/transfer/transfer_engine.py | 4 ++-- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index 3ef097c7d3..219f7abc84 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -446,7 +446,7 @@ PYBIND11_MODULE(c_ext, m) { py::arg("cpu_kv_stride_in_bytes"), py::arg("cpu_layer_stride_in_bytes"), py::arg("cpu_block_stride_in_bytes"), - py::arg("cpu_chunk_size_in_bytes"), py::arg("transfer_sms"), + py::arg("cpu_chunk_size_in_bytes"), py::arg("transfer_cta_num"), py::arg("use_ce_transfer"), py::arg("num_layers"), py::arg("layer_granularity"), py::arg("is_mla"), py::arg("counter_id") = 0); diff --git a/csrc/layerwise.cpp b/csrc/layerwise.cpp index 0f84c44998..f3a38eec42 100644 --- a/csrc/layerwise.cpp +++ b/csrc/layerwise.cpp @@ -225,7 +225,7 @@ void LayerwiseTransferGroup::layerwise_transfer( const int64_t cpu_kv_stride_in_bytes, const int64_t cpu_layer_stride_in_bytes, const int64_t cpu_block_stride_in_bytes, - const int64_t cpu_chunk_size_in_bytes, const int transfer_sms, + const int64_t cpu_chunk_size_in_bytes, const int transfer_cta_num, const bool use_ce_transfer, const int num_layers, const int layer_granularity, const bool is_mla, const int counter_id) { @@ -332,7 +332,7 @@ void LayerwiseTransferGroup::layerwise_transfer( gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, cpu_startoff_inside_chunks, chunk_size, - streams_[i], transfer_sms, true, use_ce_transfer, is_mla, false); + streams_[i], transfer_cta_num, true, use_ce_transfer, 
is_mla, false); break; case BackendType::TRTLLM: flexkv::transfer_kv_blocks( @@ -340,7 +340,7 @@ void LayerwiseTransferGroup::layerwise_transfer( gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, cpu_startoff_inside_chunks, chunk_size, - streams_[i], transfer_sms, true, use_ce_transfer, is_mla, false); + streams_[i], transfer_cta_num, true, use_ce_transfer, is_mla, false); break; case BackendType::SGLANG: flexkv::transfer_kv_blocks( @@ -348,7 +348,7 @@ void LayerwiseTransferGroup::layerwise_transfer( gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, cpu_startoff_inside_chunks, chunk_size, - streams_[i], transfer_sms, true, use_ce_transfer, is_mla, false); + streams_[i], transfer_cta_num, true, use_ce_transfer, is_mla, false); break; } } diff --git a/csrc/layerwise.h b/csrc/layerwise.h index d186316836..eea6561cf9 100644 --- a/csrc/layerwise.h +++ b/csrc/layerwise.h @@ -47,7 +47,7 @@ class LayerwiseTransferGroup { const int64_t cpu_kv_stride_in_bytes, const int64_t cpu_layer_stride_in_bytes, const int64_t cpu_block_stride_in_bytes, - const int64_t cpu_chunk_size_in_bytes, const int transfer_sms, + const int64_t cpu_chunk_size_in_bytes, const int transfer_cta_num, const bool use_ce_transfer, const int num_layers, const int layer_granularity, const bool is_mla, const int counter_id = 0); // Counter set index for triple buffering diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index 12c63224d5..59e336803a 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -67,8 +67,8 @@ def __init__(self, num_blocks_per_file: int, use_ce_transfer_h2d: bool = False, use_ce_transfer_d2h: bool = False, - transfer_sms_h2d: int = 8, - transfer_sms_d2h: int = 8) -> None: + h2d_cta_num: int = 4, + d2h_cta_num: int = 4) -> None: 
super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) assert len(gpu_blocks) == tp_group_size, f"len(gpu_blocks) = {len(gpu_blocks)}, tp_group_size = {tp_group_size}" imported_gpu_blocks = [] @@ -120,8 +120,8 @@ def __init__(self, self.use_ce_transfer_h2d = use_ce_transfer_h2d self.use_ce_transfer_d2h = use_ce_transfer_d2h - self.transfer_sms_h2d = transfer_sms_h2d - self.transfer_sms_d2h = transfer_sms_d2h + self.h2d_cta_num = h2d_cta_num + self.d2h_cta_num = d2h_cta_num # initialize SSD storage self.enable_ssd = len(ssd_files) > 0 @@ -272,7 +272,7 @@ def _transfer_impl(self, self.cpu_layer_stride_in_bytes, self.cpu_block_stride_in_bytes, self.cpu_chunk_size_in_bytes, - self.transfer_sms_h2d, + self.h2d_cta_num, self.use_ce_transfer_h2d, self.num_layers, layer_granularity, diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index 56e85267b4..7eeb76b784 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -330,8 +330,8 @@ def _init_workers(self) -> None: num_blocks_per_file=num_blocks_per_file, use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, - transfer_sms_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_sms_h2d, - transfer_sms_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_sms_d2h, + h2d_cta_num=GLOBAL_CONFIG_FROM_ENV.h2d_cta_num, + d2h_cta_num=GLOBAL_CONFIG_FROM_ENV.d2h_cta_num, ) for dp_client_id, gpu_handles in self.gpu_handle_groups.items() ] From 586dbc6caf444bc768311eadb1c749a56f77f533 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 23 Mar 2026 05:37:32 +0000 Subject: [PATCH 26/59] fix --- benchmarks/benchmark_single_batch.py | 13 ++++++----- csrc/bindings.cpp | 32 ++-------------------------- flexkv/cache/redis_meta.py | 6 +++--- flexkv/transfer/transfer_engine.py | 4 ++-- flexkv/transfer/worker_op.py | 2 +- tests/test_cache_engine.py | 5 +++-- 6 files changed, 17 insertions(+), 45 deletions(-) diff --git 
a/benchmarks/benchmark_single_batch.py b/benchmarks/benchmark_single_batch.py index 69af2f46c7..f05a63d97e 100644 --- a/benchmarks/benchmark_single_batch.py +++ b/benchmarks/benchmark_single_batch.py @@ -133,20 +133,19 @@ def benchmark_flexkv(model_config: ModelConfig, all_tokens = 0 start_time = time.time() batch_get_ids = [] - return_masks = [] - cached_tokens = 0 for i in range(batch_size): all_tokens += len(batch_sequence_tensor[i]) - task_id, return_mask = kvmanager.get_match(batch_sequence_tensor[i], + task_id, _ = kvmanager.get_match(batch_sequence_tensor[i], token_mask=None) batch_get_ids.append(task_id) - cached_tokens += return_mask.sum().item() get_match_time = time.time() - start_time - batch_id_list =kvmanager.launch(batch_get_ids, batch_slot_mapping, as_batch=True, layerwise_transfer=True) - get_result = kvmanager.wait(batch_id_list) + kvmanager.launch(batch_get_ids, batch_slot_mapping, as_batch=True, layerwise_transfer=False) + get_result = kvmanager.wait(batch_get_ids) elapsed_time_get = time.time() - start_time + cached_tokens = 0 for _, response in get_result.items(): - assert response.status == KVResponseStatus.SUCCESS + if response.status == KVResponseStatus.SUCCESS: + cached_tokens += response.return_mask.sum().item() transfer_data_size_GB = cached_tokens * model_config.token_size_in_bytes / 1024 / 1024 / 1024 transfer_bandwidth_get = transfer_data_size_GB / elapsed_time_get print(f"get {cached_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index 219f7abc84..6fa33ddbf0 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -17,8 +17,6 @@ #include #include "cache_utils.h" -#include "pcfs/pcfs.h" -#include "tp_transfer_thread_group.h" #include "gds/gds_manager.h" #include "gds/tp_gds_transfer_thread_group.h" #include "pcfs/pcfs.h" @@ -30,11 +28,11 @@ #include "dist/block_meta.h" #include "dist/distributed_radix_tree.h" #include "dist/lease_meta_mempool.h" -#include "layerwise.h" 
#include "dist/local_radix_tree.h" #include "dist/lock_free_q.h" #include "dist/redis_meta_channel.h" #endif +#include "layerwise.h" #include "monitoring/metrics_manager.h" #include @@ -423,7 +421,7 @@ PYBIND11_MODULE(c_ext, m) { py::arg("is_read"), py::arg("num_blocks_per_file"), py::arg("round_robin") = 1, py::arg("num_threads_per_device") = 16, py::arg("is_mla") = false); - py::class_(m, "LayerwiseTransferGroup") + py::class_(m, "LayerwiseTransferGroup") .def(py::init> &, torch::Tensor &, std::map> &, int, int, torch::Tensor &, torch::Tensor &, torch::Tensor &, @@ -537,32 +535,6 @@ PYBIND11_MODULE(c_ext, m) { py::arg("layer_granularity"), py::arg("is_mla")); #endif - py::class_(m, "LayerwiseTransferGroup") - .def(py::init> &, - torch::Tensor &, std::map> &, - int, int, torch::Tensor &, torch::Tensor &, torch::Tensor &, - torch::Tensor &, int, int>(), - py::arg("num_gpus"), py::arg("gpu_blocks"), py::arg("cpu_blocks"), - py::arg("ssd_files"), py::arg("dp_group_id"), py::arg("num_layers"), - py::arg("gpu_kv_strides_tensor"), - py::arg("gpu_block_strides_tensor"), - py::arg("gpu_layer_strides_tensor"), - py::arg("gpu_chunk_sizes_tensor"), py::arg("iouring_entries"), - py::arg("iouring_flags")) - .def("layerwise_transfer", - &flexkv::LayerwiseTransferGroup::layerwise_transfer, - py::arg("ssd_block_ids"), py::arg("cpu_block_ids_d2h"), - py::arg("ssd_layer_stride_in_bytes"), - py::arg("ssd_kv_stride_in_bytes"), py::arg("num_blocks_per_file"), - py::arg("round_robin"), py::arg("num_threads_per_device"), - py::arg("gpu_block_id_tensor"), py::arg("cpu_block_id_tensor"), - py::arg("cpu_kv_stride_in_bytes"), - py::arg("cpu_layer_stride_in_bytes"), - py::arg("cpu_block_stride_in_bytes"), - py::arg("cpu_chunk_size_in_bytes"), py::arg("transfer_sms"), - py::arg("use_ce_transfer"), py::arg("num_layers"), - py::arg("layer_granularity"), py::arg("is_mla")); - // Add Hasher class binding py::class_(m, "Hasher") .def(py::init<>()) diff --git a/flexkv/cache/redis_meta.py 
b/flexkv/cache/redis_meta.py index cca48a11e6..dd00f3b9bd 100644 --- a/flexkv/cache/redis_meta.py +++ b/flexkv/cache/redis_meta.py @@ -150,8 +150,8 @@ def __init__(self, host: str, port: int, local_ip: str, password: str = "") -> N self._running = False self._listener_thread: Optional[threading.Thread] = None self.current_node_id_set: set = set() - self._client: Optional[_redis.Redis] = None - self._sub_client: Optional[_redis.Redis] = None + self._client: Optional["_redis.Redis"] = None + self._sub_client: Optional["_redis.Redis"] = None self._cleanup_done = False # register cleanup function on exit @@ -167,7 +167,7 @@ def __del__(self) -> None: # ignore exceptions in destructor, avoid affecting program exit pass - def _get_client(self) -> _redis.Redis: + def _get_client(self) -> "_redis.Redis": """Get Redis client with connection settings""" return _redis.Redis( host=self.host, diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index 7eeb76b784..1f52ddca4f 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -158,7 +158,7 @@ def _init_workers(self) -> None: transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, ) - for _, gpu_handles in self.gpu_handles.items() + for _, gpu_handles in self.gpu_handle_groups.items() ] self.d2h_workers: List[WorkerHandle] = [ GPUCPUTransferWorker.create_worker( @@ -215,7 +215,7 @@ def _init_workers(self) -> None: transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, ) - for dp_client_id, gpu_handles in self.gpu_handles.items() + for dp_client_id, gpu_handles in self.gpu_handle_groups.items() ] self._worker_map[TransferType.H2D] = self.h2d_workers self._worker_map[TransferType.D2H] = self.d2h_workers diff --git a/flexkv/transfer/worker_op.py b/flexkv/transfer/worker_op.py index a906af8b12..a271435275 100644 
--- a/flexkv/transfer/worker_op.py +++ b/flexkv/transfer/worker_op.py @@ -32,7 +32,7 @@ def __init__(self, transfer_op: TransferOp): # Always preserve optional src_block_node_ids from TransferOp self.src_block_node_ids = transfer_op.src_block_node_ids - if self.src_slot_id == -1: + if self.src_slot_id == -1 or self.dst_slot_id == -1: self.src_block_ids = transfer_op.src_block_ids self.dst_block_ids = transfer_op.dst_block_ids else: diff --git a/tests/test_cache_engine.py b/tests/test_cache_engine.py index 3fe2365c69..6b1cc5779d 100644 --- a/tests/test_cache_engine.py +++ b/tests/test_cache_engine.py @@ -176,8 +176,9 @@ def test_mempool(): with pytest.raises(ValueError): mempool.recycle_blocks(np.array([1, 2, 3], dtype=np.int32)) - # recycle_blocks no longer raises ValueError for already free blocks - mempool.recycle_blocks(np.array([1, 2, 3], dtype=np.int64)) + # Recycle already free blocks raises + with pytest.raises(ValueError): + mempool.recycle_blocks(np.array([1, 2, 3], dtype=np.int64)) assert mempool.num_free_blocks == DEFAULT_NUM_TOTAL_BLOCKS # Recycle wrong ndim raises From bc1a18c6db2db3aba47a87a6d7fbcb4cf0e9b48e Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Mon, 23 Mar 2026 07:18:52 +0000 Subject: [PATCH 27/59] fix unit test --- tests/test_memory_handle.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/test_memory_handle.py b/tests/test_memory_handle.py index 3df9c228c3..18ff724406 100644 --- a/tests/test_memory_handle.py +++ b/tests/test_memory_handle.py @@ -84,6 +84,35 @@ def _worker_test_tensor_from_tensor_direct_ipc(conn, device_id): raise +def _worker_test_fp8_tensor_from_bytes(conn, device_id): + """Test construction from bytes with fp8 dtype""" + try: + handle = conn.recv() + assert isinstance(handle, TensorSharedHandle) + assert handle.use_direct_ipc + assert handle.tensor_dtype == torch.float8_e4m3fn + assert handle.tensor_shape == (10, 20) + + tensor = handle.get_tensor() + assert isinstance(tensor, 
torch.Tensor) + assert tensor.is_cuda + assert tensor.device.index == device_id + assert tensor.shape == (10, 20) + assert tensor.dtype == torch.float8_e4m3fn + + expected = ( + torch.arange(200, dtype=torch.float32) + .reshape(10, 20) + .cuda(device_id) + .to(torch.float8_e4m3fn) + ) + max_diff = (tensor.to(torch.float32) - expected.to(torch.float32)).abs().max().item() + conn.send(max_diff) + except Exception as e: + conn.send(f"Error: {e}") + raise + + def _worker_test_tensor_from_bytes(conn, device_id): """Test construction from bytes (IPC handle)""" try: From d3cc1d9bffe99606a0f6fb9a4d5fdadd4a59fb80 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Mon, 23 Mar 2026 07:19:32 +0000 Subject: [PATCH 28/59] update --- flexkv/kvtask.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index e549e7ffc7..24d4c4f8b4 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -851,12 +851,9 @@ def launch_tasks(self, if not GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer: flexkv_logger.warning("layerwise transfer is not enabled") layerwise_transfer = False - else: - for task_id in task_ids: - if self.tasks[task_id].task_type != TaskType.GET: - flexkv_logger.warning("only support layerwise get") - layerwise_transfer = False - break + elif not all_get: + flexkv_logger.warning("only support layerwise get") + layerwise_transfer = False batch_task_type = TaskType.BATCH_GET if all_get else TaskType.BATCH_PUT batch_task_graph = self.merge_to_batch_kvtask(batch_id, task_ids, batch_task_type, layerwise_transfer, counter_id) transfer_graphs = [batch_task_graph] From 12be2cc2f96a955936b95814f99ef18d1d133260 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Mon, 23 Mar 2026 07:51:08 +0000 Subject: [PATCH 29/59] merge h2d and disk2h to layerwiseop --- flexkv/common/transfer.py | 59 ++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/flexkv/common/transfer.py 
b/flexkv/common/transfer.py index 0c3529c901..3c8a3a3738 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -378,7 +378,6 @@ def _merge_ops(ops: List[TransferOp], transfer_type: TransferType, layer_granularity=ops[0].layer_granularity, dp_id=ops[0].dp_id, ) - graph.add_transfer_op(merged_op) if callbacks: if len(callbacks) == 1: op_callback_dict[merged_op.op_id] = callbacks[0] @@ -448,9 +447,8 @@ def merge_to_batch_graph(batch_id: int, merged_graph, callbacks_by_type[TransferType.DISK2H], new_op_callback_dict) merged_h2d_op = _merge_ops(ops_by_type[TransferType.H2D], TransferType.H2D, merged_graph, callbacks_by_type[TransferType.H2D], new_op_callback_dict) - if merged_disk2h_op is not None and merged_h2d_op is not None: - merged_graph.add_dependency(merged_h2d_op.op_id, merged_disk2h_op.op_id) - if layerwise_transfer: # FIXME: rebase issue + + if layerwise_transfer: if merged_h2d_op is not None: layerwise_transfer_op = LayerwiseTransferOp( graph_id=merged_graph.graph_id, @@ -464,34 +462,43 @@ def merge_to_batch_graph(batch_id: int, else np.array([], dtype=np.int64), layer_id=0, layer_granularity=1, - dp_id=h2d_ops[0].dp_id, + dp_id=ops_by_type[TransferType.H2D][0].dp_id, counter_id=counter_id, ) merged_graph.add_transfer_op(layerwise_transfer_op) batch_end_op_id = -1 - # layerwise transfer op does not need callbacks new_op_callback_dict.clear() - - # PUT path: D2H -> H2DISK - merged_d2h_op = _merge_ops(ops_by_type[TransferType.D2H], TransferType.D2H, - merged_graph, callbacks_by_type[TransferType.D2H], new_op_callback_dict) - merged_h2disk_op = _merge_ops(ops_by_type[TransferType.H2DISK], TransferType.H2DISK, - merged_graph, callbacks_by_type[TransferType.H2DISK], new_op_callback_dict) - if merged_d2h_op is not None and merged_h2disk_op is not None: - merged_graph.add_dependency(merged_h2disk_op.op_id, merged_d2h_op.op_id) - - # batch_end_op_id: GET: H2D > DISK2H; PUT: H2DISK > D2H - if merged_h2d_op is not None: - batch_end_op_id = 
merged_h2d_op.op_id - elif merged_disk2h_op is not None: - batch_end_op_id = merged_disk2h_op.op_id - elif merged_h2disk_op is not None: - batch_end_op_id = merged_h2disk_op.op_id - elif merged_d2h_op is not None: - batch_end_op_id = merged_d2h_op.op_id else: - batch_end_op_id = -1 - + if merged_disk2h_op is not None: + merged_graph.add_transfer_op(merged_disk2h_op) + if merged_h2d_op is not None: + merged_graph.add_transfer_op(merged_h2d_op) + if merged_disk2h_op is not None and merged_h2d_op is not None: + merged_graph.add_dependency(merged_h2d_op.op_id, merged_disk2h_op.op_id) + + # PUT path: D2H -> H2DISK + merged_d2h_op = _merge_ops(ops_by_type[TransferType.D2H], TransferType.D2H, + merged_graph, callbacks_by_type[TransferType.D2H], new_op_callback_dict) + merged_h2disk_op = _merge_ops(ops_by_type[TransferType.H2DISK], TransferType.H2DISK, + merged_graph, callbacks_by_type[TransferType.H2DISK], new_op_callback_dict) + if merged_d2h_op is not None: + merged_graph.add_transfer_op(merged_d2h_op) + if merged_h2disk_op is not None: + merged_graph.add_transfer_op(merged_h2disk_op) + if merged_d2h_op is not None and merged_h2disk_op is not None: + merged_graph.add_dependency(merged_h2disk_op.op_id, merged_d2h_op.op_id) + + # batch_end_op_id: GET: H2D > DISK2H; PUT: H2DISK > D2H + if merged_h2d_op is not None: + batch_end_op_id = merged_h2d_op.op_id + elif merged_disk2h_op is not None: + batch_end_op_id = merged_disk2h_op.op_id + elif merged_h2disk_op is not None: + batch_end_op_id = merged_h2disk_op.op_id + elif merged_d2h_op is not None: + batch_end_op_id = merged_d2h_op.op_id + else: + batch_end_op_id = -1 return merged_graph, batch_end_op_id, new_op_callback_dict From ee4c5c00f2ed90a5865150492dbab9a138663160 Mon Sep 17 00:00:00 2001 From: Jianying <53503712+jianyingzhu@users.noreply.github.com> Date: Wed, 1 Apr 2026 17:22:09 +0800 Subject: [PATCH 30/59] Bug Fix: illegal memory (#133) * fix illegal memory * add log_transfer_performance to layerwise worker --------- 
Co-authored-by: jianyingzhu --- flexkv/kvtask.py | 10 +++--- flexkv/transfer/layerwise.py | 53 ++++++++++++++++++++---------- flexkv/transfer/transfer_engine.py | 11 ++++--- 3 files changed, 49 insertions(+), 25 deletions(-) diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index 24d4c4f8b4..88038df389 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -777,9 +777,9 @@ def prefetch_async(self, return task_id def merge_to_batch_kvtask(self, - + batch_id: int, - + task_ids: List[int], batch_task_type: TaskType, layerwise_transfer: bool = False, @@ -844,7 +844,7 @@ def launch_tasks(self, all_get = all(self.tasks[tid].task_type == TaskType.GET for tid in task_ids) all_put = all(self.tasks[tid].task_type == TaskType.PUT for tid in task_ids) - if len(task_ids) > 1 and as_batch and (all_get or all_put): + if (len(task_ids) > 1 or layerwise_transfer) and as_batch and (all_get or all_put): if batch_id == -1: batch_id = self._gen_task_id() if layerwise_transfer: @@ -855,7 +855,9 @@ def launch_tasks(self, flexkv_logger.warning("only support layerwise get") layerwise_transfer = False batch_task_type = TaskType.BATCH_GET if all_get else TaskType.BATCH_PUT - batch_task_graph = self.merge_to_batch_kvtask(batch_id, task_ids, batch_task_type, layerwise_transfer, counter_id) + batch_task_graph = self.merge_to_batch_kvtask( + batch_id, task_ids, batch_task_type, layerwise_transfer, counter_id + ) transfer_graphs = [batch_task_graph] self.tasks[batch_id].status = TaskStatus.RUNNING task_ids = [batch_id] diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index 59e336803a..cdee47d5e0 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -35,7 +35,7 @@ def _recv_fds(sock: socket.socket, num_fds: int) -> Tuple[List[int], bytes]: """Receive multiple fds + extra_data via Unix domain socket (SCM_RIGHTS).""" data_buf = bytearray(256) anc_buf_size = socket.CMSG_SPACE(num_fds * struct.calcsize("i")) - + nbytes, ancdata, flags, addr = 
sock.recvmsg_into([data_buf], anc_buf_size, anc_buf_size) data = bytes(data_buf[:nbytes]) @@ -159,23 +159,23 @@ def __init__(self, GLOBAL_CONFIG_FROM_ENV.iouring_flags, layer_eventfds_tensor, tp_group_size) - def _receive_eventfds_from_sglang(self, tp_group_size: int, - max_retries: int = 180, + def _receive_eventfds_from_sglang(self, tp_group_size: int, + max_retries: int = 180, retry_interval: float = 1.0) -> torch.Tensor: """Receive eventfds from SGLang via Unix socket (FlexKV as server).""" socket_path = os.environ.get('FLEXKV_LAYERWISE_EVENTFD_SOCKET', '/tmp/flexkv_layerwise_eventfd.sock') - + def cleanup_socket(): try: if os.path.exists(socket_path): os.unlink(socket_path) except OSError: pass - + cleanup_socket() server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - + try: server_sock.bind(socket_path) server_sock.listen(tp_group_size) @@ -185,11 +185,11 @@ def cleanup_socket(): flexkv_logger.error(f"[LayerwiseWorker] Failed to bind/listen: {e}") server_sock.close() return torch.empty(0, dtype=torch.int32) - + server_sock.settimeout(max_retries * retry_interval) all_rank_eventfds: Dict[int, Dict[int, List[int]]] = {} num_layers, num_counters = self.num_layers, 3 - + try: for conn_idx in range(tp_group_size): try: @@ -197,23 +197,23 @@ def cleanup_socket(): except socket.timeout: flexkv_logger.warning(f"[LayerwiseWorker] Timeout, received {conn_idx}/{tp_group_size}") break - + with conn: metadata = conn.recv(16) if len(metadata) < 16: flexkv_logger.error(f"[LayerwiseWorker] Incomplete metadata: {len(metadata)} bytes") continue - + tp_rank, _, recv_num_layers, recv_num_counters = struct.unpack("iiii", metadata) if conn_idx == 0: num_layers, num_counters = recv_num_layers, recv_num_counters - + rank_eventfds = {} for _ in range(recv_num_counters): fds, extra_data = _recv_fds(conn, recv_num_layers) counter_id = struct.unpack("i", extra_data[:4])[0] rank_eventfds[counter_id] = fds - + 
all_rank_eventfds[tp_rank] = rank_eventfds flexkv_logger.info(f"[LayerwiseWorker] Received eventfds from tp_rank={tp_rank}") except Exception as e: @@ -221,20 +221,23 @@ def cleanup_socket(): finally: server_sock.close() cleanup_socket() - + if not all_rank_eventfds: flexkv_logger.warning("[LayerwiseWorker] No connections received") return torch.empty(0, dtype=torch.int32) - + # Build tensor: [num_counters, tp_size, num_layers] eventfds_list = [] for counter_id in range(num_counters): for tp_rank in range(tp_group_size): fds = all_rank_eventfds.get(tp_rank, {}).get(counter_id, [-1] * num_layers) eventfds_list.extend(fds) - + tensor = torch.tensor(eventfds_list, dtype=torch.int32) - flexkv_logger.info(f"[LayerwiseWorker] Eventfds tensor: {tensor.shape}, counters={num_counters}, tp={tp_group_size}, layers={num_layers}") + flexkv_logger.info( + f"[LayerwiseWorker] Eventfds tensor: {tensor.shape}, " + f"counters={num_counters}, tp={tp_group_size}, layers={num_layers}" + ) return tensor def _transfer_impl(self, @@ -280,7 +283,7 @@ def _transfer_impl(self, counter_id, ) - def launch_transfer(self, transfer_op: WorkerLayerwiseTransferOp) -> None: + def launch_transfer(self, transfer_op: WorkerLayerwiseTransferOp) -> bool: layer_granularity = transfer_op.layer_granularity if layer_granularity == -1: layer_granularity = self.num_layers @@ -295,6 +298,9 @@ def launch_transfer(self, transfer_op: WorkerLayerwiseTransferOp) -> None: src_block_ids_disk2h = None dst_block_ids_disk2h = None + num_h2d_blocks = len(src_block_ids_h2d) + + start_time = time.time() self._transfer_impl( src_block_ids_h2d, dst_block_ids_h2d, @@ -303,3 +309,16 @@ def launch_transfer(self, transfer_op: WorkerLayerwiseTransferOp) -> None: layer_granularity, transfer_op.counter_id, ) + end_time = time.time() + + kv_dim = 2 if not self.is_mla else 1 + transfer_size = self.cpu_chunk_size_in_bytes * layer_granularity * num_h2d_blocks * kv_dim + + self._log_transfer_performance( + transfer_op, + transfer_size, + 
start_time, + end_time, + ) + + return True diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index 1f52ddca4f..f432da16f9 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -122,7 +122,10 @@ def __init__(self, self._ssd_handle = ssd_handle self._remote_handle = remote_handle self._cache_config = cache_config - self._enable_pcfs_sharing = GLOBAL_CONFIG_FROM_ENV.index_accel and cache_config.enable_kv_sharing # TODO: is this correct? + # TODO: is this correct? + self._enable_pcfs_sharing = ( + GLOBAL_CONFIG_FROM_ENV.index_accel and cache_config.enable_kv_sharing + ) self.pin_buffer = SharedOpPool(2048, self.cache_config.num_cpu_blocks) @@ -330,13 +333,13 @@ def _init_workers(self) -> None: num_blocks_per_file=num_blocks_per_file, use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, - h2d_cta_num=GLOBAL_CONFIG_FROM_ENV.h2d_cta_num, - d2h_cta_num=GLOBAL_CONFIG_FROM_ENV.d2h_cta_num, + h2d_cta_num=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, + d2h_cta_num=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, ) for dp_client_id, gpu_handles in self.gpu_handle_groups.items() ] self._worker_map[TransferType.LAYERWISE] = self.layerwise_workers - + if self.cache_config.enable_kv_sharing and self._cpu_handle is not None and (self.cache_config.enable_p2p_cpu \ or (self._ssd_handle and self.cache_config.enable_p2p_ssd)): ## NOTE:if we have the cpu handle and enable p2p cpu transfer we need this worker From 72c51874fec857efbafb9c4f7fabe89fbcc2b32b Mon Sep 17 00:00:00 2001 From: zittozhang Date: Sun, 5 Apr 2026 17:08:24 +0800 Subject: [PATCH 31/59] feat: add DSA cache and PP support --- CMakeLists.txt | 39 +- VERSION | 1 - build.sh | 42 ++ flexkv/cache/cache_engine.py | 6 +- flexkv/cache/hie_cache_engine.py | 34 +- flexkv/cache/redis_meta.py | 56 ++- flexkv/common/config.py | 16 + flexkv/common/transfer.py | 7 + flexkv/integration/config.py 
| 67 ++- flexkv/integration/vllm/vllm_v1_adapter.py | 42 +- flexkv/server/client.py | 13 +- flexkv/server/request.py | 3 + flexkv/storage/storage_engine.py | 196 ++++++-- flexkv/transfer/layerwise.py | 8 +- flexkv/transfer/transfer_engine.py | 405 ++++++++++++++++- flexkv/transfer_manager.py | 74 ++- setup.py | 22 +- tests/test_kvmanager.py | 424 +++++++++++++++++- tests/test_transfer_engine_atomic_eviction.py | 404 +++++++++++++++++ 19 files changed, 1769 insertions(+), 90 deletions(-) delete mode 100644 VERSION create mode 100644 tests/test_transfer_engine_atomic_eviction.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 749b201b8d..34c777d811 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,42 @@ cmake_minimum_required(VERSION 3.10) -project(MainProject VERSION 1.0) + +find_package(Git QUIET) + +if(GIT_FOUND) + execute_process( + COMMAND ${GIT_EXECUTABLE} describe --tags --long --match "v*" + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE GIT_DESCRIBE_OUTPUT + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + if(GIT_DESCRIBE_OUTPUT MATCHES "^v([0-9]+\\.[0-9]+\\.[0-9]+)-([0-9]+)-g([0-9a-f]+)$") + set(GIT_VERSION "${CMAKE_MATCH_1}") + set(GIT_DISTANCE "${CMAKE_MATCH_2}") + set(GIT_HASH "${CMAKE_MATCH_3}") + if(GIT_DISTANCE STREQUAL "0") + set(DETECTED_VERSION "${GIT_VERSION}") + else() + set(DETECTED_VERSION "${GIT_VERSION}+git${GIT_HASH}") + endif() + message(STATUS "Version from git tag: ${DETECTED_VERSION}") + endif() +endif() + +if(NOT DEFINED DETECTED_VERSION OR DETECTED_VERSION STREQUAL "") + set(DETECTED_VERSION "0.0.0") + message(WARNING "Could not detect version from git tag, using fallback: ${DETECTED_VERSION}") +endif() + +# Strip +gitXXXXXXX suffix for CMake project VERSION (must be numeric X.Y.Z) +if(DETECTED_VERSION MATCHES "^([0-9]+\\.[0-9]+\\.[0-9]+)") + set(NUMERIC_VERSION "${CMAKE_MATCH_1}") +else() + set(NUMERIC_VERSION "0.0.0") +endif() + +project(MainProject VERSION ${NUMERIC_VERSION}) +message(STATUS 
"Project version: ${PROJECT_VERSION} (full: ${DETECTED_VERSION})") set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) diff --git a/VERSION b/VERSION deleted file mode 100644 index 3eefcb9dd5..0000000000 --- a/VERSION +++ /dev/null @@ -1 +0,0 @@ -1.0.0 diff --git a/build.sh b/build.sh index 3e021a940c..12f4c018c5 100755 --- a/build.sh +++ b/build.sh @@ -15,12 +15,54 @@ for arg in "$@"; do BUILD_TYPE="release" shift ;; + --clean) + BUILD_TYPE="clean" + shift + ;; *) # Unknown option ;; esac done +# Handle clean +if [ "$BUILD_TYPE" = "clean" ]; then + echo "=== Cleaning all build artifacts ===" + + # Remove CMake build directory + if [ -d "build" ]; then + rm -rf build + echo "Removed build/" + fi + + # Remove compiled .so files in package directory + find flexkv -name "*.so" -type f -delete -print | sed 's/^/Removed /' + + # Remove copied libs directory + if [ -d "flexkv/lib" ]; then + rm -rf flexkv/lib + echo "Removed flexkv/lib/" + fi + + # Remove Python build artifacts + find . -maxdepth 2 -name "*.egg-info" -type d | while read d; do + rm -rf "$d" + echo "Removed $d" + done + # Only remove top-level dist/ (Python build output), not csrc/dist/ source directory + if [ -d "dist" ]; then + rm -rf dist + echo "Removed dist/" + fi + find . 
-name "__pycache__" -type d | while read d; do + rm -rf "$d" + echo "Removed $d" + done + + echo "=== Clean completed ===" + exit 0 +fi + echo "=== Building in ${BUILD_TYPE} mode ===" # Install submodules diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index e4a198ccb9..a0571aa59e 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -396,7 +396,7 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, redis_m if cache_config.enable_cpu: if cache_config.enable_p2p_cpu: - self.cpu_cache_engine = HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.CPU, meta=self.redis_meta) #TODO + self.cpu_cache_engine = HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.CPU, meta=self.redis_meta, pp_rank=self.model_config.pp_rank, pp_size=self.model_config.pp_size) #TODO elif self.index_accel: self.cpu_cache_engine = CacheEngineAccel(DeviceType.CPU, cache_config.num_cpu_blocks, @@ -420,7 +420,7 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, redis_m self.cache_engines[DeviceType.CPU] = self.cpu_cache_engine if cache_config.enable_ssd: if cache_config.enable_p2p_ssd: - self.ssd_cache_engine = HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.SSD, meta=self.redis_meta) #TODO + self.ssd_cache_engine = HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.SSD, meta=self.redis_meta, pp_rank=self.model_config.pp_rank, pp_size=self.model_config.pp_size) #TODO elif self.index_accel: self.ssd_cache_engine = CacheEngineAccel(DeviceType.SSD, cache_config.num_ssd_blocks, @@ -445,7 +445,7 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, redis_m if cache_config.enable_remote: if cache_config.enable_kv_sharing: # Build PCFSCacheEngine from CacheConfig directly (replacing RemotePCFSCacheEngine) TODO - self.remote_cache_engine = 
HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.REMOTE, meta=self.redis_meta) + self.remote_cache_engine = HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.REMOTE, meta=self.redis_meta, pp_rank=self.model_config.pp_rank, pp_size=self.model_config.pp_size) elif self.index_accel: self.remote_cache_engine = CacheEngineAccel(DeviceType.REMOTE, cache_config.num_remote_blocks, diff --git a/flexkv/cache/hie_cache_engine.py b/flexkv/cache/hie_cache_engine.py index 52c97b3a2b..838663fdc8 100644 --- a/flexkv/cache/hie_cache_engine.py +++ b/flexkv/cache/hie_cache_engine.py @@ -37,7 +37,9 @@ def __init__(self, evict_start_threshold: float = 1.0, hit_reward_seconds: int = 0, eviction_policy: str = "lru", - meta: Optional[RedisMeta] = None) -> None: + meta: Optional[RedisMeta] = None, + pp_rank: int = 0, + pp_size: int = 1) -> None: if num_total_blocks <= 0: raise ValueError(f"Invalid num_total_blocks: {num_total_blocks}") if tokens_per_block <= 0 or (tokens_per_block & (tokens_per_block - 1)) != 0: @@ -90,6 +92,8 @@ def __init__(self, self.num_total_blocks = num_total_blocks self.evict_ratio = evict_ratio self.evict_start_threshold = evict_start_threshold + self.pp_rank = pp_rank + self.pp_size = pp_size # cumulative statistics: for analyzing distributed KV reuse benefits self._stats_total_queried_tokens = 0 # total tokens queried @@ -102,17 +106,22 @@ def start(self) -> None: if self._meta is None: raise ValueError("RedisMeta is not provided; ensure from_cache_config stores it or pass it to start().") #TODO can we use like this to distinguish the different tree pairs? 
+ # Determine base block key prefix by device type if self.device_type == DeviceType.REMOTE: - local_ch_block_key = "PCFSB" - remote_ch_block_key = "PCFSB" + base_key = "PCFSB" elif self.device_type == DeviceType.CPU: - local_ch_block_key = "CPUB" - remote_ch_block_key = "CPUB" + base_key = "CPUB" elif self.device_type == DeviceType.SSD: - local_ch_block_key = "SSDB" - remote_ch_block_key = "SSDB" + base_key = "SSDB" else: raise ValueError(f"Invalid device type: {self.device_type}") + + if self.pp_size > 1: + local_ch_block_key = f"{base_key}:pp{self.pp_rank}" + remote_ch_block_key = f"{base_key}:pp{self.pp_rank}" + else: + local_ch_block_key = base_key + remote_ch_block_key = base_key self.remote_ch = self._meta.get_redis_meta_channel(remote_ch_block_key) self.local_ch = self._meta.get_redis_meta_channel(local_ch_block_key) # Load and store mapping of node_id -> file_nodeids from Redis @@ -443,7 +452,7 @@ def recycle(self, physical_blocks: np.ndarray) -> None: #TODO pfcs may not work now @classmethod - def pcfs_ce_from_cache_config(cls, cache_config: "CacheConfig", node_id: int, meta: Optional[RedisMeta] = None) -> "HierarchyLRCacheEngine": + def pcfs_ce_from_cache_config(cls, cache_config: "CacheConfig", node_id: int, meta: Optional[RedisMeta] = None, pp_rank: int = 0, pp_size: int = 1) -> "HierarchyLRCacheEngine": """Create a PCFSCacheEngine from CacheConfig. This replaces RemotePCFSCacheEngine. It wires both local and remote @@ -522,14 +531,16 @@ def pcfs_ce_from_cache_config(cls, cache_config: "CacheConfig", node_id: int, me local_safety_ttl_ms=int(GLOBAL_CONFIG_FROM_ENV.safety_ttl_ms), eviction_policy=GLOBAL_CONFIG_FROM_ENV.eviction_policy, meta=meta, + pp_rank=pp_rank, + pp_size=pp_size, ) #TODO is this enough for peercpu and peerssd? 
@classmethod - def from_cache_config(cls, cache_config: "CacheConfig", node_id: int, device_type: DeviceType, meta: Optional[RedisMeta] = None) -> "HierarchyLRCacheEngine": + def from_cache_config(cls, cache_config: "CacheConfig", node_id: int, device_type: DeviceType, meta: Optional[RedisMeta] = None, pp_rank: int = 0, pp_size: int = 1) -> "HierarchyLRCacheEngine": if device_type == DeviceType.REMOTE: - return cls.pcfs_ce_from_cache_config(cache_config, node_id, meta) + return cls.pcfs_ce_from_cache_config(cache_config, node_id, meta, pp_rank=pp_rank, pp_size=pp_size) else: # select correct blocks configuration based on device_type if device_type == DeviceType.CPU: @@ -563,6 +574,7 @@ def from_cache_config(cls, cache_config: "CacheConfig", node_id: int, device_typ hit_reward_seconds=int(GLOBAL_CONFIG_FROM_ENV.hit_reward_seconds), eviction_policy=GLOBAL_CONFIG_FROM_ENV.eviction_policy, meta=meta, + pp_rank=pp_rank, + pp_size=pp_size, ) raise ValueError("Invalid device type: {cache_config.device_type}") - diff --git a/flexkv/cache/redis_meta.py b/flexkv/cache/redis_meta.py index dd00f3b9bd..8a2c9ace6b 100644 --- a/flexkv/cache/redis_meta.py +++ b/flexkv/cache/redis_meta.py @@ -258,7 +258,9 @@ def register_node(self) -> Optional[int]: "local_ip": self.local_ip, # Keep for backward compatibility "uuid": self.uuid, "status": "active", - "timestamp": str(int(time.time())) + "timestamp": str(int(time.time())), + "pp_rank": str(getattr(self, 'pp_rank', 0)), + "pp_size": str(getattr(self, 'pp_size', 1)), }) # Publish node update event @@ -504,13 +506,14 @@ def add_node_ids(self, node_ids: Iterable[Union[int, str]]) -> int: # rpush returns the new length of the list return int(r.rpush(f"pcfs:{nid}", *values)) - def regist_buffer(self, mrs: Iterable[object]) -> int: + def regist_buffer(self, mrs: Iterable[object], pp_rank: int = 0, pp_size: int = 1) -> int: """Register RDMA memory regions in Redis. 
Each element in mrs can be one of: - dict with keys {"buffer_ptr": ..., "buffer_size": ...} - tuple/list (buffer_ptr, buffer_size) - Stored as hash: key = buffer::, field "buffer_size" = . + Stored as hash: key = buffer:[:pp]:, field "buffer_size" = . + When pp_size > 1, pp_rank is included in the key for isolation. Returns the number of regions processed. """ nid = self.get_node_id() @@ -527,21 +530,27 @@ def regist_buffer(self, mrs: Iterable[object]) -> int: continue if ptr is None or size is None: continue - key = f"buffer:{nid}:{int(ptr)}" + if pp_size > 1: + key = f"buffer:{nid}:pp{pp_rank}:{int(ptr)}" + else: + key = f"buffer:{nid}:{int(ptr)}" pipe.hset(key, mapping={"buffer_size": int(size)}) processed += 1 if processed: pipe.execute() return processed - def unregist_buffer(self, buffer_ptr: Union[int, str]) -> bool: + def unregist_buffer(self, buffer_ptr: Union[int, str], pp_rank: int = 0, pp_size: int = 1) -> bool: """Unregister a previously registered RDMA memory region by buffer_ptr. - Looks up key buffer:: and deletes it if present. + Looks up key buffer:[:pp]: and deletes it if present. Returns True if the key existed and was deleted, otherwise False. """ nid = self.get_node_id() - key = f"buffer:{nid}:{int(buffer_ptr)}" + if pp_size > 1: + key = f"buffer:{nid}:pp{pp_rank}:{int(buffer_ptr)}" + else: + key = f"buffer:{nid}:{int(buffer_ptr)}" r = self._client() exists = bool(r.exists(key)) if exists: @@ -549,31 +558,40 @@ def unregist_buffer(self, buffer_ptr: Union[int, str]) -> bool: return True return False - def regist_node_meta(self, node_id: int, addr: str, zmq_addr: str, cpu_buffer_ptr: int, ssd_buffer_ptr: int) -> None: + def regist_node_meta(self, node_id: int, addr: str, zmq_addr: str, cpu_buffer_ptr: int, ssd_buffer_ptr: int, pp_rank: int = 0, pp_size: int = 1) -> None: """Register node meta information as a Redis hash. - Key: meta: + Key: meta:[:pp] + When pp_size > 1, pp_rank is included in the key for PP rank isolation. 
Fields: node_id (int), addr (str), cpu_buffer_ptr (int), ssd_buffer_ptr (int) """ r = self._client() - key = f"meta:{int(node_id)}" + if pp_size > 1: + key = f"meta:{int(node_id)}:pp{pp_rank}" + else: + key = f"meta:{int(node_id)}" r.hset(key, mapping={ "node_id": int(node_id), "addr": str(addr), "zmq_addr": str(zmq_addr), "cpu_buffer_ptr": int(cpu_buffer_ptr), "ssd_buffer_ptr": int(ssd_buffer_ptr), + "pp_rank": int(pp_rank), + "pp_size": int(pp_size), }) - def get_node_meta(self, node_id: int) -> dict: + def get_node_meta(self, node_id: int, pp_rank: int = 0, pp_size: int = 1) -> dict: """Get node meta information from Redis. - Reads key meta: and returns a dict with fields: + Reads key meta:[:pp] and returns a dict with fields: node_id (int), addr (str), cpu_buffer_ptr (int), ssd_buffer_ptr (int). Returns empty dict if the key does not exist. """ r = self._client() - key = f"meta:{int(node_id)}" + if pp_size > 1: + key = f"meta:{int(node_id)}:pp{pp_rank}" + else: + key = f"meta:{int(node_id)}" data = r.hgetall(key) if not data: return {} @@ -588,10 +606,16 @@ def get_node_meta(self, node_id: int) -> dict: out["ssd_buffer_ptr"] = int(sb) if sb is not None and sb != "" else 0 return out - def unregist_node_meta(self, node_id: int) -> bool: - """Unregister node meta by node_id. Returns True if deleted.""" + def unregist_node_meta(self, node_id: int, pp_rank: int = 0, pp_size: int = 1) -> bool: + """Unregister node meta by node_id. Returns True if deleted. + + When pp_size > 1, only deletes the key for the specified pp_rank. 
+ """ r = self._client() - key = f"meta:{int(node_id)}" + if pp_size > 1: + key = f"meta:{int(node_id)}:pp{pp_rank}" + else: + key = f"meta:{int(node_id)}" return bool(r.delete(key)) diff --git a/flexkv/common/config.py b/flexkv/common/config.py index 7f75a7d560..46b044a1ab 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -11,6 +11,17 @@ from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType from flexkv.common.debug import flexkv_logger + +@dataclass +class IndexerCacheConfig: + """Indexer-specific cache configuration, embedded inside CacheConfig.""" + # Indexer head layout + head_size: int = 0 # qk_rope_head_dim for DSA/NSA models + num_kv_heads: int = 1 # typically 1 for MLA-style indexer + dtype: torch.dtype = torch.uint8 # indexer storage dtype (fp8 quantized) + page_size: int = 1 + + @dataclass class ModelConfig: num_layers: int = 1 @@ -22,6 +33,8 @@ class ModelConfig: # parallel configs tp_size: int = 1 dp_size: int = 1 + pp_size: int = 1 + pp_rank: int = 0 @property def token_size_in_bytes(self) -> int: @@ -46,6 +59,9 @@ class CacheConfig: num_tmp_cpu_blocks: int = 500 # only used when distributed ssd p2p, it controls the number blocks of temp cpu buffer which used for copy data from ssd to cpu + # Indexer configuration + indexer: Optional[IndexerCacheConfig] = None + # mempool capacity configs num_cpu_blocks: int = 1000000 num_ssd_blocks: int = 10000000 diff --git a/flexkv/common/transfer.py b/flexkv/common/transfer.py index 3c8a3a3738..3b38a03b5a 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -12,6 +12,9 @@ class CompletedOp: graph_id: int op_id: int + transfer_type: Optional[str] = None + num_blocks: int = 0 + num_bytes: int = 0 def is_graph_completed(self) -> bool: return self.op_id == -1 @@ -96,6 +99,10 @@ class TransferOp: remote_node_ids: Optional[np.ndarray] = None # used for distributed cpu and ssd src_block_node_ids: Optional[np.ndarray] = None + # pending_count tracks how many workers 
(main KV + indexer) have not yet completed this op. + # Initialized to 1; incremented before submitting to indexer worker. + # _scheduler_loop decrements it on each worker completion; finalization happens only when it reaches 0. + pending_count: int = 1 def __post_init__(self) -> None: if self.transfer_type != TransferType.VIRTUAL and \ diff --git a/flexkv/integration/config.py b/flexkv/integration/config.py index 82bf15634c..5bbdb04685 100644 --- a/flexkv/integration/config.py +++ b/flexkv/integration/config.py @@ -40,6 +40,28 @@ def __post_init__(self): if self.gpu_register_port == "": self.gpu_register_port = self.server_recv_port + "_gpu_register" + def _detect_indexer_config_from_hf(self, hf_config, source: str = "", page_size: int = 1) -> None: + if hf_config is None: + return + + try: + qk_rope_head_dim = getattr(hf_config, 'qk_rope_head_dim', None) + if qk_rope_head_dim is None or qk_rope_head_dim <= 0: + return + + self.cache_config.indexer = IndexerCacheConfig( + head_size=qk_rope_head_dim, + num_kv_heads=1, + dtype=torch.uint8, + page_size=page_size, + ) + source_label = f" ({source})" if source else "" + logger.info( + f"Detected sparse attention indexer config{source_label}: " + f"head_size={qk_rope_head_dim}, dtype=uint8, page_size={page_size}") + except Exception as e: + logger.debug(f"Could not detect indexer config ({source}): {e}") + @classmethod def from_env(cls) -> 'FlexKVConfig': enable_flexkv = bool(int(os.getenv('ENABLE_FLEXKV', 1))) @@ -68,6 +90,8 @@ def post_init_from_vllm_config( self.model_config.use_mla = vllm_config.model_config.is_deepseek_mla self.model_config.tp_size = vllm_config.parallel_config.tensor_parallel_size self.model_config.dp_size = vllm_config.parallel_config.data_parallel_size + self.model_config.pp_size = vllm_config.parallel_config.pipeline_parallel_size + self.model_config.pp_rank = getattr(vllm_config.parallel_config, 'pipeline_parallel_rank', 0) if self.model_config.use_mla: self.model_config.num_kv_heads = 1 else: 
@@ -76,12 +100,17 @@ def post_init_from_vllm_config( self.server_recv_port = GLOBAL_CONFIG_FROM_ENV.server_recv_port self.gpu_register_port = self.server_recv_port + "_gpu_register" + hf_config = getattr(vllm_config.model_config, 'hf_config', None) + self._detect_indexer_config_from_hf(hf_config, source="vllm") def post_init_from_sglang_config( self, sglang_config, tp_size: int, page_size: int, + num_local_layers: int = 0, + pp_size: int = 1, + pp_rank: int = 0, ): """ Initialize FlexKVConfig fields from sglang config. @@ -89,15 +118,25 @@ def post_init_from_sglang_config( sglang_config: sglang.srt.configs.model_config.ModelConfig-like object tp_size: tensor parallel size used by sglang page_size: KV block size (tokens per block) used by sglang + num_local_layers: number of layers on this PP rank (0 means no PP, use total layers) + pp_size: pipeline parallel size (default 1, no PP) + pp_rank: pipeline parallel rank (default 0) """ # cache config - self.cache_config.tokens_per_block = int(page_size) + self.cache_config.tokens_per_block = 1 - self.model_config.num_layers = int(getattr(sglang_config, "num_hidden_layers", 0)) + total_layers = int(getattr(sglang_config, "num_hidden_layers", 0)) + self.model_config.num_layers = int(num_local_layers) if num_local_layers > 0 else total_layers - if hasattr(sglang_config, "get_num_kv_heads"): + if hasattr(sglang_config, "get_total_num_kv_heads"): + try: + self.model_config.num_kv_heads = int(sglang_config.get_total_num_kv_heads()) + except Exception: + self.model_config.num_kv_heads = int(getattr(sglang_config, "num_key_value_heads", 0)) + elif hasattr(sglang_config, "get_num_kv_heads"): try: - self.model_config.num_kv_heads = int(sglang_config.get_num_kv_heads(tp_size)) + per_rank = int(sglang_config.get_num_kv_heads(tp_size)) + self.model_config.num_kv_heads = per_rank * tp_size except Exception: self.model_config.num_kv_heads = int(getattr(sglang_config, "num_key_value_heads", 0)) else: @@ -116,7 +155,22 @@ def 
post_init_from_sglang_config( self.model_config.tp_size = int(tp_size) self.model_config.dp_size = int(getattr(sglang_config, "dp_size", 1)) + self.model_config.pp_size = int(pp_size) + self.model_config.pp_rank = int(pp_rank) update_default_config_from_user_config(self.model_config, self.cache_config, self.user_config) + + hf_config = getattr(sglang_config, 'hf_config', None) + self._detect_indexer_config_from_hf(hf_config, source="sglang", page_size=page_size) + + if self.cache_config.indexer is not None: + logger.info( + f"[FlexKV] Complete indexer config (sglang): " + f"page_size={self.cache_config.indexer.page_size}, " + f"head_size={self.cache_config.indexer.head_size}, " + f"dtype={self.cache_config.indexer.dtype}, " + f"num_layers={self.model_config.num_layers}, " + f"tokens_per_block={self.cache_config.tokens_per_block}" + ) def post_init_from_trt_config( self, @@ -172,6 +226,8 @@ def _parse_dtype_str(dtype_str: str) -> torch.dtype: else: self.model_config.tp_size = config.mapping.tp_size self.model_config.dp_size = 1 + self.model_config.pp_size = getattr(config.mapping, 'pp_size', 1) + self.model_config.pp_rank = getattr(config.mapping, 'pp_rank', 0) # self.model_config (model configs part) try: @@ -197,7 +253,8 @@ def _parse_dtype_str(dtype_str: str) -> torch.dtype: else: self.model_config.head_size = hf_config.hidden_size // hf_config.num_attention_heads self.model_config.num_kv_heads = hf_config.num_attention_heads - + + self._detect_indexer_config_from_hf(hf_config, source="TRT-LLM") except Exception as e: flexkv_logger.error(f"Failed to load config from {model_path}: {e}") # Update cache config with user config after model config is initialized diff --git a/flexkv/integration/vllm/vllm_v1_adapter.py b/flexkv/integration/vllm/vllm_v1_adapter.py index b88e17b55b..015cbb026a 100644 --- a/flexkv/integration/vllm/vllm_v1_adapter.py +++ b/flexkv/integration/vllm/vllm_v1_adapter.py @@ -709,8 +709,19 @@ def __init__( def register_to_server(self, kv_caches: 
dict[str, torch.Tensor]): logger.info("Start register kv_caches") - gpu_blocks = list(kv_caches.values()) - num_layer = len(kv_caches) + + # Separate main KV caches from indexer caches by layer name. + main_kv_caches: dict[str, torch.Tensor] = {} + indexer_kv_caches: dict[str, torch.Tensor] = {} + for layer_name, tensor in kv_caches.items(): + if ".k_cache" in layer_name: + indexer_kv_caches[layer_name] = tensor + else: + main_kv_caches[layer_name] = tensor + + # Build main KV cache layout + gpu_blocks = list(main_kv_caches.values()) + num_layer = len(main_kv_caches) if self.flexkv_config.model_config.use_mla: assert gpu_blocks[0].ndim == 3, ( f"expect kv cached tensor has 3 dim but get shape={gpu_blocks[0].shape}.") @@ -734,7 +745,32 @@ def register_to_server(self, kv_caches: dict[str, torch.Tensor]): head_size=head_size, is_mla=self.flexkv_config.model_config.use_mla, ) - self.tp_client.register_to_server(gpu_blocks, gpu_layout) + + # Build indexer layout if indexer caches are present + indexer_buffers = None + indexer_layout = None + if indexer_kv_caches: + indexer_buffers = list(indexer_kv_caches.values()) + first_indexer_buffer = indexer_buffers[0] + assert first_indexer_buffer.ndim == 3, ( + f"expect indexer cache tensor has 3 dim but get shape={first_indexer_buffer.shape}.") + indexer_layout = KVCacheLayout( + type=KVCacheLayoutType.LAYERFIRST, + num_layer=len(indexer_buffers), + num_block=first_indexer_buffer.shape[0], + tokens_per_block=first_indexer_buffer.shape[1], + num_head=1, + head_size=first_indexer_buffer.shape[2], + is_mla=True, + ) + + self.tp_client.register_to_server( + kv_caches=gpu_blocks, + kv_layout=gpu_layout, + indexer_buffers=indexer_buffers, + indexer_layout=indexer_layout, + ) + logger.info("Finish register kv_caches") def __del__(self): diff --git a/flexkv/server/client.py b/flexkv/server/client.py index 96fa3596a2..4563e601e3 100644 --- a/flexkv/server/client.py +++ b/flexkv/server/client.py @@ -256,6 +256,8 @@ def 
register_to_server( kv_caches: List[torch.Tensor], kv_layout: KVCacheLayout, override_device_id: Optional[int] = None, + indexer_buffers: Optional[List[torch.Tensor]] = None, + indexer_layout: Optional[KVCacheLayout] = None, ) -> None: if not kv_caches or not kv_caches[0].is_cuda: raise ValueError("GPU blocks must be CUDA tensors") @@ -268,11 +270,20 @@ def register_to_server( handle = TensorSharedHandle(tensor, device_id) handles.append(handle) + # Build optional indexer handles + indexer_handles = None + if indexer_buffers is not None and len(indexer_buffers) > 0: + indexer_handles = [] + for tensor in indexer_buffers: + indexer_handles.append(TensorSharedHandle(tensor, device_id)) + register_req = RegisterTPClientRequest( self.dp_client_id, device_id, handles, - kv_layout + kv_layout, + indexer_handles=indexer_handles, + indexer_gpu_layout=indexer_layout, ) self.send_to_server.send_pyobj(register_req, flags=zmq.NOBLOCK) diff --git a/flexkv/server/request.py b/flexkv/server/request.py index 3935bbbb58..757b9e0056 100644 --- a/flexkv/server/request.py +++ b/flexkv/server/request.py @@ -22,6 +22,9 @@ class RegisterTPClientRequest: device_id: int handles: List[TensorSharedHandle] gpu_layout: KVCacheLayout + # --- Indexer shadow transfer fields --- + indexer_handles: Optional[List[TensorSharedHandle]] = None + indexer_gpu_layout: Optional[KVCacheLayout] = None @dataclass class IsReadyRequest: diff --git a/flexkv/storage/storage_engine.py b/flexkv/storage/storage_engine.py index 50fe069cc5..640e100220 100644 --- a/flexkv/storage/storage_engine.py +++ b/flexkv/storage/storage_engine.py @@ -6,6 +6,7 @@ import hashlib from flexkv.common.config import ModelConfig, CacheConfig, GLOBAL_CONFIG_FROM_ENV +from flexkv.common.debug import flexkv_logger from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import StorageHandle, KVCacheLayout, KVCacheLayoutType from flexkv.common.transfer import DeviceType @@ -18,8 +19,11 @@ def __init__(self, 
cache_config: CacheConfig): """Initialize storage engine""" self._storage_handles: Dict[Tuple[DeviceType, int], StorageHandle] = {} + self._indexer_storage_handles: Dict[Tuple[DeviceType, int], StorageHandle] = {} self._model_config = model_config self._cache_config = cache_config + self._indexer_config = cache_config.indexer + if self._cache_config.enable_cpu: self._cpu_layout: Optional[KVCacheLayout] = KVCacheLayout( type=GLOBAL_CONFIG_FROM_ENV.cpu_layout_type, @@ -35,6 +39,29 @@ def __init__(self, layout=self._cpu_layout, dtype=self._model_config.dtype, ) + if self._indexer_config is not None: + indexer_page_size = self._indexer_config.page_size + indexer_num_cpu_blocks = ( + self._cache_config.num_cpu_blocks // indexer_page_size + if indexer_page_size > 1 + else self._cache_config.num_cpu_blocks + ) + indexer_cpu_layout = KVCacheLayout( + type=GLOBAL_CONFIG_FROM_ENV.cpu_layout_type, + num_layer=self._model_config.num_layers, + num_block=indexer_num_cpu_blocks, + tokens_per_block=self._indexer_config.page_size, + num_head=self._indexer_config.num_kv_heads, + head_size=self._indexer_config.head_size, + is_mla=True + ) + self.allocate( + device_type=DeviceType.CPU, + layout=indexer_cpu_layout, + dtype=self._indexer_config.dtype, + is_indexer=True, + ) + if self._cache_config.enable_ssd: if not GLOBAL_CONFIG_FROM_ENV.ssd_layout_type == self._cpu_layout.type: raise ValueError(f"SSD layout type must be the same as CPU layout type: {self._cpu_layout.type}") @@ -54,6 +81,31 @@ def __init__(self, cache_dir=self._cache_config.ssd_cache_dir, max_file_size_gb=GLOBAL_CONFIG_FROM_ENV.max_file_size_gb ) + if self._indexer_config is not None: + indexer_page_size = self._indexer_config.page_size + indexer_num_ssd_blocks = ( + self._cache_config.num_ssd_blocks // indexer_page_size + if indexer_page_size > 1 + else self._cache_config.num_ssd_blocks + ) + indexer_ssd_layout = KVCacheLayout( + type=GLOBAL_CONFIG_FROM_ENV.ssd_layout_type, + num_layer=self._model_config.num_layers, + 
num_block=indexer_num_ssd_blocks, + tokens_per_block=self._indexer_config.page_size, + num_head=self._indexer_config.num_kv_heads, + head_size=self._indexer_config.head_size, + is_mla=True + ) + self.allocate( + device_type=DeviceType.SSD, + layout=indexer_ssd_layout, + dtype=self._indexer_config.dtype, + cache_dir=self._cache_config.ssd_cache_dir, + max_file_size_gb=GLOBAL_CONFIG_FROM_ENV.max_file_size_gb, + is_indexer=True, + ) + if self._cache_config.enable_remote: if not GLOBAL_CONFIG_FROM_ENV.remote_layout_type == self._cpu_layout.type: raise ValueError(f"Remote layout type must be the same as CPU layout type: {self._cpu_layout.type}") @@ -73,12 +125,49 @@ def __init__(self, file_path=self._cache_config.remote_cache_path, remote_config_custom = self._cache_config.remote_config_custom ) + if self._indexer_config is not None: + indexer_page_size = self._indexer_config.page_size + indexer_num_remote_blocks = ( + self._cache_config.num_remote_blocks // indexer_page_size + if indexer_page_size > 1 + else self._cache_config.num_remote_blocks + ) + indexer_remote_layout = KVCacheLayout( + type=GLOBAL_CONFIG_FROM_ENV.remote_layout_type, + num_layer=self._model_config.num_layers, + num_block=indexer_num_remote_blocks, + tokens_per_block=self._indexer_config.page_size, + num_head=self._indexer_config.num_kv_heads, + head_size=self._indexer_config.head_size, + is_mla=True + ) + indexer_remote_path = self._cache_config.remote_cache_path + if isinstance(indexer_remote_path, str): + indexer_remote_path = indexer_remote_path + "_indexer" + elif isinstance(indexer_remote_path, list): + indexer_remote_path = [p + "_indexer" for p in indexer_remote_path] + self.allocate( + device_type=DeviceType.REMOTE, + layout=indexer_remote_layout, + dtype=self._indexer_config.dtype, + file_path=indexer_remote_path, + remote_config_custom=self._cache_config.remote_config_custom, + is_indexer=True, + ) + + @property + def _has_indexer(self) -> bool: + """True when indexer is configured and 
CPU buffer is allocated.""" + return (DeviceType.CPU, 0) in self._indexer_storage_handles def register_gpu_blocks(self, gpu_blocks: List[TensorSharedHandle], gpu_layout: KVCacheLayout, device_id: int = 0, - dtype: torch.dtype = torch.float16) -> None: + dtype: torch.dtype = torch.float16, + indexer_gpu_blocks: Optional[List[TensorSharedHandle]] = None, + indexer_gpu_layout: Optional[KVCacheLayout] = None, + indexer_dtype: Optional[torch.dtype] = None) -> None: self.allocate( device_type=DeviceType.GPU, layout=gpu_layout, @@ -86,6 +175,35 @@ def register_gpu_blocks(self, device_id=device_id, raw_data=gpu_blocks ) + if indexer_gpu_blocks is not None: + # Log indexer GPU registration parameters + indexer_page_size = self._indexer_config.page_size if self._indexer_config else 1 + flexkv_logger.info( + f"[StorageEngine] Registering indexer GPU buffer: " + f"num_block={indexer_gpu_layout.num_block}, " + f"page_size={indexer_page_size}, " + f"head_size={indexer_gpu_layout.head_size}, " + f"num_head={indexer_gpu_layout.num_head}, " + f"dtype={indexer_dtype}" + ) + # Validate indexer num_block vs main KV num_block + if indexer_page_size > 1: + expected_indexer_blocks = gpu_layout.num_block // indexer_page_size + if indexer_gpu_layout.num_block != expected_indexer_blocks: + flexkv_logger.warning( + f"[StorageEngine] Indexer GPU num_block mismatch: " + f"indexer_num_block={indexer_gpu_layout.num_block}, " + f"expected={expected_indexer_blocks} " + f"(main_kv_num_block={gpu_layout.num_block} // page_size={indexer_page_size})" + ) + self.allocate( + device_type=DeviceType.GPU, + layout=indexer_gpu_layout, + dtype=indexer_dtype if indexer_dtype is not None else dtype, + device_id=device_id, + raw_data=indexer_gpu_blocks, + is_indexer=True, + ) def allocate(self, device_type: DeviceType, @@ -93,24 +211,38 @@ def allocate(self, dtype: torch.dtype, device_id: int = 0, raw_data: Optional[Union[List[TensorSharedHandle], List[str], str]] = None, + is_indexer: bool = False, **kwargs: 
Any) -> bool: """ - Create and add an allocator for specified device + Create and add an allocator for specified device. Args: - device_type: Type of the device (CPU, GPU, etc.) - layout: Layout of kv cache - dtype: Data type of tensors - device_id: Device ID (default 0) - raw_data: Optional raw data to be used for initialization + device_type: Type of the device (CPU, GPU, SSD, REMOTE). + layout: Layout of kv cache. + dtype: Data type of tensors. + device_id: Device ID (default 0). + raw_data: Optional raw data to be used for initialization. + The expected type depends on ``device_type``: + + * ``DeviceType.CPU`` – ``torch.Tensor`` + * ``DeviceType.GPU`` – ``List[TensorSharedHandle]`` or + ``List[torch.Tensor]`` + * ``DeviceType.SSD`` – ``str`` or ``List[str]`` + (file path(s) to existing SSD cache files) + * ``DeviceType.REMOTE`` – ``str`` or ``List[str]`` + (remote file path(s)) + is_indexer: Whether this allocation is for indexer storage. + When True, SSD file_prefix uses 'indexer_' tag + (e.g. ``flexkv_indexer_ssdcache_``). **kwargs: Additional arguments for specific allocator types - (e.g., pin_memory for CPU, file_path for Disk) + (e.g., pin_memory for CPU, file_path for Disk). Returns: - bool: True if allocator created successfully, False otherwise + bool: True if allocator created successfully, False if already exists. 
""" + storage_handles = self._indexer_storage_handles if is_indexer else self._storage_handles key = (device_type, device_id) - if key in self._storage_handles: + if key in storage_handles: return False storage_handle: StorageHandle @@ -137,7 +269,7 @@ def allocate(self, assert isinstance(raw_data, list) and \ (all(isinstance(x, TensorSharedHandle) for x in raw_data) or \ all(isinstance(x, torch.Tensor) for x in raw_data)), \ - "raw_data for GPUAllocator must be List[TensorWrapper] or List[Tensor]" + "raw_data for GPUAllocator must be List[TensorSharedHandle] or List[Tensor]" storage_handle = GPUAllocator.from_raw_data( data=raw_data, # type: ignore layout=layout, @@ -169,7 +301,8 @@ def allocate(self, server_recv_port = GLOBAL_CONFIG_FROM_ENV.server_recv_port hash_value = hashlib.md5(server_recv_port.encode()).hexdigest() rand_suffix = f"{hash_value[:6]}" - file_prefix = f"flexkv_ssdcache_{rand_suffix}" + ssd_prefix_tag = "indexer_" if is_indexer else "" + file_prefix = f"flexkv_{ssd_prefix_tag}ssdcache_{rand_suffix}" storage_handle = SSDAllocator.allocate( layout=layout, dtype=dtype, @@ -206,28 +339,41 @@ def allocate(self, ) else: raise ValueError(f"Unsupported device type: {device_type}") - self._storage_handles[key] = storage_handle + storage_handles[key] = storage_handle return True def get_storage_handle(self, device_type: DeviceType, - device_id: int = 0) -> StorageHandle: + device_id: int = 0, + is_indexer: bool = False) -> StorageHandle: """ - Get accessible handle for specified blocks + Get accessible handle for specified blocks. Args: - device_type: Type of the device to get handle from - device_id: Device ID + device_type: Type of the device to get handle from. + device_id: Device ID. + is_indexer: Whether to get indexer storage handle. 
""" + storage_handles = self._indexer_storage_handles if is_indexer else self._storage_handles key = (device_type, device_id) - if key not in self._storage_handles: - raise ValueError(f"Storage handle not found for device type: {device_type}, device id: {device_id}") - - storage_handle = self._storage_handles[key] - return storage_handle + if key not in storage_handles: + raise ValueError( + f"Storage handle not found for device type: {device_type}, " + f"device id: {device_id}, is_indexer: {is_indexer}" + ) + return storage_handles[key] def has_storage_handle(self, device_type: DeviceType, - device_id: int = 0) -> bool: - """Check if storage handle exists for given device type and id""" - return (device_type, device_id) in self._storage_handles + device_id: int = 0, + is_indexer: bool = False) -> bool: + """ + Check if storage handle exists for given device type and id. + + Args: + device_type: Type of the device. + device_id: Device ID. + is_indexer: Whether to check indexer storage handle. 
+ """ + storage_handles = self._indexer_storage_handles if is_indexer else self._storage_handles + return (device_type, device_id) in storage_handles diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index cdee47d5e0..bef78366ce 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -68,7 +68,8 @@ def __init__(self, use_ce_transfer_h2d: bool = False, use_ce_transfer_d2h: bool = False, h2d_cta_num: int = 4, - d2h_cta_num: int = 4) -> None: + d2h_cta_num: int = 4, + enable_eventfd: bool = True) -> None: super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) assert len(gpu_blocks) == tp_group_size, f"len(gpu_blocks) = {len(gpu_blocks)}, tp_group_size = {tp_group_size}" imported_gpu_blocks = [] @@ -147,7 +148,10 @@ def __init__(self, gpu_chunk_sizes_tensor = torch.tensor(self.gpu_chunk_sizes_in_bytes, dtype=torch.int64) gpu_layer_strides_tensor = torch.tensor(self.gpu_layer_strides_in_bytes, dtype=torch.int64) - layer_eventfds_tensor = self._receive_eventfds_from_sglang(tp_group_size) + if enable_eventfd: + layer_eventfds_tensor = self._receive_eventfds_from_sglang(tp_group_size) + else: + layer_eventfds_tensor = torch.empty(0, dtype=torch.int32) # Create LayerwiseTransferGroup which handles both SSD->CPU and CPU->GPU transfers self.layerwise_transfer_group = LayerwiseTransferGroup( diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index f432da16f9..f2c3f4cc04 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -23,7 +23,7 @@ import contextlib import nvtx -import torch +import numpy as np from flexkv.common.debug import flexkv_logger from flexkv.common.storage import StorageHandle @@ -90,7 +90,11 @@ def __init__(self, cache_config: CacheConfig, cpu_handle: Optional[StorageHandle] = None, ssd_handle: Optional[StorageHandle] = None, - remote_handle: Optional[StorageHandle] = None): + remote_handle: Optional[StorageHandle] = 
None, + indexer_gpu_handles: Optional[Dict[int, List[StorageHandle]]] = None, + indexer_cpu_handle: Optional[StorageHandle] = None, + indexer_ssd_handle: Optional[StorageHandle] = None, + indexer_remote_handle: Optional[StorageHandle] = None): """ Initialize transfer engine @@ -127,6 +131,11 @@ def __init__(self, GLOBAL_CONFIG_FROM_ENV.index_accel and cache_config.enable_kv_sharing ) + self._indexer_gpu_handles = indexer_gpu_handles + self._indexer_cpu_handle = indexer_cpu_handle + self._indexer_ssd_handle = indexer_ssd_handle + self._indexer_remote_handle = indexer_remote_handle + self.pin_buffer = SharedOpPool(2048, self.cache_config.num_cpu_blocks) self.op_id_to_nvtx_range: Dict[int, str] = {} @@ -135,6 +144,12 @@ def __init__(self, self.tp_size = model_config.tp_size self.num_gpu_groups = len(self.gpu_handle_groups) self._running = False + self._has_indexer = False + + self._indexer_page_size = 1 + if cache_config.indexer is not None: + self._indexer_page_size = cache_config.indexer.page_size + self._indexer_op_to_parent_op: Dict[int, int] = {} def _init_workers(self) -> None: if self._running: @@ -366,10 +381,240 @@ def _init_workers(self) -> None: if self.cache_config.enable_p2p_ssd: self._worker_map[TransferType.PEERSSD2H] = self.cpu_remote_cpu_worker + # Initialize indexer workers + if (self._indexer_gpu_handles is not None + and self._indexer_cpu_handle is not None): + self._indexer_finished_ops_queue = self.mp_ctx.Queue() + self._indexer_worker_map: Dict[TransferType, Union[WorkerHandle, List[WorkerHandle]]] = {} + if self.tp_size == 1: + self._indexer_h2d_workers = [ + GPUCPUTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=indexer_gpu_handles_list[0].get_tensor_handle_list(), + cpu_blocks=self._indexer_cpu_handle.get_tensor(), + gpu_kv_layout=indexer_gpu_handles_list[0].kv_layout, + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + 
dtype=indexer_gpu_handles_list[0].dtype, + gpu_device_id=indexer_gpu_handles_list[0].gpu_device_id, + use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, + use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, + transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, + transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, + ) + for _, indexer_gpu_handles_list in self._indexer_gpu_handles.items() + ] + self._indexer_d2h_workers = [ + GPUCPUTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=indexer_gpu_handles_list[0].get_tensor_handle_list(), + cpu_blocks=self._indexer_cpu_handle.get_tensor(), + gpu_kv_layout=indexer_gpu_handles_list[0].kv_layout, + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + dtype=indexer_gpu_handles_list[0].dtype, + gpu_device_id=indexer_gpu_handles_list[0].gpu_device_id, + use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, + use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, + transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, + transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, + ) + for _, indexer_gpu_handles_list in self._indexer_gpu_handles.items() + ] + else: + self._indexer_h2d_workers = [ + tpGPUCPUTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=[h.get_tensor_handle_list() for h in indexer_gpu_handles_list], + cpu_blocks=self._indexer_cpu_handle.get_tensor(), + gpu_kv_layouts=[h.kv_layout for h in indexer_gpu_handles_list], + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + dtype=indexer_gpu_handles_list[0].dtype, + tp_group_size=self.tp_size, + dp_group_id=dp_client_id, + use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, + 
use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, + transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, + transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, + ) + for dp_client_id, indexer_gpu_handles_list in self._indexer_gpu_handles.items() + ] + self._indexer_d2h_workers = [ + tpGPUCPUTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=[h.get_tensor_handle_list() for h in indexer_gpu_handles_list], + cpu_blocks=self._indexer_cpu_handle.get_tensor(), + gpu_kv_layouts=[h.kv_layout for h in indexer_gpu_handles_list], + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + dtype=indexer_gpu_handles_list[0].dtype, + tp_group_size=self.tp_size, + dp_group_id=dp_client_id, + use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, + use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, + transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, + transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, + ) + for dp_client_id, indexer_gpu_handles_list in self._indexer_gpu_handles.items() + ] + self._indexer_worker_map[TransferType.H2D] = self._indexer_h2d_workers + self._indexer_worker_map[TransferType.D2H] = self._indexer_d2h_workers + if self._indexer_ssd_handle is not None and self._indexer_cpu_handle is not None: + self._indexer_h2disk_worker = CPUSSDDiskTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + cpu_blocks=self._indexer_cpu_handle.get_tensor(), + ssd_files=self._indexer_ssd_handle.get_file_list(), + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + ssd_kv_layout=self._indexer_ssd_handle.kv_layout, + dtype=self._indexer_cpu_handle.dtype, + num_blocks_per_file=self._indexer_ssd_handle.num_blocks_per_file, + cache_config=self._cache_config, + ) + 
self._indexer_disk2h_worker = CPUSSDDiskTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + cpu_blocks=self._indexer_cpu_handle.get_tensor(), + ssd_files=self._indexer_ssd_handle.get_file_list(), + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + ssd_kv_layout=self._indexer_ssd_handle.kv_layout, + dtype=self._indexer_cpu_handle.dtype, + num_blocks_per_file=self._indexer_ssd_handle.num_blocks_per_file, + cache_config=self._cache_config, + ) + self._indexer_worker_map[TransferType.H2DISK] = self._indexer_h2disk_worker + self._indexer_worker_map[TransferType.DISK2H] = self._indexer_disk2h_worker + flexkv_logger.info("TransferEngine: indexer SSD workers initialized") + if self._indexer_remote_handle is not None and self._indexer_cpu_handle is not None: + self._indexer_h2remote_worker = CPURemoteTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + cpu_blocks=self._indexer_cpu_handle.get_tensor(), + remote_file=self._indexer_remote_handle.get_file_list(), + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + remote_kv_layout=self._indexer_remote_handle.kv_layout, + dtype=self._indexer_cpu_handle.dtype, + remote_config_custom=self._indexer_remote_handle.remote_config_custom, + enable_pcfs_sharing=self._enable_pcfs_sharing, + ) + self._indexer_remote2h_worker = CPURemoteTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + cpu_blocks=self._indexer_cpu_handle.get_tensor(), + remote_file=self._indexer_remote_handle.get_file_list(), + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + remote_kv_layout=self._indexer_remote_handle.kv_layout, + dtype=self._indexer_cpu_handle.dtype, + remote_config_custom=self._indexer_remote_handle.remote_config_custom, + ) + 
self._indexer_worker_map[TransferType.H2REMOTE] = self._indexer_h2remote_worker + self._indexer_worker_map[TransferType.REMOTE2H] = self._indexer_remote2h_worker + flexkv_logger.info("TransferEngine: indexer Remote workers initialized") + if self.cache_config.enable_gds and self._indexer_ssd_handle is not None: + if self.tp_size == 1: + self._indexer_gds_workers = [ + GDSTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=indexer_gpu_handles_list[0].get_tensor_handle_list(), + ssd_files=self._indexer_ssd_handle.get_file_list(), + num_blocks_per_file=self._indexer_ssd_handle.num_blocks_per_file, + gpu_kv_layout=indexer_gpu_handles_list[0].kv_layout, + ssd_kv_layout=self._indexer_ssd_handle.kv_layout, + dtype=self._indexer_ssd_handle.dtype, + gpu_device_id=indexer_gpu_handles_list[0].gpu_device_id, + ) + for _, indexer_gpu_handles_list in self._indexer_gpu_handles.items() + ] + else: + self._indexer_gds_workers = [ + tpGDSTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=[h.get_tensor_handle_list() for h in indexer_gpu_handles_list], + ssd_files=self._indexer_ssd_handle.get_file_list(), + num_blocks_per_file=self._indexer_ssd_handle.num_blocks_per_file, + gpu_kv_layouts=[h.kv_layout for h in indexer_gpu_handles_list], + ssd_kv_layout=self._indexer_ssd_handle.kv_layout, + dtype=self._indexer_ssd_handle.dtype, + tp_group_size=self.tp_size, + dp_group_id=dp_client_id, + ) + for dp_client_id, indexer_gpu_handles_list in self._indexer_gpu_handles.items() + ] + self._indexer_worker_map[TransferType.DISK2D] = self._indexer_gds_workers + self._indexer_worker_map[TransferType.D2DISK] = self._indexer_gds_workers + flexkv_logger.info("TransferEngine: indexer GDS workers initialized") + if self.cache_config.enable_kv_sharing and 
self._indexer_cpu_handle is not None and ( + self.cache_config.enable_p2p_cpu + or (self._indexer_ssd_handle and self.cache_config.enable_p2p_ssd)): + flexkv_logger.info("[transfer_engine] initializing the indexer PEER2CPUTransferWorker!") + self._indexer_cpu_remote_cpu_worker: WorkerHandle = PEER2CPUTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + cpu_blocks=self._indexer_cpu_handle.get_tensor(), + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + remote_kv_layout=self._indexer_cpu_handle.kv_layout, + dtype=self._indexer_cpu_handle.dtype, + cache_config=self._cache_config, + ssd_kv_layout=self._indexer_ssd_handle.kv_layout if self._indexer_ssd_handle else None, + ssd_files=self._indexer_ssd_handle.get_file_list() if self._indexer_ssd_handle else None, + num_blocks_per_file=self._indexer_ssd_handle.num_blocks_per_file if self._indexer_ssd_handle else None, + ) + if self.cache_config.enable_p2p_cpu: + self._indexer_worker_map[TransferType.PEERH2H] = self._indexer_cpu_remote_cpu_worker + if self.cache_config.enable_p2p_ssd: + self._indexer_worker_map[TransferType.PEERSSD2H] = self._indexer_cpu_remote_cpu_worker + flexkv_logger.info("TransferEngine: indexer P2P workers initialized") + if GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer: + indexer_ssd_files = {} if self._indexer_ssd_handle is None else self._indexer_ssd_handle.get_file_list() + indexer_ssd_kv_layout = None if self._indexer_ssd_handle is None else self._indexer_ssd_handle.kv_layout + indexer_num_blocks_per_file = 0 if self._indexer_ssd_handle is None else self._indexer_ssd_handle.num_blocks_per_file + self._indexer_layerwise_workers = [ + LayerwiseTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=[h.get_tensor_handle_list() for h in indexer_gpu_handles_list], + 
cpu_blocks=self._indexer_cpu_handle.get_tensor(), + ssd_files=indexer_ssd_files, + gpu_kv_layouts=[h.kv_layout for h in indexer_gpu_handles_list], + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + ssd_kv_layout=indexer_ssd_kv_layout, + dtype=indexer_gpu_handles_list[0].dtype, + tp_group_size=self.tp_size, + dp_group_id=dp_client_id, + num_blocks_per_file=indexer_num_blocks_per_file, + use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, + use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, + h2d_cta_num=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, + d2h_cta_num=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, + enable_eventfd=False, + ) + for dp_client_id, indexer_gpu_handles_list in self._indexer_gpu_handles.items() + ] + self._indexer_worker_map[TransferType.LAYERWISE] = self._indexer_layerwise_workers + flexkv_logger.info("TransferEngine: indexer Layerwise workers initialized") + self._has_indexer = True + flexkv_logger.info( + f"TransferEngine: indexer inline workers initialized " + f"({len(self._indexer_h2d_workers)} H2D + {len(self._indexer_d2h_workers)} D2H)") if len(self._worker_map) == 0: raise ValueError("No workers initialized, please check the config") - # Wait for all workers to ready + # Wait for all main KV workers to ready for transfer_type, worker in self._worker_map.items(): if isinstance(worker, List): for w in worker: @@ -380,6 +625,18 @@ def _init_workers(self) -> None: flexkv_logger.info(f"waiting for {transfer_type.name} worker {worker.worker_id} to ready") worker.ready_event.wait() flexkv_logger.info(f"{transfer_type.name} worker {worker.worker_id} is ready") + # Wait for all indexer workers to ready + if self._has_indexer: + for transfer_type, worker in self._indexer_worker_map.items(): + if isinstance(worker, List): + for w in worker: + flexkv_logger.info(f"waiting for indexer {transfer_type.name} worker {w.worker_id} to ready") + w.ready_event.wait() + flexkv_logger.info(f"indexer {transfer_type.name} worker 
{w.worker_id} is ready") + else: + flexkv_logger.info(f"waiting for indexer {transfer_type.name} worker {worker.worker_id} to ready") + worker.ready_event.wait() + flexkv_logger.info(f"indexer {transfer_type.name} worker {worker.worker_id} is ready") # Start scheduler thread self._running = True self._scheduler_thread = threading.Thread(target=self._scheduler_loop) @@ -399,6 +656,10 @@ def _scheduler_loop(self) -> None: sel.register(self.task_queue._reader, selectors.EVENT_READ, data="new_graph") sel.register(self.finished_ops_queue._reader, selectors.EVENT_READ, data="finished_op") + # Register indexer finished_ops_queue when indexer is enabled + if self._has_indexer: + sel.register(self._indexer_finished_ops_queue._reader, selectors.EVENT_READ, data="indexer_finished_op") + # Register shutdown pipe for zero-latency shutdown sel.register(self.shutdown_read_fd, selectors.EVENT_READ, data="shutdown") @@ -439,21 +700,42 @@ def _scheduler_loop(self) -> None: nvtx.end_range(nvtx_r1) elif key.data == "finished_op": - # Collect finished ops (batch get all available) + # Collect finished ops from main KV worker (batch get all available) nvtx_r2 = nvtx.start_range(message="transfer scheduler. collect finished ops", color="orange") # Get all available ops in one go to reduce system calls while True: try: op_id = self.finished_ops_queue.get_nowait() op = self.op_id_to_op[op_id] - free_op_from_buffer(op, self.pin_buffer) - self.completed_queue.put(CompletedOp(graph_id=op.graph_id, op_id=op.op_id)) - finished_ops.append(op) - del self.op_id_to_op[op_id] + op.pending_count -= 1 + if op.pending_count == 0: + self._finalize_op(op, finished_ops) except queue.Empty: break nvtx.end_range(nvtx_r2) + elif key.data == "indexer_finished_op": + # Collect finished ops from indexer worker (batch get all available) + nvtx_r2i = nvtx.start_range(message="transfer scheduler. 
collect indexer finished ops", color="blue") + while True: + try: + op_id = self._indexer_finished_ops_queue.get_nowait() + assert op_id in self._indexer_op_to_parent_op, ( + f"[TransferEngine] Indexer op {op_id} not found in " + f"_indexer_op_to_parent_op. All indexer ops must be " + f"registered with a parent op." + ) + indexer_op = self.op_id_to_op.pop(op_id) + free_op_from_buffer(indexer_op, self.pin_buffer) + parent_op_id = self._indexer_op_to_parent_op.pop(op_id) + parent_op = self.op_id_to_op[parent_op_id] + parent_op.pending_count -= 1 + if parent_op.pending_count == 0: + self._finalize_op(parent_op, finished_ops) + except queue.Empty: + break + nvtx.end_range(nvtx_r2i) + # Exit loop if shutdown requested if should_shutdown: break @@ -489,6 +771,59 @@ def _scheduler_loop(self) -> None: sel.close() flexkv_logger.info("TransferEngine scheduler loop stopped") + def _finalize_op(self, op: TransferOp, finished_ops: List[TransferOp]) -> None: + """Finalize a completed op: release pin buffer, notify upper layer, and clean up. + + Called only when op.pending_count reaches 0, i.e., all workers (main KV + indexer) + have completed this op. This ensures atomic eviction semantics. + """ + free_op_from_buffer(op, self.pin_buffer) + # Compute transfer metrics for this completed op + num_blocks = len(op.src_block_ids) if op.src_block_ids is not None else 0 + num_bytes = num_blocks * self.cache_config.tokens_per_block * self.model_config.token_size_in_bytes + transfer_type_str = op.transfer_type.value if op.transfer_type != TransferType.VIRTUAL else None + self.completed_queue.put(CompletedOp( + graph_id=op.graph_id, + op_id=op.op_id, + transfer_type=transfer_type_str, + num_blocks=num_blocks, + num_bytes=num_bytes, + )) + finished_ops.append(op) + del self.op_id_to_op[op.op_id] + + def _convert_to_page_level_block_ids(self, block_ids: np.ndarray) -> np.ndarray: + """Convert block-level IDs to page-level IDs by dividing by page_size. 
+ + Input block_ids are guaranteed to be page-aligned by the sglang caller + (e.g. flexkv_radix_cache), so we only do a defensive assert here. + """ + page_size = self._indexer_page_size + if block_ids.size == 0: + return block_ids.copy() + + assert block_ids.size % page_size == 0, ( + f"[TransferEngine] block_ids size {block_ids.size} is not a multiple " + f"of indexer page_size {page_size}. " + f"Caller must page-align block_ids before reaching transfer engine." + ) + + reshaped = block_ids.reshape(-1, page_size) + page_block_ids = reshaped[:, 0] // page_size + + # Validate: all block_ids in each group must map to the same indexer page. + # sglang's PagedTokenToKVPoolAllocator guarantees page-aligned allocation, + # so this should never fail in practice. + last_in_group = reshaped[:, -1] // page_size + assert np.array_equal(page_block_ids, last_in_group), ( + f"[TransferEngine] Indexer page group(s) have block_ids spanning multiple pages " + f"(page_size={page_size}). This indicates slot indices are not page-aligned. " + f"First mismatch: first_block={reshaped[page_block_ids != last_in_group][0, 0]}, " + f"last_block={reshaped[page_block_ids != last_in_group][0, -1]}" + ) + + return page_block_ids.astype(np.int64) + def _assign_op_to_worker(self, op: TransferOp) -> None: self.op_id_to_nvtx_range[op.op_id] = nvtx.start_range(f"schedule {op.transfer_type.name} " f"op_id: {op.op_id}, " @@ -501,6 +836,50 @@ def _assign_op_to_worker(self, op: TransferOp) -> None: if op.transfer_type not in self._worker_map: raise ValueError(f"Unsupported transfer type: {op.transfer_type}") + if self._has_indexer and op.transfer_type in self._indexer_worker_map: + # Convert block_ids to indexer page-level IDs. 
+ # _convert_to_page_level_block_ids handles all page_size values uniformly + src_page_ids = self._convert_to_page_level_block_ids(op.src_block_ids) + dst_page_ids = self._convert_to_page_level_block_ids(op.dst_block_ids) + + # Ensure both sides have the same number of pages after conversion + assert src_page_ids.size == dst_page_ids.size, ( + f"[TransferEngine] src_page_ids size {src_page_ids.size} != " + f"dst_page_ids size {dst_page_ids.size} for op {op.op_id}. " + f"Both sides must have the same number of pages." + ) + num_pages = src_page_ids.size + + if num_pages > 0: + # Always create a separate indexer_op to avoid sharing the same op + # object between indexer worker and main KV worker. + indexer_op = TransferOp( + graph_id=op.graph_id, + transfer_type=op.transfer_type, + src_block_ids=src_page_ids, + dst_block_ids=dst_page_ids, + layer_id=op.layer_id, + layer_granularity=op.layer_granularity, + dp_id=op.dp_id, + ) + register_op_to_buffer(indexer_op, self.pin_buffer) + self._indexer_op_to_parent_op[indexer_op.op_id] = op.op_id + self.op_id_to_op[indexer_op.op_id] = indexer_op + op.pending_count += 1 + + flexkv_logger.debug( + f"[TransferEngine] Created indexer op {indexer_op.op_id} " + f"for parent op {op.op_id}: {num_pages} pages, " + f"page_size={self._indexer_page_size}, " + f"type={op.transfer_type.name}" + ) + + indexer_worker = self._indexer_worker_map[op.transfer_type] + if isinstance(indexer_worker, List): + indexer_worker[op.dp_id].submit_transfer(indexer_op) + else: + indexer_worker.submit_transfer(indexer_op) + worker = self._worker_map[op.transfer_type] if isinstance(worker, List): worker[op.dp_id].submit_transfer(op) @@ -569,7 +948,15 @@ def shutdown(self) -> None: else: flexkv_logger.debug(f"Shutdown pipes already closed: {e}") - # shutdown all workers + # shutdown indexer workers first + if self._has_indexer: + for worker in self._indexer_worker_map.values(): + if isinstance(worker, List): + for w in worker: + w.shutdown() + else: + 
worker.shutdown() + # shutdown main KV workers for worker in self._worker_map.values(): if isinstance(worker, List): for w in worker: diff --git a/flexkv/transfer_manager.py b/flexkv/transfer_manager.py index 0633447360..49166f3aea 100644 --- a/flexkv/transfer_manager.py +++ b/flexkv/transfer_manager.py @@ -49,6 +49,10 @@ def __init__(self, self.all_gpu_blocks: Dict[int, List[TensorSharedHandle]] = {} # device_id -> gpu_blocks self.gpu_client_mapping: Dict[int, int] = {} # device_id -> dp_client_id + # Indexer GPU registration data + self.all_indexer_gpu_blocks: Dict[int, List[TensorSharedHandle]] = {} # device_id -> indexer_gpu_blocks + self.all_indexer_gpu_layouts: Dict[int, KVCacheLayout] = {} + self.context = zmq.Context(2) self.recv_from_client = get_zmq_socket( self.context, zmq.SocketType.PULL, gpu_register_port, True) @@ -68,6 +72,13 @@ def _handle_gpu_blocks_registration(self, req: RegisterTPClientRequest) -> None: self.all_gpu_blocks[device_id] = req.handles self.all_gpu_layouts[device_id] = req.gpu_layout self.gpu_client_mapping[device_id] = req.dp_client_id + # Store indexer GPU data if present + if req.indexer_handles is not None: + self.all_indexer_gpu_blocks[device_id] = req.indexer_handles + self.all_indexer_gpu_layouts[device_id] = req.indexer_gpu_layout + flexkv_logger.info( + f"GPU {device_id}: registered indexer handles " + f"({len(req.indexer_handles)} layers)") except Exception as e: flexkv_logger.error(f"Failed to register GPU {device_id}: {e}") @@ -115,10 +126,20 @@ def initialize_transfer_engine(self) -> None: # Register GPU blocks with their global device IDs for device_id, gpu_blocks_wrapper in self.all_gpu_blocks.items(): - self.storage_engine.register_gpu_blocks(gpu_blocks_wrapper, - self.all_gpu_layouts[device_id], - device_id, - dtype=self.model_config.dtype) + # Get indexer data for this device if available + indexer_gpu_blocks = self.all_indexer_gpu_blocks.get(device_id) + indexer_gpu_layout = 
self.all_indexer_gpu_layouts.get(device_id) + indexer_dtype = (self.cache_config.indexer.dtype + if self.cache_config.indexer is not None else None) + self.storage_engine.register_gpu_blocks( + gpu_blocks_wrapper, + self.all_gpu_layouts[device_id], + device_id, + dtype=self.model_config.dtype, + indexer_gpu_blocks=indexer_gpu_blocks, + indexer_gpu_layout=indexer_gpu_layout, + indexer_dtype=indexer_dtype, + ) # Group GPU handles by dp_client_id grouped_gpu_handles: Dict[int, List] = {} @@ -138,12 +159,45 @@ def initialize_transfer_engine(self) -> None: if self.cache_config.enable_remote \ else None ) - self.transfer_engine = TransferEngine(gpu_handles=grouped_gpu_handles, - model_config=self.model_config, - cache_config=self.cache_config, - cpu_handle=cpu_handle, - ssd_handle=ssd_handle, - remote_handle=remote_handle) + + indexer_gpu_handles: Optional[Dict[int, List]] = None + if self.storage_engine.has_storage_handle(DeviceType.CPU, is_indexer=True): + indexer_gpu_handles = {} + for device_id in sorted(self.all_gpu_blocks.keys()): + if self.storage_engine.has_storage_handle(DeviceType.GPU, device_id, is_indexer=True): + dp_client_id = self.gpu_client_mapping[device_id] + if dp_client_id not in indexer_gpu_handles: + indexer_gpu_handles[dp_client_id] = [] + indexer_gpu_handles[dp_client_id].append( + self.storage_engine.get_storage_handle(DeviceType.GPU, device_id, is_indexer=True)) + indexer_cpu_handle = ( + self.storage_engine.get_storage_handle(DeviceType.CPU, is_indexer=True) + if self.storage_engine.has_storage_handle(DeviceType.CPU, is_indexer=True) + else None + ) + indexer_ssd_handle = ( + self.storage_engine.get_storage_handle(DeviceType.SSD, is_indexer=True) + if self.storage_engine.has_storage_handle(DeviceType.SSD, is_indexer=True) + else None + ) + indexer_remote_handle = ( + self.storage_engine.get_storage_handle(DeviceType.REMOTE, is_indexer=True) + if self.storage_engine.has_storage_handle(DeviceType.REMOTE, is_indexer=True) + else None + ) + + 
self.transfer_engine = TransferEngine( + gpu_handles=grouped_gpu_handles, + model_config=self.model_config, + cache_config=self.cache_config, + cpu_handle=cpu_handle, + ssd_handle=ssd_handle, + remote_handle=remote_handle, + indexer_gpu_handles=indexer_gpu_handles, + indexer_cpu_handle=indexer_cpu_handle, + indexer_ssd_handle=indexer_ssd_handle, + indexer_remote_handle=indexer_remote_handle, + ) flexkv_logger.info("Initialized TransferEngine successfully") def submit(self, transfer_graph: TransferOpGraph) -> None: diff --git a/setup.py b/setup.py index 02bb0b5d01..e5b618f65a 100755 --- a/setup.py +++ b/setup.py @@ -8,8 +8,26 @@ from torch.utils import cpp_extension def get_version(): - with open(os.path.join(os.path.dirname(__file__), "VERSION")) as f: - return f.read().strip() + import subprocess + try: + # e.g. "v1.0.0-0-gabc1234" or "v1.0.0-3-gabc1234" + raw = subprocess.check_output( + ["git", "describe", "--tags", "--long", "--match", "v*"], + stderr=subprocess.PIPE, + cwd=os.path.dirname(os.path.abspath(__file__)), + ).decode().strip() + # parse: v1.0.0-DISTANCE-gHASH (tag, commits since tag, "g" + abbreviated commit hash) + parts = raw.rsplit("-", 2) + if len(parts) != 3: + raise ValueError(f"Unexpected git describe output format: {raw!r}") + tag, distance, git_hash = parts + tag = tag.lstrip("v") + if distance == "0": + return tag # clean release + else: + return f"{tag}+git{git_hash[1:]}" # dev build + except Exception: + return "0.0.0+unknown" build_dir = "build" os.makedirs(build_dir, exist_ok=True) diff --git a/tests/test_kvmanager.py b/tests/test_kvmanager.py index 6f8393023d..e0c5f21d54 100644 --- a/tests/test_kvmanager.py +++ b/tests/test_kvmanager.py @@ -8,13 +8,15 @@ import multiprocessing as mp from multiprocessing import Process, Pipe -from flexkv.common.config import ModelConfig, CacheConfig, GLOBAL_CONFIG_FROM_ENV +from flexkv.common.config import ModelConfig, CacheConfig from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType from flexkv.common.request import KVResponseStatus from flexkv.kvtask 
import KVTaskEngine from flexkv.kvmanager import KVManager from flexkv.common.memory_handle import TensorSharedHandle from flexkv.server.client import KVTPClient +import traceback + from flexkv.common.debug import flexkv_logger # Import utilities from common_utils @@ -455,3 +457,423 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): return elif total_cache_miss > 0: print(f"verify skipped, because of total_cache_miss={total_cache_miss} > 0") + + +class GPUIndexerCacheVerifier: + def __init__(self, + shared_indexer_blocks, + indexer_kv_layout: KVCacheLayout, + tp_size: int, + dtype: torch.dtype) -> None: + if not shared_indexer_blocks: + raise ValueError("shared_indexer_blocks must not be empty") + + if isinstance(shared_indexer_blocks[0][0], torch.Tensor): + self.gpu_blocks = shared_indexer_blocks + else: + imported_gpu_blocks = [] + for handles_in_one_gpu in shared_indexer_blocks: + imported_gpu_blocks.append([handle.get_tensor() for handle in handles_in_one_gpu]) + self.gpu_blocks = imported_gpu_blocks + + self.num_layers = indexer_kv_layout.num_layer + self.tokens_per_block = indexer_kv_layout.tokens_per_block + self.head_size = indexer_kv_layout.head_size + self.tp_size = tp_size + self.dtype = dtype + + def hash_all_values(self, layer_id, token_ids): + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.tolist() + + token_hash = 0 + for i, token_id in enumerate(token_ids): + token_hash += int(token_id) * (i + 17) + return torch.tensor(((layer_id + 1) * 29 + token_hash) % 251 + 1, dtype=self.dtype).item() + + def fill_gpu_blocks(self, token_ids, block_ids): + assert len(token_ids) == len(block_ids) * self.tokens_per_block + + if not isinstance(token_ids, torch.Tensor): + token_ids = torch.tensor(token_ids, dtype=torch.int64) + if not isinstance(block_ids, torch.Tensor): + block_ids = torch.tensor(block_ids, dtype=torch.int64) + + for tp_id in range(self.tp_size): + for layer_id in range(self.num_layers): + gpu_tensor = 
self.gpu_blocks[tp_id][layer_id] + for block_idx, block_id in enumerate(block_ids): + start_token_idx = block_idx * self.tokens_per_block + end_token_idx = start_token_idx + self.tokens_per_block + hash_value = self.hash_all_values( + layer_id, + token_ids[start_token_idx:end_token_idx], + ) + gpu_tensor[block_id, :, :] = hash_value + + def clear_gpu_blocks(self, block_ids): + if not isinstance(block_ids, torch.Tensor): + block_ids = torch.tensor(block_ids, dtype=torch.int64) + + for tp_id in range(self.tp_size): + for layer_id in range(self.num_layers): + self.gpu_blocks[tp_id][layer_id][block_ids, :, :] = 0 + + def verify_gpu_blocks(self, token_ids, block_ids) -> bool: + assert len(token_ids) == len(block_ids) * self.tokens_per_block + + if not isinstance(token_ids, torch.Tensor): + token_ids = torch.tensor(token_ids, dtype=torch.int64) + if not isinstance(block_ids, torch.Tensor): + block_ids = torch.tensor(block_ids, dtype=torch.int64) + + verification_passed = True + errors = [] + + for tp_id in range(self.tp_size): + for layer_id in range(self.num_layers): + gpu_tensor = self.gpu_blocks[tp_id][layer_id] + for block_idx, block_id in enumerate(block_ids): + start_token_idx = block_idx * self.tokens_per_block + end_token_idx = start_token_idx + self.tokens_per_block + expected_hash_value = self.hash_all_values( + layer_id, + token_ids[start_token_idx:end_token_idx], + ) + actual_values = gpu_tensor[block_id, :, :] + expected_tensor = torch.full_like(actual_values, expected_hash_value) + if not torch.equal(actual_values, expected_tensor): + verification_passed = False + max_abs_diff = ( + actual_values.to(torch.int32) - expected_tensor.to(torch.int32) + ).abs().max().item() + errors.append( + f"Mismatch at tp={tp_id}, layer={layer_id}, block={block_id}: " + f"expected={expected_hash_value}, max_abs_diff={max_abs_diff}" + ) + + if not verification_passed: + print(f"Indexer verification failed with {len(errors)} errors:") + for error in errors[:10]: + print(f" 
{error}") + if len(errors) > 10: + print(f" ... and {len(errors) - 10} more errors") + else: + print("Indexer GPU blocks verification passed!") + assert verification_passed + return verification_passed + + +def run_tp_client_with_indexer(dp_client_id, + tp_rank, + server_recv_port, + model_config, + cache_config, + num_gpu_blocks, + child_conn, + gpu_layout_type): + """Run tp_client process with indexer support (shadow transfer mode). + + Indexer configuration is read from cache_config.indexer (IndexerCacheConfig). + """ + try: + device_id = tp_rank + dp_client_id * model_config.tp_size + + gpu_kv_layout = create_gpu_kv_layout(model_config, cache_config, num_gpu_blocks, gpu_layout_type) + + # Create main GPU blocks + gpu_blocks_for_tp = [] + if gpu_layout_type == 0: + for _ in range(model_config.num_layers): + gpu_blocks_for_tp.append( + torch.empty(size=tuple(gpu_kv_layout.kv_shape[1:]), dtype=model_config.dtype).cuda(device_id) + ) + elif gpu_layout_type == 2: + kv_dim = 1 if model_config.use_mla else 2 + for _ in range(model_config.num_layers * kv_dim): + gpu_blocks_for_tp.append( + torch.empty(size=tuple(gpu_kv_layout.kv_shape[2:]), dtype=model_config.dtype).cuda(device_id) + ) + else: + raise ValueError(f"Invalid GPU layout type for indexer test: {gpu_layout_type}") + + # Derive indexer params from cache_config.indexer (IndexerCacheConfig). + # Shared fields (num_layers, tokens_per_block) come from main model_config / cache_config. 
+ indexer_cfg = cache_config.indexer + assert indexer_cfg is not None, "cache_config.indexer must be set for indexer shadow transfer tests" + indexer_tokens_per_block = cache_config.tokens_per_block # shared with main KV + indexer_num_layers = model_config.num_layers # shared with main KV + + # Create indexer GPU blocks (MLA-style: 3D tensors) + indexer_blocks = [] + for _ in range(indexer_num_layers): + indexer_blocks.append( + torch.empty( + num_gpu_blocks, + indexer_tokens_per_block, + indexer_cfg.head_size, + dtype=indexer_cfg.dtype, + ).cuda(device_id) + ) + + from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType + indexer_layout = KVCacheLayout( + type=KVCacheLayoutType.LAYERFIRST, + num_layer=indexer_num_layers, + num_block=num_gpu_blocks, + tokens_per_block=indexer_tokens_per_block, + num_head=indexer_cfg.num_kv_heads, + head_size=indexer_cfg.head_size, + is_mla=indexer_cfg.use_mla, + ) + + # Use KVTPClient directly with indexer buffers (shadow transfer mode) + tp_client = KVTPClient( + gpu_register_port=server_recv_port + "_gpu_register", + dp_client_id=dp_client_id, + device_id=device_id, + ) + tp_client.register_to_server( + kv_caches=gpu_blocks_for_tp, + kv_layout=gpu_kv_layout, + indexer_buffers=indexer_blocks, + indexer_layout=indexer_layout, + ) + + # Send GPU blocks back to main process via pipe + if child_conn is not None: + shared_gpu_blocks = [TensorSharedHandle(tensor) for tensor in gpu_blocks_for_tp] + shared_indexer_blocks = [TensorSharedHandle(tensor) for tensor in indexer_blocks] + child_conn.send({ + "main": shared_gpu_blocks, + "indexer": shared_indexer_blocks, + }) + child_conn.close() + + # Keep the process running + while True: + time.sleep(1) + except Exception as e: + print(f"[TP Client {tp_rank}] Exception occurred: {type(e).__name__}: {str(e)}") + traceback.print_exc() + if child_conn is not None: + child_conn.send(None) + child_conn.close() + + +@pytest.mark.parametrize( + "model_config", + [ + {"tp_size": 1, 
"dp_size": 1}, + ], indirect=True, +) +@pytest.mark.parametrize("cache_config", [ + {'enable_cpu': True, 'enable_ssd': False, 'num_cpu_blocks': 1024}, +], indirect=True) +@pytest.mark.parametrize("test_config", [ + {'num_gpu_blocks': 256, 'requests_per_block': 16, 'initial_write_ratio': 0.4}, +], indirect=True) +@pytest.mark.parametrize("gpu_layout_type", [0]) +def test_kvmanager_with_indexer(model_config, cache_config, test_config, gpu_layout_type): + """Test KVManager shadow transfer mode with attached indexer buffers.""" + tp_size = model_config.tp_size + tokens_per_block = cache_config.tokens_per_block + num_gpu_blocks = test_config["num_gpu_blocks"] + block_per_request = test_config['requests_per_block'] + initial_write_ratio = test_config['initial_write_ratio'] + num_requests = num_gpu_blocks // block_per_request + + skip_if_insufficient_gpus(tp_size) + + # Set indexer config inside cache_config.indexer (IndexerCacheConfig). + # Only indexer-unique fields are stored here; shared fields (num_layers, + # tokens_per_block, num_cpu_blocks) are read from model_config / cache_config. 
+ from flexkv.common.config import IndexerCacheConfig + cache_config.indexer = IndexerCacheConfig( + head_size=64, + num_kv_heads=1, + use_mla=True, + dtype=torch.uint8, + ) + + kvmanager = KVManager( + model_config, + cache_config, + ) + kvmanager.start() + + mp_ctx = mp.get_context('spawn') + pipe_connections = [] + tp_client_processes = [] + + for tp_rank in range(tp_size): + parent_conn, child_conn = mp_ctx.Pipe() + pipe_connections.append(parent_conn) + + tp_client_process = mp_ctx.Process( + target=run_tp_client_with_indexer, + args=(0, tp_rank, kvmanager.server_recv_port, + model_config, cache_config, num_gpu_blocks, child_conn, + gpu_layout_type), + daemon=True + ) + tp_client_processes.append(tp_client_process) + tp_client_process.start() + + all_gpu_blocks = [] + all_indexer_blocks = [] + for tp_rank, parent_conn in enumerate(pipe_connections): + try: + shared_payload = parent_conn.recv() + if shared_payload is not None: + if isinstance(shared_payload, dict): + shared_gpu_blocks = shared_payload.get("main") + shared_indexer_blocks = shared_payload.get("indexer") + else: + shared_gpu_blocks = shared_payload + shared_indexer_blocks = None + if shared_gpu_blocks is not None: + all_gpu_blocks.append(shared_gpu_blocks) + print(f"[Main Process] Received GPU blocks from TP client {tp_rank}") + if shared_indexer_blocks is not None: + all_indexer_blocks.append(shared_indexer_blocks) + parent_conn.close() + except Exception as e: + print(f"[Main Process] Error receiving from TP client {tp_rank}: {e}") + + gpu_kv_verifier = None + if all_gpu_blocks and len(all_gpu_blocks) == tp_size: + gpu_kv_layout = create_gpu_kv_layout(model_config, cache_config, num_gpu_blocks, gpu_layout_type) + gpu_kv_verifier = GPUKVCacheVerifier( + shared_gpu_blocks=all_gpu_blocks, + gpu_kv_layout=gpu_kv_layout, + tp_size=model_config.tp_size, + tokens_per_block=cache_config.tokens_per_block, + dtype=model_config.dtype, + gpu_layout_type=gpu_layout_type, + ) + + indexer_kv_verifier = None + 
indexer_cfg = cache_config.indexer + if all_indexer_blocks and len(all_indexer_blocks) == tp_size and indexer_cfg is not None: + indexer_gpu_layout = KVCacheLayout( + type=KVCacheLayoutType.LAYERFIRST, + num_layer=model_config.num_layers, + num_block=num_gpu_blocks, + tokens_per_block=cache_config.tokens_per_block, + num_head=indexer_cfg.num_kv_heads, + head_size=indexer_cfg.head_size, + is_mla=indexer_cfg.use_mla, + ) + indexer_kv_verifier = GPUIndexerCacheVerifier( + shared_indexer_blocks=all_indexer_blocks, + indexer_kv_layout=indexer_gpu_layout, + tp_size=model_config.tp_size, + dtype=indexer_cfg.dtype, + ) + + while not kvmanager.is_ready(): + time.sleep(1) + flexkv_logger.info("waiting for flexkv (with indexer shadow transfer) to be ready") + print("[Test] KVManager (with indexer shadow transfer) is ready") + + request_pairs = [generate_request_pair(i, block_per_request, num_gpu_blocks, tokens_per_block, 1) + for i in range(num_requests)] + initial_write_num = int(num_requests * initial_write_ratio) + + print("[Test] Testing put flow with indexer shadow transfer...") + for token_ids, block_ids, dp_id in request_pairs[:initial_write_num]: + if gpu_kv_verifier is not None: + gpu_kv_verifier.fill_gpu_blocks(token_ids, block_ids) + if indexer_kv_verifier is not None: + indexer_kv_verifier.fill_gpu_blocks(token_ids, block_ids) + write_request = kvmanager.put_async( + token_ids=token_ids, + slot_mapping=block_ids_2_slot_mapping(block_ids, tokens_per_block), + token_mask=None, + dp_id=dp_id, + ) + put_results = kvmanager.wait([write_request], completely=True) + assert put_results[write_request].status == KVResponseStatus.SUCCESS + if gpu_kv_verifier is not None: + gpu_kv_verifier.clear_gpu_blocks(block_ids) + if indexer_kv_verifier is not None: + indexer_kv_verifier.clear_gpu_blocks(block_ids) + print(f"[Test] Initial {initial_write_num} put operations completed with indexer shadow transfer") + + print("[Test] Testing get flow with indexer shadow transfer...") + 
total_cache_hit = 0 + total_cache_miss = 0 + running_get_requests = [] + req_id2block_ids = {} + req_id2token_ids = {} + + for i in range(min(initial_write_num, num_requests)): + token_ids, block_ids, dp_id = request_pairs[i] + slot_mapping = block_ids_2_slot_mapping(block_ids, tokens_per_block) + request_id, _ = kvmanager.get_match( + token_ids=token_ids, + layer_granularity=-1, + token_mask=None, + dp_id=dp_id, + ) + kvmanager.launch(request_id, slot_mapping) + running_get_requests.append(request_id) + req_id2block_ids[request_id] = block_ids + req_id2token_ids[request_id] = token_ids + + if running_get_requests: + return_results = kvmanager.wait(running_get_requests, completely=True) + for req_id, kvresponse in return_results.items(): + assert kvresponse.status == KVResponseStatus.SUCCESS + total_cache_hit += kvresponse.return_mask.sum().item() + total_cache_miss += len(kvresponse.return_mask) - kvresponse.return_mask.sum().item() + if gpu_kv_verifier is not None: + valid_fetched_tokens = kvresponse.return_mask.sum().item() // tokens_per_block * tokens_per_block + if valid_fetched_tokens > 0: + assert gpu_kv_verifier.verify_kv_blocks( + req_id2token_ids[req_id][:valid_fetched_tokens], + req_id2block_ids[req_id][:valid_fetched_tokens // tokens_per_block]) + if indexer_kv_verifier is not None: + valid_fetched_blocks = kvresponse.return_mask.sum().item() // tokens_per_block + if valid_fetched_blocks > 0: + assert indexer_kv_verifier.verify_gpu_blocks( + req_id2token_ids[req_id][:valid_fetched_blocks * tokens_per_block], + req_id2block_ids[req_id][:valid_fetched_blocks]) + print(f"[Test] Get flow completed: hit={total_cache_hit}, miss={total_cache_miss}") + + print("[Test] Testing try_wait flow with indexer shadow transfer...") + if initial_write_num < num_requests: + token_ids, block_ids, dp_id = request_pairs[initial_write_num] + if gpu_kv_verifier is not None: + gpu_kv_verifier.fill_gpu_blocks(token_ids, block_ids) + if indexer_kv_verifier is not None: + 
indexer_kv_verifier.fill_gpu_blocks(token_ids, block_ids) + write_request = kvmanager.put_async( + token_ids=token_ids, + slot_mapping=block_ids_2_slot_mapping(block_ids, tokens_per_block), + token_mask=None, + dp_id=dp_id, + ) + finished = {} + for _ in range(200): + finished = kvmanager.try_wait([write_request]) + if write_request in finished: + break + time.sleep(0.1) + assert write_request in finished, "try_wait should eventually return the completed task" + assert finished[write_request].status == KVResponseStatus.SUCCESS + if gpu_kv_verifier is not None: + gpu_kv_verifier.clear_gpu_blocks(block_ids) + if indexer_kv_verifier is not None: + indexer_kv_verifier.clear_gpu_blocks(block_ids) + print("[Test] try_wait flow completed") + + print("[Test] Testing shutdown with indexer shadow transfer...") + if cache_config.enable_cpu and cache_config.num_cpu_blocks >= num_gpu_blocks: + assert total_cache_miss == 0, f"Expected 0 cache miss, got {total_cache_miss}" + + shutdown_tp_client(tp_client_processes) + kvmanager.shutdown() + print("[Test] Shutdown completed successfully") + print("[Test] test_kvmanager_with_indexer PASSED") diff --git a/tests/test_transfer_engine_atomic_eviction.py b/tests/test_transfer_engine_atomic_eviction.py new file mode 100644 index 0000000000..4310e1d1f2 --- /dev/null +++ b/tests/test_transfer_engine_atomic_eviction.py @@ -0,0 +1,404 @@ +""" +Unit tests for atomic indexer eviction in TransferEngine. + +These tests verify that: +1. TransferOp.pending_count defaults to 1. +2. _finalize_op is called only when pending_count reaches 0. +3. With indexer enabled: CompletedOp is NOT emitted until both main KV and indexer + workers complete (pending_count == 0). +4. With indexer disabled: behavior is identical to the original (pending_count starts + at 1, _finalize_op is called immediately after main KV completes). 
+""" +import queue +import unittest +from typing import List +from unittest.mock import MagicMock, patch, call + +import numpy as np + +from flexkv.common.transfer import TransferOp, TransferType, CompletedOp + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_op(transfer_type: TransferType = TransferType.D2H) -> TransferOp: + """Create a minimal TransferOp for testing.""" + return TransferOp( + graph_id=0, + transfer_type=transfer_type, + src_block_ids=np.array([0, 1], dtype=np.int64), + dst_block_ids=np.array([2, 3], dtype=np.int64), + ) + + +# --------------------------------------------------------------------------- +# Tests – TransferOp.pending_count field +# --------------------------------------------------------------------------- + +class TestTransferOpPendingCount(unittest.TestCase): + """Requirement 5: TransferOp supports pending_count field.""" + + def test_default_pending_count_is_one(self): + """pending_count SHALL default to 1 (req 5.1).""" + op = _make_op() + self.assertEqual(op.pending_count, 1) + + def test_pending_count_is_mutable(self): + """pending_count SHALL be mutable (dataclass, not frozen).""" + op = _make_op() + op.pending_count += 1 + self.assertEqual(op.pending_count, 2) + op.pending_count -= 1 + self.assertEqual(op.pending_count, 1) + op.pending_count -= 1 + self.assertEqual(op.pending_count, 0) + + +# --------------------------------------------------------------------------- +# Tests – _finalize_op logic (unit-level, no real workers) +# --------------------------------------------------------------------------- + +class TestFinalizeOpLogic(unittest.TestCase): + """ + Requirement 1, 3, 4: _finalize_op is called only when pending_count == 0. + We test the logic directly by simulating what _scheduler_loop does. 
+ """ + + def _simulate_worker_done(self, op: TransferOp, finished_ops: List[TransferOp], + finalize_fn) -> None: + """Simulate what _scheduler_loop does when a worker completes an op.""" + op.pending_count -= 1 + if op.pending_count == 0: + finalize_fn(op, finished_ops) + + def test_no_indexer_finalize_called_immediately(self): + """Without indexer: pending_count starts at 1, finalize called after main KV done (req 6.1).""" + op = _make_op() + self.assertEqual(op.pending_count, 1) + + finalize_mock = MagicMock() + finished_ops: List[TransferOp] = [] + + # Main KV worker completes + self._simulate_worker_done(op, finished_ops, finalize_mock) + + # pending_count should be 0 and finalize should have been called once + self.assertEqual(op.pending_count, 0) + finalize_mock.assert_called_once_with(op, finished_ops) + + def test_with_indexer_finalize_not_called_after_main_kv_only(self): + """With indexer: finalize NOT called when only main KV completes (req 3.1, 4.1).""" + op = _make_op() + # Simulate _assign_op_to_worker incrementing pending_count before submitting to indexer + op.pending_count += 1 + self.assertEqual(op.pending_count, 2) + + finalize_mock = MagicMock() + finished_ops: List[TransferOp] = [] + + # Main KV worker completes first + self._simulate_worker_done(op, finished_ops, finalize_mock) + + # pending_count should be 1, finalize should NOT have been called + self.assertEqual(op.pending_count, 1) + finalize_mock.assert_not_called() + self.assertEqual(len(finished_ops), 0) + + def test_with_indexer_finalize_called_after_both_complete(self): + """With indexer: finalize called exactly once when both workers complete (req 3.2, 4.2).""" + op = _make_op() + # Simulate _assign_op_to_worker incrementing pending_count before submitting to indexer + op.pending_count += 1 + self.assertEqual(op.pending_count, 2) + + finalize_mock = MagicMock() + finished_ops: List[TransferOp] = [] + + # Main KV worker completes first + self._simulate_worker_done(op, finished_ops, 
finalize_mock) + self.assertEqual(op.pending_count, 1) + finalize_mock.assert_not_called() + + # Indexer worker completes + self._simulate_worker_done(op, finished_ops, finalize_mock) + self.assertEqual(op.pending_count, 0) + finalize_mock.assert_called_once_with(op, finished_ops) + + def test_with_indexer_finalize_called_once_regardless_of_order(self): + """Finalize called exactly once even if indexer completes before main KV (req 3.2, 4.2).""" + op = _make_op() + op.pending_count += 1 # indexer registered + self.assertEqual(op.pending_count, 2) + + finalize_mock = MagicMock() + finished_ops: List[TransferOp] = [] + + # Indexer worker completes first + self._simulate_worker_done(op, finished_ops, finalize_mock) + self.assertEqual(op.pending_count, 1) + finalize_mock.assert_not_called() + + # Main KV worker completes + self._simulate_worker_done(op, finished_ops, finalize_mock) + self.assertEqual(op.pending_count, 0) + finalize_mock.assert_called_once_with(op, finished_ops) + + +# --------------------------------------------------------------------------- +# Tests – _finalize_op method behavior +# --------------------------------------------------------------------------- + +class TestFinalizeOpMethod(unittest.TestCase): + """ + Test that _finalize_op correctly calls free_op_from_buffer, puts CompletedOp, + appends to finished_ops, and deletes from op_id_to_op. 
+ """ + + def _make_engine_stub(self): + """Create a minimal stub of TransferEngine with the real _finalize_op method.""" + from flexkv.transfer.transfer_engine import TransferEngine, free_op_from_buffer + + engine = object.__new__(TransferEngine) + engine.op_id_to_op = {} + engine.completed_queue = MagicMock() + engine.pin_buffer = MagicMock() + engine.cache_config = MagicMock() + engine.cache_config.tokens_per_block = 16 + engine.model_config = MagicMock() + engine.model_config.token_size_in_bytes = 2 + return engine + + def test_finalize_op_releases_buffer_and_notifies(self): + """_finalize_op SHALL call free_op_from_buffer and put CompletedOp (req 3.2, 4.2).""" + from flexkv.transfer.transfer_engine import TransferEngine, free_op_from_buffer + + engine = self._make_engine_stub() + op = _make_op() + engine.op_id_to_op[op.op_id] = op + + finished_ops: List[TransferOp] = [] + + with patch('flexkv.transfer.transfer_engine.free_op_from_buffer') as mock_free: + engine._finalize_op(op, finished_ops) + + # free_op_from_buffer called once + mock_free.assert_called_once_with(op, engine.pin_buffer) + # CompletedOp put to completed_queue once + engine.completed_queue.put.assert_called_once() + completed_op_arg = engine.completed_queue.put.call_args[0][0] + self.assertIsInstance(completed_op_arg, CompletedOp) + self.assertEqual(completed_op_arg.graph_id, op.graph_id) + self.assertEqual(completed_op_arg.op_id, op.op_id) + # op appended to finished_ops + self.assertIn(op, finished_ops) + # op removed from op_id_to_op + self.assertNotIn(op.op_id, engine.op_id_to_op) + + def test_finalize_op_removes_op_from_tracking_dict(self): + """_finalize_op SHALL delete op from op_id_to_op (req 3.2 - no double free).""" + engine = self._make_engine_stub() + op = _make_op() + engine.op_id_to_op[op.op_id] = op + + finished_ops: List[TransferOp] = [] + + with patch('flexkv.transfer.transfer_engine.free_op_from_buffer'): + engine._finalize_op(op, finished_ops) + + self.assertNotIn(op.op_id, 
engine.op_id_to_op) + + def test_finalize_op_not_called_twice(self): + """op_id_to_op deletion prevents double finalization (req 3.2 - exactly once).""" + engine = self._make_engine_stub() + op = _make_op() + engine.op_id_to_op[op.op_id] = op + + finished_ops: List[TransferOp] = [] + + with patch('flexkv.transfer.transfer_engine.free_op_from_buffer'): + engine._finalize_op(op, finished_ops) + # Second call should raise KeyError since op was already removed + with self.assertRaises(KeyError): + engine._finalize_op(op, finished_ops) + + +# --------------------------------------------------------------------------- +# Tests – Indexer Layerwise Worker initialization and op dispatch +# --------------------------------------------------------------------------- + +class TestIndexerLayerwiseWorkerInit(unittest.TestCase): + """ + Tests for indexer LayerwiseTransferWorker initialization and LAYERWISE op dispatch. + Verifies requirements 1.1, 1.3, 2.1, 2.2, 5.1, 5.3. + """ + + def _make_engine_stub_with_indexer(self, enable_layerwise: bool = True): + """ + Create a minimal TransferEngine stub with _has_indexer=True and + a pre-populated _indexer_worker_map (simulating post-_init_workers state). 
+ """ + from flexkv.transfer.transfer_engine import TransferEngine + + engine = object.__new__(TransferEngine) + engine._has_indexer = True + engine._worker_map = {} + engine._indexer_worker_map = {} + engine._indexer_page_size = 1 + engine._indexer_op_to_parent_op = {} + engine.op_id_to_op = {} + engine.op_id_to_nvtx_range = {} + engine.completed_queue = MagicMock() + engine.pin_buffer = MagicMock() + engine.cache_config = MagicMock() + engine.cache_config.tokens_per_block = 16 + engine.model_config = MagicMock() + engine.model_config.token_size_in_bytes = 2 + + # Create mock workers for main KV + main_layerwise_worker = MagicMock() + engine._worker_map[TransferType.H2D] = [MagicMock()] + engine._worker_map[TransferType.D2H] = [MagicMock()] + if enable_layerwise: + engine._worker_map[TransferType.LAYERWISE] = [main_layerwise_worker] + + # Create mock workers for indexer + indexer_h2d_worker = MagicMock() + indexer_layerwise_worker = MagicMock() + engine._indexer_worker_map[TransferType.H2D] = [indexer_h2d_worker] + engine._indexer_worker_map[TransferType.D2H] = [MagicMock()] + if enable_layerwise: + engine._indexer_worker_map[TransferType.LAYERWISE] = [indexer_layerwise_worker] + + return engine, main_layerwise_worker, indexer_layerwise_worker + + def test_indexer_worker_map_contains_layerwise_when_enabled(self): + """ + WHEN enable_layerwise_transfer=True AND indexer handles exist + THEN _indexer_worker_map SHALL contain TransferType.LAYERWISE (req 1.1). + """ + engine, _, _ = self._make_engine_stub_with_indexer(enable_layerwise=True) + self.assertIn(TransferType.LAYERWISE, engine._indexer_worker_map) + + def test_indexer_worker_map_no_layerwise_when_disabled(self): + """ + IF enable_layerwise_transfer=False + THEN _indexer_worker_map SHALL NOT contain TransferType.LAYERWISE (req 5.1). 
+ """ + engine, _, _ = self._make_engine_stub_with_indexer(enable_layerwise=False) + self.assertNotIn(TransferType.LAYERWISE, engine._indexer_worker_map) + + def test_layerwise_op_pending_count_incremented_for_indexer(self): + """ + WHEN _assign_op_to_worker processes a LAYERWISE op with _has_indexer=True + THEN op.pending_count SHALL be incremented by 1 before submitting to indexer (req 2.2). + """ + from flexkv.transfer.transfer_engine import register_op_to_buffer + import nvtx + + engine, main_worker, indexer_worker = self._make_engine_stub_with_indexer(enable_layerwise=True) + + op = _make_op(TransferType.LAYERWISE) + op.dp_id = 0 + engine.op_id_to_op[op.op_id] = op + + initial_pending_count = op.pending_count # should be 1 + + with patch('flexkv.transfer.transfer_engine.register_op_to_buffer'), \ + patch('nvtx.start_range', return_value=MagicMock()): + engine._assign_op_to_worker(op) + + # pending_count should have been incremented by 1 (for indexer) before submission + # After _assign_op_to_worker: pending_count = initial + 1 = 2 + self.assertEqual(op.pending_count, initial_pending_count + 1) + + def test_layerwise_op_submitted_to_both_main_and_indexer_workers(self): + """ + WHEN _assign_op_to_worker processes a LAYERWISE op with _has_indexer=True + THEN op SHALL be submitted to main KV worker, and a separate indexer_op + SHALL be submitted to the indexer layerwise worker (req 2.1). 
+ """ + from flexkv.transfer.transfer_engine import register_op_to_buffer + + engine, main_worker, indexer_worker = self._make_engine_stub_with_indexer(enable_layerwise=True) + + op = _make_op(TransferType.LAYERWISE) + op.dp_id = 0 + engine.op_id_to_op[op.op_id] = op + + with patch('flexkv.transfer.transfer_engine.register_op_to_buffer'), \ + patch('nvtx.start_range', return_value=MagicMock()): + engine._assign_op_to_worker(op) + + # Main KV worker should have received the original op + main_worker.submit_transfer.assert_called_once_with(op) + # Indexer worker should have received a separate indexer_op (not the same op object) + indexer_worker.submit_transfer.assert_called_once() + indexer_op = indexer_worker.submit_transfer.call_args[0][0] + self.assertIsNot(indexer_op, op, "Indexer worker must receive a separate op, not the original") + self.assertEqual(indexer_op.graph_id, op.graph_id) + self.assertEqual(indexer_op.transfer_type, op.transfer_type) + + def test_layerwise_op_no_indexer_pending_count_stays_one(self): + """ + WHEN no indexer exists and LAYERWISE op is dispatched + THEN pending_count SHALL remain 1 (req 5.3). 
+ """ + from flexkv.transfer.transfer_engine import TransferEngine + + engine = object.__new__(TransferEngine) + engine._has_indexer = False + engine._worker_map = {} + engine._indexer_worker_map = {} + engine.op_id_to_op = {} + engine.op_id_to_nvtx_range = {} + + main_layerwise_worker = MagicMock() + engine._worker_map[TransferType.LAYERWISE] = [main_layerwise_worker] + + op = _make_op(TransferType.LAYERWISE) + op.dp_id = 0 + engine.op_id_to_op[op.op_id] = op + + initial_pending_count = op.pending_count # should be 1 + + with patch('flexkv.transfer.transfer_engine.register_op_to_buffer'), \ + patch('nvtx.start_range', return_value=MagicMock()): + engine._assign_op_to_worker(op) + + # pending_count should remain 1 (no indexer to increment for) + self.assertEqual(op.pending_count, initial_pending_count) + main_layerwise_worker.submit_transfer.assert_called_once_with(op) + + def test_finalize_called_after_both_layerwise_workers_complete(self): + """ + WHEN both main KV and indexer layerwise workers complete + THEN _finalize_op SHALL be called exactly once (req 3.2, 4.2). 
+ """ + op = _make_op(TransferType.LAYERWISE) + # Simulate _assign_op_to_worker incrementing pending_count for indexer + op.pending_count += 1 + self.assertEqual(op.pending_count, 2) + + finalize_mock = MagicMock() + finished_ops: List[TransferOp] = [] + + def simulate_done(o, fo, fn): + o.pending_count -= 1 + if o.pending_count == 0: + fn(o, fo) + + # Main KV layerwise worker completes + simulate_done(op, finished_ops, finalize_mock) + self.assertEqual(op.pending_count, 1) + finalize_mock.assert_not_called() + + # Indexer layerwise worker completes + simulate_done(op, finished_ops, finalize_mock) + self.assertEqual(op.pending_count, 0) + finalize_mock.assert_called_once_with(op, finished_ops) + + +if __name__ == "__main__": + unittest.main() From 9b27128e55438f792a9b658158514ea2fe955b72 Mon Sep 17 00:00:00 2001 From: zittozhang Date: Tue, 7 Apr 2026 10:32:56 +0800 Subject: [PATCH 32/59] fix page align bug --- flexkv/common/config.py | 1 - flexkv/integration/config.py | 18 +++--- flexkv/storage/storage_engine.py | 53 ++++++------------ flexkv/transfer/transfer_engine.py | 55 ++----------------- tests/test_kvmanager.py | 5 +- tests/test_transfer_engine_atomic_eviction.py | 2 +- 6 files changed, 36 insertions(+), 98 deletions(-) diff --git a/flexkv/common/config.py b/flexkv/common/config.py index 46b044a1ab..650e7ea347 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -19,7 +19,6 @@ class IndexerCacheConfig: head_size: int = 0 # qk_rope_head_dim for DSA/NSA models num_kv_heads: int = 1 # typically 1 for MLA-style indexer dtype: torch.dtype = torch.uint8 # indexer storage dtype (fp8 quantized) - page_size: int = 1 @dataclass diff --git a/flexkv/integration/config.py b/flexkv/integration/config.py index 5bbdb04685..c857050b62 100644 --- a/flexkv/integration/config.py +++ b/flexkv/integration/config.py @@ -40,7 +40,7 @@ def __post_init__(self): if self.gpu_register_port == "": self.gpu_register_port = self.server_recv_port + "_gpu_register" - def 
_detect_indexer_config_from_hf(self, hf_config, source: str = "", page_size: int = 1) -> None: + def _detect_indexer_config_from_hf(self, hf_config, source: str = "") -> None: if hf_config is None: return @@ -49,16 +49,19 @@ def _detect_indexer_config_from_hf(self, hf_config, source: str = "", page_size: if qk_rope_head_dim is None or qk_rope_head_dim <= 0: return + # tokens_per_block is already set to sglang page_size before this + # call, so each FlexKV block = 1 sglang page. The indexer maps + # 1:1 with blocks — no extra page_size grouping is needed. self.cache_config.indexer = IndexerCacheConfig( head_size=qk_rope_head_dim, num_kv_heads=1, dtype=torch.uint8, - page_size=page_size, ) source_label = f" ({source})" if source else "" logger.info( f"Detected sparse attention indexer config{source_label}: " - f"head_size={qk_rope_head_dim}, dtype=uint8, page_size={page_size}") + f"head_size={qk_rope_head_dim}, dtype=uint8, " + f"tokens_per_block={self.cache_config.tokens_per_block}") except Exception as e: logger.debug(f"Could not detect indexer config ({source}): {e}") @@ -122,8 +125,10 @@ def post_init_from_sglang_config( pp_size: pipeline parallel size (default 1, no PP) pp_rank: pipeline parallel rank (default 0) """ - # cache config - self.cache_config.tokens_per_block = 1 + # cache config: use page_size as tokens_per_block so that FlexKV's + # CPU radix tree manages blocks at page granularity, ensuring that + # hash generation, matching, insertion and eviction are all page-aligned. 
+ self.cache_config.tokens_per_block = page_size total_layers = int(getattr(sglang_config, "num_hidden_layers", 0)) self.model_config.num_layers = int(num_local_layers) if num_local_layers > 0 else total_layers @@ -160,12 +165,11 @@ def post_init_from_sglang_config( update_default_config_from_user_config(self.model_config, self.cache_config, self.user_config) hf_config = getattr(sglang_config, 'hf_config', None) - self._detect_indexer_config_from_hf(hf_config, source="sglang", page_size=page_size) + self._detect_indexer_config_from_hf(hf_config, source="sglang") if self.cache_config.indexer is not None: logger.info( f"[FlexKV] Complete indexer config (sglang): " - f"page_size={self.cache_config.indexer.page_size}, " f"head_size={self.cache_config.indexer.head_size}, " f"dtype={self.cache_config.indexer.dtype}, " f"num_layers={self.model_config.num_layers}, " diff --git a/flexkv/storage/storage_engine.py b/flexkv/storage/storage_engine.py index 640e100220..9d99c05303 100644 --- a/flexkv/storage/storage_engine.py +++ b/flexkv/storage/storage_engine.py @@ -40,17 +40,14 @@ def __init__(self, dtype=self._model_config.dtype, ) if self._indexer_config is not None: - indexer_page_size = self._indexer_config.page_size - indexer_num_cpu_blocks = ( - self._cache_config.num_cpu_blocks // indexer_page_size - if indexer_page_size > 1 - else self._cache_config.num_cpu_blocks - ) + # Indexer maps 1:1 with main KV blocks (each block = 1 page), + # so indexer num_blocks equals main KV num_blocks and + # tokens_per_block is 1 (one indexer entry per page). 
indexer_cpu_layout = KVCacheLayout( type=GLOBAL_CONFIG_FROM_ENV.cpu_layout_type, num_layer=self._model_config.num_layers, - num_block=indexer_num_cpu_blocks, - tokens_per_block=self._indexer_config.page_size, + num_block=self._cache_config.num_cpu_blocks, + tokens_per_block=1, num_head=self._indexer_config.num_kv_heads, head_size=self._indexer_config.head_size, is_mla=True @@ -82,17 +79,11 @@ def __init__(self, max_file_size_gb=GLOBAL_CONFIG_FROM_ENV.max_file_size_gb ) if self._indexer_config is not None: - indexer_page_size = self._indexer_config.page_size - indexer_num_ssd_blocks = ( - self._cache_config.num_ssd_blocks // indexer_page_size - if indexer_page_size > 1 - else self._cache_config.num_ssd_blocks - ) indexer_ssd_layout = KVCacheLayout( type=GLOBAL_CONFIG_FROM_ENV.ssd_layout_type, num_layer=self._model_config.num_layers, - num_block=indexer_num_ssd_blocks, - tokens_per_block=self._indexer_config.page_size, + num_block=self._cache_config.num_ssd_blocks, + tokens_per_block=1, num_head=self._indexer_config.num_kv_heads, head_size=self._indexer_config.head_size, is_mla=True @@ -126,17 +117,11 @@ def __init__(self, remote_config_custom = self._cache_config.remote_config_custom ) if self._indexer_config is not None: - indexer_page_size = self._indexer_config.page_size - indexer_num_remote_blocks = ( - self._cache_config.num_remote_blocks // indexer_page_size - if indexer_page_size > 1 - else self._cache_config.num_remote_blocks - ) indexer_remote_layout = KVCacheLayout( type=GLOBAL_CONFIG_FROM_ENV.remote_layout_type, num_layer=self._model_config.num_layers, - num_block=indexer_num_remote_blocks, - tokens_per_block=self._indexer_config.page_size, + num_block=self._cache_config.num_remote_blocks, + tokens_per_block=1, num_head=self._indexer_config.num_kv_heads, head_size=self._indexer_config.head_size, is_mla=True @@ -176,26 +161,20 @@ def register_gpu_blocks(self, raw_data=gpu_blocks ) if indexer_gpu_blocks is not None: - # Log indexer GPU registration 
parameters - indexer_page_size = self._indexer_config.page_size if self._indexer_config else 1 + # Indexer maps 1:1 with main KV blocks; validate consistency. flexkv_logger.info( f"[StorageEngine] Registering indexer GPU buffer: " f"num_block={indexer_gpu_layout.num_block}, " - f"page_size={indexer_page_size}, " f"head_size={indexer_gpu_layout.head_size}, " f"num_head={indexer_gpu_layout.num_head}, " f"dtype={indexer_dtype}" ) - # Validate indexer num_block vs main KV num_block - if indexer_page_size > 1: - expected_indexer_blocks = gpu_layout.num_block // indexer_page_size - if indexer_gpu_layout.num_block != expected_indexer_blocks: - flexkv_logger.warning( - f"[StorageEngine] Indexer GPU num_block mismatch: " - f"indexer_num_block={indexer_gpu_layout.num_block}, " - f"expected={expected_indexer_blocks} " - f"(main_kv_num_block={gpu_layout.num_block} // page_size={indexer_page_size})" - ) + if indexer_gpu_layout.num_block != gpu_layout.num_block: + flexkv_logger.warning( + f"[StorageEngine] Indexer GPU num_block mismatch: " + f"indexer_num_block={indexer_gpu_layout.num_block}, " + f"expected={gpu_layout.num_block} (1:1 with main KV blocks)" + ) self.allocate( device_type=DeviceType.GPU, layout=indexer_gpu_layout, diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index f2c3f4cc04..a687720296 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -146,10 +146,8 @@ def __init__(self, self._running = False self._has_indexer = False - self._indexer_page_size = 1 - if cache_config.indexer is not None: - self._indexer_page_size = cache_config.indexer.page_size self._indexer_op_to_parent_op: Dict[int, int] = {} + self._indexer_op_map: Dict[int, TransferOp] = {} def _init_workers(self) -> None: if self._running: @@ -725,7 +723,7 @@ def _scheduler_loop(self) -> None: f"_indexer_op_to_parent_op. All indexer ops must be " f"registered with a parent op." 
) - indexer_op = self.op_id_to_op.pop(op_id) + indexer_op = self._indexer_op_map.pop(op_id) free_op_from_buffer(indexer_op, self.pin_buffer) parent_op_id = self._indexer_op_to_parent_op.pop(op_id) parent_op = self.op_id_to_op[parent_op_id] @@ -792,38 +790,6 @@ def _finalize_op(self, op: TransferOp, finished_ops: List[TransferOp]) -> None: finished_ops.append(op) del self.op_id_to_op[op.op_id] - def _convert_to_page_level_block_ids(self, block_ids: np.ndarray) -> np.ndarray: - """Convert block-level IDs to page-level IDs by dividing by page_size. - - Input block_ids are guaranteed to be page-aligned by the sglang caller - (e.g. flexkv_radix_cache), so we only do a defensive assert here. - """ - page_size = self._indexer_page_size - if block_ids.size == 0: - return block_ids.copy() - - assert block_ids.size % page_size == 0, ( - f"[TransferEngine] block_ids size {block_ids.size} is not a multiple " - f"of indexer page_size {page_size}. " - f"Caller must page-align block_ids before reaching transfer engine." - ) - - reshaped = block_ids.reshape(-1, page_size) - page_block_ids = reshaped[:, 0] // page_size - - # Validate: all block_ids in each group must map to the same indexer page. - # sglang's PagedTokenToKVPoolAllocator guarantees page-aligned allocation, - # so this should never fail in practice. - last_in_group = reshaped[:, -1] // page_size - assert np.array_equal(page_block_ids, last_in_group), ( - f"[TransferEngine] Indexer page group(s) have block_ids spanning multiple pages " - f"(page_size={page_size}). This indicates slot indices are not page-aligned. 
" - f"First mismatch: first_block={reshaped[page_block_ids != last_in_group][0, 0]}, " - f"last_block={reshaped[page_block_ids != last_in_group][0, -1]}" - ) - - return page_block_ids.astype(np.int64) - def _assign_op_to_worker(self, op: TransferOp) -> None: self.op_id_to_nvtx_range[op.op_id] = nvtx.start_range(f"schedule {op.transfer_type.name} " f"op_id: {op.op_id}, " @@ -837,17 +803,9 @@ def _assign_op_to_worker(self, op: TransferOp) -> None: raise ValueError(f"Unsupported transfer type: {op.transfer_type}") if self._has_indexer and op.transfer_type in self._indexer_worker_map: - # Convert block_ids to indexer page-level IDs. - # _convert_to_page_level_block_ids handles all page_size values uniformly - src_page_ids = self._convert_to_page_level_block_ids(op.src_block_ids) - dst_page_ids = self._convert_to_page_level_block_ids(op.dst_block_ids) - - # Ensure both sides have the same number of pages after conversion - assert src_page_ids.size == dst_page_ids.size, ( - f"[TransferEngine] src_page_ids size {src_page_ids.size} != " - f"dst_page_ids size {dst_page_ids.size} for op {op.op_id}. " - f"Both sides must have the same number of pages." - ) + # Indexer maps 1:1 with main KV blocks, use block_ids directly. 
+ src_page_ids = op.src_block_ids + dst_page_ids = op.dst_block_ids num_pages = src_page_ids.size if num_pages > 0: @@ -864,13 +822,12 @@ def _assign_op_to_worker(self, op: TransferOp) -> None: ) register_op_to_buffer(indexer_op, self.pin_buffer) self._indexer_op_to_parent_op[indexer_op.op_id] = op.op_id - self.op_id_to_op[indexer_op.op_id] = indexer_op + self._indexer_op_map[indexer_op.op_id] = indexer_op op.pending_count += 1 flexkv_logger.debug( f"[TransferEngine] Created indexer op {indexer_op.op_id} " f"for parent op {op.op_id}: {num_pages} pages, " - f"page_size={self._indexer_page_size}, " f"type={op.transfer_type.name}" ) diff --git a/tests/test_kvmanager.py b/tests/test_kvmanager.py index e0c5f21d54..74f7f1e0be 100644 --- a/tests/test_kvmanager.py +++ b/tests/test_kvmanager.py @@ -624,7 +624,7 @@ def run_tp_client_with_indexer(dp_client_id, tokens_per_block=indexer_tokens_per_block, num_head=indexer_cfg.num_kv_heads, head_size=indexer_cfg.head_size, - is_mla=indexer_cfg.use_mla, + is_mla=True, ) # Use KVTPClient directly with indexer buffers (shadow transfer mode) @@ -692,7 +692,6 @@ def test_kvmanager_with_indexer(model_config, cache_config, test_config, gpu_lay cache_config.indexer = IndexerCacheConfig( head_size=64, num_kv_heads=1, - use_mla=True, dtype=torch.uint8, ) @@ -763,7 +762,7 @@ def test_kvmanager_with_indexer(model_config, cache_config, test_config, gpu_lay tokens_per_block=cache_config.tokens_per_block, num_head=indexer_cfg.num_kv_heads, head_size=indexer_cfg.head_size, - is_mla=indexer_cfg.use_mla, + is_mla=True, ) indexer_kv_verifier = GPUIndexerCacheVerifier( shared_indexer_blocks=all_indexer_blocks, diff --git a/tests/test_transfer_engine_atomic_eviction.py b/tests/test_transfer_engine_atomic_eviction.py index 4310e1d1f2..d08e646081 100644 --- a/tests/test_transfer_engine_atomic_eviction.py +++ b/tests/test_transfer_engine_atomic_eviction.py @@ -245,8 +245,8 @@ def _make_engine_stub_with_indexer(self, enable_layerwise: bool = True): 
engine._has_indexer = True engine._worker_map = {} engine._indexer_worker_map = {} - engine._indexer_page_size = 1 engine._indexer_op_to_parent_op = {} + engine._indexer_op_map = {} engine.op_id_to_op = {} engine.op_id_to_nvtx_range = {} engine.completed_queue = MagicMock() From c5b292db37913b3d218dfc67fdce0693585a832c Mon Sep 17 00:00:00 2001 From: phaedonsun Date: Wed, 8 Apr 2026 19:27:30 +0800 Subject: [PATCH 33/59] feat: distributed KV cache improvements - RedisMeta refactor, benchmark tooling, node TTL/heartbeat, build fixes, and default rebuild interval change to 100ms --- benchmarks/benchmark_dist_kvcache.py | 637 +++++++++++++++++ .../dist_benchmark/benchmark_dist_direct.py | 660 ++++++++++++++++++ .../dist_benchmark/benchmark_dist_kvcache.py | 637 +++++++++++++++++ .../dist_benchmark/example_dist_config.yml | 34 + .../example_dist_direct_config.yml | 40 ++ benchmarks/dist_benchmark/redis_check.py | 338 +++++++++ .../dist_benchmark/run_dist_benchmark.sh | 405 +++++++++++ .../run_dist_direct_benchmark.sh | 372 ++++++++++ benchmarks/dist_benchmark/utils.py | 131 ++++ benchmarks/example_dist_config.yml | 34 + benchmarks/redis_check.py | 338 +++++++++ benchmarks/run_dist_benchmark.sh | 405 +++++++++++ benchmarks/utils.py | 7 +- flexkv/cache/hie_cache_engine.py | 2 +- flexkv/cache/redis_meta.py | 192 ++++- flexkv/common/config.py | 11 +- flexkv/common/tracer.py | 5 +- flexkv/kvmanager.py | 35 +- flexkv/mooncakeEngineWrapper.py | 14 +- flexkv/server/server.py | 55 +- flexkv/transfer/transfer_engine.py | 3 +- flexkv/transfer/worker.py | 39 +- install.sh | 549 +++++++++++++++ requirements.txt | 7 +- setup.py | 28 +- 25 files changed, 4930 insertions(+), 48 deletions(-) create mode 100644 benchmarks/benchmark_dist_kvcache.py create mode 100644 benchmarks/dist_benchmark/benchmark_dist_direct.py create mode 100644 benchmarks/dist_benchmark/benchmark_dist_kvcache.py create mode 100644 benchmarks/dist_benchmark/example_dist_config.yml create mode 100644 
benchmarks/dist_benchmark/example_dist_direct_config.yml create mode 100644 benchmarks/dist_benchmark/redis_check.py create mode 100755 benchmarks/dist_benchmark/run_dist_benchmark.sh create mode 100755 benchmarks/dist_benchmark/run_dist_direct_benchmark.sh create mode 100644 benchmarks/dist_benchmark/utils.py create mode 100644 benchmarks/example_dist_config.yml create mode 100644 benchmarks/redis_check.py create mode 100755 benchmarks/run_dist_benchmark.sh create mode 100755 install.sh diff --git a/benchmarks/benchmark_dist_kvcache.py b/benchmarks/benchmark_dist_kvcache.py new file mode 100644 index 0000000000..26e6c84426 --- /dev/null +++ b/benchmarks/benchmark_dist_kvcache.py @@ -0,0 +1,637 @@ +""" +Benchmark for FlexKV distributed KVCache in server_client_mode. + +This script tests the put/get performance of FlexKV when running in +server_client_mode with distributed KVCache sharing enabled (enable_p2p_cpu). + +Prerequisites: + - A running Redis server (default: 127.0.0.1:6379) + - At least 1 GPU available + - FlexKV built with distributed support (FLEXKV_ENABLE_P2P=1) + +Usage: + # Basic usage with default config + python benchmarks/benchmark_dist_kvcache.py --config benchmarks/example_dist_config.yml + + # Custom parameters + python benchmarks/benchmark_dist_kvcache.py \ + --config benchmarks/example_dist_config.yml \ + --batch-size 4 \ + --sequence-length 2048 \ + --cache-ratio 0.5 \ + --num-users 10 \ + --num-turns 3 + + # Multi-turn conversation benchmark only + python benchmarks/benchmark_dist_kvcache.py \\ + --config benchmarks/example_dist_config.yml \\ + --mode multiturn \\ + --num-users 20 \\ + --num-turns 5 + + # Cross-node benchmark: Node A (PUT only) + python benchmarks/benchmark_dist_kvcache.py \\ + --config config_a.yml --seed 42 --mode put-only + + # Cross-node benchmark: Node B (GET only, same seed) + python benchmarks/benchmark_dist_kvcache.py \\ + --config config_b.yml --seed 42 --mode get-only +""" +import os +import atexit +import signal 
+import argparse +import json +import tempfile +import time +from multiprocessing import Process +from dataclasses import dataclass + +import torch +import numpy as np + +from flexkv.server.client import KVTPClient +from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType +from flexkv.common.config import ( + ModelConfig, CacheConfig, UserConfig, + update_default_config_from_user_config, parse_path_list, + GLOBAL_CONFIG_FROM_ENV, +) +from flexkv.common.debug import flexkv_logger +from flexkv.kvmanager import KVManager +from flexkv.kvtask import KVResponseStatus + +from utils import generate_random_multiturn + +flexkv_logger.set_level("INFO") + + +def load_dist_config(config_path: str): + """Load config with distributed KVCache support. + + Extends the standard load_config to handle distributed-specific fields: + enable_p2p_cpu, enable_p2p_ssd, enable_3rd_remote, + redis_host, redis_port, local_ip, redis_password, + server_client_mode, etc. + """ + import yaml + + with open(config_path) as f: + config = yaml.load(f, Loader=yaml.SafeLoader) + print(f"Loaded config: {config}") + + model_config = ModelConfig() + cache_config = CacheConfig() + user_config = UserConfig() + + # Model config + model_config.num_layers = config["num_layers"] + model_config.num_kv_heads = config["num_kv_heads"] + model_config.head_size = config["head_size"] + model_config.dtype = eval(f"torch.{config['dtype']}") + model_config.use_mla = config["use_mla"] + model_config.tp_size = config["tp_size"] + model_config.dp_size = config["dp_size"] + cache_config.tokens_per_block = config["tokens_per_block"] + + # Cache size config + if "cpu_cache_gb" in config: + user_config.cpu_cache_gb = config["cpu_cache_gb"] + if "ssd_cache_gb" in config: + user_config.ssd_cache_gb = config["ssd_cache_gb"] + if "ssd_cache_dir" in config: + user_config.ssd_cache_dir = parse_path_list(config["ssd_cache_dir"]) + if "enable_gds" in config: + user_config.enable_gds = config["enable_gds"] + + # Distributed 
KVCache config + if "enable_p2p_cpu" in config: + user_config.enable_p2p_cpu = config["enable_p2p_cpu"] + if "enable_p2p_ssd" in config: + user_config.enable_p2p_ssd = config["enable_p2p_ssd"] + if "enable_3rd_remote" in config: + user_config.enable_3rd_remote = config["enable_3rd_remote"] + + # Redis config + if "redis_host" in config: + user_config.redis_host = config["redis_host"] + if "redis_port" in config: + user_config.redis_port = config["redis_port"] + if "local_ip" in config: + user_config.local_ip = config["local_ip"] + if "redis_password" in config: + user_config.redis_password = config["redis_password"] + + # Auto-generate mooncake config JSON and set MOONCAKE_CONFIG_PATH if P2P is enabled + if config.get("enable_p2p_cpu", False) or config.get("enable_p2p_ssd", False): + if "MOONCAKE_CONFIG_PATH" not in os.environ: + mooncake_config = { + "engine_ip": config.get("mooncake_engine_ip", config.get("local_ip", "127.0.0.1")), + "engine_port": config.get("mooncake_engine_port", 5555), + "metadata_backend": config.get("mooncake_metadata_backend", "redis"), + "metadata_server": config.get("mooncake_metadata_server", + f"redis://{config.get('redis_host', '127.0.0.1')}:{config.get('redis_port', 6379)}"), + "metadata_server_auth": config.get("mooncake_metadata_server_auth", + config.get("redis_password", "")), + "protocol": config.get("mooncake_protocol", "tcp"), + "device_name": config.get("mooncake_device_name", ""), + } + # Write to a temp file that persists until process exits + mooncake_config_fd, mooncake_config_path = tempfile.mkstemp( + suffix=".json", prefix="mooncake_config_" + ) + with os.fdopen(mooncake_config_fd, "w") as f: + json.dump(mooncake_config, f, indent=2) + os.environ["MOONCAKE_CONFIG_PATH"] = mooncake_config_path + print(f"[INFO] Auto-generated mooncake config at: {mooncake_config_path}") + print(f"[INFO] Mooncake config: {json.dumps(mooncake_config, indent=2)}") + else: + mooncake_config_path = os.environ['MOONCAKE_CONFIG_PATH'] + 
print(f"[INFO] Using existing MOONCAKE_CONFIG_PATH: {mooncake_config_path}") + + # Store mooncake_config_path in cache_config so it survives spawn subprocesses via pickle + cache_config.mooncake_config_path = mooncake_config_path + + update_default_config_from_user_config(model_config, cache_config, user_config) + + # Handle server_client_mode from config + if config.get("server_client_mode", False): + os.environ["FLEXKV_SERVER_CLIENT_MODE"] = "1" + GLOBAL_CONFIG_FROM_ENV.server_client_mode = True + + return model_config, cache_config + + +@dataclass +class BenchmarkConfig: + # Single batch benchmark params + batch_size: int = 1 + sequence_length: int = 1024 + cache_ratio: float = 1.0 + clear_cpu_cache: bool = False + + # Multi-turn benchmark params + num_users: int = 10 + num_turns: int = 3 + system_prompt_length: int = 100 + input_length: int = 512 + output_length: int = 64 + + # General + mode: str = "all" # "single", "multiturn", "all", "put-only", "get-only" + seed: int = None # Random seed for deterministic token generation (cross-node) + + +def run_tp_client(dp_client_id, tp_rank, gpu_register_port, model_config, cache_config, num_gpu_blocks): + """Run tp_client process to register GPU blocks""" + device_id = tp_rank + dp_client_id * model_config.tp_size + tp_client = KVTPClient(gpu_register_port, dp_client_id, device_id) + + gpu_kv_layout = KVCacheLayout( + type=KVCacheLayoutType.LAYERFIRST, + num_layer=model_config.num_layers, + num_block=num_gpu_blocks, + tokens_per_block=cache_config.tokens_per_block, + num_head=model_config.num_kv_heads, + head_size=model_config.head_size, + is_mla=model_config.use_mla, + ) + + # Create GPU blocks for this tp_rank in the tp_client process + gpu_blocks_for_tp = [] + for _ in range(model_config.num_layers): + gpu_blocks_for_tp.append( + torch.empty(size=tuple(gpu_kv_layout.kv_shape[1:]), dtype=model_config.dtype).cuda(device_id) + ) + tp_client.register_to_server(gpu_blocks_for_tp, gpu_kv_layout) + + # Keep the process 
running + while True: + time.sleep(1) + + +def shutdown_tp_clients(tp_client_processes): + """Terminate all tp_client processes""" + for tp_process in tp_client_processes: + if tp_process.is_alive(): + tp_process.terminate() + tp_process.join(timeout=5) + if tp_process.is_alive(): + print(f"Force killing tp_client process {tp_process.pid}") + tp_process.kill() + tp_process.join(timeout=2) + + +def benchmark_single_batch(kvmanager, model_config, cache_config, bench_config): + """Benchmark single batch put/get with distributed KVCache""" + print("\n" + "=" * 60) + print(" Single Batch Benchmark (Distributed KVCache)") + print("=" * 60) + + sequence_length = bench_config.sequence_length + batch_size = bench_config.batch_size + cache_length = int(sequence_length * bench_config.cache_ratio) + + print(f" batch_size={batch_size}, sequence_length={sequence_length}, " + f"cache_ratio={bench_config.cache_ratio}, cache_length={cache_length}") + if bench_config.seed is not None: + print(f" seed={bench_config.seed}") + + # Generate random sequences (use seed for deterministic cross-node benchmarks) + if bench_config.seed is not None: + torch.manual_seed(bench_config.seed) + batch_sequence_tensor = [] + batch_slot_mapping = [] + for i in range(batch_size): + batch_sequence_tensor.append(torch.randint(0, 100000, (sequence_length,), dtype=torch.int64)) + batch_slot_mapping.append(torch.arange(i * sequence_length, (i + 1) * sequence_length, dtype=torch.int64)) + + results = {} + skip_put = (bench_config.mode == "get-only") + skip_get = (bench_config.mode == "put-only") + + # In get-only mode, wait for remote index to be refreshed from Redis + if skip_put: + rebuild_interval_ms = int(os.environ.get("FLEXKV_REBUILD_INTERVAL_MS", "100")) + # Wait at least 3x rebuild_interval to ensure at least one full refresh cycle + wait_time_s = max(rebuild_interval_ms * 3 / 1000.0, 0.5) + print(f" Waiting {wait_time_s:.2f}s for remote index refresh " + 
f"(FLEXKV_REBUILD_INTERVAL_MS={rebuild_interval_ms})...") + time.sleep(wait_time_s) + + # ---- Benchmark PUT ---- + if not skip_put: + print("\n--- PUT Phase ---") + start_time = time.time() + batch_put_ids = [] + if bench_config.cache_ratio > 0: + for i in range(batch_size): + task_id = kvmanager.put_async( + batch_sequence_tensor[i][:cache_length], + batch_slot_mapping[i][:cache_length], + token_mask=None, + ) + batch_put_ids.append(task_id) + put_result = kvmanager.wait(batch_put_ids, completely=True) + end_time = time.time() + + elapsed_time_put = end_time - start_time + put_tokens = 0 + for _, response in put_result.items(): + if response.status == KVResponseStatus.SUCCESS: + put_tokens += response.return_mask.sum().item() + transfer_data_size_GB = put_tokens * model_config.token_size_in_bytes / (1024 ** 3) + transfer_bandwidth_put = transfer_data_size_GB / elapsed_time_put if elapsed_time_put > 0 else 0 + print(f" PUT: {put_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " + f"time: {elapsed_time_put * 1000:.2f}ms, bandwidth: {transfer_bandwidth_put:.2f} GB/s") + results.update({ + "put_tokens": put_tokens, + "put_time_ms": elapsed_time_put * 1000, + "put_bandwidth_GBs": transfer_bandwidth_put, + }) + else: + print("\n--- PUT Phase SKIPPED (get-only mode) ---") + + if bench_config.clear_cpu_cache: + kvmanager._clear_cpu_cache() + + # ---- Benchmark GET ---- + if not skip_get: + print("\n--- GET Phase ---") + all_tokens = 0 + start_time = time.time() + batch_get_ids = [] + for i in range(batch_size): + all_tokens += len(batch_sequence_tensor[i]) + task_id, _ = kvmanager.get_match(batch_sequence_tensor[i], token_mask=None) + batch_get_ids.append(task_id) + get_match_time = time.time() - start_time + + kvmanager.launch(batch_get_ids, batch_slot_mapping, as_batch=True, layerwise_transfer=False) + get_result = kvmanager.wait(batch_get_ids) + elapsed_time_get = time.time() - start_time + + cached_tokens = 0 + for _, response in get_result.items(): + if 
response.status == KVResponseStatus.SUCCESS: + cached_tokens += response.return_mask.sum().item() + transfer_data_size_GB = cached_tokens * model_config.token_size_in_bytes / (1024 ** 3) + transfer_bandwidth_get = transfer_data_size_GB / elapsed_time_get if elapsed_time_get > 0 else 0 + print(f" GET: {cached_tokens}/{all_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " + f"cache_ratio: {cached_tokens * 100 / all_tokens:.2f}%, " + f"match time: {get_match_time * 1000:.2f}ms, " + f"e2e time: {elapsed_time_get * 1000:.2f}ms, " + f"bandwidth: {transfer_bandwidth_get:.2f} GB/s") + results.update({ + "get_cached_tokens": cached_tokens, + "get_total_tokens": all_tokens, + "get_cache_ratio": cached_tokens / all_tokens if all_tokens > 0 else 0, + "get_match_time_ms": get_match_time * 1000, + "get_e2e_time_ms": elapsed_time_get * 1000, + "get_bandwidth_GBs": transfer_bandwidth_get, + }) + else: + print("\n--- GET Phase SKIPPED (put-only mode) ---") + + return results + + +def benchmark_multiturn(kvmanager, model_config, cache_config, bench_config): + """Benchmark multi-turn conversation with distributed KVCache""" + print("\n" + "=" * 60) + print(" Multi-Turn Conversation Benchmark (Distributed KVCache)") + print("=" * 60) + print(f" num_users={bench_config.num_users}, num_turns={bench_config.num_turns}, " + f"system_prompt_length={bench_config.system_prompt_length}, " + f"input_length={bench_config.input_length}, output_length={bench_config.output_length}") + + # Generate multi-turn requests + reqs = generate_random_multiturn( + num_user_requests=bench_config.num_users, + num_turns=bench_config.num_turns, + system_prompt_length=bench_config.system_prompt_length, + input_length=bench_config.input_length, + output_length=bench_config.output_length, + seed=bench_config.seed, + ) + + total_get_requests = 0 + total_put_requests = 0 + cache_hit_ratios = [] + total_put_time = 0 + total_get_time = 0 + total_put_tokens = 0 + total_get_cached_tokens = 0 + 
total_get_all_tokens = 0 + + request_id = 0 + for req in reqs: + fake_slot_mapping = torch.arange(req.token_mask.sum(), dtype=torch.int64) + + if req.request_type == "get": + total_get_requests += 1 + total_get_all_tokens += req.token_mask.sum().item() + + start_time = time.time() + task_id, _ = kvmanager.get_match( + req.token_ids, + token_mask=torch.ones_like(torch.from_numpy(req.token_ids) if isinstance(req.token_ids, np.ndarray) else req.token_ids), + ) + kvmanager.launch([task_id], [fake_slot_mapping.numpy()]) + result = kvmanager.wait([task_id]) + elapsed = time.time() - start_time + total_get_time += elapsed + + for _, response in result.items(): + if response.status == KVResponseStatus.SUCCESS and response.return_mask is not None: + cached = response.return_mask.sum().item() + total_get_cached_tokens += cached + ratio = cached / req.token_mask.sum().item() + cache_hit_ratios.append(ratio) + else: + cache_hit_ratios.append(0.0) + + elif req.request_type == "put": + total_put_requests += 1 + + start_time = time.time() + task_id = kvmanager.put_async( + req.token_ids, + fake_slot_mapping.numpy(), + token_mask=None, + ) + result = kvmanager.wait([task_id], completely=True) + elapsed = time.time() - start_time + total_put_time += elapsed + + for _, response in result.items(): + if response.status == KVResponseStatus.SUCCESS and response.return_mask is not None: + total_put_tokens += response.return_mask.sum().item() + + request_id += 1 + + # Print results + print(f"\n--- Results ---") + print(f" Total requests: {len(reqs)} (GET: {total_get_requests}, PUT: {total_put_requests})") + print(f" PUT: {total_put_tokens} tokens, total time: {total_put_time * 1000:.2f}ms, " + f"avg time: {total_put_time * 1000 / max(total_put_requests, 1):.2f}ms/req") + print(f" GET: {total_get_cached_tokens}/{total_get_all_tokens} tokens cached, " + f"total time: {total_get_time * 1000:.2f}ms, " + f"avg time: {total_get_time * 1000 / max(total_get_requests, 1):.2f}ms/req") + + if 
cache_hit_ratios: + sorted_ratios = sorted(cache_hit_ratios) + avg_ratio = sum(sorted_ratios) / len(sorted_ratios) + print(f" Cache hit ratio: avg={avg_ratio * 100:.2f}%, " + f"min={sorted_ratios[0] * 100:.2f}%, " + f"median={sorted_ratios[len(sorted_ratios) // 2] * 100:.2f}%, " + f"max={sorted_ratios[-1] * 100:.2f}%") + + return { + "total_requests": len(reqs), + "get_requests": total_get_requests, + "put_requests": total_put_requests, + "put_tokens": total_put_tokens, + "put_total_time_ms": total_put_time * 1000, + "get_cached_tokens": total_get_cached_tokens, + "get_total_tokens": total_get_all_tokens, + "get_total_time_ms": total_get_time * 1000, + "avg_cache_hit_ratio": sum(cache_hit_ratios) / len(cache_hit_ratios) if cache_hit_ratios else 0, + } + + +def main(args): + # Set FLEXKV_REBUILD_INTERVAL_MS for faster cross-node index sync + # NOTE: Must set env var AND update GLOBAL_CONFIG_FROM_ENV because + # GLOBAL_CONFIG_FROM_ENV is evaluated at module import time (before main runs). + # The env var alone is not enough since the Namespace is already frozen. 
+ if args.rebuild_interval_ms is not None: + os.environ["FLEXKV_REBUILD_INTERVAL_MS"] = str(args.rebuild_interval_ms) + GLOBAL_CONFIG_FROM_ENV.rebuild_interval_ms = args.rebuild_interval_ms + print(f"[INFO] Set FLEXKV_REBUILD_INTERVAL_MS={args.rebuild_interval_ms}") + + # Load config + model_config, cache_config = load_dist_config(args.config) + + bench_config = BenchmarkConfig( + batch_size=args.batch_size, + sequence_length=args.sequence_length, + cache_ratio=args.cache_ratio, + clear_cpu_cache=args.clear_cpu_cache, + num_users=args.num_users, + num_turns=args.num_turns, + system_prompt_length=args.system_prompt_length, + input_length=args.input_length, + output_length=args.output_length, + mode=args.mode, + seed=args.seed, + ) + + # Pad sequence length to be divisible by tokens_per_block + bench_config.sequence_length = ( + ((bench_config.sequence_length - 1) // cache_config.tokens_per_block + 1) + * cache_config.tokens_per_block + ) + + num_gpu_blocks = bench_config.sequence_length * bench_config.batch_size // cache_config.tokens_per_block + # Ensure enough GPU blocks for multi-turn mode too + if bench_config.mode in ("multiturn", "all", "put-only", "get-only"): + max_tokens_per_user = ( + bench_config.system_prompt_length + + bench_config.num_turns * (bench_config.input_length + bench_config.output_length) + ) + multiturn_blocks = max_tokens_per_user * bench_config.num_users // cache_config.tokens_per_block + num_gpu_blocks = max(num_gpu_blocks, multiturn_blocks) + # Add some extra blocks for safety + num_gpu_blocks = int(num_gpu_blocks * 1.5) + 64 + + if model_config.tp_size * model_config.dp_size > torch.cuda.device_count(): + raise ValueError( + f"tp_size {model_config.tp_size} * dp_size {model_config.dp_size} > " + f"available GPUs {torch.cuda.device_count()}" + ) + + print("=" * 60) + print(" FlexKV Distributed KVCache Benchmark (server_client_mode)") + print("=" * 60) + print(f" model_config: {model_config}") + print(f" cache_config: {cache_config}") + 
print(f" enable_kv_sharing: {cache_config.enable_kv_sharing}") + print(f" enable_p2p_cpu: {cache_config.enable_p2p_cpu}") + print(f" redis: {cache_config.redis_host}:{cache_config.redis_port}") + print(f" num_gpu_blocks: {num_gpu_blocks}") + print(f" bench_config: {bench_config}") + + # Create KVManager (this will start KVServer in server_client_mode) + kvmanager = KVManager(model_config, cache_config) + kvmanager.start() + + # Start tp_client processes to register GPU blocks + tp_client_processes = [] + + # Register cleanup handler to ensure processes are terminated on exit + def _cleanup(): + shutdown_tp_clients(tp_client_processes) + try: + kvmanager.shutdown() + except Exception: + pass + atexit.register(_cleanup) + + def _signal_handler(signum, frame): + print(f"\nReceived signal {signum}, shutting down...") + _cleanup() + # Re-raise to allow default handler + signal.signal(signum, signal.SIG_DFL) + os.kill(os.getpid(), signum) + signal.signal(signal.SIGINT, _signal_handler) + signal.signal(signal.SIGTERM, _signal_handler) + for tp_rank in range(model_config.tp_size): + tp_process = Process( + target=run_tp_client, + args=(0, tp_rank, kvmanager.gpu_register_port, + model_config, cache_config, num_gpu_blocks), + daemon=True, + ) + tp_process.start() + tp_client_processes.append(tp_process) + + # Wait for system to be ready + print("\nWaiting for FlexKV to be ready...") + wait_start = time.time() + while not kvmanager.is_ready(): + time.sleep(1) + elapsed = time.time() - wait_start + if elapsed > 120: + print("ERROR: Timeout waiting for FlexKV to be ready (120s)") + shutdown_tp_clients(tp_client_processes) + kvmanager.shutdown() + return + if int(elapsed) % 10 == 0 and int(elapsed) > 0: + print(f" Still waiting... ({int(elapsed)}s)") + print(f"FlexKV is ready! 
(took {time.time() - wait_start:.1f}s)") + + try: + results = {} + + if bench_config.mode in ("single", "all", "put-only", "get-only"): + results["single_batch"] = benchmark_single_batch( + kvmanager, model_config, cache_config, bench_config + ) + + if bench_config.mode in ("multiturn", "all"): + results["multiturn"] = benchmark_multiturn( + kvmanager, model_config, cache_config, bench_config + ) + + # Print summary + print("\n" + "=" * 60) + print(" Benchmark Summary") + print("=" * 60) + for name, result in results.items(): + print(f"\n [{name}]") + for k, v in result.items(): + if isinstance(v, float): + print(f" {k}: {v:.4f}") + else: + print(f" {k}: {v}") + + # In put-only mode, keep the process alive so other nodes can GET the data + if bench_config.mode == "put-only": + print("\n" + "-" * 60) + print("Data published to Redis. Press Enter to shutdown " + "(keep running for other nodes to GET)...") + print("-" * 60) + try: + input() + except EOFError: + # Handle non-interactive environments + print("Non-interactive mode detected. 
Sleeping indefinitely (Ctrl+C to stop)...") + while True: + time.sleep(1) + + finally: + print("\nShutting down...") + shutdown_tp_clients(tp_client_processes) + kvmanager.shutdown() + # Unregister atexit handler since we've already cleaned up + atexit.unregister(_cleanup) + print("Done.") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark FlexKV distributed KVCache in server_client_mode" + ) + parser.add_argument("--config", type=str, default="benchmarks/example_dist_config.yml", + help="Path to config YAML file") + parser.add_argument("--mode", type=str, default="all", + choices=["single", "multiturn", "all", "put-only", "get-only"], + help="Benchmark mode: single, multiturn, all, put-only, get-only") + parser.add_argument("--seed", type=int, default=None, + help="Random seed for deterministic token generation (for cross-node benchmarks)") + + # Single batch params + parser.add_argument("--batch-size", type=int, default=1, help="Batch size for single batch benchmark") + parser.add_argument("--sequence-length", type=int, default=1024, help="Sequence length per request") + parser.add_argument("--cache-ratio", type=float, default=1.0, help="Ratio of tokens to cache in PUT phase") + parser.add_argument("--clear-cpu-cache", action="store_true", help="Clear CPU cache between PUT and GET") + + # Multi-turn params + parser.add_argument("--num-users", type=int, default=10, help="Number of simulated users") + parser.add_argument("--num-turns", type=int, default=3, help="Number of conversation turns per user") + parser.add_argument("--system-prompt-length", type=int, default=100, help="System prompt length in tokens") + parser.add_argument("--input-length", type=int, default=512, help="Input length per turn in tokens") + parser.add_argument("--output-length", type=int, default=64, help="Output length per turn in tokens") + + # Cross-node sync params + parser.add_argument("--rebuild-interval-ms", type=int, default=None, + help="Override 
FLEXKV_REBUILD_INTERVAL_MS (default: use env or 100). " + "Recommended: 20 for cross-node benchmarks") + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/benchmarks/dist_benchmark/benchmark_dist_direct.py b/benchmarks/dist_benchmark/benchmark_dist_direct.py new file mode 100644 index 0000000000..7e22e51e50 --- /dev/null +++ b/benchmarks/dist_benchmark/benchmark_dist_direct.py @@ -0,0 +1,660 @@ +""" +Benchmark for FlexKV distributed KVCache in direct mode (non-server_client_mode). + +In direct mode, KVManager creates KVTaskEngine directly in the main process +without going through KVServer/KVDPClient IPC. This is simpler and has lower +overhead, suitable for single-instance single-dp benchmarks. + +The key difference from server_client_mode: + - No KVServer subprocess is spawned + - KVTaskEngine runs directly in the main process + - KVTPClient still registers GPU blocks via ZMQ to KVTaskEngine + - RedisMeta is created directly in KVManager (not inside KVServer) + +Prerequisites: + - A running Redis server (default: 127.0.0.1:6379) + - At least 1 GPU available + - FlexKV built with distributed support (FLEXKV_ENABLE_P2P=1) + +Usage: + # Basic usage with default config + python benchmarks/dist_benchmark/benchmark_dist_direct.py \\ + --config benchmarks/dist_benchmark/example_dist_direct_config.yml + + # Custom parameters + python benchmarks/dist_benchmark/benchmark_dist_direct.py \\ + --config benchmarks/dist_benchmark/example_dist_direct_config.yml \\ + --batch-size 4 \\ + --sequence-length 2048 \\ + --cache-ratio 0.5 \\ + --num-users 10 \\ + --num-turns 3 + + # Multi-turn conversation benchmark only + python benchmarks/dist_benchmark/benchmark_dist_direct.py \\ + --config benchmarks/dist_benchmark/example_dist_direct_config.yml \\ + --mode multiturn \\ + --num-users 20 \\ + --num-turns 5 + + # Cross-node benchmark: Node A (PUT only) + python benchmarks/dist_benchmark/benchmark_dist_direct.py \\ + --config 
config_a.yml --seed 42 --mode put-only + + # Cross-node benchmark: Node B (GET only, same seed) + python benchmarks/dist_benchmark/benchmark_dist_direct.py \\ + --config config_b.yml --seed 42 --mode get-only +""" +import os +import sys +import atexit +import signal +import argparse +import json +import tempfile +import time +from multiprocessing import Process +from dataclasses import dataclass + +import torch +import numpy as np + +# Add parent directory to path so we can import utils +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from flexkv.server.client import KVTPClient +from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType +from flexkv.common.config import ( + ModelConfig, CacheConfig, UserConfig, + update_default_config_from_user_config, parse_path_list, + GLOBAL_CONFIG_FROM_ENV, +) +from flexkv.common.debug import flexkv_logger +from flexkv.kvmanager import KVManager +from flexkv.kvtask import KVResponseStatus + +from utils import generate_random_multiturn + +flexkv_logger.set_level("INFO") + + +def load_dist_direct_config(config_path: str): + """Load config for direct mode (non-server_client_mode) distributed KVCache. + + This is similar to load_dist_config in benchmark_dist_kvcache.py, but + ensures server_client_mode is NOT set, so KVManager uses KVTaskEngine + directly instead of going through KVServer IPC. 
+ """ + import yaml + + with open(config_path) as f: + config = yaml.load(f, Loader=yaml.SafeLoader) + print(f"Loaded config: {config}") + + model_config = ModelConfig() + cache_config = CacheConfig() + user_config = UserConfig() + + # Model config + model_config.num_layers = config["num_layers"] + model_config.num_kv_heads = config["num_kv_heads"] + model_config.head_size = config["head_size"] + model_config.dtype = eval(f"torch.{config['dtype']}") + model_config.use_mla = config["use_mla"] + model_config.tp_size = config["tp_size"] + model_config.dp_size = config["dp_size"] + cache_config.tokens_per_block = config["tokens_per_block"] + + # Cache size config + if "cpu_cache_gb" in config: + user_config.cpu_cache_gb = config["cpu_cache_gb"] + if "ssd_cache_gb" in config: + user_config.ssd_cache_gb = config["ssd_cache_gb"] + if "ssd_cache_dir" in config: + user_config.ssd_cache_dir = parse_path_list(config["ssd_cache_dir"]) + if "enable_gds" in config: + user_config.enable_gds = config["enable_gds"] + + # Distributed KVCache config + if "enable_p2p_cpu" in config: + user_config.enable_p2p_cpu = config["enable_p2p_cpu"] + if "enable_p2p_ssd" in config: + user_config.enable_p2p_ssd = config["enable_p2p_ssd"] + if "enable_3rd_remote" in config: + user_config.enable_3rd_remote = config["enable_3rd_remote"] + + # Redis config + if "redis_host" in config: + user_config.redis_host = config["redis_host"] + if "redis_port" in config: + user_config.redis_port = config["redis_port"] + if "local_ip" in config: + user_config.local_ip = config["local_ip"] + if "redis_password" in config: + user_config.redis_password = config["redis_password"] + + # Auto-generate mooncake config JSON and set MOONCAKE_CONFIG_PATH if P2P is enabled + if config.get("enable_p2p_cpu", False) or config.get("enable_p2p_ssd", False): + if "MOONCAKE_CONFIG_PATH" not in os.environ: + mooncake_config = { + "engine_ip": config.get("mooncake_engine_ip", config.get("local_ip", "127.0.0.1")), + "engine_port": 
config.get("mooncake_engine_port", 5555), + "metadata_backend": config.get("mooncake_metadata_backend", "redis"), + "metadata_server": config.get("mooncake_metadata_server", + f"redis://{config.get('redis_host', '127.0.0.1')}:{config.get('redis_port', 6379)}"), + "metadata_server_auth": config.get("mooncake_metadata_server_auth", + config.get("redis_password", "")), + "protocol": config.get("mooncake_protocol", "tcp"), + "device_name": config.get("mooncake_device_name", ""), + } + # Write to a temp file that persists until process exits + mooncake_config_fd, mooncake_config_path = tempfile.mkstemp( + suffix=".json", prefix="mooncake_config_" + ) + with os.fdopen(mooncake_config_fd, "w") as f: + json.dump(mooncake_config, f, indent=2) + os.environ["MOONCAKE_CONFIG_PATH"] = mooncake_config_path + print(f"[INFO] Auto-generated mooncake config at: {mooncake_config_path}") + print(f"[INFO] Mooncake config: {json.dumps(mooncake_config, indent=2)}") + else: + mooncake_config_path = os.environ['MOONCAKE_CONFIG_PATH'] + print(f"[INFO] Using existing MOONCAKE_CONFIG_PATH: {mooncake_config_path}") + + # Store mooncake_config_path in cache_config so it survives spawn subprocesses via pickle + cache_config.mooncake_config_path = mooncake_config_path + + update_default_config_from_user_config(model_config, cache_config, user_config) + + # IMPORTANT: Ensure server_client_mode is NOT set for direct mode + # Even if config says server_client_mode: true, we override it here + if config.get("server_client_mode", False): + print("[WARN] server_client_mode is set in config but will be IGNORED in direct mode benchmark.") + os.environ.pop("FLEXKV_SERVER_CLIENT_MODE", None) + GLOBAL_CONFIG_FROM_ENV.server_client_mode = False + + return model_config, cache_config + + +@dataclass +class BenchmarkConfig: + # Single batch benchmark params + batch_size: int = 1 + sequence_length: int = 1024 + cache_ratio: float = 1.0 + clear_cpu_cache: bool = False + + # Multi-turn benchmark params + 
def shutdown_tp_clients(tp_client_processes):
    """Stop every tp_client process that is still running.

    Each live process is first asked to exit with ``terminate()`` and given
    5 seconds to do so. A process that is still alive afterwards is killed
    and joined for up to 2 more seconds. Processes that are already dead
    are left untouched.

    Args:
        tp_client_processes: Iterable of process handles exposing
            ``is_alive``, ``terminate``, ``kill``, ``join`` and ``pid``.
    """
    for proc in tp_client_processes:
        # Nothing to do for a process that has already exited.
        if not proc.is_alive():
            continue

        # Polite shutdown first.
        proc.terminate()
        proc.join(timeout=5)
        if not proc.is_alive():
            continue

        # Escalate only when terminate() was not enough.
        print(f"Force killing tp_client process {proc.pid}")
        proc.kill()
        proc.join(timeout=2)
+ print("\n" + "=" * 60) + print(" Single Batch Benchmark (Distributed KVCache - Direct Mode)") + print("=" * 60) + + sequence_length = bench_config.sequence_length + batch_size = bench_config.batch_size + cache_length = int(sequence_length * bench_config.cache_ratio) + + print(f" batch_size={batch_size}, sequence_length={sequence_length}, " + f"cache_ratio={bench_config.cache_ratio}, cache_length={cache_length}") + if bench_config.seed is not None: + print(f" seed={bench_config.seed}") + + # Generate random sequences (use seed for deterministic cross-node benchmarks) + if bench_config.seed is not None: + torch.manual_seed(bench_config.seed) + batch_sequence_tensor = [] + batch_slot_mapping = [] + for i in range(batch_size): + batch_sequence_tensor.append(torch.randint(0, 100000, (sequence_length,), dtype=torch.int64)) + batch_slot_mapping.append(torch.arange(i * sequence_length, (i + 1) * sequence_length, dtype=torch.int64)) + + results = {} + skip_put = (bench_config.mode == "get-only") + skip_get = (bench_config.mode == "put-only") + + # In get-only mode, wait for remote index to be refreshed from Redis + if skip_put: + rebuild_interval_ms = int(os.environ.get("FLEXKV_REBUILD_INTERVAL_MS", "100")) + wait_time_s = max(rebuild_interval_ms * 3 / 1000.0, 0.5) + print(f" Waiting {wait_time_s:.2f}s for remote index refresh " + f"(FLEXKV_REBUILD_INTERVAL_MS={rebuild_interval_ms})...") + time.sleep(wait_time_s) + + # ---- Benchmark PUT ---- + if not skip_put: + print("\n--- PUT Phase ---") + start_time = time.time() + batch_put_ids = [] + if bench_config.cache_ratio > 0: + for i in range(batch_size): + task_id = kvmanager.put_async( + batch_sequence_tensor[i][:cache_length], + batch_slot_mapping[i][:cache_length], + token_mask=None, + ) + batch_put_ids.append(task_id) + put_result = kvmanager.wait(batch_put_ids, completely=True) + end_time = time.time() + + elapsed_time_put = end_time - start_time + put_tokens = 0 + for _, response in put_result.items(): + if 
response.status == KVResponseStatus.SUCCESS: + put_tokens += response.return_mask.sum().item() + transfer_data_size_GB = put_tokens * model_config.token_size_in_bytes / (1024 ** 3) + transfer_bandwidth_put = transfer_data_size_GB / elapsed_time_put if elapsed_time_put > 0 else 0 + print(f" PUT: {put_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " + f"time: {elapsed_time_put * 1000:.2f}ms, bandwidth: {transfer_bandwidth_put:.2f} GB/s") + results.update({ + "put_tokens": put_tokens, + "put_time_ms": elapsed_time_put * 1000, + "put_bandwidth_GBs": transfer_bandwidth_put, + }) + else: + print("\n--- PUT Phase SKIPPED (get-only mode) ---") + + if bench_config.clear_cpu_cache: + kvmanager._clear_cpu_cache() + + # ---- Benchmark GET ---- + if not skip_get: + print("\n--- GET Phase ---") + all_tokens = 0 + start_time = time.time() + batch_get_ids = [] + for i in range(batch_size): + all_tokens += len(batch_sequence_tensor[i]) + task_id, _ = kvmanager.get_match(batch_sequence_tensor[i], token_mask=None) + batch_get_ids.append(task_id) + get_match_time = time.time() - start_time + + kvmanager.launch(batch_get_ids, batch_slot_mapping, as_batch=True, layerwise_transfer=False) + get_result = kvmanager.wait(batch_get_ids) + elapsed_time_get = time.time() - start_time + + cached_tokens = 0 + for _, response in get_result.items(): + if response.status == KVResponseStatus.SUCCESS: + cached_tokens += response.return_mask.sum().item() + transfer_data_size_GB = cached_tokens * model_config.token_size_in_bytes / (1024 ** 3) + transfer_bandwidth_get = transfer_data_size_GB / elapsed_time_get if elapsed_time_get > 0 else 0 + print(f" GET: {cached_tokens}/{all_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " + f"cache_ratio: {cached_tokens * 100 / all_tokens:.2f}%, " + f"match time: {get_match_time * 1000:.2f}ms, " + f"e2e time: {elapsed_time_get * 1000:.2f}ms, " + f"bandwidth: {transfer_bandwidth_get:.2f} GB/s") + results.update({ + "get_cached_tokens": 
def benchmark_multiturn(kvmanager, model_config, cache_config, bench_config):
    """Benchmark multi-turn conversation with distributed KVCache (direct mode).

    Replays the request stream produced by ``generate_random_multiturn``.
    A "get" request is matched, launched and waited on; a "put" request is
    submitted with ``put_async`` and waited on with ``completely=True``.
    Every request is timed individually with wall-clock time and the totals
    are printed and returned.

    Args:
        kvmanager: Manager exposing ``get_match``, ``launch``, ``put_async``
            and ``wait``.
        model_config: Not read by this function.
        cache_config: Not read by this function.
        bench_config: Supplies ``num_users``, ``num_turns``,
            ``system_prompt_length``, ``input_length``, ``output_length``
            and ``seed``.

    Returns:
        dict with request counts, token counts, total PUT/GET time in
        milliseconds and the average cache hit ratio (0 when no GET
        response was recorded).
    """
    print("\n" + "=" * 60)
    print(" Multi-Turn Conversation Benchmark (Distributed KVCache - Direct Mode)")
    print("=" * 60)
    print(f" num_users={bench_config.num_users}, num_turns={bench_config.num_turns}, "
          f"system_prompt_length={bench_config.system_prompt_length}, "
          f"input_length={bench_config.input_length}, output_length={bench_config.output_length}")

    # Generate multi-turn requests
    reqs = generate_random_multiturn(
        num_user_requests=bench_config.num_users,
        num_turns=bench_config.num_turns,
        system_prompt_length=bench_config.system_prompt_length,
        input_length=bench_config.input_length,
        output_length=bench_config.output_length,
        seed=bench_config.seed,
    )

    total_get_requests = 0
    total_put_requests = 0
    cache_hit_ratios = []
    total_put_time = 0
    total_get_time = 0
    total_put_tokens = 0
    total_get_cached_tokens = 0
    total_get_all_tokens = 0

    for req in reqs:
        # Count the request's tokens once; it sizes the slot mapping and is
        # the denominator of the hit ratio.
        num_tokens = req.token_mask.sum().item()
        fake_slot_mapping = torch.arange(num_tokens, dtype=torch.int64)

        if req.request_type == "get":
            total_get_requests += 1
            total_get_all_tokens += num_tokens

            start_time = time.time()
            token_ids = req.token_ids
            mask_like = torch.from_numpy(token_ids) if isinstance(token_ids, np.ndarray) else token_ids
            task_id, _ = kvmanager.get_match(
                token_ids,
                token_mask=torch.ones_like(mask_like),
            )
            kvmanager.launch([task_id], [fake_slot_mapping.numpy()])
            result = kvmanager.wait([task_id])
            total_get_time += time.time() - start_time

            for response in result.values():
                if response.status == KVResponseStatus.SUCCESS and response.return_mask is not None:
                    cached = response.return_mask.sum().item()
                    total_get_cached_tokens += cached
                    # Guard the division: an empty request counts as a 0% hit
                    # instead of raising ZeroDivisionError.
                    cache_hit_ratios.append(cached / num_tokens if num_tokens > 0 else 0.0)
                else:
                    cache_hit_ratios.append(0.0)

        elif req.request_type == "put":
            total_put_requests += 1

            start_time = time.time()
            task_id = kvmanager.put_async(
                req.token_ids,
                fake_slot_mapping.numpy(),
                token_mask=None,
            )
            result = kvmanager.wait([task_id], completely=True)
            total_put_time += time.time() - start_time

            for response in result.values():
                if response.status == KVResponseStatus.SUCCESS and response.return_mask is not None:
                    total_put_tokens += response.return_mask.sum().item()

    # Print results
    print("\n--- Results ---")
    print(f" Total requests: {len(reqs)} (GET: {total_get_requests}, PUT: {total_put_requests})")
    # max(..., 1) keeps the per-request averages defined when a request type
    # never occurred.
    print(f" PUT: {total_put_tokens} tokens, total time: {total_put_time * 1000:.2f}ms, "
          f"avg time: {total_put_time * 1000 / max(total_put_requests, 1):.2f}ms/req")
    print(f" GET: {total_get_cached_tokens}/{total_get_all_tokens} tokens cached, "
          f"total time: {total_get_time * 1000:.2f}ms, "
          f"avg time: {total_get_time * 1000 / max(total_get_requests, 1):.2f}ms/req")

    avg_ratio = sum(cache_hit_ratios) / len(cache_hit_ratios) if cache_hit_ratios else 0
    if cache_hit_ratios:
        sorted_ratios = sorted(cache_hit_ratios)
        # For an even count this reports the upper of the two middle values.
        print(f" Cache hit ratio: avg={avg_ratio * 100:.2f}%, "
              f"min={sorted_ratios[0] * 100:.2f}%, "
              f"median={sorted_ratios[len(sorted_ratios) // 2] * 100:.2f}%, "
              f"max={sorted_ratios[-1] * 100:.2f}%")

    return {
        "total_requests": len(reqs),
        "get_requests": total_get_requests,
        "put_requests": total_put_requests,
        "put_tokens": total_put_tokens,
        "put_total_time_ms": total_put_time * 1000,
        "get_cached_tokens": total_get_cached_tokens,
        "get_total_tokens": total_get_all_tokens,
        "get_total_time_ms": total_get_time * 1000,
        "avg_cache_hit_ratio": avg_ratio,
    }
dp_size {model_config.dp_size} > " + f"available GPUs {torch.cuda.device_count()}" + ) + + print("=" * 60) + print(" FlexKV Distributed KVCache Benchmark (Direct Mode)") + print("=" * 60) + print(f" model_config: {model_config}") + print(f" cache_config: {cache_config}") + print(f" enable_kv_sharing: {cache_config.enable_kv_sharing}") + print(f" enable_p2p_cpu: {cache_config.enable_p2p_cpu}") + print(f" redis: {cache_config.redis_host}:{cache_config.redis_port}") + print(f" num_gpu_blocks: {num_gpu_blocks}") + print(f" bench_config: {bench_config}") + print(f" server_client_mode: False (direct mode)") + + # Create KVManager in direct mode + # In direct mode, KVManager creates KVTaskEngine directly (no KVServer subprocess) + # RedisMeta is also created directly in KVManager + kvmanager = KVManager(model_config, cache_config) + kvmanager.start() + + # Verify we are indeed in direct mode + assert not kvmanager.server_client_mode, \ + "Expected direct mode (server_client_mode=False), but got server_client_mode=True. " \ + "Check your config: dp_size must be 1, instance_num must be 1, and " \ + "FLEXKV_SERVER_CLIENT_MODE env var must not be set." 
+ + # Start tp_client processes to register GPU blocks + # Even in direct mode, GPU blocks are registered via KVTPClient -> KVTaskEngine ZMQ + tp_client_processes = [] + + # Register cleanup handler to ensure processes are terminated on exit + def _cleanup(): + shutdown_tp_clients(tp_client_processes) + try: + kvmanager.shutdown() + except Exception: + pass + atexit.register(_cleanup) + + def _signal_handler(signum, frame): + print(f"\nReceived signal {signum}, shutting down...") + _cleanup() + signal.signal(signum, signal.SIG_DFL) + os.kill(os.getpid(), signum) + signal.signal(signal.SIGINT, _signal_handler) + signal.signal(signal.SIGTERM, _signal_handler) + + for tp_rank in range(model_config.tp_size): + tp_process = Process( + target=run_tp_client, + args=(0, tp_rank, kvmanager.gpu_register_port, + model_config, cache_config, num_gpu_blocks), + daemon=True, + ) + tp_process.start() + tp_client_processes.append(tp_process) + + # Wait for system to be ready + print("\nWaiting for FlexKV to be ready (direct mode)...") + wait_start = time.time() + while not kvmanager.is_ready(): + time.sleep(1) + elapsed = time.time() - wait_start + if elapsed > 120: + print("ERROR: Timeout waiting for FlexKV to be ready (120s)") + shutdown_tp_clients(tp_client_processes) + kvmanager.shutdown() + return + if int(elapsed) % 10 == 0 and int(elapsed) > 0: + print(f" Still waiting... ({int(elapsed)}s)") + print(f"FlexKV is ready! 
(took {time.time() - wait_start:.1f}s)") + + try: + results = {} + + if bench_config.mode in ("single", "all", "put-only", "get-only"): + results["single_batch"] = benchmark_single_batch( + kvmanager, model_config, cache_config, bench_config + ) + + if bench_config.mode in ("multiturn", "all"): + results["multiturn"] = benchmark_multiturn( + kvmanager, model_config, cache_config, bench_config + ) + + # Print summary + print("\n" + "=" * 60) + print(" Benchmark Summary (Direct Mode)") + print("=" * 60) + for name, result in results.items(): + print(f"\n [{name}]") + for k, v in result.items(): + if isinstance(v, float): + print(f" {k}: {v:.4f}") + else: + print(f" {k}: {v}") + + # In put-only mode, keep the process alive so other nodes can GET the data + if bench_config.mode == "put-only": + print("\n" + "-" * 60) + print("Data published to Redis. Press Enter to shutdown " + "(keep running for other nodes to GET)...") + print("-" * 60) + try: + input() + except EOFError: + print("Non-interactive mode detected. 
def parse_args(argv=None):
    """Parse the benchmark's command-line options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, so existing
            ``parse_args()`` callers behave as before. Passing a list
            allows the parser to be driven programmatically.

    Returns:
        argparse.Namespace with ``config``, ``mode``, ``seed``, the
        single-batch options (``batch_size``, ``sequence_length``,
        ``cache_ratio``, ``clear_cpu_cache``), the multi-turn options
        (``num_users``, ``num_turns``, ``system_prompt_length``,
        ``input_length``, ``output_length``) and ``rebuild_interval_ms``.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark FlexKV distributed KVCache in direct mode (non-server_client_mode)"
    )
    parser.add_argument("--config", type=str,
                        default="benchmarks/dist_benchmark/example_dist_direct_config.yml",
                        help="Path to config YAML file")
    parser.add_argument("--mode", type=str, default="all",
                        choices=["single", "multiturn", "all", "put-only", "get-only"],
                        help="Benchmark mode: single, multiturn, all, put-only, get-only")
    parser.add_argument("--seed", type=int, default=None,
                        help="Random seed for deterministic token generation (for cross-node benchmarks)")

    # Single batch params
    parser.add_argument("--batch-size", type=int, default=1, help="Batch size for single batch benchmark")
    parser.add_argument("--sequence-length", type=int, default=1024, help="Sequence length per request")
    parser.add_argument("--cache-ratio", type=float, default=1.0, help="Ratio of tokens to cache in PUT phase")
    parser.add_argument("--clear-cpu-cache", action="store_true", help="Clear CPU cache between PUT and GET")

    # Multi-turn params
    parser.add_argument("--num-users", type=int, default=10, help="Number of simulated users")
    parser.add_argument("--num-turns", type=int, default=3, help="Number of conversation turns per user")
    parser.add_argument("--system-prompt-length", type=int, default=100, help="System prompt length in tokens")
    parser.add_argument("--input-length", type=int, default=512, help="Input length per turn in tokens")
    parser.add_argument("--output-length", type=int, default=64, help="Output length per turn in tokens")

    # Cross-node sync params
    parser.add_argument("--rebuild-interval-ms", type=int, default=None,
                        help="Override FLEXKV_REBUILD_INTERVAL_MS (default: use env or 100). "
                             "Recommended: 20 for cross-node benchmarks")

    return parser.parse_args(argv)
def _resolve_torch_dtype(dtype_name):
    """Map a config string such as ``"float16"`` to the matching ``torch.dtype``."""
    # Attribute lookup instead of eval(): the value comes from a config file
    # and must never be executed as code.
    dtype = getattr(torch, str(dtype_name), None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Unsupported dtype in config: {dtype_name!r}")
    return dtype


def _ensure_mooncake_config_path(config):
    """Return the mooncake config path, writing a temp JSON file when the env var is unset."""
    existing_path = os.environ.get("MOONCAKE_CONFIG_PATH")
    if existing_path is not None:
        print(f"[INFO] Using existing MOONCAKE_CONFIG_PATH: {existing_path}")
        return existing_path

    mooncake_config = {
        "engine_ip": config.get("mooncake_engine_ip", config.get("local_ip", "127.0.0.1")),
        "engine_port": config.get("mooncake_engine_port", 5555),
        "metadata_backend": config.get("mooncake_metadata_backend", "redis"),
        "metadata_server": config.get(
            "mooncake_metadata_server",
            f"redis://{config.get('redis_host', '127.0.0.1')}:{config.get('redis_port', 6379)}"),
        "metadata_server_auth": config.get("mooncake_metadata_server_auth",
                                           config.get("redis_password", "")),
        "protocol": config.get("mooncake_protocol", "tcp"),
        "device_name": config.get("mooncake_device_name", ""),
    }
    # mkstemp never deletes the file on its own; it is deliberately left in
    # place so it is still readable after this function returns.
    mooncake_config_fd, mooncake_config_path = tempfile.mkstemp(
        suffix=".json", prefix="mooncake_config_"
    )
    with os.fdopen(mooncake_config_fd, "w") as f:
        json.dump(mooncake_config, f, indent=2)
    os.environ["MOONCAKE_CONFIG_PATH"] = mooncake_config_path
    print(f"[INFO] Auto-generated mooncake config at: {mooncake_config_path}")
    print(f"[INFO] Mooncake config: {json.dumps(mooncake_config, indent=2)}")
    return mooncake_config_path


def load_dist_config(config_path: str):
    """Load config with distributed KVCache support.

    Extends the standard load_config to handle distributed-specific fields:
    enable_p2p_cpu, enable_p2p_ssd, enable_3rd_remote,
    redis_host, redis_port, local_ip, redis_password,
    server_client_mode, etc.

    Args:
        config_path: Path to the YAML config file.

    Returns:
        Tuple ``(model_config, cache_config)``.

    Raises:
        KeyError: A required model key (num_layers, num_kv_heads, head_size,
            dtype, use_mla, tp_size, dp_size, tokens_per_block) is missing.
        ValueError: ``dtype`` does not name a ``torch.dtype``.

    Side effects:
        May set ``MOONCAKE_CONFIG_PATH`` and ``FLEXKV_SERVER_CLIENT_MODE`` in
        ``os.environ``, may create a temp JSON file, and may set
        ``GLOBAL_CONFIG_FROM_ENV.server_client_mode``.
    """
    import yaml

    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)
    print(f"Loaded config: {config}")

    model_config = ModelConfig()
    cache_config = CacheConfig()
    user_config = UserConfig()

    # Model config (all keys required)
    model_config.num_layers = config["num_layers"]
    model_config.num_kv_heads = config["num_kv_heads"]
    model_config.head_size = config["head_size"]
    model_config.dtype = _resolve_torch_dtype(config["dtype"])
    model_config.use_mla = config["use_mla"]
    model_config.tp_size = config["tp_size"]
    model_config.dp_size = config["dp_size"]
    cache_config.tokens_per_block = config["tokens_per_block"]

    # Optional keys copied verbatim onto user_config when present:
    # cache sizes, distributed KVCache switches, Redis connection.
    passthrough_keys = (
        "cpu_cache_gb", "ssd_cache_gb", "enable_gds",
        "enable_p2p_cpu", "enable_p2p_ssd", "enable_3rd_remote",
        "redis_host", "redis_port", "local_ip", "redis_password",
    )
    for key in passthrough_keys:
        if key in config:
            setattr(user_config, key, config[key])
    # ssd_cache_dir is the only optional key that needs conversion.
    if "ssd_cache_dir" in config:
        user_config.ssd_cache_dir = parse_path_list(config["ssd_cache_dir"])

    # Auto-generate mooncake config JSON and set MOONCAKE_CONFIG_PATH if P2P is enabled
    if config.get("enable_p2p_cpu", False) or config.get("enable_p2p_ssd", False):
        # Store mooncake_config_path in cache_config so it survives spawn subprocesses via pickle
        cache_config.mooncake_config_path = _ensure_mooncake_config_path(config)

    update_default_config_from_user_config(model_config, cache_config, user_config)

    # Handle server_client_mode from config
    if config.get("server_client_mode", False):
        os.environ["FLEXKV_SERVER_CLIENT_MODE"] = "1"
        GLOBAL_CONFIG_FROM_ENV.server_client_mode = True

    return model_config, cache_config
def _count_successful_tokens(responses):
    """Sum ``return_mask`` over the successful responses that carry a mask."""
    return sum(
        response.return_mask.sum().item()
        for response in responses.values()
        if response.status == KVResponseStatus.SUCCESS and response.return_mask is not None
    )


def benchmark_single_batch(kvmanager, model_config, cache_config, bench_config):
    """Benchmark single batch put/get with distributed KVCache.

    Generates ``batch_size`` random token sequences, then, depending on
    ``bench_config.mode``, times a PUT phase (``put_async`` + ``wait``)
    and/or a GET phase (``get_match`` + ``launch`` + ``wait``).

    Args:
        kvmanager: Manager exposing ``put_async``, ``get_match``, ``launch``,
            ``wait`` and ``_clear_cpu_cache``.
        model_config: Only ``token_size_in_bytes`` is read, to convert token
            counts into transferred bytes.
        cache_config: Not read by this function.
        bench_config: Supplies ``sequence_length``, ``batch_size``,
            ``cache_ratio``, ``seed``, ``mode`` and ``clear_cpu_cache``.

    Returns:
        dict with ``put_*`` keys when the PUT phase ran and ``get_*`` keys
        when the GET phase ran.
    """
    print("\n" + "=" * 60)
    print(" Single Batch Benchmark (Distributed KVCache)")
    print("=" * 60)

    sequence_length = bench_config.sequence_length
    batch_size = bench_config.batch_size
    cache_length = int(sequence_length * bench_config.cache_ratio)

    print(f" batch_size={batch_size}, sequence_length={sequence_length}, "
          f"cache_ratio={bench_config.cache_ratio}, cache_length={cache_length}")

    # Generate random sequences (use seed for deterministic cross-node benchmarks)
    if bench_config.seed is not None:
        print(f" seed={bench_config.seed}")
        torch.manual_seed(bench_config.seed)
    batch_sequence_tensor = [
        torch.randint(0, 100000, (sequence_length,), dtype=torch.int64)
        for _ in range(batch_size)
    ]
    batch_slot_mapping = [
        torch.arange(i * sequence_length, (i + 1) * sequence_length, dtype=torch.int64)
        for i in range(batch_size)
    ]

    results = {}
    skip_put = (bench_config.mode == "get-only")
    skip_get = (bench_config.mode == "put-only")

    # In get-only mode, wait for remote index to be refreshed from Redis
    if skip_put:
        rebuild_interval_ms = int(os.environ.get("FLEXKV_REBUILD_INTERVAL_MS", "100"))
        # Sleep for 3x the rebuild interval, but never less than 0.5s.
        wait_time_s = max(rebuild_interval_ms * 3 / 1000.0, 0.5)
        print(f" Waiting {wait_time_s:.2f}s for remote index refresh "
              f"(FLEXKV_REBUILD_INTERVAL_MS={rebuild_interval_ms})...")
        time.sleep(wait_time_s)

    # ---- Benchmark PUT ----
    if not skip_put:
        print("\n--- PUT Phase ---")
        start_time = time.time()
        batch_put_ids = []
        # Pre-bind so the token count below is well-defined even when
        # cache_ratio == 0 and nothing is submitted.
        put_result = {}
        if bench_config.cache_ratio > 0:
            for i in range(batch_size):
                task_id = kvmanager.put_async(
                    batch_sequence_tensor[i][:cache_length],
                    batch_slot_mapping[i][:cache_length],
                    token_mask=None,
                )
                batch_put_ids.append(task_id)
            put_result = kvmanager.wait(batch_put_ids, completely=True)
        elapsed_time_put = time.time() - start_time

        put_tokens = _count_successful_tokens(put_result)
        transfer_data_size_GB = put_tokens * model_config.token_size_in_bytes / (1024 ** 3)
        transfer_bandwidth_put = transfer_data_size_GB / elapsed_time_put if elapsed_time_put > 0 else 0
        print(f" PUT: {put_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, "
              f"time: {elapsed_time_put * 1000:.2f}ms, bandwidth: {transfer_bandwidth_put:.2f} GB/s")
        results.update({
            "put_tokens": put_tokens,
            "put_time_ms": elapsed_time_put * 1000,
            "put_bandwidth_GBs": transfer_bandwidth_put,
        })
    else:
        print("\n--- PUT Phase SKIPPED (get-only mode) ---")

    if bench_config.clear_cpu_cache:
        kvmanager._clear_cpu_cache()

    # ---- Benchmark GET ----
    if not skip_get:
        print("\n--- GET Phase ---")
        all_tokens = 0
        start_time = time.time()
        batch_get_ids = []
        for sequence in batch_sequence_tensor:
            all_tokens += len(sequence)
            task_id, _ = kvmanager.get_match(sequence, token_mask=None)
            batch_get_ids.append(task_id)
        get_match_time = time.time() - start_time

        kvmanager.launch(batch_get_ids, batch_slot_mapping, as_batch=True, layerwise_transfer=False)
        get_result = kvmanager.wait(batch_get_ids)
        # End-to-end time includes the match phase measured above.
        elapsed_time_get = time.time() - start_time

        cached_tokens = _count_successful_tokens(get_result)
        transfer_data_size_GB = cached_tokens * model_config.token_size_in_bytes / (1024 ** 3)
        transfer_bandwidth_get = transfer_data_size_GB / elapsed_time_get if elapsed_time_get > 0 else 0
        # Same zero guard for the printed percentage as for the returned ratio.
        hit_ratio = cached_tokens / all_tokens if all_tokens > 0 else 0
        print(f" GET: {cached_tokens}/{all_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, "
              f"cache_ratio: {hit_ratio * 100:.2f}%, "
              f"match time: {get_match_time * 1000:.2f}ms, "
              f"e2e time: {elapsed_time_get * 1000:.2f}ms, "
              f"bandwidth: {transfer_bandwidth_get:.2f} GB/s")
        results.update({
            "get_cached_tokens": cached_tokens,
            "get_total_tokens": all_tokens,
            "get_cache_ratio": hit_ratio,
            "get_match_time_ms": get_match_time * 1000,
            "get_e2e_time_ms": elapsed_time_get * 1000,
            "get_bandwidth_GBs": transfer_bandwidth_get,
        })
    else:
        print("\n--- GET Phase SKIPPED (put-only mode) ---")

    return results
time.time() + task_id, _ = kvmanager.get_match( + req.token_ids, + token_mask=torch.ones_like(torch.from_numpy(req.token_ids) if isinstance(req.token_ids, np.ndarray) else req.token_ids), + ) + kvmanager.launch([task_id], [fake_slot_mapping.numpy()]) + result = kvmanager.wait([task_id]) + elapsed = time.time() - start_time + total_get_time += elapsed + + for _, response in result.items(): + if response.status == KVResponseStatus.SUCCESS and response.return_mask is not None: + cached = response.return_mask.sum().item() + total_get_cached_tokens += cached + ratio = cached / req.token_mask.sum().item() + cache_hit_ratios.append(ratio) + else: + cache_hit_ratios.append(0.0) + + elif req.request_type == "put": + total_put_requests += 1 + + start_time = time.time() + task_id = kvmanager.put_async( + req.token_ids, + fake_slot_mapping.numpy(), + token_mask=None, + ) + result = kvmanager.wait([task_id], completely=True) + elapsed = time.time() - start_time + total_put_time += elapsed + + for _, response in result.items(): + if response.status == KVResponseStatus.SUCCESS and response.return_mask is not None: + total_put_tokens += response.return_mask.sum().item() + + request_id += 1 + + # Print results + print(f"\n--- Results ---") + print(f" Total requests: {len(reqs)} (GET: {total_get_requests}, PUT: {total_put_requests})") + print(f" PUT: {total_put_tokens} tokens, total time: {total_put_time * 1000:.2f}ms, " + f"avg time: {total_put_time * 1000 / max(total_put_requests, 1):.2f}ms/req") + print(f" GET: {total_get_cached_tokens}/{total_get_all_tokens} tokens cached, " + f"total time: {total_get_time * 1000:.2f}ms, " + f"avg time: {total_get_time * 1000 / max(total_get_requests, 1):.2f}ms/req") + + if cache_hit_ratios: + sorted_ratios = sorted(cache_hit_ratios) + avg_ratio = sum(sorted_ratios) / len(sorted_ratios) + print(f" Cache hit ratio: avg={avg_ratio * 100:.2f}%, " + f"min={sorted_ratios[0] * 100:.2f}%, " + f"median={sorted_ratios[len(sorted_ratios) // 2] * 
100:.2f}%, " + f"max={sorted_ratios[-1] * 100:.2f}%") + + return { + "total_requests": len(reqs), + "get_requests": total_get_requests, + "put_requests": total_put_requests, + "put_tokens": total_put_tokens, + "put_total_time_ms": total_put_time * 1000, + "get_cached_tokens": total_get_cached_tokens, + "get_total_tokens": total_get_all_tokens, + "get_total_time_ms": total_get_time * 1000, + "avg_cache_hit_ratio": sum(cache_hit_ratios) / len(cache_hit_ratios) if cache_hit_ratios else 0, + } + + +def main(args): + # Set FLEXKV_REBUILD_INTERVAL_MS for faster cross-node index sync + # NOTE: Must set env var AND update GLOBAL_CONFIG_FROM_ENV because + # GLOBAL_CONFIG_FROM_ENV is evaluated at module import time (before main runs). + # The env var alone is not enough since the Namespace is already frozen. + if args.rebuild_interval_ms is not None: + os.environ["FLEXKV_REBUILD_INTERVAL_MS"] = str(args.rebuild_interval_ms) + GLOBAL_CONFIG_FROM_ENV.rebuild_interval_ms = args.rebuild_interval_ms + print(f"[INFO] Set FLEXKV_REBUILD_INTERVAL_MS={args.rebuild_interval_ms}") + + # Load config + model_config, cache_config = load_dist_config(args.config) + + bench_config = BenchmarkConfig( + batch_size=args.batch_size, + sequence_length=args.sequence_length, + cache_ratio=args.cache_ratio, + clear_cpu_cache=args.clear_cpu_cache, + num_users=args.num_users, + num_turns=args.num_turns, + system_prompt_length=args.system_prompt_length, + input_length=args.input_length, + output_length=args.output_length, + mode=args.mode, + seed=args.seed, + ) + + # Pad sequence length to be divisible by tokens_per_block + bench_config.sequence_length = ( + ((bench_config.sequence_length - 1) // cache_config.tokens_per_block + 1) + * cache_config.tokens_per_block + ) + + num_gpu_blocks = bench_config.sequence_length * bench_config.batch_size // cache_config.tokens_per_block + # Ensure enough GPU blocks for multi-turn mode too + if bench_config.mode in ("multiturn", "all", "put-only", "get-only"): + 
max_tokens_per_user = ( + bench_config.system_prompt_length + + bench_config.num_turns * (bench_config.input_length + bench_config.output_length) + ) + multiturn_blocks = max_tokens_per_user * bench_config.num_users // cache_config.tokens_per_block + num_gpu_blocks = max(num_gpu_blocks, multiturn_blocks) + # Add some extra blocks for safety + num_gpu_blocks = int(num_gpu_blocks * 1.5) + 64 + + if model_config.tp_size * model_config.dp_size > torch.cuda.device_count(): + raise ValueError( + f"tp_size {model_config.tp_size} * dp_size {model_config.dp_size} > " + f"available GPUs {torch.cuda.device_count()}" + ) + + print("=" * 60) + print(" FlexKV Distributed KVCache Benchmark (server_client_mode)") + print("=" * 60) + print(f" model_config: {model_config}") + print(f" cache_config: {cache_config}") + print(f" enable_kv_sharing: {cache_config.enable_kv_sharing}") + print(f" enable_p2p_cpu: {cache_config.enable_p2p_cpu}") + print(f" redis: {cache_config.redis_host}:{cache_config.redis_port}") + print(f" num_gpu_blocks: {num_gpu_blocks}") + print(f" bench_config: {bench_config}") + + # Create KVManager (this will start KVServer in server_client_mode) + kvmanager = KVManager(model_config, cache_config) + kvmanager.start() + + # Start tp_client processes to register GPU blocks + tp_client_processes = [] + + # Register cleanup handler to ensure processes are terminated on exit + def _cleanup(): + shutdown_tp_clients(tp_client_processes) + try: + kvmanager.shutdown() + except Exception: + pass + atexit.register(_cleanup) + + def _signal_handler(signum, frame): + print(f"\nReceived signal {signum}, shutting down...") + _cleanup() + # Re-raise to allow default handler + signal.signal(signum, signal.SIG_DFL) + os.kill(os.getpid(), signum) + signal.signal(signal.SIGINT, _signal_handler) + signal.signal(signal.SIGTERM, _signal_handler) + for tp_rank in range(model_config.tp_size): + tp_process = Process( + target=run_tp_client, + args=(0, tp_rank, kvmanager.gpu_register_port, + 
model_config, cache_config, num_gpu_blocks), + daemon=True, + ) + tp_process.start() + tp_client_processes.append(tp_process) + + # Wait for system to be ready + print("\nWaiting for FlexKV to be ready...") + wait_start = time.time() + while not kvmanager.is_ready(): + time.sleep(1) + elapsed = time.time() - wait_start + if elapsed > 120: + print("ERROR: Timeout waiting for FlexKV to be ready (120s)") + shutdown_tp_clients(tp_client_processes) + kvmanager.shutdown() + return + if int(elapsed) % 10 == 0 and int(elapsed) > 0: + print(f" Still waiting... ({int(elapsed)}s)") + print(f"FlexKV is ready! (took {time.time() - wait_start:.1f}s)") + + try: + results = {} + + if bench_config.mode in ("single", "all", "put-only", "get-only"): + results["single_batch"] = benchmark_single_batch( + kvmanager, model_config, cache_config, bench_config + ) + + if bench_config.mode in ("multiturn", "all"): + results["multiturn"] = benchmark_multiturn( + kvmanager, model_config, cache_config, bench_config + ) + + # Print summary + print("\n" + "=" * 60) + print(" Benchmark Summary") + print("=" * 60) + for name, result in results.items(): + print(f"\n [{name}]") + for k, v in result.items(): + if isinstance(v, float): + print(f" {k}: {v:.4f}") + else: + print(f" {k}: {v}") + + # In put-only mode, keep the process alive so other nodes can GET the data + if bench_config.mode == "put-only": + print("\n" + "-" * 60) + print("Data published to Redis. Press Enter to shutdown " + "(keep running for other nodes to GET)...") + print("-" * 60) + try: + input() + except EOFError: + # Handle non-interactive environments + print("Non-interactive mode detected. 
Sleeping indefinitely (Ctrl+C to stop)...") + while True: + time.sleep(1) + + finally: + print("\nShutting down...") + shutdown_tp_clients(tp_client_processes) + kvmanager.shutdown() + # Unregister atexit handler since we've already cleaned up + atexit.unregister(_cleanup) + print("Done.") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark FlexKV distributed KVCache in server_client_mode" + ) + parser.add_argument("--config", type=str, default="benchmarks/example_dist_config.yml", + help="Path to config YAML file") + parser.add_argument("--mode", type=str, default="all", + choices=["single", "multiturn", "all", "put-only", "get-only"], + help="Benchmark mode: single, multiturn, all, put-only, get-only") + parser.add_argument("--seed", type=int, default=None, + help="Random seed for deterministic token generation (for cross-node benchmarks)") + + # Single batch params + parser.add_argument("--batch-size", type=int, default=1, help="Batch size for single batch benchmark") + parser.add_argument("--sequence-length", type=int, default=1024, help="Sequence length per request") + parser.add_argument("--cache-ratio", type=float, default=1.0, help="Ratio of tokens to cache in PUT phase") + parser.add_argument("--clear-cpu-cache", action="store_true", help="Clear CPU cache between PUT and GET") + + # Multi-turn params + parser.add_argument("--num-users", type=int, default=10, help="Number of simulated users") + parser.add_argument("--num-turns", type=int, default=3, help="Number of conversation turns per user") + parser.add_argument("--system-prompt-length", type=int, default=100, help="System prompt length in tokens") + parser.add_argument("--input-length", type=int, default=512, help="Input length per turn in tokens") + parser.add_argument("--output-length", type=int, default=64, help="Output length per turn in tokens") + + # Cross-node sync params + parser.add_argument("--rebuild-interval-ms", type=int, default=None, + help="Override 
FLEXKV_REBUILD_INTERVAL_MS (default: use env or 100). " + "Recommended: 20 for cross-node benchmarks") + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/benchmarks/dist_benchmark/example_dist_config.yml b/benchmarks/dist_benchmark/example_dist_config.yml new file mode 100644 index 0000000000..0f7a175344 --- /dev/null +++ b/benchmarks/dist_benchmark/example_dist_config.yml @@ -0,0 +1,34 @@ +# Distributed KVCache benchmark config (server_client_mode) +# Model config +num_layers: 4 +num_kv_heads: 8 +head_size: 128 +dtype: bfloat16 +use_mla: false +tp_size: 1 +dp_size: 1 +tokens_per_block: 16 + +# Cache config +cpu_cache_gb: 4 +ssd_cache_gb: 0 + +# Distributed KVCache config +enable_p2p_cpu: true + +# Redis config (for KV sharing metadata) +redis_host: "10.135.1.175" +redis_port: 6379 +redis_password: "123456" +local_ip: "10.135.1.176" + +# Mooncake Transfer Engine config (required for P2P) +mooncake_engine_ip: "10.135.1.176" +mooncake_engine_port: 5555 +mooncake_metadata_backend: "redis" +mooncake_metadata_server: "redis://10.135.1.175:6379" +mooncake_metadata_server_auth: "123456" +mooncake_protocol: "rdma" # "tcp" or "rdma" +mooncake_device_name: "mlx5_0,mlx5_1,mlx5_4,mlx5_5" # RDMA device name, e.g. "mlx5_0"; leave empty for tcp +# Force server_client_mode +server_client_mode: true diff --git a/benchmarks/dist_benchmark/example_dist_direct_config.yml b/benchmarks/dist_benchmark/example_dist_direct_config.yml new file mode 100644 index 0000000000..ae03da9855 --- /dev/null +++ b/benchmarks/dist_benchmark/example_dist_direct_config.yml @@ -0,0 +1,40 @@ +# Distributed KVCache benchmark config (direct mode, non-server_client_mode) +# In direct mode, KVManager creates KVTaskEngine directly in the main process +# without going through KVServer/KVDPClient IPC. This is simpler and has +# lower overhead, suitable for single-instance single-dp benchmarks. 
+ +# Model config +num_layers: 4 +num_kv_heads: 8 +head_size: 128 +dtype: bfloat16 +use_mla: false +tp_size: 1 +dp_size: 1 +tokens_per_block: 16 + +# Cache config +cpu_cache_gb: 4 +ssd_cache_gb: 0 + +# Distributed KVCache config +enable_p2p_cpu: true + +# Redis config (for KV sharing metadata) +redis_host: "10.135.1.175" +redis_port: 6379 +redis_password: "123456" +local_ip: "10.135.1.176" + +# Mooncake Transfer Engine config (required for P2P) +mooncake_engine_ip: "10.135.1.176" +mooncake_engine_port: 5555 +mooncake_metadata_backend: "redis" +mooncake_metadata_server: "redis://10.135.1.175:6379" +mooncake_metadata_server_auth: "123456" +mooncake_protocol: "rdma" # "tcp" or "rdma" +mooncake_device_name: "mlx5_0,mlx5_1,mlx5_4,mlx5_5" # RDMA device name, e.g. "mlx5_0"; leave empty for tcp + +# Direct mode (non-server_client_mode) +# In direct mode, KVTaskEngine runs in the main process +server_client_mode: false diff --git a/benchmarks/dist_benchmark/redis_check.py b/benchmarks/dist_benchmark/redis_check.py new file mode 100644 index 0000000000..2c702056f6 --- /dev/null +++ b/benchmarks/dist_benchmark/redis_check.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python3 +""" +FlexKV Redis Data Inspector + +Check what data the put-only node has pushed to Redis. 
+This script inspects all FlexKV-related keys in Redis including: + - global:node_id (global node ID counter) + - node: (registered node info) + - meta: (node meta: mooncake engine addr, buffer ptrs) + - buffer::* (RDMA memory region registrations) + - CPUB:: (CPU KVCache block metadata - the actual cached data index) + - SSDB:: (SSD KVCache block metadata) + - PCFSB:: (PCFS remote KVCache block metadata) + - pcfs: (PCFS file node IDs) + - mooncake/* (Mooncake Transfer Engine metadata) + +Usage: + python benchmarks/redis_check.py [--host HOST] [--port PORT] [--password PWD] + + # With defaults from example_dist_config.yml: + python benchmarks/redis_check.py --host 10.135.1.175 --port 6379 --password 123456 +""" + +import argparse +import sys + +try: + import redis +except ImportError: + print("ERROR: redis-py is required. Install with: pip install redis") + sys.exit(1) + + +def connect_redis(host, port, password): + """Connect to Redis and verify connectivity.""" + r = redis.Redis( + host=host, port=port, + password=password if password else None, + decode_responses=True, + socket_connect_timeout=5, + ) + try: + r.ping() + print(f"✅ Connected to Redis at {host}:{port}") + except redis.ConnectionError as e: + print(f"❌ Failed to connect to Redis at {host}:{port}: {e}") + sys.exit(1) + return r + + +def scan_keys(r, pattern, count=1000): + """Scan Redis keys matching pattern (non-blocking).""" + keys = [] + cursor = 0 + while True: + cursor, batch = r.scan(cursor=cursor, match=pattern, count=count) + keys.extend(batch) + if cursor == 0: + break + return sorted(keys) + + +def check_global_node_id(r): + """Check the global node ID counter.""" + print("\n" + "=" * 60) + print(" 1. 
Global Node ID Counter") + print("=" * 60) + val = r.get("global:node_id") + if val is not None: + print(f" global:node_id = {val}") + print(f" → {val} node(s) have been registered in total") + else: + print(" ⚠️ global:node_id not found (no nodes registered yet)") + + +def check_registered_nodes(r): + """Check registered node information.""" + print("\n" + "=" * 60) + print(" 2. Registered Nodes (node:*)") + print("=" * 60) + keys = scan_keys(r, "node:*") + if not keys: + print(" ⚠️ No registered nodes found") + return + + print(f" Found {len(keys)} registered node(s):\n") + for key in keys: + data = r.hgetall(key) + print(f" 📌 {key}:") + for field, value in sorted(data.items()): + print(f" {field}: {value}") + print() + + +def check_node_meta(r): + """Check node meta information (mooncake engine addr, buffer ptrs).""" + print("\n" + "=" * 60) + print(" 3. Node Meta (meta:*)") + print("=" * 60) + keys = scan_keys(r, "meta:*") + if not keys: + print(" ⚠️ No node meta found") + print(" → This means PEER2CPUTransferWorker hasn't registered yet,") + print(" or mooncake transfer engine initialization failed.") + return + + print(f" Found {len(keys)} node meta entry(ies):\n") + for key in keys: + data = r.hgetall(key) + print(f" 📌 {key}:") + for field, value in sorted(data.items()): + # Format large integers (pointers) in hex for readability + if field in ("cpu_buffer_ptr", "ssd_buffer_ptr"): + try: + int_val = int(value) + print(f" {field}: {value} (0x{int_val:x})") + except (ValueError, TypeError): + print(f" {field}: {value}") + else: + print(f" {field}: {value}") + print() + + +def check_buffer_registrations(r): + """Check RDMA buffer registrations.""" + print("\n" + "=" * 60) + print(" 4. 
RDMA Buffer Registrations (buffer:*)") + print("=" * 60) + keys = scan_keys(r, "buffer:*") + if not keys: + print(" ⚠️ No RDMA buffer registrations found") + return + + print(f" Found {len(keys)} buffer registration(s):\n") + for key in keys: + data = r.hgetall(key) + buf_size = data.get("buffer_size", "?") + try: + size_mb = int(buf_size) / (1024 * 1024) + print(f" 📌 {key}: size={buf_size} bytes ({size_mb:.2f} MB)") + except (ValueError, TypeError): + print(f" 📌 {key}: size={buf_size}") + + +def check_block_metadata(r): + """Check KVCache block metadata - this is the core data from put operations. + + FlexKV uses different key prefixes for different device types: + - CPUB:: — CPU block metadata (P2P CPU sharing) + - SSDB:: — SSD block metadata (P2P SSD sharing) + - PCFSB:: — PCFS remote block metadata + Each key is a Redis hash with fields: ph, pb, nid, hash, lt, state. + """ + print("\n" + "=" * 60) + print(" 5. KVCache Block Metadata (CPUB/SSDB/PCFSB)") + print("=" * 60) + + # FlexKV actual block key prefixes (set in hie_cache_engine.py) + block_prefixes = { + "CPUB": "CPU", + "SSDB": "SSD", + "PCFSB": "PCFS (Remote)", + } + + grand_total = 0 + for prefix, label in block_prefixes.items(): + keys = scan_keys(r, f"{prefix}:*") + if not keys: + print(f"\n [{label}] {prefix}:* — no entries found") + continue + + grand_total += len(keys) + + # Group by node_id: key format is PREFIX:: + node_blocks = {} + for key in keys: + parts = key.split(":") + if len(parts) >= 2: + node_id = parts[1] + if node_id not in node_blocks: + node_blocks[node_id] = [] + node_blocks[node_id].append(key) + + print(f"\n [{label}] {prefix}:* — {len(keys)} block(s) across {len(node_blocks)} node(s):") + + for node_id, block_keys in sorted(node_blocks.items(), key=lambda x: int(x[0]) if x[0].isdigit() else 0): + print(f" 📌 Node {node_id}: {len(block_keys)} block(s)") + + # Show first few blocks as samples + sample_count = min(3, len(block_keys)) + for key in block_keys[:sample_count]: + data = 
r.hgetall(key) + if data: + # BlockMeta fields: ph (physical hash), pb (physical block), + # nid (node id), hash, lt (lease time), state + ph = data.get("ph", "?") + pb = data.get("pb", "?") + nid = data.get("nid", "?") + hash_val = data.get("hash", "?") + lt = data.get("lt", "?") + state = data.get("state", "?") + print(f" {key}: ph={ph}, pb={pb}, nid={nid}, hash={hash_val}, lt={lt}, state={state}") + else: + key_type = r.type(key) + print(f" {key}: type={key_type}, (empty hash)") + + if len(block_keys) > sample_count: + print(f" ... and {len(block_keys) - sample_count} more block(s)") + + if grand_total == 0: + print("\n ⚠️ No block metadata found in any prefix (CPUB/SSDB/PCFSB)") + print(" → This means no KVCache data has been published to Redis yet.") + print(" The put-only node may still be uploading, or the upload") + print(" interval (rebuild_interval_ms) hasn't elapsed yet.") + else: + print(f"\n ✅ Total block metadata entries: {grand_total}") + + +def check_pcfs_data(r): + """Check PCFS file node IDs.""" + print("\n" + "=" * 60) + print(" 6. PCFS File Node IDs (pcfs:*)") + print("=" * 60) + keys = scan_keys(r, "pcfs:*") + if not keys: + print(" (none found - this is normal if PCFS sharing is not used)") + return + + print(f" Found {len(keys)} PCFS entry(ies):\n") + for key in keys: + values = r.lrange(key, 0, -1) + print(f" 📌 {key}: {len(values)} file node ID(s)") + if values: + sample = values[:10] + print(f" sample: {sample}") + if len(values) > 10: + print(f" ... and {len(values) - 10} more") + + +def check_mooncake_keys(r): + """Check Mooncake Transfer Engine related keys.""" + print("\n" + "=" * 60) + print(" 7. 
Mooncake Transfer Engine Keys") + print("=" * 60) + # Mooncake uses Redis as metadata backend, keys may vary + # Common patterns: segment info, endpoint info + patterns = ["mooncake/*", "mooncake:*", "segment:*", "endpoint:*", "mc:*"] + found_any = False + for pattern in patterns: + keys = scan_keys(r, pattern) + if keys: + found_any = True + print(f"\n Pattern '{pattern}': {len(keys)} key(s)") + for key in keys[:10]: + key_type = r.type(key) + if key_type == "hash": + data = r.hgetall(key) + print(f" 📌 {key} (hash): {data}") + elif key_type == "string": + val = r.get(key) + if val and len(val) > 200: + print(f" 📌 {key} (string): {val[:200]}...") + else: + print(f" 📌 {key} (string): {val}") + elif key_type == "set": + members = r.smembers(key) + print(f" 📌 {key} (set): {members}") + elif key_type == "list": + vals = r.lrange(key, 0, 9) + print(f" 📌 {key} (list): {vals}") + else: + print(f" 📌 {key} (type={key_type})") + if len(keys) > 10: + print(f" ... and {len(keys) - 10} more") + + if not found_any: + print(" (no mooncake-specific keys found)") + + +def check_all_keys_summary(r): + """Show a summary of ALL keys in Redis grouped by prefix.""" + print("\n" + "=" * 60) + print(" 8. 
All Keys Summary") + print("=" * 60) + all_keys = scan_keys(r, "*") + if not all_keys: + print(" ⚠️ Redis is completely empty!") + return + + print(f" Total keys in Redis: {len(all_keys)}\n") + + # Group by prefix (first part before ':') + prefix_counts = {} + for key in all_keys: + prefix = key.split(":")[0] if ":" in key else key + prefix_counts[prefix] = prefix_counts.get(prefix, 0) + 1 + + print(f" {'Prefix':<30} {'Count':>8}") + print(f" {'-'*30} {'-'*8}") + for prefix, count in sorted(prefix_counts.items(), key=lambda x: -x[1]): + print(f" {prefix:<30} {count:>8}") + + +def main(): + parser = argparse.ArgumentParser( + description="FlexKV Redis Data Inspector - Check put-only node data" + ) + parser.add_argument("--host", type=str, default="10.135.1.175", + help="Redis host (default: 10.135.1.175)") + parser.add_argument("--port", type=int, default=6379, + help="Redis port (default: 6379)") + parser.add_argument("--password", type=str, default="123456", + help="Redis password (default: 123456)") + args = parser.parse_args() + + print("=" * 60) + print(" FlexKV Redis Data Inspector") + print("=" * 60) + print(f" Target: {args.host}:{args.port}") + + r = connect_redis(args.host, args.port, args.password) + + check_global_node_id(r) + check_registered_nodes(r) + check_node_meta(r) + check_buffer_registrations(r) + check_block_metadata(r) + check_pcfs_data(r) + check_mooncake_keys(r) + check_all_keys_summary(r) + + print("\n" + "=" * 60) + print(" Inspection Complete") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/dist_benchmark/run_dist_benchmark.sh b/benchmarks/dist_benchmark/run_dist_benchmark.sh new file mode 100755 index 0000000000..8e64fdac22 --- /dev/null +++ b/benchmarks/dist_benchmark/run_dist_benchmark.sh @@ -0,0 +1,405 @@ +#!/bin/bash +# ============================================================================= +# FlexKV Distributed KVCache Benchmark - One-Click Launch Script +# +# This script handles: +# 1. 
Check and start Redis server if not running +# 2. Set up environment variables +# 3. Run the distributed KVCache benchmark +# +# Usage: +# bash benchmarks/run_dist_benchmark.sh [options] +# +# Options (passed through to benchmark_dist_kvcache.py): +# --config Config YAML file (default: benchmarks/example_dist_config.yml) +# --mode Benchmark mode: single, multiturn, or all (default: all) +# --batch-size Batch size (default: 1) +# --sequence-length Sequence length (default: 1024) +# --num-users Number of simulated users (default: 10) +# --num-turns Number of conversation turns (default: 3) +# --clean-redis Clean up FlexKV & Mooncake residual data in Redis before running benchmark +# (removes node:*, meta:*, CPUB:block:*, SSDB:block:*, PCFSB:block:*, +# mooncake/*, mooncake:*, segment:*, endpoint:*, mc:* keys) +# --clean-redis-only Clean up FlexKV & Mooncake residual data in Redis and exit (no benchmark) +# +# Examples: +# # Run with defaults +# bash benchmarks/run_dist_benchmark.sh +# +# # Custom parameters +# bash benchmarks/run_dist_benchmark.sh --batch-size 4 --sequence-length 2048 +# +# # Multi-turn only +# bash benchmarks/run_dist_benchmark.sh --mode multiturn --num-users 20 --num-turns 5 +# +# # Clean Redis residual data before benchmark +# bash benchmarks/run_dist_benchmark.sh --clean-redis +# +# # Only clean Redis residual data (no benchmark) +# bash benchmarks/run_dist_benchmark.sh --clean-redis-only +# ============================================================================= + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "${SCRIPT_DIR}/.." 
&& pwd)" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +info() { echo -e "${BLUE}[INFO]${NC} $*"; } +ok() { echo -e "${GREEN}[OK]${NC} $*"; } +warn() { echo -e "${YELLOW}[WARN]${NC} $*"; } +error() { echo -e "${RED}[ERROR]${NC} $*"; } + +# Default config file +CONFIG_FILE="${SCRIPT_DIR}/example_dist_config.yml" +REDIS_STARTED_BY_US=false +CLEAN_REDIS=false +CLEAN_REDIS_ONLY=false + +# Parse script-specific arguments and --config, pass the rest through to benchmark +BENCH_ARGS=() +prev_arg="" +for arg in "$@"; do + if [[ "$prev_arg" == "--config" ]]; then + CONFIG_FILE="$arg" + BENCH_ARGS+=("$arg") + prev_arg="$arg" + continue + fi + case "$arg" in + --clean-redis) + CLEAN_REDIS=true + ;; + --clean-redis-only) + CLEAN_REDIS=true + CLEAN_REDIS_ONLY=true + ;; + *) + BENCH_ARGS+=("$arg") + ;; + esac + prev_arg="$arg" +done + +# ============================================ +# Step 1: Parse Redis config from YAML +# ============================================ +info "============================================" +info "Step 1: Parsing configuration" +info "============================================" + +# Helper function to parse a YAML value using Python (handles comments, quotes, etc. 
correctly) +# Usage: parse_yaml_value [default] +parse_yaml_value() { + local key="$1" file="$2" default="${3:-}" + local val + val=$(python3 -c " +import yaml, sys +with open('$file') as f: + d = yaml.safe_load(f) +v = d.get('$key') +if v is None: + print('$default') +else: + print(v) +" 2>/dev/null) || val="$default" + echo "$val" +} + +# Simple YAML parser for redis config +REDIS_HOST=$(parse_yaml_value "redis_host" "$CONFIG_FILE" "127.0.0.1") +REDIS_PORT=$(parse_yaml_value "redis_port" "$CONFIG_FILE" "6379") +REDIS_PASSWORD=$(parse_yaml_value "redis_password" "$CONFIG_FILE" "") + +info "Config file: ${CONFIG_FILE}" +info "Redis: ${REDIS_HOST}:${REDIS_PORT}" + +# ============================================ +# Step 2: Check and start Redis +# ============================================ +info "============================================" +info "Step 2: Checking Redis server" +info "============================================" + +check_redis() { + local auth_args="" + if [[ -n "$REDIS_PASSWORD" ]]; then + auth_args="-a $REDIS_PASSWORD" + fi + redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" $auth_args ping 2>/dev/null | grep -q "PONG" +} + +# Build redis-cli auth arguments (reused across the script) +REDIS_AUTH_ARGS="" +if [[ -n "$REDIS_PASSWORD" ]]; then + REDIS_AUTH_ARGS="-a $REDIS_PASSWORD" +fi + +if check_redis; then + ok "Redis is already running at ${REDIS_HOST}:${REDIS_PORT}" +else + warn "Redis is not running at ${REDIS_HOST}:${REDIS_PORT}" + + # Only try to start Redis if it's localhost + if [[ "$REDIS_HOST" == "127.0.0.1" ]] || [[ "$REDIS_HOST" == "localhost" ]]; then + if command -v redis-server &>/dev/null; then + info "Starting Redis server on port ${REDIS_PORT}..." 
+ redis-server --port "$REDIS_PORT" --daemonize yes --save "" --appendonly no \ + --protected-mode no --loglevel warning + sleep 1 + + if check_redis; then + ok "Redis server started successfully" + REDIS_STARTED_BY_US=true + else + error "Failed to start Redis server" + error "Please install Redis: sudo apt install redis-server" + exit 1 + fi + else + error "redis-server not found. Please install Redis:" + error " sudo apt install redis-server" + exit 1 + fi + else + error "Redis is not running at ${REDIS_HOST}:${REDIS_PORT}" + error "Please start Redis on the remote host first." + exit 1 + fi +fi + +# ============================================ +# Step 2.5: Clean FlexKV residual data in Redis (if requested) +# ============================================ +if [[ "$CLEAN_REDIS" == "true" ]]; then + info "============================================" + info "Cleaning FlexKV residual data in Redis" + info "============================================" + + clean_redis_keys() { + local pattern="$1" + local count=0 + local cursor=0 + while true; do + local result + result=$(redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" $REDIS_AUTH_ARGS SCAN $cursor MATCH "$pattern" COUNT 500 2>/dev/null) + cursor=$(echo "$result" | head -1) + local keys + keys=$(echo "$result" | tail -n +2) + if [[ -n "$keys" ]]; then + local batch_keys + batch_keys=$(echo "$keys" | tr '\n' ' ') + if [[ -n "$batch_keys" ]]; then + local deleted + deleted=$(echo "$batch_keys" | xargs redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" $REDIS_AUTH_ARGS DEL 2>/dev/null) + count=$((count + deleted)) + fi + fi + if [[ "$cursor" == "0" ]]; then + break + fi + done + echo "$count" + } + + total_deleted=0 + + # Clean node:* keys + n=$(clean_redis_keys "node:*") + [[ $n -gt 0 ]] && info "Deleted $n node:* key(s)" + total_deleted=$((total_deleted + n)) + + # Clean meta:* keys + n=$(clean_redis_keys "meta:*") + [[ $n -gt 0 ]] && info "Deleted $n meta:* key(s)" + total_deleted=$((total_deleted + n)) + + # Clean CPUB:block:* 
keys + n=$(clean_redis_keys "CPUB:block:*") + [[ $n -gt 0 ]] && info "Deleted $n CPUB:block:* key(s)" + total_deleted=$((total_deleted + n)) + + # Clean SSDB:block:* keys + n=$(clean_redis_keys "SSDB:block:*") + [[ $n -gt 0 ]] && info "Deleted $n SSDB:block:* key(s)" + total_deleted=$((total_deleted + n)) + + # Clean PCFSB:block:* keys + n=$(clean_redis_keys "PCFSB:block:*") + [[ $n -gt 0 ]] && info "Deleted $n PCFSB:block:* key(s)" + total_deleted=$((total_deleted + n)) + + # Clean Mooncake Transfer Engine residual keys + # Mooncake uses Redis as metadata backend to store segment/endpoint info + for mc_pattern in "mooncake/*" "mooncake:*" "segment:*" "endpoint:*" "mc:*"; do + n=$(clean_redis_keys "$mc_pattern") + [[ $n -gt 0 ]] && info "Deleted $n ${mc_pattern} key(s)" + total_deleted=$((total_deleted + n)) + done + + if [[ $total_deleted -gt 0 ]]; then + ok "Cleaned $total_deleted FlexKV & Mooncake residual key(s) from Redis" + else + ok "No FlexKV residual data found in Redis" + fi + + if [[ "$CLEAN_REDIS_ONLY" == "true" ]]; then + ok "Clean-only mode, exiting." + exit 0 + fi +fi + +# ============================================ +# Step 3: Set up environment +# ============================================ +info "============================================" +info "Step 3: Setting up environment" +info "============================================" + +# Detect Python (prefer virtual env) +if [[ -n "$VIRTUAL_ENV" ]]; then + # Prefer 'which python3' to get the actual resolved path in the activated venv, + # because $VIRTUAL_ENV may point to a path that doesn't match the real filesystem + # (e.g. symlinks, home dir aliases like ~ vs /data1/home). + if command -v python3 &>/dev/null; then + PYTHON="$(which python3)" + else + PYTHON="$VIRTUAL_ENV/bin/python3" + fi + if [[ ! 
-x "$PYTHON" ]]; then + error "Python3 not found at $PYTHON (VIRTUAL_ENV=$VIRTUAL_ENV)" + exit 1 + fi + info "Using virtual env Python: $PYTHON" +elif command -v python3 &>/dev/null; then + PYTHON="$(which python3)" + info "Using system Python: $PYTHON" +else + error "Python3 not found!" + exit 1 +fi + +# Set PYTHONPATH to include project root +export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}" + +# Set LD_LIBRARY_PATH for C++ libraries +if [[ -d "${PROJECT_ROOT}/build" ]]; then + export LD_LIBRARY_PATH="${PROJECT_ROOT}/build:${LD_LIBRARY_PATH:-}" +fi + +info "PYTHONPATH=${PYTHONPATH}" +info "LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}" + +# Generate mooncake config JSON and export MOONCAKE_CONFIG_PATH if P2P is enabled +ENABLE_P2P_CPU=$(grep -E "^enable_p2p_cpu:" "$CONFIG_FILE" 2>/dev/null | awk '{print $2}' || echo "false") +ENABLE_P2P_SSD=$(grep -E "^enable_p2p_ssd:" "$CONFIG_FILE" 2>/dev/null | awk '{print $2}' || echo "false") + +if [[ "$ENABLE_P2P_CPU" == "true" ]] || [[ "$ENABLE_P2P_SSD" == "true" ]]; then + if [[ -z "${MOONCAKE_CONFIG_PATH:-}" ]]; then + info "P2P enabled, generating mooncake config..." 
+ + # Parse mooncake config from YAML using helper function + MC_ENGINE_IP=$(parse_yaml_value "mooncake_engine_ip" "$CONFIG_FILE") + MC_ENGINE_PORT=$(parse_yaml_value "mooncake_engine_port" "$CONFIG_FILE") + MC_METADATA_BACKEND=$(parse_yaml_value "mooncake_metadata_backend" "$CONFIG_FILE") + MC_METADATA_SERVER=$(parse_yaml_value "mooncake_metadata_server" "$CONFIG_FILE") + MC_METADATA_SERVER_AUTH=$(parse_yaml_value "mooncake_metadata_server_auth" "$CONFIG_FILE") + MC_PROTOCOL=$(parse_yaml_value "mooncake_protocol" "$CONFIG_FILE") + MC_DEVICE_NAME=$(parse_yaml_value "mooncake_device_name" "$CONFIG_FILE") + LOCAL_IP=$(parse_yaml_value "local_ip" "$CONFIG_FILE" "127.0.0.1") + + # Use defaults if not specified + MC_ENGINE_IP="${MC_ENGINE_IP:-$LOCAL_IP}" + MC_ENGINE_PORT="${MC_ENGINE_PORT:-5555}" + MC_METADATA_BACKEND="${MC_METADATA_BACKEND:-redis}" + MC_METADATA_SERVER="${MC_METADATA_SERVER:-redis://${REDIS_HOST}:${REDIS_PORT}}" + MC_PROTOCOL="${MC_PROTOCOL:-tcp}" + MC_DEVICE_NAME="${MC_DEVICE_NAME:-}" + + # Generate JSON config file + MOONCAKE_CONFIG_FILE=$(mktemp /tmp/mooncake_config_XXXXXX.json) + cat > "$MOONCAKE_CONFIG_FILE" </dev/null || true + ok "Redis stopped." +fi + +if [[ $BENCH_EXIT_CODE -eq 0 ]]; then + echo "" + ok "Benchmark completed successfully!" 
+else + echo "" + error "Benchmark failed with exit code: $BENCH_EXIT_CODE" +fi + +exit $BENCH_EXIT_CODE diff --git a/benchmarks/dist_benchmark/run_dist_direct_benchmark.sh b/benchmarks/dist_benchmark/run_dist_direct_benchmark.sh new file mode 100755 index 0000000000..6a9bd5887a --- /dev/null +++ b/benchmarks/dist_benchmark/run_dist_direct_benchmark.sh @@ -0,0 +1,372 @@ +#!/bin/bash +# ============================================================================= +# FlexKV Distributed KVCache Benchmark (Direct Mode) - One-Click Launch Script +# +# This script runs the distributed KVCache benchmark in direct mode +# (non-server_client_mode), where KVManager creates KVTaskEngine directly +# in the main process without going through KVServer/KVDPClient IPC. +# +# Usage: +# bash benchmarks/dist_benchmark/run_dist_direct_benchmark.sh [options] +# +# Options (passed through to benchmark_dist_direct.py): +# --config Config YAML file (default: benchmarks/dist_benchmark/example_dist_direct_config.yml) +# --mode Benchmark mode: single, multiturn, or all (default: all) +# --batch-size Batch size (default: 1) +# --sequence-length Sequence length (default: 1024) +# --num-users Number of simulated users (default: 10) +# --num-turns Number of conversation turns (default: 3) +# --clean-redis Clean up FlexKV & Mooncake residual data in Redis before running benchmark +# --clean-redis-only Clean up FlexKV & Mooncake residual data in Redis and exit (no benchmark) +# +# Examples: +# # Run with defaults (direct mode) +# bash benchmarks/dist_benchmark/run_dist_direct_benchmark.sh +# +# # Custom parameters +# bash benchmarks/dist_benchmark/run_dist_direct_benchmark.sh --batch-size 4 --sequence-length 2048 +# +# # Multi-turn only +# bash benchmarks/dist_benchmark/run_dist_direct_benchmark.sh --mode multiturn --num-users 20 --num-turns 5 +# +# # Clean Redis residual data before benchmark +# bash benchmarks/dist_benchmark/run_dist_direct_benchmark.sh --clean-redis +# +# # Only clean Redis 
residual data (no benchmark) +# bash benchmarks/dist_benchmark/run_dist_direct_benchmark.sh --clean-redis-only +# ============================================================================= + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +info() { echo -e "${BLUE}[INFO]${NC} $*"; } +ok() { echo -e "${GREEN}[OK]${NC} $*"; } +warn() { echo -e "${YELLOW}[WARN]${NC} $*"; } +error() { echo -e "${RED}[ERROR]${NC} $*"; } + +# Default config file (direct mode config) +CONFIG_FILE="${SCRIPT_DIR}/example_dist_direct_config.yml" +REDIS_STARTED_BY_US=false +CLEAN_REDIS=false +CLEAN_REDIS_ONLY=false + +# Parse script-specific arguments and --config, pass the rest through to benchmark +BENCH_ARGS=() +prev_arg="" +for arg in "$@"; do + if [[ "$prev_arg" == "--config" ]]; then + CONFIG_FILE="$arg" + BENCH_ARGS+=("$arg") + prev_arg="$arg" + continue + fi + case "$arg" in + --clean-redis) + CLEAN_REDIS=true + ;; + --clean-redis-only) + CLEAN_REDIS=true + CLEAN_REDIS_ONLY=true + ;; + *) + BENCH_ARGS+=("$arg") + ;; + esac + prev_arg="$arg" +done + +# ============================================ +# Step 1: Parse Redis config from YAML +# ============================================ +info "============================================" +info "Step 1: Parsing configuration" +info "============================================" + +parse_yaml_value() { + local key="$1" file="$2" default="${3:-}" + local val + val=$(python3 -c " +import yaml, sys +with open('$file') as f: + d = yaml.safe_load(f) +v = d.get('$key') +if v is None: + print('$default') +else: + print(v) +" 2>/dev/null) || val="$default" + echo "$val" +} + +REDIS_HOST=$(parse_yaml_value "redis_host" "$CONFIG_FILE" "127.0.0.1") +REDIS_PORT=$(parse_yaml_value "redis_port" "$CONFIG_FILE" "6379") +REDIS_PASSWORD=$(parse_yaml_value 
"redis_password" "$CONFIG_FILE" "") + +info "Config file: ${CONFIG_FILE}" +info "Redis: ${REDIS_HOST}:${REDIS_PORT}" +info "Mode: Direct (non-server_client_mode)" + +# ============================================ +# Step 2: Check and start Redis +# ============================================ +info "============================================" +info "Step 2: Checking Redis server" +info "============================================" + +check_redis() { + local auth_args="" + if [[ -n "$REDIS_PASSWORD" ]]; then + auth_args="-a $REDIS_PASSWORD" + fi + redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" $auth_args ping 2>/dev/null | grep -q "PONG" +} + +REDIS_AUTH_ARGS="" +if [[ -n "$REDIS_PASSWORD" ]]; then + REDIS_AUTH_ARGS="-a $REDIS_PASSWORD" +fi + +if check_redis; then + ok "Redis is already running at ${REDIS_HOST}:${REDIS_PORT}" +else + warn "Redis is not running at ${REDIS_HOST}:${REDIS_PORT}" + + if [[ "$REDIS_HOST" == "127.0.0.1" ]] || [[ "$REDIS_HOST" == "localhost" ]]; then + if command -v redis-server &>/dev/null; then + info "Starting Redis server on port ${REDIS_PORT}..." + redis-server --port "$REDIS_PORT" --daemonize yes --save "" --appendonly no \ + --protected-mode no --loglevel warning + sleep 1 + + if check_redis; then + ok "Redis server started successfully" + REDIS_STARTED_BY_US=true + else + error "Failed to start Redis server" + error "Please install Redis: sudo apt install redis-server" + exit 1 + fi + else + error "redis-server not found. Please install Redis:" + error " sudo apt install redis-server" + exit 1 + fi + else + error "Redis is not running at ${REDIS_HOST}:${REDIS_PORT}" + error "Please start Redis on the remote host first." 
+ exit 1 + fi +fi + +# ============================================ +# Step 2.5: Clean FlexKV & Mooncake residual data in Redis (if requested) +# ============================================ +if [[ "$CLEAN_REDIS" == "true" ]]; then + info "============================================" + info "Cleaning FlexKV & Mooncake residual data in Redis" + info "============================================" + + clean_redis_keys() { + local pattern="$1" + local count=0 + local cursor=0 + while true; do + local result + result=$(redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" $REDIS_AUTH_ARGS SCAN $cursor MATCH "$pattern" COUNT 500 2>/dev/null) + cursor=$(echo "$result" | head -1) + local keys + keys=$(echo "$result" | tail -n +2) + if [[ -n "$keys" ]]; then + local batch_keys + batch_keys=$(echo "$keys" | tr '\n' ' ') + if [[ -n "$batch_keys" ]]; then + local deleted + deleted=$(echo "$batch_keys" | xargs redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" $REDIS_AUTH_ARGS DEL 2>/dev/null) + count=$((count + deleted)) + fi + fi + if [[ "$cursor" == "0" ]]; then + break + fi + done + echo "$count" + } + + total_deleted=0 + + # Clean FlexKV keys + for pattern in "node:*" "meta:*" "CPUB:block:*" "SSDB:block:*" "PCFSB:block:*"; do + n=$(clean_redis_keys "$pattern") + [[ $n -gt 0 ]] && info "Deleted $n ${pattern} key(s)" + total_deleted=$((total_deleted + n)) + done + + # Clean Mooncake Transfer Engine residual keys + for mc_pattern in "mooncake/*" "mooncake:*" "segment:*" "endpoint:*" "mc:*"; do + n=$(clean_redis_keys "$mc_pattern") + [[ $n -gt 0 ]] && info "Deleted $n ${mc_pattern} key(s)" + total_deleted=$((total_deleted + n)) + done + + if [[ $total_deleted -gt 0 ]]; then + ok "Cleaned $total_deleted FlexKV & Mooncake residual key(s) from Redis" + else + ok "No FlexKV residual data found in Redis" + fi + + if [[ "$CLEAN_REDIS_ONLY" == "true" ]]; then + ok "Clean-only mode, exiting." 
+ exit 0 + fi +fi + +# ============================================ +# Step 3: Set up environment +# ============================================ +info "============================================" +info "Step 3: Setting up environment" +info "============================================" + +if [[ -n "$VIRTUAL_ENV" ]]; then + # Prefer 'which python3' to get the actual resolved path in the activated venv, + # because $VIRTUAL_ENV may point to a path that doesn't match the real filesystem + # (e.g. symlinks, home dir aliases like ~ vs /data1/home). + if command -v python3 &>/dev/null; then + PYTHON="$(which python3)" + else + PYTHON="$VIRTUAL_ENV/bin/python3" + fi + if [[ ! -x "$PYTHON" ]]; then + error "Python3 not found at $PYTHON (VIRTUAL_ENV=$VIRTUAL_ENV)" + exit 1 + fi + info "Using virtual env Python: $PYTHON" +elif command -v python3 &>/dev/null; then + PYTHON="$(which python3)" + info "Using system Python: $PYTHON" +else + error "Python3 not found!" + exit 1 +fi + +export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}" + +if [[ -d "${PROJECT_ROOT}/build" ]]; then + export LD_LIBRARY_PATH="${PROJECT_ROOT}/build:${LD_LIBRARY_PATH:-}" +fi + +# IMPORTANT: Ensure server_client_mode is NOT set for direct mode +unset FLEXKV_SERVER_CLIENT_MODE + +info "PYTHONPATH=${PYTHONPATH}" +info "LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}" +info "FLEXKV_SERVER_CLIENT_MODE= (direct mode)" + +# Generate mooncake config JSON if P2P is enabled +ENABLE_P2P_CPU=$(grep -E "^enable_p2p_cpu:" "$CONFIG_FILE" 2>/dev/null | awk '{print $2}' || echo "false") +ENABLE_P2P_SSD=$(grep -E "^enable_p2p_ssd:" "$CONFIG_FILE" 2>/dev/null | awk '{print $2}' || echo "false") + +if [[ "$ENABLE_P2P_CPU" == "true" ]] || [[ "$ENABLE_P2P_SSD" == "true" ]]; then + if [[ -z "${MOONCAKE_CONFIG_PATH:-}" ]]; then + info "P2P enabled, generating mooncake config..." 
+ + MC_ENGINE_IP=$(parse_yaml_value "mooncake_engine_ip" "$CONFIG_FILE") + MC_ENGINE_PORT=$(parse_yaml_value "mooncake_engine_port" "$CONFIG_FILE") + MC_METADATA_BACKEND=$(parse_yaml_value "mooncake_metadata_backend" "$CONFIG_FILE") + MC_METADATA_SERVER=$(parse_yaml_value "mooncake_metadata_server" "$CONFIG_FILE") + MC_METADATA_SERVER_AUTH=$(parse_yaml_value "mooncake_metadata_server_auth" "$CONFIG_FILE") + MC_PROTOCOL=$(parse_yaml_value "mooncake_protocol" "$CONFIG_FILE") + MC_DEVICE_NAME=$(parse_yaml_value "mooncake_device_name" "$CONFIG_FILE") + LOCAL_IP=$(parse_yaml_value "local_ip" "$CONFIG_FILE" "127.0.0.1") + + MC_ENGINE_IP="${MC_ENGINE_IP:-$LOCAL_IP}" + MC_ENGINE_PORT="${MC_ENGINE_PORT:-5555}" + MC_METADATA_BACKEND="${MC_METADATA_BACKEND:-redis}" + MC_METADATA_SERVER="${MC_METADATA_SERVER:-redis://${REDIS_HOST}:${REDIS_PORT}}" + MC_PROTOCOL="${MC_PROTOCOL:-tcp}" + MC_DEVICE_NAME="${MC_DEVICE_NAME:-}" + + MOONCAKE_CONFIG_FILE=$(mktemp /tmp/mooncake_config_XXXXXX.json) + cat > "$MOONCAKE_CONFIG_FILE" </dev/null || true + ok "Redis stopped." +fi + +if [[ $BENCH_EXIT_CODE -eq 0 ]]; then + echo "" + ok "Benchmark (Direct Mode) completed successfully!" 
import asyncio
import random
import time
from dataclasses import dataclass, field
from typing import Optional, List, Tuple, Any, ClassVar
import yaml

import torch
import numpy as np
from tqdm import tqdm

from flexkv.common.config import *
from flexkv.common.storage import KVCacheLayoutType


def _as_int64_array(value):
    """Return *value* as an int64 numpy array if it is a torch tensor, else unchanged."""
    if isinstance(value, torch.Tensor):
        return value.numpy().astype(np.int64)
    return value


@dataclass
class KVRequest:
    """One benchmark request ("get" or "put") for a given user and turn.

    ``token_ids``, ``token_mask`` and ``slot_mapping`` may be passed as torch
    tensors; they are converted to int64 numpy arrays on construction.
    ``request_id`` is not an init argument: it is taken from a class-level
    counter that is incremented for every instance created.
    """

    user_id: int
    turn_id: int
    request_type: str  # "get" or "put"
    token_ids: np.ndarray
    token_mask: np.ndarray
    slot_mapping: Optional[np.ndarray] = None

    request_id: int = field(init=False)
    # Class-level counter shared by all requests. Declared as ClassVar so the
    # dataclass machinery does not turn it into a per-instance field (which
    # would add a constant 0 to every instance's fields, repr and comparison).
    _request_id_counter: ClassVar[int] = 0

    def __post_init__(self):
        self.request_id = KVRequest._request_id_counter
        KVRequest._request_id_counter += 1

        self.token_ids = _as_int64_array(self.token_ids)
        self.token_mask = _as_int64_array(self.token_mask)
        self.slot_mapping = _as_int64_array(self.slot_mapping)


def generate_random_multiturn(num_user_requests: int,
                              num_turns: int,
                              system_prompt_length: int,
                              input_length: int,
                              output_length: int,
                              num_turns_ratio: float = 0.5,
                              input_length_ratio: float = 0.5,
                              output_length_ratio: float = 0.5,
                              seed: Optional[int] = None) -> List[KVRequest]:
    """Generate a randomly interleaved multi-turn get/put request stream.

    Every user shares one random system prompt. For each turn the input is the
    user's whole history plus freshly drawn input tokens; the history then
    grows by that turn's input and output tokens. Users are interleaved at
    random while each user's own turns stay in order. Every turn yields a
    "get" request for the input followed by a "put" request for input+output.

    Args:
        num_user_requests: Number of simulated users.
        num_turns: Upper bound on the number of turns per user.
        system_prompt_length: Length of the shared system prompt, in tokens.
        input_length: Upper bound on the new input tokens per turn.
        output_length: Upper bound on the output tokens per turn.
        num_turns_ratio: Lower bound for turns is ``int(ratio * num_turns)``
            (at least one turn is always generated).
        input_length_ratio: Lower bound for input is ``int(ratio * input_length)``.
        output_length_ratio: Lower bound for output is ``int(ratio * output_length)``.
        seed: If given, seeds both ``random`` and ``torch`` (global state) so
            the generated stream is reproducible, e.g. across nodes.

    Returns:
        Flat list of ``KVRequest`` objects in issue order.
    """
    token_id_range = 10000
    # Set seed for deterministic generation (useful for cross-node benchmarks)
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    system_prompt = torch.randint(0, token_id_range, (system_prompt_length,))

    # Phase 1: build every user's conversation. The order of RNG calls is kept
    # exactly as before so that seeded runs produce the same stream.
    all_requests = []
    for user in range(num_user_requests):
        user_num_turns = max(random.randint(int(num_turns_ratio * num_turns), num_turns), 1)
        history = system_prompt.clone()
        user_requests = []
        for turn in range(user_num_turns):
            turn_input_length = random.randint(int(input_length_ratio * input_length), input_length)
            turn_output_length = random.randint(int(output_length_ratio * output_length), output_length)
            input_tokens = torch.randint(0, token_id_range, (turn_input_length,))
            output_tokens = torch.randint(0, token_id_range, (turn_output_length,))
            user_requests.append(dict(
                user_id=user,
                turn_id=turn,
                input=torch.cat([history, input_tokens], dim=0),
                output=output_tokens,
            ))
            history = torch.cat([history, input_tokens, output_tokens], dim=0)
        all_requests.append(user_requests)

    # Phase 2: interleave users at random, preserving per-user turn order.
    next_turn = [0] * num_user_requests
    kv_requests = []
    while True:
        pending_users = [
            user for user in range(num_user_requests)
            if next_turn[user] < len(all_requests[user])
        ]
        if not pending_users:
            break
        user = random.choice(pending_users)
        request = all_requests[user][next_turn[user]]
        next_turn[user] += 1

        prompt = request["input"]
        kv_requests.append(KVRequest(
            user_id=request["user_id"],
            turn_id=request["turn_id"],
            request_type="get",
            token_ids=prompt,
            token_mask=torch.ones_like(prompt),
        ))
        # Concatenate once and reuse it for both the ids and the mask.
        prompt_and_output = torch.cat([prompt, request["output"]], dim=0)
        kv_requests.append(KVRequest(
            user_id=request["user_id"],
            turn_id=request["turn_id"],
            request_type="put",
            token_ids=prompt_and_output,
            token_mask=torch.ones_like(prompt_and_output),
        ))
    return kv_requests


def load_config(config_path: str) -> Tuple[ModelConfig, CacheConfig]:
    """Load a benchmark YAML file into a ``(ModelConfig, CacheConfig)`` pair.

    Required keys: ``num_layers``, ``num_kv_heads``, ``head_size``, ``dtype``,
    ``use_mla``, ``tp_size``, ``dp_size``, ``tokens_per_block``.
    Optional keys: ``cpu_cache_gb``, ``ssd_cache_gb``, ``ssd_cache_dir``,
    ``enable_gds`` (copied onto a ``UserConfig`` that is then merged in via
    ``update_default_config_from_user_config``).

    Raises:
        KeyError: If a required key is missing.
        ValueError: If ``dtype`` does not name a ``torch.dtype``.
    """
    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)
    print(config)

    model_config = ModelConfig()
    cache_config = CacheConfig()
    user_config = UserConfig()

    model_config.num_layers = config["num_layers"]
    model_config.num_kv_heads = config["num_kv_heads"]
    model_config.head_size = config["head_size"]

    # Resolve the dtype by attribute lookup instead of eval(): the value comes
    # from a file and must never be executed as code.
    dtype = getattr(torch, str(config["dtype"]), None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Unsupported dtype in config: {config['dtype']!r}")
    model_config.dtype = dtype

    model_config.use_mla = config["use_mla"]
    model_config.tp_size = config["tp_size"]
    model_config.dp_size = config["dp_size"]
    cache_config.tokens_per_block = config["tokens_per_block"]

    if "cpu_cache_gb" in config:
        user_config.cpu_cache_gb = config["cpu_cache_gb"]
    if "ssd_cache_gb" in config:
        user_config.ssd_cache_gb = config["ssd_cache_gb"]
    if "ssd_cache_dir" in config:
        user_config.ssd_cache_dir = parse_path_list(config["ssd_cache_dir"])
    if "enable_gds" in config:
        user_config.enable_gds = config["enable_gds"]

    update_default_config_from_user_config(model_config, cache_config, user_config)
    return model_config, cache_config


if __name__ == "__main__":
    model_config, cache_config = load_config("./benchmarks/example_config.yml")
    print(model_config)
    print(cache_config)
#!/usr/bin/env python3
"""
FlexKV Redis Data Inspector

Check what data the put-only node has pushed to Redis.
This script inspects all FlexKV-related keys in Redis including:
  - global:node_id  (global node ID counter)
  - node:*          (registered node info)
  - meta:*          (node meta: mooncake engine addr, buffer ptrs)
  - buffer:*        (RDMA memory region registrations)
  - CPUB:*          (CPU KVCache block metadata - the actual cached data index)
  - SSDB:*          (SSD KVCache block metadata)
  - PCFSB:*         (PCFS remote KVCache block metadata)
  - pcfs:*          (PCFS file node IDs)
  - mooncake/*      (Mooncake Transfer Engine metadata)

Usage:
    python benchmarks/redis_check.py [--host HOST] [--port PORT] [--password PWD]

    # With defaults from example_dist_config.yml:
    python benchmarks/redis_check.py --host 10.135.1.175 --port 6379 --password 123456
"""

import argparse
import sys
from collections import Counter, defaultdict

try:
    import redis
except ImportError:
    # Do not exit at import time: the helpers below only need a client object,
    # so the module stays importable. connect_redis() reports the problem.
    redis = None


_RULE = "=" * 60


def _print_section(title):
    """Print a section banner: blank line, rule, *title*, rule."""
    print("\n" + _RULE)
    print(title)
    print(_RULE)


def connect_redis(host, port, password):
    """Connect to Redis and verify connectivity.

    Exits the process with status 1 if redis-py is not installed or if the
    server cannot be reached (connection refused, timeout, auth failure, ...).

    Returns:
        A ``redis.Redis`` client with ``decode_responses=True``.
    """
    if redis is None:
        print("ERROR: redis-py is required. Install with: pip install redis")
        sys.exit(1)

    client = redis.Redis(
        host=host, port=port,
        password=password if password else None,
        decode_responses=True,
        socket_connect_timeout=5,
    )
    try:
        client.ping()
    except redis.RedisError as e:
        # RedisError also covers connect timeouts, which are not a
        # ConnectionError and would otherwise escape as a traceback.
        print(f"❌ Failed to connect to Redis at {host}:{port}: {e}")
        sys.exit(1)
    print(f"✅ Connected to Redis at {host}:{port}")
    return client


def scan_keys(r, pattern, count=1000):
    """Return the sorted, de-duplicated keys matching *pattern*.

    Uses cursor-based SCAN (non-blocking) instead of KEYS. SCAN may report the
    same key more than once during a full iteration, so results are collected
    in a set to keep the counts printed by the callers accurate.

    Args:
        r: Redis client exposing ``scan(cursor=, match=, count=)``.
        pattern: Glob-style match pattern.
        count: SCAN ``COUNT`` hint per round trip.
    """
    found = set()
    cursor = 0
    while True:
        cursor, batch = r.scan(cursor=cursor, match=pattern, count=count)
        found.update(batch)
        if cursor == 0:
            break
    return sorted(found)


def check_global_node_id(r):
    """Check the global node ID counter."""
    _print_section(" 1. Global Node ID Counter")
    val = r.get("global:node_id")
    if val is None:
        print(" ⚠️ global:node_id not found (no nodes registered yet)")
        return
    print(f" global:node_id = {val}")
    print(f" → {val} node(s) have been registered in total")


def check_registered_nodes(r):
    """Check registered node information."""
    _print_section(" 2. Registered Nodes (node:*)")
    keys = scan_keys(r, "node:*")
    if not keys:
        print(" ⚠️ No registered nodes found")
        return

    print(f" Found {len(keys)} registered node(s):\n")
    for key in keys:
        data = r.hgetall(key)
        print(f" 📌 {key}:")
        for field, value in sorted(data.items()):
            print(f" {field}: {value}")
        print()


def check_node_meta(r):
    """Check node meta information (mooncake engine addr, buffer ptrs)."""
    _print_section(" 3. Node Meta (meta:*)")
    keys = scan_keys(r, "meta:*")
    if not keys:
        print(" ⚠️ No node meta found")
        print(" → This means PEER2CPUTransferWorker hasn't registered yet,")
        print(" or mooncake transfer engine initialization failed.")
        return

    print(f" Found {len(keys)} node meta entry(ies):\n")
    for key in keys:
        data = r.hgetall(key)
        print(f" 📌 {key}:")
        for field, value in sorted(data.items()):
            # Format large integers (pointers) in hex for readability
            if field in ("cpu_buffer_ptr", "ssd_buffer_ptr"):
                try:
                    int_val = int(value)
                    print(f" {field}: {value} (0x{int_val:x})")
                except (ValueError, TypeError):
                    print(f" {field}: {value}")
            else:
                print(f" {field}: {value}")
        print()


def check_buffer_registrations(r):
    """Check RDMA buffer registrations."""
    _print_section(" 4. RDMA Buffer Registrations (buffer:*)")
    keys = scan_keys(r, "buffer:*")
    if not keys:
        print(" ⚠️ No RDMA buffer registrations found")
        return

    print(f" Found {len(keys)} buffer registration(s):\n")
    for key in keys:
        data = r.hgetall(key)
        buf_size = data.get("buffer_size", "?")
        try:
            size_mb = int(buf_size) / (1024 * 1024)
            print(f" 📌 {key}: size={buf_size} bytes ({size_mb:.2f} MB)")
        except (ValueError, TypeError):
            print(f" 📌 {key}: size={buf_size}")


def _node_id_from_block_key(key):
    """Extract the node-id segment from a block key, or None if there is none.

    The segment after the prefix is taken as the node id. The cleanup scripts
    shipped with the benchmarks match ``<PREFIX>:block:*``, so a literal
    ``block`` segment is skipped when something follows it.
    NOTE(review): confirm the real key layout against the code that writes
    these keys.
    """
    parts = key.split(":")
    if len(parts) < 2:
        return None
    if parts[1] == "block" and len(parts) >= 3:
        return parts[2]
    return parts[1]


def _print_block_sample(r, key):
    """Print the BlockMeta hash stored at *key* (or its type if it is empty)."""
    data = r.hgetall(key)
    if not data:
        key_type = r.type(key)
        print(f" {key}: type={key_type}, (empty hash)")
        return
    # BlockMeta fields: ph (physical hash), pb (physical block),
    # nid (node id), hash, lt (lease time), state
    ph = data.get("ph", "?")
    pb = data.get("pb", "?")
    nid = data.get("nid", "?")
    hash_val = data.get("hash", "?")
    lt = data.get("lt", "?")
    state = data.get("state", "?")
    print(f" {key}: ph={ph}, pb={pb}, nid={nid}, hash={hash_val}, lt={lt}, state={state}")


def check_block_metadata(r):
    """Check KVCache block metadata - this is the core data from put operations.

    FlexKV uses different key prefixes for different device types:
      - CPUB:*  — CPU block metadata (P2P CPU sharing)
      - SSDB:*  — SSD block metadata (P2P SSD sharing)
      - PCFSB:* — PCFS remote block metadata
    Each key is a Redis hash with fields: ph, pb, nid, hash, lt, state.
    Blocks are grouped per node and up to three samples per node are printed.
    """
    _print_section(" 5. KVCache Block Metadata (CPUB/SSDB/PCFSB)")

    # FlexKV actual block key prefixes (set in hie_cache_engine.py)
    block_prefixes = {
        "CPUB": "CPU",
        "SSDB": "SSD",
        "PCFSB": "PCFS (Remote)",
    }

    grand_total = 0
    for prefix, label in block_prefixes.items():
        keys = scan_keys(r, f"{prefix}:*")
        if not keys:
            print(f"\n [{label}] {prefix}:* — no entries found")
            continue

        grand_total += len(keys)

        # Group keys by the node that owns them.
        node_blocks = defaultdict(list)
        for key in keys:
            node_id = _node_id_from_block_key(key)
            if node_id is not None:
                node_blocks[node_id].append(key)

        print(f"\n [{label}] {prefix}:* — {len(keys)} block(s) across {len(node_blocks)} node(s):")

        # Numeric node ids sort numerically; anything else sorts as 0.
        for node_id, block_keys in sorted(node_blocks.items(),
                                          key=lambda x: int(x[0]) if x[0].isdigit() else 0):
            print(f" 📌 Node {node_id}: {len(block_keys)} block(s)")

            # Show first few blocks as samples
            sample_count = min(3, len(block_keys))
            for key in block_keys[:sample_count]:
                _print_block_sample(r, key)

            if len(block_keys) > sample_count:
                print(f" ... and {len(block_keys) - sample_count} more block(s)")

    if grand_total == 0:
        print("\n ⚠️ No block metadata found in any prefix (CPUB/SSDB/PCFSB)")
        print(" → This means no KVCache data has been published to Redis yet.")
        print(" The put-only node may still be uploading, or the upload")
        print(" interval (rebuild_interval_ms) hasn't elapsed yet.")
    else:
        print(f"\n ✅ Total block metadata entries: {grand_total}")


def check_pcfs_data(r):
    """Check PCFS file node IDs (each ``pcfs:*`` key is read as a Redis list)."""
    _print_section(" 6. PCFS File Node IDs (pcfs:*)")
    keys = scan_keys(r, "pcfs:*")
    if not keys:
        print(" (none found - this is normal if PCFS sharing is not used)")
        return

    print(f" Found {len(keys)} PCFS entry(ies):\n")
    for key in keys:
        values = r.lrange(key, 0, -1)
        print(f" 📌 {key}: {len(values)} file node ID(s)")
        if values:
            sample = values[:10]
            print(f" sample: {sample}")
            if len(values) > 10:
                print(f" ... and {len(values) - 10} more")


def check_mooncake_keys(r):
    """Check Mooncake Transfer Engine related keys.

    Prints up to ten keys per pattern, formatted according to their Redis type.
    """
    _print_section(" 7. Mooncake Transfer Engine Keys")
    # Mooncake uses Redis as metadata backend, keys may vary
    # Common patterns: segment info, endpoint info
    patterns = ["mooncake/*", "mooncake:*", "segment:*", "endpoint:*", "mc:*"]
    found_any = False
    for pattern in patterns:
        keys = scan_keys(r, pattern)
        if not keys:
            continue
        found_any = True
        print(f"\n Pattern '{pattern}': {len(keys)} key(s)")
        for key in keys[:10]:
            key_type = r.type(key)
            if key_type == "hash":
                data = r.hgetall(key)
                print(f" 📌 {key} (hash): {data}")
            elif key_type == "string":
                val = r.get(key)
                # Truncate long string values to keep the report readable.
                if val and len(val) > 200:
                    print(f" 📌 {key} (string): {val[:200]}...")
                else:
                    print(f" 📌 {key} (string): {val}")
            elif key_type == "set":
                members = r.smembers(key)
                print(f" 📌 {key} (set): {members}")
            elif key_type == "list":
                vals = r.lrange(key, 0, 9)
                print(f" 📌 {key} (list): {vals}")
            else:
                print(f" 📌 {key} (type={key_type})")
        if len(keys) > 10:
            print(f" ... and {len(keys) - 10} more")

    if not found_any:
        print(" (no mooncake-specific keys found)")


def check_all_keys_summary(r):
    """Show a summary of ALL keys in Redis grouped by prefix."""
    _print_section(" 8. All Keys Summary")
    all_keys = scan_keys(r, "*")
    if not all_keys:
        print(" ⚠️ Redis is completely empty!")
        return

    print(f" Total keys in Redis: {len(all_keys)}\n")

    # Group by prefix (first part before ':'); a key without ':' is its own prefix.
    prefix_counts = Counter(key.split(":", 1)[0] for key in all_keys)

    print(f" {'Prefix':<30} {'Count':>8}")
    print(f" {'-'*30} {'-'*8}")
    for prefix, count in sorted(prefix_counts.items(), key=lambda x: -x[1]):
        print(f" {prefix:<30} {count:>8}")


def main():
    """Parse CLI arguments, connect to Redis and run every inspection in order."""
    parser = argparse.ArgumentParser(
        description="FlexKV Redis Data Inspector - Check put-only node data"
    )
    # NOTE(review): the host and password defaults mirror example_dist_config.yml
    # and are environment-specific; consider requiring them explicitly.
    parser.add_argument("--host", type=str, default="10.135.1.175",
                        help="Redis host (default: 10.135.1.175)")
    parser.add_argument("--port", type=int, default=6379,
                        help="Redis port (default: 6379)")
    parser.add_argument("--password", type=str, default="123456",
                        help="Redis password (default: 123456)")
    args = parser.parse_args()

    print("=" * 60)
    print(" FlexKV Redis Data Inspector")
    print("=" * 60)
    print(f" Target: {args.host}:{args.port}")

    r = connect_redis(args.host, args.port, args.password)

    check_global_node_id(r)
    check_registered_nodes(r)
    check_node_meta(r)
    check_buffer_registrations(r)
    check_block_metadata(r)
    check_pcfs_data(r)
    check_mooncake_keys(r)
    check_all_keys_summary(r)

    print("\n" + "=" * 60)
    print(" Inspection Complete")
    print("=" * 60)


if __name__ == "__main__":
    main()
Set up environment variables +# 3. Run the distributed KVCache benchmark +# +# Usage: +# bash benchmarks/run_dist_benchmark.sh [options] +# +# Options (passed through to benchmark_dist_kvcache.py): +# --config Config YAML file (default: benchmarks/example_dist_config.yml) +# --mode Benchmark mode: single, multiturn, or all (default: all) +# --batch-size Batch size (default: 1) +# --sequence-length Sequence length (default: 1024) +# --num-users Number of simulated users (default: 10) +# --num-turns Number of conversation turns (default: 3) +# --clean-redis Clean up FlexKV & Mooncake residual data in Redis before running benchmark +# (removes node:*, meta:*, CPUB:block:*, SSDB:block:*, PCFSB:block:*, +# mooncake/*, mooncake:*, segment:*, endpoint:*, mc:* keys) +# --clean-redis-only Clean up FlexKV & Mooncake residual data in Redis and exit (no benchmark) +# +# Examples: +# # Run with defaults +# bash benchmarks/run_dist_benchmark.sh +# +# # Custom parameters +# bash benchmarks/run_dist_benchmark.sh --batch-size 4 --sequence-length 2048 +# +# # Multi-turn only +# bash benchmarks/run_dist_benchmark.sh --mode multiturn --num-users 20 --num-turns 5 +# +# # Clean Redis residual data before benchmark +# bash benchmarks/run_dist_benchmark.sh --clean-redis +# +# # Only clean Redis residual data (no benchmark) +# bash benchmarks/run_dist_benchmark.sh --clean-redis-only +# ============================================================================= + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "${SCRIPT_DIR}/.." 
&& pwd)" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +info() { echo -e "${BLUE}[INFO]${NC} $*"; } +ok() { echo -e "${GREEN}[OK]${NC} $*"; } +warn() { echo -e "${YELLOW}[WARN]${NC} $*"; } +error() { echo -e "${RED}[ERROR]${NC} $*"; } + +# Default config file +CONFIG_FILE="${SCRIPT_DIR}/example_dist_config.yml" +REDIS_STARTED_BY_US=false +CLEAN_REDIS=false +CLEAN_REDIS_ONLY=false + +# Parse script-specific arguments and --config, pass the rest through to benchmark +BENCH_ARGS=() +prev_arg="" +for arg in "$@"; do + if [[ "$prev_arg" == "--config" ]]; then + CONFIG_FILE="$arg" + BENCH_ARGS+=("$arg") + prev_arg="$arg" + continue + fi + case "$arg" in + --clean-redis) + CLEAN_REDIS=true + ;; + --clean-redis-only) + CLEAN_REDIS=true + CLEAN_REDIS_ONLY=true + ;; + *) + BENCH_ARGS+=("$arg") + ;; + esac + prev_arg="$arg" +done + +# ============================================ +# Step 1: Parse Redis config from YAML +# ============================================ +info "============================================" +info "Step 1: Parsing configuration" +info "============================================" + +# Helper function to parse a YAML value using Python (handles comments, quotes, etc. 
correctly) +# Usage: parse_yaml_value [default] +parse_yaml_value() { + local key="$1" file="$2" default="${3:-}" + local val + val=$(python3 -c " +import yaml, sys +with open('$file') as f: + d = yaml.safe_load(f) +v = d.get('$key') +if v is None: + print('$default') +else: + print(v) +" 2>/dev/null) || val="$default" + echo "$val" +} + +# Simple YAML parser for redis config +REDIS_HOST=$(parse_yaml_value "redis_host" "$CONFIG_FILE" "127.0.0.1") +REDIS_PORT=$(parse_yaml_value "redis_port" "$CONFIG_FILE" "6379") +REDIS_PASSWORD=$(parse_yaml_value "redis_password" "$CONFIG_FILE" "") + +info "Config file: ${CONFIG_FILE}" +info "Redis: ${REDIS_HOST}:${REDIS_PORT}" + +# ============================================ +# Step 2: Check and start Redis +# ============================================ +info "============================================" +info "Step 2: Checking Redis server" +info "============================================" + +check_redis() { + local auth_args="" + if [[ -n "$REDIS_PASSWORD" ]]; then + auth_args="-a $REDIS_PASSWORD" + fi + redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" $auth_args ping 2>/dev/null | grep -q "PONG" +} + +# Build redis-cli auth arguments (reused across the script) +REDIS_AUTH_ARGS="" +if [[ -n "$REDIS_PASSWORD" ]]; then + REDIS_AUTH_ARGS="-a $REDIS_PASSWORD" +fi + +if check_redis; then + ok "Redis is already running at ${REDIS_HOST}:${REDIS_PORT}" +else + warn "Redis is not running at ${REDIS_HOST}:${REDIS_PORT}" + + # Only try to start Redis if it's localhost + if [[ "$REDIS_HOST" == "127.0.0.1" ]] || [[ "$REDIS_HOST" == "localhost" ]]; then + if command -v redis-server &>/dev/null; then + info "Starting Redis server on port ${REDIS_PORT}..." 
+ redis-server --port "$REDIS_PORT" --daemonize yes --save "" --appendonly no \ + --protected-mode no --loglevel warning + sleep 1 + + if check_redis; then + ok "Redis server started successfully" + REDIS_STARTED_BY_US=true + else + error "Failed to start Redis server" + error "Please install Redis: sudo apt install redis-server" + exit 1 + fi + else + error "redis-server not found. Please install Redis:" + error " sudo apt install redis-server" + exit 1 + fi + else + error "Redis is not running at ${REDIS_HOST}:${REDIS_PORT}" + error "Please start Redis on the remote host first." + exit 1 + fi +fi + +# ============================================ +# Step 2.5: Clean FlexKV residual data in Redis (if requested) +# ============================================ +if [[ "$CLEAN_REDIS" == "true" ]]; then + info "============================================" + info "Cleaning FlexKV residual data in Redis" + info "============================================" + + clean_redis_keys() { + local pattern="$1" + local count=0 + local cursor=0 + while true; do + local result + result=$(redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" $REDIS_AUTH_ARGS SCAN $cursor MATCH "$pattern" COUNT 500 2>/dev/null) + cursor=$(echo "$result" | head -1) + local keys + keys=$(echo "$result" | tail -n +2) + if [[ -n "$keys" ]]; then + local batch_keys + batch_keys=$(echo "$keys" | tr '\n' ' ') + if [[ -n "$batch_keys" ]]; then + local deleted + deleted=$(echo "$batch_keys" | xargs redis-cli -h "$REDIS_HOST" -p "$REDIS_PORT" $REDIS_AUTH_ARGS DEL 2>/dev/null) + count=$((count + deleted)) + fi + fi + if [[ "$cursor" == "0" ]]; then + break + fi + done + echo "$count" + } + + total_deleted=0 + + # Clean node:* keys + n=$(clean_redis_keys "node:*") + [[ $n -gt 0 ]] && info "Deleted $n node:* key(s)" + total_deleted=$((total_deleted + n)) + + # Clean meta:* keys + n=$(clean_redis_keys "meta:*") + [[ $n -gt 0 ]] && info "Deleted $n meta:* key(s)" + total_deleted=$((total_deleted + n)) + + # Clean CPUB:block:* 
keys + n=$(clean_redis_keys "CPUB:block:*") + [[ $n -gt 0 ]] && info "Deleted $n CPUB:block:* key(s)" + total_deleted=$((total_deleted + n)) + + # Clean SSDB:block:* keys + n=$(clean_redis_keys "SSDB:block:*") + [[ $n -gt 0 ]] && info "Deleted $n SSDB:block:* key(s)" + total_deleted=$((total_deleted + n)) + + # Clean PCFSB:block:* keys + n=$(clean_redis_keys "PCFSB:block:*") + [[ $n -gt 0 ]] && info "Deleted $n PCFSB:block:* key(s)" + total_deleted=$((total_deleted + n)) + + # Clean Mooncake Transfer Engine residual keys + # Mooncake uses Redis as metadata backend to store segment/endpoint info + for mc_pattern in "mooncake/*" "mooncake:*" "segment:*" "endpoint:*" "mc:*"; do + n=$(clean_redis_keys "$mc_pattern") + [[ $n -gt 0 ]] && info "Deleted $n ${mc_pattern} key(s)" + total_deleted=$((total_deleted + n)) + done + + if [[ $total_deleted -gt 0 ]]; then + ok "Cleaned $total_deleted FlexKV & Mooncake residual key(s) from Redis" + else + ok "No FlexKV residual data found in Redis" + fi + + if [[ "$CLEAN_REDIS_ONLY" == "true" ]]; then + ok "Clean-only mode, exiting." + exit 0 + fi +fi + +# ============================================ +# Step 3: Set up environment +# ============================================ +info "============================================" +info "Step 3: Setting up environment" +info "============================================" + +# Detect Python (prefer virtual env) +if [[ -n "$VIRTUAL_ENV" ]]; then + # Prefer 'which python3' to get the actual resolved path in the activated venv, + # because $VIRTUAL_ENV may point to a path that doesn't match the real filesystem + # (e.g. symlinks, home dir aliases like ~ vs /data1/home). + if command -v python3 &>/dev/null; then + PYTHON="$(which python3)" + else + PYTHON="$VIRTUAL_ENV/bin/python3" + fi + if [[ ! 
-x "$PYTHON" ]]; then + error "Python3 not found at $PYTHON (VIRTUAL_ENV=$VIRTUAL_ENV)" + exit 1 + fi + info "Using virtual env Python: $PYTHON" +elif command -v python3 &>/dev/null; then + PYTHON="$(which python3)" + info "Using system Python: $PYTHON" +else + error "Python3 not found!" + exit 1 +fi + +# Set PYTHONPATH to include project root +export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}" + +# Set LD_LIBRARY_PATH for C++ libraries +if [[ -d "${PROJECT_ROOT}/build" ]]; then + export LD_LIBRARY_PATH="${PROJECT_ROOT}/build:${LD_LIBRARY_PATH:-}" +fi + +info "PYTHONPATH=${PYTHONPATH}" +info "LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}" + +# Generate mooncake config JSON and export MOONCAKE_CONFIG_PATH if P2P is enabled +ENABLE_P2P_CPU=$(grep -E "^enable_p2p_cpu:" "$CONFIG_FILE" 2>/dev/null | awk '{print $2}' || echo "false") +ENABLE_P2P_SSD=$(grep -E "^enable_p2p_ssd:" "$CONFIG_FILE" 2>/dev/null | awk '{print $2}' || echo "false") + +if [[ "$ENABLE_P2P_CPU" == "true" ]] || [[ "$ENABLE_P2P_SSD" == "true" ]]; then + if [[ -z "${MOONCAKE_CONFIG_PATH:-}" ]]; then + info "P2P enabled, generating mooncake config..." 
+ + # Parse mooncake config from YAML using helper function + MC_ENGINE_IP=$(parse_yaml_value "mooncake_engine_ip" "$CONFIG_FILE") + MC_ENGINE_PORT=$(parse_yaml_value "mooncake_engine_port" "$CONFIG_FILE") + MC_METADATA_BACKEND=$(parse_yaml_value "mooncake_metadata_backend" "$CONFIG_FILE") + MC_METADATA_SERVER=$(parse_yaml_value "mooncake_metadata_server" "$CONFIG_FILE") + MC_METADATA_SERVER_AUTH=$(parse_yaml_value "mooncake_metadata_server_auth" "$CONFIG_FILE") + MC_PROTOCOL=$(parse_yaml_value "mooncake_protocol" "$CONFIG_FILE") + MC_DEVICE_NAME=$(parse_yaml_value "mooncake_device_name" "$CONFIG_FILE") + LOCAL_IP=$(parse_yaml_value "local_ip" "$CONFIG_FILE" "127.0.0.1") + + # Use defaults if not specified + MC_ENGINE_IP="${MC_ENGINE_IP:-$LOCAL_IP}" + MC_ENGINE_PORT="${MC_ENGINE_PORT:-5555}" + MC_METADATA_BACKEND="${MC_METADATA_BACKEND:-redis}" + MC_METADATA_SERVER="${MC_METADATA_SERVER:-redis://${REDIS_HOST}:${REDIS_PORT}}" + MC_PROTOCOL="${MC_PROTOCOL:-tcp}" + MC_DEVICE_NAME="${MC_DEVICE_NAME:-}" + + # Generate JSON config file + MOONCAKE_CONFIG_FILE=$(mktemp /tmp/mooncake_config_XXXXXX.json) + cat > "$MOONCAKE_CONFIG_FILE" </dev/null || true + ok "Redis stopped." +fi + +if [[ $BENCH_EXIT_CODE -eq 0 ]]; then + echo "" + ok "Benchmark completed successfully!" 
+else + echo "" + error "Benchmark failed with exit code: $BENCH_EXIT_CODE" +fi + +exit $BENCH_EXIT_CODE diff --git a/benchmarks/utils.py b/benchmarks/utils.py index 1ebabc402e..d979a27a07 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -43,9 +43,14 @@ def generate_random_multiturn(num_user_requests: int, output_length: int, num_turns_ratio: float = 0.5, input_length_ratio: float = 0.5, - output_length_ratio: float = 0.5) -> List[KVRequest]: + output_length_ratio: float = 0.5, + seed: int = None) -> List[KVRequest]: all_requests = [] token_id_range = 10000 + # Set seed for deterministic generation (useful for cross-node benchmarks) + if seed is not None: + random.seed(seed) + torch.manual_seed(seed) system_prompt = torch.randint(0, token_id_range, (system_prompt_length,)) for i in range(num_user_requests): user_requests = [] diff --git a/flexkv/cache/hie_cache_engine.py b/flexkv/cache/hie_cache_engine.py index 838663fdc8..a4acafdb26 100644 --- a/flexkv/cache/hie_cache_engine.py +++ b/flexkv/cache/hie_cache_engine.py @@ -31,7 +31,7 @@ def __init__(self, remote_max_num_blocks: int = 4000000, redis_node_id: int = 0, remote_refresh_batch_size: int = 1000, - remote_rebuild_interval_ms: int = 10000, + remote_rebuild_interval_ms: int = 100, remote_idle_sleep_ms: int = 10, local_safety_ttl_ms: int = 100, evict_start_threshold: float = 1.0, diff --git a/flexkv/cache/redis_meta.py b/flexkv/cache/redis_meta.py index 8a2c9ace6b..536d58c3ab 100644 --- a/flexkv/cache/redis_meta.py +++ b/flexkv/cache/redis_meta.py @@ -137,8 +137,12 @@ def delete_blockmeta_batch(self, node_id: int, hashes: Iterable[int], batch_size class RedisNodeInfo: """Redis node information management class implemented in Python""" + + # Default TTL for node: key in seconds. Active nodes renew before expiry. + # If a process crashes (kill -9), the key auto-expires after this period. 
+ DEFAULT_NODE_TTL_SECONDS: int = 30 - def __init__(self, host: str, port: int, local_ip: str, password: str = "") -> None: + def __init__(self, host: str, port: int, local_ip: str, password: str = "", node_ttl_seconds: int = 0) -> None: if _redis is None: raise ImportError("redis-py is required: pip install redis") self.host = host @@ -146,9 +150,14 @@ def __init__(self, host: str, port: int, local_ip: str, password: str = "") -> N self.local_ip = str(local_ip) self.password = str(password) self.uuid = str(uuid1()) + # Use provided TTL or fall back to default + self.node_ttl_seconds: int = node_ttl_seconds if node_ttl_seconds > 0 else self.DEFAULT_NODE_TTL_SECONDS + # Heartbeat interval – renew TTL at roughly 1/3 of the TTL period + self.heartbeat_interval_seconds: float = max(1.0, self.node_ttl_seconds / 3.0) self._node_id: Optional[int] = None self._running = False self._listener_thread: Optional[threading.Thread] = None + self._heartbeat_thread: Optional[threading.Thread] = None self.current_node_id_set: set = set() self._client: Optional["_redis.Redis"] = None self._sub_client: Optional["_redis.Redis"] = None @@ -179,7 +188,7 @@ def _get_client(self) -> "_redis.Redis": ) def connect(self) -> bool: - """Connect to Redis and start listener thread""" + """Connect to Redis and start listener + heartbeat threads""" try: self._client = self._get_client() # Test connection @@ -193,17 +202,29 @@ def connect(self) -> bool: daemon=True ) self._listener_thread.start() + + # Start heartbeat thread for TTL renewal + self._heartbeat_thread = threading.Thread( + target=self._heartbeat_worker, + name="redis-node-heartbeat", + daemon=True + ) + self._heartbeat_thread.start() return True except Exception: return False def disconnect(self) -> None: - """Disconnect from Redis and stop listener thread""" + """Disconnect from Redis and stop listener + heartbeat threads""" self._running = False if self._listener_thread and self._listener_thread.is_alive(): 
self._listener_thread.join(timeout=2.0) self._listener_thread = None + + if self._heartbeat_thread and self._heartbeat_thread.is_alive(): + self._heartbeat_thread.join(timeout=2.0) + self._heartbeat_thread = None if self._client: self._client.close() @@ -241,11 +262,14 @@ def _cleanup(self) -> None: pass def register_node(self) -> Optional[int]: - """Register a new node and get node_id""" + """Register a new node and get node_id, with TTL for automatic expiry on crash""" if not self._client: return None try: + # Clean up stale nodes from the same IP before registering + self._cleanup_stale_nodes_by_ip() + # Atomically increment global:node_id to get new node_id node_id = self._client.incr("global:node_id") self._node_id = node_id @@ -262,6 +286,9 @@ def register_node(self) -> Optional[int]: "pp_rank": str(getattr(self, 'pp_rank', 0)), "pp_size": str(getattr(self, 'pp_size', 1)), }) + + # Set TTL so the key auto-expires if the process crashes + self._client.expire(node_key, self.node_ttl_seconds) # Publish node update event self._client.publish("flexkv_node_id_updated", str(node_id)) @@ -271,17 +298,22 @@ def register_node(self) -> Optional[int]: return None def unregister_node(self) -> bool: - """Unregister current node""" + """Unregister current node and clean up associated meta/block data""" if not self._client or self._node_id is None: return False try: + node_id = self._node_id + # Delete node:node_id key - node_key = f"node:{self._node_id}" + node_key = f"node:{node_id}" self._client.delete(node_key) + + # Also clean up meta: to prevent stale RDMA addresses + self._cleanup_node_data(node_id) # Publish node update event - self._client.publish("flexkv_node_id_updated", str(self._node_id)) + self._client.publish("flexkv_node_id_updated", str(node_id)) self._node_id = None return True @@ -305,6 +337,48 @@ def is_node_active(self, node_id: int) -> bool: """Check if a node_id is active - lock-free RCU check""" return node_id in self.current_node_id_set + def 
_heartbeat_worker(self) -> None: + """Background thread that periodically renews the TTL of node: key. + + This ensures that if the process is alive, the node key never expires. + If the process crashes (kill -9), the TTL will not be renewed and the + key will auto-expire after NODE_TTL_SECONDS, allowing other nodes to + detect the crash and stop using stale meta/block data. + """ + heartbeat_client: Optional["_redis.Redis"] = None + while self._running: + try: + if heartbeat_client is None: + heartbeat_client = self._get_client() + + if self._node_id is not None: + node_key = f"node:{self._node_id}" + # Renew TTL + heartbeat_client.expire(node_key, self.node_ttl_seconds) + # Also update the timestamp field + heartbeat_client.hset(node_key, "timestamp", str(int(time.time()))) + + except Exception: + # Connection lost, reset client so it reconnects next iteration + if heartbeat_client: + try: + heartbeat_client.close() + except Exception: + pass + heartbeat_client = None + + # Sleep in small increments so we can exit quickly when _running becomes False + for _ in range(int(self.heartbeat_interval_seconds * 10)): + if not self._running: + break + time.sleep(0.1) + + if heartbeat_client: + try: + heartbeat_client.close() + except Exception: + pass + def _listener_worker(self) -> None: """Background thread that listens for node updates""" backoff = 0.5 @@ -346,6 +420,10 @@ def scan_active_nodes(self) -> None: This method can be called externally to manually refresh the active nodes list. It uses SCAN to avoid blocking Redis server. + + Because node: keys now have a TTL (heartbeat), expired keys are + automatically removed by Redis. SCAN will only return keys that are + still alive, so stale/crashed nodes are naturally excluded. 
""" if not self._client: return @@ -369,6 +447,15 @@ def scan_active_nodes(self) -> None: if cursor == 0: break + # Detect nodes that disappeared (TTL expired or unregistered) + disappeared = self.current_node_id_set - new_active_nodes + if disappeared: + # Clean up meta and block data for disappeared nodes + for stale_nid in disappeared: + if stale_nid == self._node_id: + continue # Don't clean up ourselves + self._cleanup_node_data(stale_nid) + # lock-free RCU switch: atomic assignment self.current_node_id_set = new_active_nodes @@ -376,10 +463,97 @@ def scan_active_nodes(self) -> None: # If scan fails, continue with current active nodes pass + def _cleanup_stale_nodes_by_ip(self) -> None: + """Clean up stale node registrations from the same IP. + + On startup, scan all node:* keys and remove those that have the same + local_ip but a different UUID (i.e. leftover from a previous crashed process). + """ + if not self._client: + return + + try: + cursor = 0 + stale_node_ids = [] + + while True: + cursor, keys = self._client.scan(cursor=cursor, match="node:*", count=100) + for key in keys: + if not key.startswith("node:"): + continue + try: + nid = int(key[5:]) + except (ValueError, IndexError): + continue + + data = self._client.hgetall(key) + node_ip = data.get("ip", "") or data.get("local_ip", "") + node_uuid = data.get("uuid", "") + + # Same IP but different UUID → stale node from a previous process + if node_ip == self.local_ip and node_uuid != self.uuid: + stale_node_ids.append(nid) + + if cursor == 0: + break + + for stale_nid in stale_node_ids: + print(f"[RedisNodeInfo] Cleaning up stale node:{stale_nid} (same IP={self.local_ip}, different UUID)") + self._client.delete(f"node:{stale_nid}") + self._cleanup_node_data(stale_nid) + + if stale_node_ids: + # Notify other nodes about the cleanup + self._client.publish("flexkv_node_id_updated", "cleanup") + + except Exception: + pass + + def _cleanup_node_data(self, node_id: int) -> None: + """Clean up meta: and 
CPUB/SSDB/PCFSB block keys for a given node. + + This is called when: + 1. A node is unregistered (graceful shutdown) + 2. A stale node is detected (TTL expired / startup cleanup) + """ + if not self._client: + return + + try: + # Delete meta: (and meta::pp* for pipeline parallel) + cursor = 0 + meta_keys = [] + while True: + cursor, keys = self._client.scan(cursor=cursor, match=f"meta:{node_id}*", count=100) + meta_keys.extend(keys) + if cursor == 0: + break + if meta_keys: + self._client.delete(*meta_keys) + print(f"[RedisNodeInfo] Deleted {len(meta_keys)} meta key(s) for node {node_id}") + + # Delete CPUB:block::* / SSDB:block::* / PCFSB:block::* keys + for prefix in ("CPUB", "SSDB", "PCFSB"): + cursor = 0 + block_keys = [] + while True: + cursor, keys = self._client.scan(cursor=cursor, match=f"{prefix}:block:{node_id}:*", count=500) + block_keys.extend(keys) + if cursor == 0: + break + if block_keys: + # Delete in batches to avoid blocking Redis + batch_size = 500 + for i in range(0, len(block_keys), batch_size): + self._client.delete(*block_keys[i:i + batch_size]) + print(f"[RedisNodeInfo] Deleted {len(block_keys)} {prefix}:block key(s) for node {node_id}") + + except Exception as e: + print(f"[RedisNodeInfo] Warning: failed to clean up data for node {node_id}: {e}") class RedisMeta: - def __init__(self, host: str, port: int, password: Optional[str] = None, local_ip: str = "127.0.0.1", decode_responses: bool = True) -> None: + def __init__(self, host: str, port: int, password: Optional[str] = None, local_ip: str = "127.0.0.1", decode_responses: bool = True, node_ttl_seconds: int = 0) -> None: if _redis is None: # pragma: no cover raise ImportError("redis-py is required: pip install redis") self.host = host @@ -396,7 +570,7 @@ def __init__(self, host: str, port: int, password: Optional[str] = None, local_i self._init_error: Optional[Exception] = None # create RedisNodeInfo object - self.nodeinfo = RedisNodeInfo(host, port, local_ip, password or "") + 
self.nodeinfo = RedisNodeInfo(host, port, local_ip, password or "", node_ttl_seconds=node_ttl_seconds) # get UUID via nodeinfo self._uuid = self.nodeinfo.get_uuid() diff --git a/flexkv/common/config.py b/flexkv/common/config.py index 650e7ea347..c56712989f 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -87,6 +87,12 @@ class CacheConfig: redis_port: int = 6379 local_ip: str = "127.0.0.1" redis_password: Optional[str] = None + # TTL (seconds) for node: key in Redis. Active nodes renew via heartbeat. + # If a process crashes, the key auto-expires after this period. + node_ttl_seconds: int = 30 + + # Mooncake transfer engine config path (serialized via pickle to survive spawn subprocesses) + mooncake_config_path: Optional[str] = None def __post_init__(self): self.enable_kv_sharing = self.enable_p2p_cpu or \ @@ -143,7 +149,7 @@ def __post_init__(self): lt_pool_initial_capacity=int(os.getenv('FLEXKV_LT_POOL_INITIAL_CAPACITY', 10000000)), refresh_batch_size=int(os.getenv('FLEXKV_REFRESH_BATCH_SIZE', 256)), - rebuild_interval_ms=int(os.getenv('FLEXKV_REBUILD_INTERVAL_MS', 10000)), + rebuild_interval_ms=int(os.getenv('FLEXKV_REBUILD_INTERVAL_MS', 100)), idle_sleep_ms=int(os.getenv('FLEXKV_IDLE_SLEEP_MS', 10)), lease_ttl_ms=int(os.getenv('FLEXKV_LEASE_TTL_MS', 30000)), safety_ttl_ms=int(os.getenv('FLEXKV_SAFETY_TTL_MS', 100)), @@ -168,6 +174,7 @@ class UserConfig: redis_port: Optional[int] = None local_ip: Optional[str] = None redis_password: Optional[str] = None + node_ttl_seconds: Optional[int] = None kv_cache_dtype: Optional[str] = None # Override kv_cache_dtype when TRT config uses "auto". 
Supported values: "fp8", "float8", "e4m3", "fp16", "float16", "bf16", "bfloat16", "fp32", "float32" def __post_init__(self): @@ -267,6 +274,8 @@ def update_default_config_from_user_config(model_config: ModelConfig, cache_config.local_ip = user_config.local_ip if user_config.redis_password is not None: cache_config.redis_password = user_config.redis_password + if user_config.node_ttl_seconds is not None: + cache_config.node_ttl_seconds = user_config.node_ttl_seconds global_config_attrs = set(vars(GLOBAL_CONFIG_FROM_ENV).keys()) for attr_name in dir(user_config): diff --git a/flexkv/common/tracer.py b/flexkv/common/tracer.py index cd4553526b..4b054f575b 100644 --- a/flexkv/common/tracer.py +++ b/flexkv/common/tracer.py @@ -329,6 +329,7 @@ def flush(self): def __del__(self): """Ensure all records are flushed when tracer is destroyed""" - from contextlib import suppress - with suppress(Exception): + try: self.flush() + except Exception: + pass diff --git a/flexkv/kvmanager.py b/flexkv/kvmanager.py index 280242f0aa..2375948cda 100644 --- a/flexkv/kvmanager.py +++ b/flexkv/kvmanager.py @@ -67,24 +67,10 @@ def __init__(self, flexkv_logger.info(f"server_client_mode: {self.server_client_mode}") self.redis_meta_client = None - if self.cache_config.enable_kv_sharing: - flexkv_logger.info(f"[kv manager] initializing RedisMeta and connection to \ - {self.cache_config.redis_host}:{self.cache_config.redis_port}") - # initialize redis Meta obj - self.redis_meta_client = RedisMeta( - self.cache_config.redis_host, - self.cache_config.redis_port, - self.cache_config.redis_password, - self.cache_config.local_ip, - ) - self.redis_meta_client.init_meta() - # update distributed_node_id - self.cache_config.distributed_node_id = self.redis_meta_client.get_node_id() # update distributed_node_id of current node - - self.enable_mps = GLOBAL_CONFIG_FROM_ENV.enable_mps if self.server_client_mode: + # In server_client_mode, RedisMeta is created and initialized inside KVServer # Server should 
only be created once across all instances and dp ranks if self.instance_id == 0 and dp_client_id == 0: total_clients = self.instance_num * model_config.dp_size @@ -99,6 +85,21 @@ def __init__(self, self.server_handle = None self.dp_client = KVDPClient(self.server_recv_port, self.model_config, self.global_client_id) else: + # In non-server_client_mode, create RedisMeta here and pass to KVTaskEngine + if self.cache_config.enable_kv_sharing: + flexkv_logger.info(f"[kv manager] initializing RedisMeta and connection to " + f"{self.cache_config.redis_host}:{self.cache_config.redis_port}") + self.redis_meta_client = RedisMeta( + self.cache_config.redis_host, + self.cache_config.redis_port, + self.cache_config.redis_password, + self.cache_config.local_ip, + node_ttl_seconds=self.cache_config.node_ttl_seconds, + ) + self.redis_meta_client.init_meta() + # update distributed_node_id + self.cache_config.distributed_node_id = self.redis_meta_client.get_node_id() + self.server_handle = None self.kv_task_engine = KVTaskEngine(self.model_config, self.cache_config, self.gpu_register_port, redis_meta=self.redis_meta_client, event_collector=event_collector) @@ -127,6 +128,10 @@ def is_ready(self) -> bool: def shutdown(self) -> None: if self.server_client_mode: self.dp_client.shutdown() + # Wait for the server process to exit after sending shutdown request + if self.server_handle is not None: + self.server_handle.shutdown() + self.server_handle = None else: self.kv_task_engine.shutdown() diff --git a/flexkv/mooncakeEngineWrapper.py b/flexkv/mooncakeEngineWrapper.py index bcb080f44b..bc08255e68 100644 --- a/flexkv/mooncakeEngineWrapper.py +++ b/flexkv/mooncakeEngineWrapper.py @@ -31,7 +31,12 @@ def __init__( ) if config is None: - mooncake_config_path = os.environ["MOONCAKE_CONFIG_PATH"] + mooncake_config_path = os.environ.get("MOONCAKE_CONFIG_PATH") + if mooncake_config_path is None: + raise RuntimeError( + "MOONCAKE_CONFIG_PATH is not set. 
Please set the MOONCAKE_CONFIG_PATH " + "environment variable or pass a MooncakeTransferEngineConfig object." + ) self.config = MooncakeTransferEngineConfig.from_file(mooncake_config_path) else: self.config = config @@ -51,13 +56,18 @@ def __init__( # transfer engine initialize self.engine = TransferEngine() + # Set Redis auth env vars for mooncake engine (it reads MC_REDIS_PASSWORD internally) + if self.config.metadata_server_auth: + os.environ["MC_REDIS_PASSWORD"] = self.config.metadata_server_auth + flexkv_logger.info("Set MC_REDIS_PASSWORD environment variable for mooncake Redis authentication") + self.engine.initialize_ext( self.mooncake_addr, self.config.metadata_server, self.config.protocol, self.config.device_name, self.metadata_backend, - ) + ) # mooncake operations def regist_buffer(self, buffer_ptr: int, buffer_size: int) -> int: diff --git a/flexkv/server/server.py b/flexkv/server/server.py index 2991ce4e3d..e5f808c697 100644 --- a/flexkv/server/server.py +++ b/flexkv/server/server.py @@ -16,6 +16,7 @@ from flexkv.common.config import CacheConfig, ModelConfig from flexkv.common.debug import flexkv_logger +from flexkv.cache.redis_meta import RedisMeta from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType from flexkv.kvtask import KVTaskEngine @@ -111,15 +112,31 @@ class KVServerHandle: def __init__(self, process: Union[mp.Process, 'subprocess.Popen']): self.process = process + def _is_alive(self) -> bool: + """Check if the process is still running (compatible with both Process and Popen).""" + if isinstance(self.process, subprocess.Popen): + return self.process.poll() is None + return self.process.is_alive() + + def _join(self, timeout: float = None) -> None: + """Wait for the process to finish (compatible with both Process and Popen).""" + if isinstance(self.process, subprocess.Popen): + try: + self.process.wait(timeout=timeout) + except subprocess.TimeoutExpired: + pass + else: + 
self.process.join(timeout=timeout) + def shutdown(self) -> None: - self.process.join(timeout=5) - if self.process.is_alive(): + self._join(timeout=5) + if self._is_alive(): flexkv_logger.info("force terminate the server process") self.process.terminate() - self.process.join() + self._join() def __del__(self) -> None: - if self.process.is_alive(): + if self._is_alive(): self.shutdown() class KVServer: @@ -141,7 +158,24 @@ def __init__( # Use total_clients if provided (multi-instance mode), otherwise use dp_size max_clients = total_clients if total_clients > 0 else model_config.dp_size self.client_manager = ClientManager(max_num_dp_client=max_clients) - self.kv_task_engine = KVTaskEngine(model_config, cache_config, gpu_register_port) + + # Initialize RedisMeta in KVServer for server_client_mode + self.redis_meta_client = None + if cache_config.enable_kv_sharing: + flexkv_logger.info(f"[kv server] initializing RedisMeta and connection to " + f"{cache_config.redis_host}:{cache_config.redis_port}") + self.redis_meta_client = RedisMeta( + cache_config.redis_host, + cache_config.redis_port, + cache_config.redis_password, + cache_config.local_ip, + node_ttl_seconds=cache_config.node_ttl_seconds, + ) + self.redis_meta_client.init_meta() + # update distributed_node_id + cache_config.distributed_node_id = self.redis_meta_client.get_node_id() + + self.kv_task_engine = KVTaskEngine(model_config, cache_config, gpu_register_port, redis_meta=self.redis_meta_client) self.req_counter = 0 self._is_ready = False @@ -209,6 +243,12 @@ def create_server(cls, env.update(child_env) else: env = child_env or {} + # Always propagate FLEXKV_* env vars to child process so that + # runtime config overrides (e.g. FLEXKV_REBUILD_INTERVAL_MS) + # are visible when config.py is re-imported in the subprocess. 
+ for key, val in os.environ.items(): + if key.startswith("FLEXKV_") and key not in env: + env[key] = val # Remove CUDA_VISIBLE_DEVICES so server can see all GPUs env.pop('CUDA_VISIBLE_DEVICES', None) @@ -285,8 +325,8 @@ def run(self) -> None: # Cleanup after shutdown flexkv_logger.info("Server shutting down, cleaning up...") - if hasattr(self, 'kvmanager'): - self.kvmanager.shutdown() + if hasattr(self, 'kv_task_engine'): + self.kv_task_engine.shutdown() flexkv_logger.info("Server shutdown complete") @@ -295,7 +335,6 @@ def _verify_model_config(self, model_config: ModelConfig) -> None: for field in fields(ModelConfig): client_val = getattr(model_config, field.name) server_val = getattr(self.model_config, field.name) - print(f"ModelConfig.{field.name} mismatch: client={client_val}, server={server_val}") assert client_val == server_val, \ f"ModelConfig.{field.name} mismatch: client={client_val}, server={server_val}" diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index a687720296..649caaf80f 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -371,7 +371,8 @@ def _init_workers(self) -> None: cache_config = self.cache_config, ssd_kv_layout = self._ssd_handle.kv_layout if self._ssd_handle else None, ssd_files = self._ssd_handle.get_file_list() if self._ssd_handle else None, - num_blocks_per_file = self._ssd_handle.num_blocks_per_file if self._ssd_handle else None + num_blocks_per_file = self._ssd_handle.num_blocks_per_file if self._ssd_handle else None, + mooncake_config_path = getattr(self.cache_config, 'mooncake_config_path', None) or os.environ.get("MOONCAKE_CONFIG_PATH"), ) # NOTE: now peerH2H and peerSSD2H op use the same worker if self.cache_config.enable_p2p_cpu: diff --git a/flexkv/transfer/worker.py b/flexkv/transfer/worker.py index 13980932cf..0933efce6a 100644 --- a/flexkv/transfer/worker.py +++ b/flexkv/transfer/worker.py @@ -1329,6 +1329,7 @@ def __init__(self, ssd_kv_layout: 
KVCacheLayout = None, ssd_files: Dict[int, List[str]] = None, # ssd_device_id -> file_paths num_blocks_per_file: int = 0, + mooncake_config_path: str = None, ): super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) self.cpu_layer_ptrs = self._get_layer_ptrs(cpu_blocks) @@ -1361,6 +1362,7 @@ def __init__(self, self.cache_config.redis_port, self.cache_config.redis_password, self.cache_config.local_ip, + node_ttl_seconds=self.cache_config.node_ttl_seconds, ) self.redis_meta_client.set_node_id(self.cache_config.distributed_node_id) @@ -1372,8 +1374,18 @@ def __init__(self, # step2: initialize mooncake transfer engine for the whole flexkv - # NOTE:now we read the config file by env paras - mooncake_config_path = os.environ["MOONCAKE_CONFIG_PATH"] + # NOTE: prefer explicit parameter > cache_config > env variable + # (spawn subprocesses may lose env vars, but cache_config is pickle-serialized) + if mooncake_config_path is None: + mooncake_config_path = getattr(self.cache_config, 'mooncake_config_path', None) + if mooncake_config_path is None: + mooncake_config_path = os.environ.get("MOONCAKE_CONFIG_PATH") + if mooncake_config_path is None: + raise RuntimeError( + "MOONCAKE_CONFIG_PATH is not set. Please either pass mooncake_config_path " + "parameter, set cache_config.mooncake_config_path, or set the " + "MOONCAKE_CONFIG_PATH environment variable." + ) self.mooncake_config = MooncakeTransferEngineConfig.from_file( mooncake_config_path ) @@ -2114,8 +2126,27 @@ def unregist_node_meta(self, node_id: int = None) -> None: flexkv_logger.info(f"Unregistered node {self.redis_meta_client.get_node_id()} from Redis.") def get_node_meta(self, node_id: int) -> Optional[NodeMetaInfo]: - # TODO: how to remove the invalid node meta info in node_metas - """Get the node meta info by node id.""" + """Get the node meta info by node id. 
+ + Before returning cached or freshly-fetched meta, we verify that the + node is still active (its node: key exists in Redis and has not + expired). This prevents RDMA transfers to stale addresses after a + remote node has crashed. + """ + # ===== Active-node validation (Scheme 4) ===== + if not self.redis_meta_client.is_node_active(node_id): + # Node is no longer active – purge cached meta if any + if node_id in self.node_metas: + del self.node_metas[node_id] + flexkv_logger.warning( + f"Node {node_id} is no longer active, removed cached meta." + ) + else: + flexkv_logger.warning( + f"Node {node_id} is not active, skipping meta fetch." + ) + return None + if node_id not in self.node_metas: ## fetch from redis node_redis_data = self.redis_meta_client.get_node_meta(node_id) diff --git a/install.sh b/install.sh new file mode 100755 index 0000000000..4f5d126b03 --- /dev/null +++ b/install.sh @@ -0,0 +1,549 @@ +#!/bin/bash +# ============================================================================= +# FlexKV One-Click Install Script +# ============================================================================= +# Usage: +# bash install.sh [OPTIONS] +# +# Options: +# --venv PATH Specify virtual environment path (default: ./venv) +# --no-venv Skip virtual environment creation, install directly +# --release Build in release mode (with Cython compilation) +# --debug Build in debug mode (default, no Cython) +# --enable-metrics Enable Prometheus monitoring support +# --enable-p2p Enable distributed P2P/Redis support (default: enabled) +# --disable-p2p Disable distributed P2P/Redis support +# --mooncake-version VER Mooncake release tag to build from source (default: latest main branch) +# --enable-gds Enable GDS support +# --enable-cfs Enable CFS support +# --skip-deps Skip system dependency installation +# --clean Clean all build artifacts and exit +# -h, --help Show this help message +# ============================================================================= 
+set -e + +# ======================== Default Configuration ======================== +VENV_PATH="./venv" +USE_VENV=1 +BUILD_TYPE="debug" +ENABLE_METRICS=0 +ENABLE_P2P=1 +ENABLE_GDS=0 +ENABLE_CFS=0 +SKIP_DEPS=0 +CLEAN_ONLY=0 +MOONCAKE_VERSION="" + +# Use sudo only if not running as root +if [ "$(id -u)" -eq 0 ]; then + SUDO="" +else + SUDO="sudo" +fi + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# ======================== Helper Functions ======================== +info() { echo -e "${BLUE}[INFO]${NC} $*"; } +success() { echo -e "${GREEN}[OK]${NC} $*"; } +warn() { echo -e "${YELLOW}[WARN]${NC} $*"; } +error() { echo -e "${RED}[ERROR]${NC} $*"; exit 1; } + +usage() { + head -n 17 "$0" | tail -n 14 | sed 's/^# \?//' + exit 0 +} + +# ======================== Parse Arguments ======================== +while [[ $# -gt 0 ]]; do + case "$1" in + --venv) + VENV_PATH="$2" + USE_VENV=1 + shift 2 + ;; + --no-venv) + USE_VENV=0 + shift + ;; + --release) + BUILD_TYPE="release" + shift + ;; + --debug) + BUILD_TYPE="debug" + shift + ;; + --enable-metrics) + ENABLE_METRICS=1 + shift + ;; + --enable-p2p) + ENABLE_P2P=1 + shift + ;; + --disable-p2p) + ENABLE_P2P=0 + shift + ;; + --mooncake-version) + MOONCAKE_VERSION="$2" + shift 2 + ;; + --enable-gds) + ENABLE_GDS=1 + shift + ;; + --enable-cfs) + ENABLE_CFS=1 + shift + ;; + --skip-deps) + SKIP_DEPS=1 + shift + ;; + --clean) + CLEAN_ONLY=1 + shift + ;; + -h|--help) + usage + ;; + *) + warn "Unknown option: $1" + shift + ;; + esac +done + +# ======================== Project Root ======================== +PROJECT_ROOT="$(cd "$(dirname "$0")" && pwd)" +cd "$PROJECT_ROOT" +info "Project root: $PROJECT_ROOT" + +# ======================== Clean Mode ======================== +if [ "$CLEAN_ONLY" -eq 1 ]; then + info "Cleaning all build artifacts..." 
+    bash build.sh --clean
+    if [ -d "$VENV_PATH" ]; then
+        rm -rf "$VENV_PATH"
+        info "Removed virtual environment: $VENV_PATH"
+    fi
+    success "Clean completed."
+    exit 0
+fi
+
+# ======================== Step 1: Check System Dependencies ========================
+info "============================================"
+info "Step 1: Checking system dependencies"
+info "============================================"
+
+check_command() {  # Log whether command "$1" is on PATH; returns 0 if found, 1 otherwise.
+    if command -v "$1" &>/dev/null; then
+        success "$1 found: $(command -v "$1")"
+        return 0
+    else
+        warn "$1 not found"
+        return 1
+    fi
+}
+
+MISSING_CMDS=()
+MISSING_PKGS=()
+
+# Check essential commands. Compiler package names are chosen per distro (build-essential exists only on Debian/Ubuntu).
+check_command python3 || MISSING_CMDS+=("python3")
+check_command cmake || { MISSING_CMDS+=("cmake"); MISSING_PKGS+=("cmake"); }
+check_command git || { MISSING_CMDS+=("git"); MISSING_PKGS+=("git"); }
+check_command gcc || { MISSING_CMDS+=("gcc"); if [ -f /etc/debian_version ]; then MISSING_PKGS+=("build-essential"); else MISSING_PKGS+=("gcc"); fi; }
+check_command g++ || { MISSING_CMDS+=("g++"); if [ -f /etc/debian_version ]; then MISSING_PKGS+=("build-essential"); else MISSING_PKGS+=("gcc-c++"); fi; }
+
+# Check python3-venv availability (test with a real temporary venv to catch missing ensurepip)
+if [ "$USE_VENV" -eq 1 ]; then
+    _VENV_TEST_DIR=$(mktemp -d)
+    if ! python3 -m venv "$_VENV_TEST_DIR/test_venv" &>/dev/null 2>&1; then
+        rm -rf "$_VENV_TEST_DIR"
+        warn "python3-venv not available (ensurepip missing)"
+        PY_MINOR=$(python3 -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")
+        MISSING_PKGS+=("python3.${PY_MINOR#3.}-venv" "python3-venv" "python3-full")
+    else
+        rm -rf "$_VENV_TEST_DIR"
+    fi
+fi
+
+# Check for liburing-dev (required by setup.py: -luring)
+if ! dpkg -s liburing-dev &>/dev/null 2>&1 && ! rpm -q liburing-devel &>/dev/null 2>&1; then
+    if [ -f /etc/debian_version ]; then
+        warn "liburing-dev not found"
+        MISSING_PKGS+=("liburing-dev")
+    elif [ -f /etc/redhat-release ]; then
+        warn "liburing-devel not found"
+        MISSING_PKGS+=("liburing-devel")
+    fi
+fi
+
+# Check for hiredis if P2P enabled
+if [ "$ENABLE_P2P" -eq 1 ]; then
+    if ! dpkg -s libhiredis-dev &>/dev/null 2>&1 && ! rpm -q hiredis-devel &>/dev/null 2>&1; then
+        if [ -f /etc/debian_version ]; then
+            MISSING_PKGS+=("libhiredis-dev")
+        elif [ -f /etc/redhat-release ]; then
+            MISSING_PKGS+=("hiredis-devel")
+        fi
+    fi
+fi
+
+# Install missing packages (skipped with a warning when --skip-deps is given)
+if [ ${#MISSING_PKGS[@]} -gt 0 ] && [ "$SKIP_DEPS" -eq 0 ]; then
+    # Deduplicate (package names contain no spaces, so word splitting is safe here)
+    UNIQUE_PKGS=($(echo "${MISSING_PKGS[@]}" | tr ' ' '\n' | sort -u | tr '\n' ' '))
+    info "Installing missing packages: ${UNIQUE_PKGS[*]}"
+
+    if command -v apt-get &>/dev/null; then
+        $SUDO apt-get update -qq
+        $SUDO apt-get install -y -qq "${UNIQUE_PKGS[@]}"
+    elif command -v yum &>/dev/null; then
+        $SUDO yum install -y "${UNIQUE_PKGS[@]}"
+    elif command -v dnf &>/dev/null; then
+        $SUDO dnf install -y "${UNIQUE_PKGS[@]}"
+    else
+        error "Cannot auto-install packages. Please manually install: ${UNIQUE_PKGS[*]}"
+    fi
+    success "System packages installed."
+elif [ ${#MISSING_PKGS[@]} -gt 0 ] && [ "$SKIP_DEPS" -eq 1 ]; then
+    warn "Skipping dependency installation (--skip-deps). Missing: ${MISSING_PKGS[*]}"
+fi
+
+# Final check for critical commands
+for cmd in python3 cmake git gcc g++; do
+    command -v "$cmd" &>/dev/null || error "$cmd is still not available. Please install it manually."
+done
+
+# Check NVIDIA CUDA toolkit (only warns; a missing nvcc does not stop the script here)
+if ! command -v nvcc &>/dev/null; then
+    warn "nvcc not found. CUDA toolkit is required for building FlexKV."
+    warn "Please install CUDA toolkit from: https://developer.nvidia.com/cuda-downloads"
+    warn "Or load it via: module load cuda"
+fi
+
+success "System dependencies check passed."
+
+# ======================== Step 2: Setup Python Virtual Environment ========================
+info "============================================"
+info "Step 2: Setting up Python environment"
+info "============================================"
+
+if [ "$USE_VENV" -eq 1 ]; then
+    if [ -d "$VENV_PATH" ] && [ -f "$VENV_PATH/bin/activate" ]; then
+        info "Using existing virtual environment: $VENV_PATH"
+    else
+        info "Creating virtual environment at: $VENV_PATH"
+        python3 -m venv "$VENV_PATH"
+        success "Virtual environment created."
+    fi
+
+    # Activate virtual environment
+    source "$VENV_PATH/bin/activate"
+    success "Virtual environment activated: $(which python3)"
+
+    # Upgrade pip (invoked through python3 so it targets the interpreter used below)
+    info "Upgrading pip..."
+    python3 -m pip install --upgrade pip -q
+else
+    warn "Skipping virtual environment (--no-venv). Installing to system Python."
+    warn "If you encounter 'externally-managed-environment' error, use --venv instead."
+fi
+
+# Install Python build dependencies (Cython is only needed for release builds)
+info "Installing Python build dependencies..."
+python3 -m pip install -q setuptools wheel
+if [ "$BUILD_TYPE" = "release" ]; then
+    python3 -m pip install -q "Cython>=3.0.10"
+fi
+
+# Install PyTorch only when this python3 cannot import it; an existing install is left untouched
+if ! python3 -c "import torch" &>/dev/null 2>&1; then
+    warn "PyTorch not found. Installing PyTorch..."
+    warn "If you need a specific CUDA version, please install PyTorch manually first."
+    python3 -m pip install torch
+fi
+success "Python environment ready."
+
+# ======================== Step 3: Initialize Git Submodules ========================
+info "============================================"
+info "Step 3: Initializing git submodules"
+info "============================================"
+
+if [ "$ENABLE_METRICS" -eq 1 ]; then
+    info "Metrics enabled: initializing all submodules (including prometheus-cpp)..."
+    git submodule update --init --recursive
+else
+    info "Metrics disabled: initializing only xxHash submodule..."
+    git submodule update --init --recursive third_party/xxHash
+fi
+success "Git submodules initialized."
+
+# ======================== Step 4: Build C++ Libraries ========================
+info "============================================"
+info "Step 4: Building C++ libraries (CMake)"
+info "============================================"
+
+mkdir -p build
+cd build
+
+CMAKE_ARGS=""
+if [ "$ENABLE_METRICS" -eq 0 ]; then
+    CMAKE_ARGS="-DFLEXKV_ENABLE_MONITORING=OFF"
+fi
+
+info "Running CMake configuration..."
+cmake .. $CMAKE_ARGS
+
+info "Building C++ libraries..."
+cmake --build . -j"$(nproc)"
+
+BUILD_LIB_PATH="$(pwd)/lib"
+cd "$PROJECT_ROOT"
+
+# Prepend the build lib dir; add ':' only if LD_LIBRARY_PATH is non-empty (an empty entry means the current directory)
+export LD_LIBRARY_PATH="$BUILD_LIB_PATH${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
+
+# Copy shared libraries to package directory
+info "Copying shared libraries to package directory..."
+PACKAGE_LIB_DIR="flexkv/lib"
+mkdir -p "$PACKAGE_LIB_DIR"
+if [ -d "$BUILD_LIB_PATH" ]; then
+    for lib_file in "$BUILD_LIB_PATH"/*.so*; do
+        if [ -f "$lib_file" ]; then
+            cp "$lib_file" "$PACKAGE_LIB_DIR/"
+        fi
+    done
+fi
+success "C++ libraries built successfully."
+
+# ======================== Step 4.5: Install Python Runtime Dependencies ========================
+info "============================================"
+info "Step 4.5: Installing Python runtime dependencies"
+info "============================================"
+
+# Core runtime dependencies (always needed)
+RUNTIME_DEPS="numpy pyzmq psutil nvtx pyyaml expiring-dict"
+
+# Additional dependencies for P2P/distributed mode
+if [ "$ENABLE_P2P" -eq 1 ]; then
+    RUNTIME_DEPS="$RUNTIME_DEPS redis"
+    info "mooncake-transfer-engine will be built from source in Step 4.6"
+fi
+
+info "Installing runtime dependencies: $RUNTIME_DEPS"
+python3 -m pip install -q $RUNTIME_DEPS
+success "Python runtime dependencies installed."
+
+# ======================== Step 4.6: Build Mooncake from Source ========================
+if [ "$ENABLE_P2P" -eq 1 ]; then
+    info "============================================"
+    info "Step 4.6: Building mooncake-transfer-engine from source"
+    info "============================================"
+    if [ -n "$MOONCAKE_VERSION" ]; then
+        info "Target version: $MOONCAKE_VERSION"
+    else
+        info "Target version: latest (main branch)"
+    fi
+
+    MOONCAKE_BUILD_DIR="${PROJECT_ROOT}/.mooncake-build"
+
+    # Reuse an existing git checkout (fetching tags), otherwise remove whatever is at that path and clone afresh
+    if [ -d "$MOONCAKE_BUILD_DIR" ] && [ -d "$MOONCAKE_BUILD_DIR/.git" ]; then
+        info "Found existing mooncake source at $MOONCAKE_BUILD_DIR, updating..."
+        cd "$MOONCAKE_BUILD_DIR"
+        git fetch --tags
+    else
+        info "Cloning mooncake source to $MOONCAKE_BUILD_DIR..."
+        rm -rf "$MOONCAKE_BUILD_DIR"
+        git clone --recurse-submodules https://github.com/kvcache-ai/Mooncake.git "$MOONCAKE_BUILD_DIR"
+        cd "$MOONCAKE_BUILD_DIR"
+    fi
+
+    # Check out the requested version; otherwise try main, then master, and fast-forward (failures on this path are ignored)
+    if [ -n "$MOONCAKE_VERSION" ]; then
+        info "Checking out $MOONCAKE_VERSION..."
+        git checkout "$MOONCAKE_VERSION"
+    else
+        info "Using latest main branch..."
+        git checkout main 2>/dev/null || git checkout master 2>/dev/null || true
+        git pull --ff-only 2>/dev/null || true
+    fi
+    git submodule sync --recursive
+    git submodule update --init --recursive
+
+    # Install mooncake system dependencies via the repository's own dependencies.sh (skipped with --skip-deps)
+    if [ "$SKIP_DEPS" -eq 0 ]; then
+        info "Installing mooncake system dependencies..."
+        $SUDO bash -x dependencies.sh -y
+    else
+        warn "Skipping mooncake dependency installation (--skip-deps)"
+    fi
+
+    # Configure: only build transfer-engine with Redis support
+    info "Configuring mooncake-transfer-engine with Redis metadata backend support..."
+    mkdir -p build && cd build
+
+    # Detect CUDA stubs path (prefer /usr/local/cuda, then $CUDA_HOME)
+    CUDA_STUBS_PATH=""
+    if [ -d "/usr/local/cuda/lib64/stubs" ]; then
+        CUDA_STUBS_PATH="/usr/local/cuda/lib64/stubs"
+    elif [ -n "$CUDA_HOME" ] && [ -d "$CUDA_HOME/lib64/stubs" ]; then
+        CUDA_STUBS_PATH="$CUDA_HOME/lib64/stubs"
+    fi
+
+    CMAKE_EXTRA_FLAGS=""
+    if [ -n "$CUDA_STUBS_PATH" ]; then
+        CMAKE_EXTRA_FLAGS="-DCMAKE_EXE_LINKER_FLAGS=-L${CUDA_STUBS_PATH}"
+    fi
+
+    cmake -G Ninja .. \
+        -DWITH_TE=ON \
+        -DUSE_REDIS=ON \
+        -DUSE_HTTP=ON \
+        -DUSE_ETCD=OFF \
+        -DUSE_CUDA=ON \
+        -DWITH_STORE=OFF \
+        -DWITH_P2P_STORE=OFF \
+        -DWITH_EP=OFF \
+        -DWITH_METRICS=OFF \
+        -DBUILD_UNIT_TESTS=OFF \
+        -DBUILD_EXAMPLES=ON \
+        -DCMAKE_BUILD_TYPE=Release \
+        $CMAKE_EXTRA_FLAGS
+
+    # Build; stubs are prepended with ':' only when the variable is non-empty, so no empty (current-directory) entry is added
+    info "Building mooncake-transfer-engine (this may take a while)..."
+    if [ -n "$CUDA_STUBS_PATH" ]; then
+        export LD_LIBRARY_PATH="${CUDA_STUBS_PATH}${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
+        export LIBRARY_PATH="${CUDA_STUBS_PATH}${LIBRARY_PATH:+:$LIBRARY_PATH}"
+    fi
+    cmake --build . -j"$(nproc)"
+    $SUDO cmake --install .
+
+    # Build and install Python wheel
+    info "Building mooncake-transfer-engine Python wheel..."
+    cd "$MOONCAKE_BUILD_DIR"
+
+    # Uninstall any existing mooncake pip package (best effort)
+    python3 -m pip uninstall -y mooncake-transfer-engine mooncake-transfer-engine-cuda13 2>/dev/null || true
+
+    # Detect the CUDA major version from nvcc (PATH first, then $CUDA_HOME) to decide on a CUDA 13 build
+    CUDA_MAJOR_VERSION=""
+    if command -v nvcc &>/dev/null; then
+        CUDA_MAJOR_VERSION=$(nvcc --version | grep -oP 'release \K[0-9]+')
+    elif [ -n "$CUDA_HOME" ] && [ -f "$CUDA_HOME/bin/nvcc" ]; then
+        CUDA_MAJOR_VERSION=$("$CUDA_HOME/bin/nvcc" --version | grep -oP 'release \K[0-9]+')
+    fi
+
+    MOONCAKE_BUILD_ENV=""
+    if [ -n "$CUDA_MAJOR_VERSION" ] && [ "$CUDA_MAJOR_VERSION" -ge 13 ] 2>/dev/null; then
+        MOONCAKE_BUILD_ENV="CU13_BUILD=1"
+    fi
+
+    env $MOONCAKE_BUILD_ENV OUTPUT_DIR=dist ./scripts/build_wheel.sh
+
+    # build_wheel.sh outputs wheel to mooncake-wheel/dist/; take the newest so a stale wheel from an earlier build is not picked
+    MOONCAKE_WHEEL=$(ls -t mooncake-wheel/dist/*.whl 2>/dev/null | head -n 1)
+    if [ -z "$MOONCAKE_WHEEL" ]; then
+        error "mooncake-transfer-engine wheel not found in mooncake-wheel/dist/"
+    fi
+    python3 -m pip install "$MOONCAKE_WHEEL"
+
+    cd "$PROJECT_ROOT"
+    success "mooncake-transfer-engine built from source with Redis support!"
+
+    # Smoke test: import the module and construct a TransferEngine (failure only produces a warning)
+    info "Verifying mooncake Redis metadata backend support..."
+    python3 -c "
+from mooncake import engine
+e = engine.TransferEngine()
+print('mooncake-transfer-engine loaded successfully (built from source with Redis support)')
+" && success "mooncake verification passed!" || warn "mooncake verification had warnings, see above."
+fi
+
+# ======================== Step 5: Install Python Package ========================
+info "============================================"
+info "Step 5: Installing FlexKV Python package"
+info "============================================"
+
+# Set environment variables for build
+export FLEXKV_ENABLE_METRICS="$ENABLE_METRICS"
+export FLEXKV_ENABLE_P2P="$ENABLE_P2P"
+export FLEXKV_ENABLE_GDS="$ENABLE_GDS"
+export FLEXKV_ENABLE_CFS="$ENABLE_CFS"
+
+if [ "$BUILD_TYPE" = "debug" ]; then
+    export FLEXKV_DEBUG=1
+    info "Installing in debug mode (editable, no Cython)..."
+    python3 -m pip install -v --no-build-isolation -e .
+elif [ "$BUILD_TYPE" = "release" ]; then
+    export FLEXKV_DEBUG=0
+    info "Building release wheel..."
+    python3 setup.py bdist_wheel -v
+    # Install the built wheel; pick the newest file so an older wheel left in dist/ is not installed instead
+    WHEEL_FILE=$(ls -t dist/flexkv-*.whl 2>/dev/null | head -n 1)
+    if [ -n "$WHEEL_FILE" ]; then
+        python3 -m pip install "$WHEEL_FILE"
+    else
+        error "Wheel file not found in dist/"
+    fi
+fi
+success "FlexKV Python package installed."
+
+# ======================== Step 6: Verify Installation ========================
+info "============================================"
+info "Step 6: Verifying installation"
+info "============================================"
+
+python3 -c "
+import flexkv
+print('FlexKV imported successfully')
+try:
+    print(f'Version: {flexkv.__version__}')
+except AttributeError:
+    pass
+try:
+    from flexkv import c_ext
+    print('C extension loaded successfully')
+except ImportError as e:
+    print(f'Warning: C extension not loaded: {e}')
+" && success "FlexKV installation verified!" || warn "Verification had warnings, see above."
+
+# ======================== Summary ========================
+echo ""
+info "============================================"
+success "FlexKV installation completed!"
+info "============================================"
+echo ""
+info "Build type: $BUILD_TYPE"
+info "Metrics: $([ $ENABLE_METRICS -eq 1 ] && echo 'Enabled' || echo 'Disabled')"
+info "P2P/Redis: $([ $ENABLE_P2P -eq 1 ] && echo 'Enabled' || echo 'Disabled')"
+if [ "$ENABLE_P2P" -eq 1 ]; then
+    if [ -n "$MOONCAKE_VERSION" ]; then
+        info "Mooncake: Built from source ($MOONCAKE_VERSION) with Redis metadata backend"
+    else
+        info "Mooncake: Built from source (latest) with Redis metadata backend"
+    fi
+fi
+info "GDS: $([ $ENABLE_GDS -eq 1 ] && echo 'Enabled' || echo 'Disabled')"
+info "CFS: $([ $ENABLE_CFS -eq 1 ] && echo 'Enabled' || echo 'Disabled')"
+
+if [ "$USE_VENV" -eq 1 ]; then
+    VENV_ABS_PATH="$(cd "$VENV_PATH" && pwd)"
+    echo ""
+    info "Virtual environment: $VENV_ABS_PATH"
+    info "To activate it in a new terminal, run:"
+    echo ""
+    echo " source $VENV_ABS_PATH/bin/activate"
+    echo ""
+fi
diff --git a/requirements.txt b/requirements.txt
index 4c1ec7be69..69eb85601f 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,12 @@
 setuptools>=40.0.0
 torch>=1.10.0
-# nvtx==0.2.11 # Skip nvtx for now due to compatibility issues
+numpy>=1.20.0
+pyzmq>=22.0.0
+psutil>=5.8.0
+nvtx>=0.2.8
+pyyaml>=5.4.0
 Cython>=3.0.10
 pytest>=6.0.0
 pytest-benchmark>=3.0.0
 expiring-dict==1.1.2
+redis>=4.0.0
diff --git a/setup.py b/setup.py
index e5b618f65a..2d3813d3cc 100755
--- a/setup.py
+++ b/setup.py
@@ -7,6 +7,28 @@ from setuptools.command.build_ext import build_ext
 from torch.utils import cpp_extension
 
 
+
+def detect_cuda_arch():
+    """Return a semicolon-separated "major.minor" list for the visible GPUs, in numeric order.
+    Falls back to "8.0;8.6;9.0" when no GPU is available or detection raises."""
+    try:
+        import torch
+        if torch.cuda.is_available():
+            archs = set()
+            for i in range(torch.cuda.device_count()):
+                # Keep (major, minor) as ints so sorting is numeric ("8.0" before "10.0"), not lexicographic.
+                archs.add(torch.cuda.get_device_capability(i))
+            if archs:
+                arch_list = ";".join(f"{major}.{minor}" for major, minor in sorted(archs))
+                print(f"Auto-detected GPU architectures: {arch_list}")
+                return arch_list
+    except Exception as e:
+        print(f"GPU architecture auto-detection failed: {e}")
+    # Fallback: common architectures (Ampere + Hopper); reached when no GPU is visible or detection failed
+    fallback = "8.0;8.6;9.0"
+    print(f"GPU architectures not detected, using fallback architectures: {fallback}")
+    return fallback
+
 def get_version():
     import subprocess
     try:
@@ -81,10 +103,10 @@ def get_version():
     extra_link_args.extend(["-lprometheus-cpp-pull", "-lprometheus-cpp-core"])
 else:
     print("FLEXKV_ENABLE_METRICS=0: building without Prometheus monitoring")
-# If TORCH_CUDA_ARCH_LIST is not set, default to known supported archs
-# to avoid auto-detection failure on newer GPUs (e.g. Blackwell sm_100)
+# Auto-detect GPU architecture if TORCH_CUDA_ARCH_LIST is not explicitly set
 if not os.environ.get("TORCH_CUDA_ARCH_LIST"):
-    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;9.0"
+    os.environ["TORCH_CUDA_ARCH_LIST"] = detect_cuda_arch()
+print(f"TORCH_CUDA_ARCH_LIST = {os.environ['TORCH_CUDA_ARCH_LIST']}")
 
 extra_compile_args = ["-std=c++17", "-O3"]
 if enable_metrics:
From 1c14efd50869999884788e043f6e04b25a5cf1ca Mon Sep 17 00:00:00 2001
From: linhu-nv
Date: Thu, 9 Apr 2026 10:59:09 +0800
Subject: [PATCH 34/59] fix cpu layout problems in tp + blockwise cpu layout

---
 csrc/bindings.cpp            | 3 ++-
 csrc/layerwise.cpp           | 5 +++--
 csrc/layerwise.h             | 3 ++-
 flexkv/transfer/layerwise.py | 8 +++++++-
 4 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp
index 6fa33ddbf0..e9f04ed024 100644
--- a/csrc/bindings.cpp
+++ b/csrc/bindings.cpp
@@ -444,7 +444,8 @@ PYBIND11_MODULE(c_ext, m) {
           py::arg("cpu_kv_stride_in_bytes"),
py::arg("cpu_layer_stride_in_bytes"), py::arg("cpu_block_stride_in_bytes"), - py::arg("cpu_chunk_size_in_bytes"), py::arg("transfer_cta_num"), + py::arg("cpu_chunk_size_in_bytes"), + py::arg("cpu_tp_stride_in_bytes"), py::arg("transfer_cta_num"), py::arg("use_ce_transfer"), py::arg("num_layers"), py::arg("layer_granularity"), py::arg("is_mla"), py::arg("counter_id") = 0); diff --git a/csrc/layerwise.cpp b/csrc/layerwise.cpp index f3a38eec42..44a8442a06 100644 --- a/csrc/layerwise.cpp +++ b/csrc/layerwise.cpp @@ -225,7 +225,8 @@ void LayerwiseTransferGroup::layerwise_transfer( const int64_t cpu_kv_stride_in_bytes, const int64_t cpu_layer_stride_in_bytes, const int64_t cpu_block_stride_in_bytes, - const int64_t cpu_chunk_size_in_bytes, const int transfer_cta_num, + const int64_t cpu_chunk_size_in_bytes, + const int64_t cpu_tp_stride_in_bytes, const int transfer_cta_num, const bool use_ce_transfer, const int num_layers, const int layer_granularity, const bool is_mla, const int counter_id) { @@ -318,7 +319,7 @@ void LayerwiseTransferGroup::layerwise_transfer( for (int i = 0; i < num_gpus_; ++i) { cudaSetDevice(gpu_device_ids_[i]); - int64_t cpu_startoff_inside_chunks = i * gpu_chunk_sizes_in_bytes_[i]; + int64_t cpu_startoff_inside_chunks = i * cpu_tp_stride_in_bytes; if (is_mla) { cpu_startoff_inside_chunks = 0; } diff --git a/csrc/layerwise.h b/csrc/layerwise.h index eea6561cf9..187bc31327 100644 --- a/csrc/layerwise.h +++ b/csrc/layerwise.h @@ -47,7 +47,8 @@ class LayerwiseTransferGroup { const int64_t cpu_kv_stride_in_bytes, const int64_t cpu_layer_stride_in_bytes, const int64_t cpu_block_stride_in_bytes, - const int64_t cpu_chunk_size_in_bytes, const int transfer_cta_num, + const int64_t cpu_chunk_size_in_bytes, + const int64_t cpu_tp_stride_in_bytes, const int transfer_cta_num, const bool use_ce_transfer, const int num_layers, const int layer_granularity, const bool is_mla, const int counter_id = 0); // Counter set index for triple buffering diff --git 
a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index bef78366ce..90a332a85b 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -115,9 +115,14 @@ def __init__(self, self.cpu_blocks = cpu_blocks self.cpu_chunk_size_in_bytes = cpu_kv_layout.get_chunk_size() * self.dtype.itemsize - self.cpu_kv_stride_in_bytes = cpu_kv_layout.get_kv_stride() * self.dtype.itemsize self.cpu_block_stride_in_bytes = cpu_kv_layout.get_block_stride() * self.dtype.itemsize + # tp has effect on the layout of the cpu tensor + # the tp dim should always be right after the block dim + if cpu_kv_layout.type == KVCacheLayoutType.BLOCKFIRST and not self.is_mla: + cpu_kv_layout = cpu_kv_layout.div_head(self.tp_group_size) self.cpu_layer_stride_in_bytes = cpu_kv_layout.get_layer_stride() * self.dtype.itemsize + self.cpu_kv_stride_in_bytes = cpu_kv_layout.get_kv_stride() * self.dtype.itemsize + self.cpu_tp_stride_in_bytes = self.cpu_block_stride_in_bytes // self.tp_group_size self.use_ce_transfer_h2d = use_ce_transfer_h2d self.use_ce_transfer_d2h = use_ce_transfer_d2h @@ -279,6 +284,7 @@ def _transfer_impl(self, self.cpu_layer_stride_in_bytes, self.cpu_block_stride_in_bytes, self.cpu_chunk_size_in_bytes, + self.cpu_tp_stride_in_bytes, self.h2d_cta_num, self.use_ce_transfer_h2d, self.num_layers, From cc835f103f62fefccd8fdf613072b4d5cbcb95c3 Mon Sep 17 00:00:00 2001 From: linhu-nv Date: Thu, 9 Apr 2026 12:56:28 +0800 Subject: [PATCH 35/59] split cpu stride for ssd and for gpu --- csrc/bindings.cpp | 2 ++ csrc/layerwise.cpp | 8 +++++--- csrc/layerwise.h | 2 ++ flexkv/transfer/layerwise.py | 16 +++++++++++----- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index e9f04ed024..3bb718a992 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -445,6 +445,8 @@ PYBIND11_MODULE(c_ext, m) { py::arg("cpu_layer_stride_in_bytes"), py::arg("cpu_block_stride_in_bytes"), py::arg("cpu_chunk_size_in_bytes"), + 
py::arg("h2d_cpu_kv_stride_in_bytes"), + py::arg("h2d_cpu_layer_stride_in_bytes"), py::arg("cpu_tp_stride_in_bytes"), py::arg("transfer_cta_num"), py::arg("use_ce_transfer"), py::arg("num_layers"), py::arg("layer_granularity"), py::arg("is_mla"), diff --git a/csrc/layerwise.cpp b/csrc/layerwise.cpp index 44a8442a06..ffaa57ea8d 100644 --- a/csrc/layerwise.cpp +++ b/csrc/layerwise.cpp @@ -226,6 +226,8 @@ void LayerwiseTransferGroup::layerwise_transfer( const int64_t cpu_layer_stride_in_bytes, const int64_t cpu_block_stride_in_bytes, const int64_t cpu_chunk_size_in_bytes, + const int64_t h2d_cpu_kv_stride_in_bytes, + const int64_t h2d_cpu_layer_stride_in_bytes, const int64_t cpu_tp_stride_in_bytes, const int transfer_cta_num, const bool use_ce_transfer, const int num_layers, const int layer_granularity, const bool is_mla, @@ -331,7 +333,7 @@ void LayerwiseTransferGroup::layerwise_transfer( flexkv::transfer_kv_blocks( num_blocks, start_layer, layers_this_batch, gpu_block_ids, gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, cpu_block_ids, - cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, + cpu_ptr, h2d_cpu_kv_stride_in_bytes, h2d_cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, cpu_startoff_inside_chunks, chunk_size, streams_[i], transfer_cta_num, true, use_ce_transfer, is_mla, false); break; @@ -339,7 +341,7 @@ void LayerwiseTransferGroup::layerwise_transfer( flexkv::transfer_kv_blocks( num_blocks, start_layer, layers_this_batch, gpu_block_ids, gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, cpu_block_ids, - cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, + cpu_ptr, h2d_cpu_kv_stride_in_bytes, h2d_cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, cpu_startoff_inside_chunks, chunk_size, streams_[i], transfer_cta_num, true, use_ce_transfer, is_mla, false); break; @@ -347,7 +349,7 @@ void LayerwiseTransferGroup::layerwise_transfer( flexkv::transfer_kv_blocks( num_blocks, start_layer, layers_this_batch, gpu_block_ids, 
gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, cpu_block_ids, - cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, + cpu_ptr, h2d_cpu_kv_stride_in_bytes, h2d_cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, cpu_startoff_inside_chunks, chunk_size, streams_[i], transfer_cta_num, true, use_ce_transfer, is_mla, false); break; diff --git a/csrc/layerwise.h b/csrc/layerwise.h index 187bc31327..0a3fd7fcae 100644 --- a/csrc/layerwise.h +++ b/csrc/layerwise.h @@ -48,6 +48,8 @@ class LayerwiseTransferGroup { const int64_t cpu_layer_stride_in_bytes, const int64_t cpu_block_stride_in_bytes, const int64_t cpu_chunk_size_in_bytes, + const int64_t h2d_cpu_kv_stride_in_bytes, + const int64_t h2d_cpu_layer_stride_in_bytes, const int64_t cpu_tp_stride_in_bytes, const int transfer_cta_num, const bool use_ce_transfer, const int num_layers, const int layer_granularity, const bool is_mla, diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index 90a332a85b..4ffce16b7b 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -116,12 +116,16 @@ def __init__(self, self.cpu_chunk_size_in_bytes = cpu_kv_layout.get_chunk_size() * self.dtype.itemsize self.cpu_block_stride_in_bytes = cpu_kv_layout.get_block_stride() * self.dtype.itemsize - # tp has effect on the layout of the cpu tensor - # the tp dim should always be right after the block dim - if cpu_kv_layout.type == KVCacheLayoutType.BLOCKFIRST and not self.is_mla: - cpu_kv_layout = cpu_kv_layout.div_head(self.tp_group_size) - self.cpu_layer_stride_in_bytes = cpu_kv_layout.get_layer_stride() * self.dtype.itemsize + # Full CPU strides (for SSD->CPU, which transfers all TP ranks' data) self.cpu_kv_stride_in_bytes = cpu_kv_layout.get_kv_stride() * self.dtype.itemsize + self.cpu_layer_stride_in_bytes = cpu_kv_layout.get_layer_stride() * self.dtype.itemsize + # TP-divided CPU strides (for CPU->GPU, each rank reads its own portion) + if cpu_kv_layout.type == 
KVCacheLayoutType.BLOCKFIRST and not self.is_mla: + cpu_kv_layout_tp = cpu_kv_layout.div_head(self.tp_group_size) + else: + cpu_kv_layout_tp = cpu_kv_layout + self.h2d_cpu_kv_stride_in_bytes = cpu_kv_layout_tp.get_kv_stride() * self.dtype.itemsize + self.h2d_cpu_layer_stride_in_bytes = cpu_kv_layout_tp.get_layer_stride() * self.dtype.itemsize self.cpu_tp_stride_in_bytes = self.cpu_block_stride_in_bytes // self.tp_group_size self.use_ce_transfer_h2d = use_ce_transfer_h2d @@ -284,6 +288,8 @@ def _transfer_impl(self, self.cpu_layer_stride_in_bytes, self.cpu_block_stride_in_bytes, self.cpu_chunk_size_in_bytes, + self.h2d_cpu_kv_stride_in_bytes, + self.h2d_cpu_layer_stride_in_bytes, self.cpu_tp_stride_in_bytes, self.h2d_cta_num, self.use_ce_transfer_h2d, From c63ffeedc05bcc19f924666402b528488a0112cd Mon Sep 17 00:00:00 2001 From: zittozhang Date: Thu, 9 Apr 2026 15:26:37 +0800 Subject: [PATCH 36/59] fix: comprehensive TP/PP/DP dimension fixes and observability improvements - Add dp_rank to ModelConfig and propagate dp_size/dp_rank through config chain - Generate per-PP-rank IPC port suffixes to avoid ZMQ endpoint collisions - Generate per-PP/DP-rank eventfd socket paths for LayerwiseTransferWorker - Use global device_id (dp_rank * tp_size + tp_rank) to avoid GPU registration conflicts - Skip dp_rank in model config verification (DP ranks share same KVServer) - Fix recvmsg_into flags parameter from anc_buf_size to 0 - Handle zmq.Again in KVTPClient registration with blocking fallback - Add periodic GPU registration wait diagnostics in TransferManager - Add comprehensive IPC port/socket logging throughout initialization --- flexkv/common/config.py | 1 + flexkv/integration/config.py | 35 +++++++++++- flexkv/kvmanager.py | 5 ++ flexkv/server/client.py | 15 ++++- flexkv/server/server.py | 7 +++ flexkv/transfer/layerwise.py | 88 ++++++++++++++++++++++++++---- flexkv/transfer/transfer_engine.py | 8 +++ flexkv/transfer_manager.py | 11 ++++ 8 files changed, 154 insertions(+), 16 
deletions(-) diff --git a/flexkv/common/config.py b/flexkv/common/config.py index c56712989f..7a64b6f7a1 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -32,6 +32,7 @@ class ModelConfig: # parallel configs tp_size: int = 1 dp_size: int = 1 + dp_rank: int = 0 pp_size: int = 1 pp_rank: int = 0 diff --git a/flexkv/integration/config.py b/flexkv/integration/config.py index c857050b62..5d9946d430 100644 --- a/flexkv/integration/config.py +++ b/flexkv/integration/config.py @@ -114,6 +114,8 @@ def post_init_from_sglang_config( num_local_layers: int = 0, pp_size: int = 1, pp_rank: int = 0, + dp_size: int = 1, + dp_rank: int = 0, ): """ Initialize FlexKVConfig fields from sglang config. @@ -124,6 +126,8 @@ def post_init_from_sglang_config( num_local_layers: number of layers on this PP rank (0 means no PP, use total layers) pp_size: pipeline parallel size (default 1, no PP) pp_rank: pipeline parallel rank (default 0) + dp_size: data parallel size (default 1, no DP) + dp_rank: data parallel rank (default 0) """ # cache config: use page_size as tokens_per_block so that FlexKV's # CPU radix tree manages blocks at page granularity, ensuring that @@ -159,11 +163,38 @@ def post_init_from_sglang_config( self.model_config.use_mla = use_mla self.model_config.tp_size = int(tp_size) - self.model_config.dp_size = int(getattr(sglang_config, "dp_size", 1)) + self.model_config.dp_size = int(dp_size if dp_size is not None else 1) + self.model_config.dp_rank = int(dp_rank if dp_rank is not None else 0) self.model_config.pp_size = int(pp_size) self.model_config.pp_rank = int(pp_rank) update_default_config_from_user_config(self.model_config, self.cache_config, self.user_config) - + + # Each PP rank needs its own IPC ports so that their + # KVManager / TransferManager instances do not collide on the same + # ZMQ endpoint. DP ranks share the same KVServer (only DP0 creates + # it), so they must use the same IPC port. 
+ _dp_rank = int(dp_rank if dp_rank is not None else 0) + port_suffix = "" + if int(pp_size) > 1: + port_suffix += f"_pp{int(pp_rank)}" + if port_suffix: + self.server_recv_port = f"{self.server_recv_port}{port_suffix}" + self.gpu_register_port = f"{self.server_recv_port}_gpu_register" + + rank_parts = [] + if int(tp_size) > 1: + rank_parts.append(f"tp_rank=0") + if int(pp_size) > 1: + rank_parts.append(f"pp_rank={int(pp_rank)}") + if int(self.model_config.dp_size) > 1: + rank_parts.append(f"dp_rank={_dp_rank}") + rank_label = f" [{', '.join(rank_parts)}]" if rank_parts else "" + logger.info( + f"[FlexKV] IPC ports configured{rank_label}: " + f"server_recv_port={self.server_recv_port}, " + f"gpu_register_port={self.gpu_register_port}" + ) + hf_config = getattr(sglang_config, 'hf_config', None) self._detect_indexer_config_from_hf(hf_config, source="sglang") diff --git a/flexkv/kvmanager.py b/flexkv/kvmanager.py index 2375948cda..b5c91e297a 100644 --- a/flexkv/kvmanager.py +++ b/flexkv/kvmanager.py @@ -55,6 +55,11 @@ def __init__(self, else: self.gpu_register_port = self.server_recv_port + "_gpu_register" + flexkv_logger.info( + f"[KVManager] IPC ports: server_recv_port={self.server_recv_port}, " + f"gpu_register_port={self.gpu_register_port}" + ) + # Multi-instance mode also requires server_client_mode self.server_client_mode = (model_config.dp_size > 1 or self.instance_num > 1 or diff --git a/flexkv/server/client.py b/flexkv/server/client.py index 4563e601e3..f367b1d67b 100644 --- a/flexkv/server/client.py +++ b/flexkv/server/client.py @@ -249,7 +249,8 @@ def __init__( self.dp_client_id = dp_client_id self.device_id = device_id - flexkv_logger.info(f"KVTPClient {device_id} of KVDPClient {self.dp_client_id} Initialized!") + flexkv_logger.info(f"KVTPClient {device_id} of KVDPClient {self.dp_client_id} Initialized! 
" + f"(gpu_register_port={gpu_register_port})") def register_to_server( self, @@ -286,7 +287,17 @@ def register_to_server( indexer_gpu_layout=indexer_layout, ) - self.send_to_server.send_pyobj(register_req, flags=zmq.NOBLOCK) + try: + self.send_to_server.send_pyobj(register_req, flags=zmq.NOBLOCK) + flexkv_logger.info( + f"KVTPClient {device_id}: registration message sent " + f"(dp_client_id={self.dp_client_id}, num_kv_caches={len(kv_caches)})") + except zmq.Again: + flexkv_logger.error( + f"KVTPClient {device_id}: zmq.Again when sending registration " + f"(send buffer full or no connection). Retrying with blocking send...") + self.send_to_server.send_pyobj(register_req) + flexkv_logger.info(f"KVTPClient {device_id}: registration message sent (blocking retry)") if __name__ == "__main__": diff --git a/flexkv/server/server.py b/flexkv/server/server.py index e5f808c697..6b64866b63 100644 --- a/flexkv/server/server.py +++ b/flexkv/server/server.py @@ -154,6 +154,10 @@ def __init__( self.context = zmq.Context(2) self.recv_from_client = get_zmq_socket( self.context, zmq.SocketType.PULL, server_recv_port, True) + flexkv_logger.info( + f"[KVServer] IPC ports bound: server_recv_port={server_recv_port}, " + f"gpu_register_port={gpu_register_port}" + ) # Use total_clients if provided (multi-instance mode), otherwise use dp_size max_clients = total_clients if total_clients > 0 else model_config.dp_size @@ -332,7 +336,10 @@ def run(self) -> None: def _verify_model_config(self, model_config: ModelConfig) -> None: """Verify that client's model config matches server's config.""" + skip_fields = {"dp_rank"} for field in fields(ModelConfig): + if field.name in skip_fields: + continue client_val = getattr(model_config, field.name) server_val = getattr(self.model_config, field.name) assert client_val == server_val, \ diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index 4ffce16b7b..a97793fe08 100644 --- a/flexkv/transfer/layerwise.py +++ 
b/flexkv/transfer/layerwise.py @@ -36,7 +36,7 @@ def _recv_fds(sock: socket.socket, num_fds: int) -> Tuple[List[int], bytes]: data_buf = bytearray(256) anc_buf_size = socket.CMSG_SPACE(num_fds * struct.calcsize("i")) - nbytes, ancdata, flags, addr = sock.recvmsg_into([data_buf], anc_buf_size, anc_buf_size) + nbytes, ancdata, flags, addr = sock.recvmsg_into([data_buf], anc_buf_size, 0) data = bytes(data_buf[:nbytes]) fds = [] @@ -64,12 +64,22 @@ def __init__(self, dtype: torch.dtype, tp_group_size: int, dp_group_id: int, + pp_rank: int, + pp_size: int, + dp_size: int, + dp_rank: int, num_blocks_per_file: int, use_ce_transfer_h2d: bool = False, use_ce_transfer_d2h: bool = False, h2d_cta_num: int = 4, d2h_cta_num: int = 4, enable_eventfd: bool = True) -> None: + flexkv_logger.debug( + f"[LayerwiseWorker] __init__ started: worker_id={worker_id}, " + f"tp_group_size={tp_group_size}, dp_group_id={dp_group_id}, " + f"pp_rank={pp_rank}, pp_size={pp_size}, " + f"enable_eventfd={enable_eventfd}, " + f"num_gpu_blocks={[len(b) for b in gpu_blocks]}") super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) assert len(gpu_blocks) == tp_group_size, f"len(gpu_blocks) = {len(gpu_blocks)}, tp_group_size = {tp_group_size}" imported_gpu_blocks = [] @@ -84,7 +94,11 @@ def __init__(self, self.num_gpus = len(self.gpu_blocks) self.tp_group_size = tp_group_size + self.pp_rank = pp_rank + self.pp_size = pp_size if pp_size > 0 else 1 self.dp_group_id = dp_group_id + self.dp_size = dp_size if dp_size > 0 else 1 + self.dp_rank = dp_rank # initialize GPU storage self.num_layers = gpu_kv_layouts[0].num_layer @@ -109,9 +123,10 @@ def __init__(self, raise ValueError(f"Invalid GPU block type: {num_blocks_first_gpu}") # initialize CPU storage - flexkv_logger.info(f"Pinning CPU Memory: " + flexkv_logger.info(f"[LayerwiseWorker] Pinning CPU Memory: " f"{cpu_blocks.numel() * cpu_blocks.element_size() / (1024 ** 3):.2f} GB") cudaHostRegister(cpu_blocks) + 
flexkv_logger.debug(f"[LayerwiseWorker] CPU memory pinned successfully") self.cpu_blocks = cpu_blocks self.cpu_chunk_size_in_bytes = cpu_kv_layout.get_chunk_size() * self.dtype.itemsize @@ -157,12 +172,15 @@ def __init__(self, gpu_chunk_sizes_tensor = torch.tensor(self.gpu_chunk_sizes_in_bytes, dtype=torch.int64) gpu_layer_strides_tensor = torch.tensor(self.gpu_layer_strides_in_bytes, dtype=torch.int64) + flexkv_logger.debug(f"[LayerwiseWorker] About to receive eventfds, enable_eventfd={enable_eventfd}") if enable_eventfd: layer_eventfds_tensor = self._receive_eventfds_from_sglang(tp_group_size) else: layer_eventfds_tensor = torch.empty(0, dtype=torch.int32) + flexkv_logger.debug(f"[LayerwiseWorker] Eventfds received, tensor shape={layer_eventfds_tensor.shape}") # Create LayerwiseTransferGroup which handles both SSD->CPU and CPU->GPU transfers + flexkv_logger.debug(f"[LayerwiseWorker] Creating LayerwiseTransferGroup...") self.layerwise_transfer_group = LayerwiseTransferGroup( self.num_gpus, self.gpu_blocks, cpu_blocks, ssd_files, dp_group_id, self.num_layers, @@ -171,12 +189,32 @@ def __init__(self, GLOBAL_CONFIG_FROM_ENV.iouring_entries, GLOBAL_CONFIG_FROM_ENV.iouring_flags, layer_eventfds_tensor, tp_group_size) + flexkv_logger.info(f"[LayerwiseWorker] __init__ completed successfully, worker_id={worker_id}") def _receive_eventfds_from_sglang(self, tp_group_size: int, max_retries: int = 180, retry_interval: float = 1.0) -> torch.Tensor: """Receive eventfds from SGLang via Unix socket (FlexKV as server).""" - socket_path = os.environ.get('FLEXKV_LAYERWISE_EVENTFD_SOCKET', '/tmp/flexkv_layerwise_eventfd.sock') + base_socket_path = os.environ.get('FLEXKV_LAYERWISE_EVENTFD_SOCKET', '/tmp/flexkv_layerwise_eventfd.sock') + sock_suffix = "" + if int(self.pp_size) > 1: + sock_suffix += f"_pp{int(self.pp_rank)}" + if int(self.dp_size) > 1: + sock_suffix += f"_dp{int(self.dp_rank)}" + if sock_suffix: + root, ext = os.path.splitext(base_socket_path) + socket_path = 
f"{root}{sock_suffix}{ext}" + else: + socket_path = base_socket_path + + rank_parts = [] + if int(self.tp_group_size) > 1: + rank_parts.append("tp_rank=0") + if int(self.pp_size) > 1: + rank_parts.append(f"pp_rank={int(self.pp_rank)}") + if int(self.dp_size) > 1: + rank_parts.append(f"dp_rank={int(self.dp_rank)}") + rank_label = f" [{', '.join(rank_parts)}]" if rank_parts else "" def cleanup_socket(): try: @@ -193,9 +231,12 @@ def cleanup_socket(): server_sock.bind(socket_path) server_sock.listen(tp_group_size) os.chmod(socket_path, 0o777) - flexkv_logger.info(f"[LayerwiseWorker] Listening on {socket_path}, waiting for {tp_group_size} connections") + flexkv_logger.info( + f"[LayerwiseWorker] Eventfd server created{rank_label}: " + f"socket={socket_path}, waiting for {tp_group_size} connection(s)") except Exception as e: - flexkv_logger.error(f"[LayerwiseWorker] Failed to bind/listen: {e}") + flexkv_logger.error( + f"[LayerwiseWorker] Failed to bind/listen on {socket_path}{rank_label}: {e}") server_sock.close() return torch.empty(0, dtype=torch.int32) @@ -205,38 +246,60 @@ def cleanup_socket(): try: for conn_idx in range(tp_group_size): + flexkv_logger.debug( + f"[LayerwiseWorker] Waiting for connection " + f"{conn_idx + 1}/{tp_group_size} on {socket_path}...") try: conn, _ = server_sock.accept() + flexkv_logger.info( + f"[LayerwiseWorker] Accepted connection " + f"{conn_idx + 1}/{tp_group_size} on {socket_path}{rank_label}") except socket.timeout: - flexkv_logger.warning(f"[LayerwiseWorker] Timeout, received {conn_idx}/{tp_group_size}") + flexkv_logger.error( + f"[LayerwiseWorker] Timeout waiting for connection on {socket_path}{rank_label}, " + f"received {conn_idx}/{tp_group_size}") break with conn: metadata = conn.recv(16) if len(metadata) < 16: - flexkv_logger.error(f"[LayerwiseWorker] Incomplete metadata: {len(metadata)} bytes") + flexkv_logger.error( + f"[LayerwiseWorker] Incomplete metadata on {socket_path}{rank_label}: " + f"{len(metadata)} bytes") continue 
tp_rank, _, recv_num_layers, recv_num_counters = struct.unpack("iiii", metadata) if conn_idx == 0: num_layers, num_counters = recv_num_layers, recv_num_counters + flexkv_logger.debug( + f"[LayerwiseWorker] Connection {conn_idx + 1}: " + f"tp_rank={tp_rank}, num_layers={recv_num_layers}, " + f"num_counters={recv_num_counters}") + rank_eventfds = {} for _ in range(recv_num_counters): fds, extra_data = _recv_fds(conn, recv_num_layers) counter_id = struct.unpack("i", extra_data[:4])[0] rank_eventfds[counter_id] = fds + flexkv_logger.debug( + f"[LayerwiseWorker] Received counter_id={counter_id}, " + f"num_fds={len(fds)} from tp_rank={tp_rank}") all_rank_eventfds[tp_rank] = rank_eventfds - flexkv_logger.info(f"[LayerwiseWorker] Received eventfds from tp_rank={tp_rank}") + flexkv_logger.info( + f"[LayerwiseWorker] Received all eventfds from tp_rank={tp_rank} " + f"on {socket_path}") except Exception as e: - flexkv_logger.error(f"[LayerwiseWorker] Error in accept loop: {e}") + flexkv_logger.error( + f"[LayerwiseWorker] Error in accept loop on {socket_path}{rank_label}: {e}") finally: server_sock.close() cleanup_socket() if not all_rank_eventfds: - flexkv_logger.warning("[LayerwiseWorker] No connections received") + flexkv_logger.warning( + f"[LayerwiseWorker] No connections received on {socket_path}{rank_label}") return torch.empty(0, dtype=torch.int32) # Build tensor: [num_counters, tp_size, num_layers] @@ -248,8 +311,9 @@ def cleanup_socket(): tensor = torch.tensor(eventfds_list, dtype=torch.int32) flexkv_logger.info( - f"[LayerwiseWorker] Eventfds tensor: {tensor.shape}, " - f"counters={num_counters}, tp={tp_group_size}, layers={num_layers}" + f"[LayerwiseWorker] Eventfd setup complete{rank_label}: " + f"socket={socket_path}, tensor_shape={tensor.shape}, " + f"counters={num_counters}, tp_size={tp_group_size}, layers={num_layers}" ) return tensor diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index 649caaf80f..8490781bf7 100644 --- 
a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -343,6 +343,10 @@ def _init_workers(self) -> None: dtype=gpu_handles[0].dtype, tp_group_size=self.tp_size, dp_group_id=dp_client_id, + pp_rank=self.model_config.pp_rank, + pp_size=self.model_config.pp_size, + dp_size=self.model_config.dp_size, + dp_rank=dp_client_id, num_blocks_per_file=num_blocks_per_file, use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, @@ -595,6 +599,10 @@ def _init_workers(self) -> None: dtype=indexer_gpu_handles_list[0].dtype, tp_group_size=self.tp_size, dp_group_id=dp_client_id, + pp_rank=self.model_config.pp_rank, + pp_size=self.model_config.pp_size, + dp_size=self.model_config.dp_size, + dp_rank=dp_client_id, num_blocks_per_file=indexer_num_blocks_per_file, use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, diff --git a/flexkv/transfer_manager.py b/flexkv/transfer_manager.py index 49166f3aea..dd9e2a3292 100644 --- a/flexkv/transfer_manager.py +++ b/flexkv/transfer_manager.py @@ -88,11 +88,22 @@ def _register_gpu_blocks_via_socket(self) -> None: f"expected {self.expected_gpus} GPUs to register " f"(instance_num={self.instance_num}, tp={self.model_config.tp_size}, " f"dp={self.model_config.dp_size})") + last_log_time = time.time() while len(self.all_gpu_blocks) < self.expected_gpus: try: # Recv from: flexkv.server.client.KVTPClient.register_to_server req = self.recv_from_client.recv_pyobj(zmq.NOBLOCK) except zmq.Again: + # Periodically log waiting status for debugging + now = time.time() + if now - last_log_time >= 5.0: + registered_ids = sorted(self.all_gpu_blocks.keys()) + flexkv_logger.info( + f"Still waiting for GPU registrations: " + f"{len(self.all_gpu_blocks)}/{self.expected_gpus} registered " + f"(registered_device_ids={registered_ids}, " + f"port={self.gpu_register_port})") + last_log_time = now 
time.sleep(0.001) continue From a5eb20f201590eb752d4814f1e57101ddaa41fba Mon Sep 17 00:00:00 2001 From: linhu-nv Date: Thu, 9 Apr 2026 17:32:05 +0800 Subject: [PATCH 37/59] fix ssd read when blockwise + tp + layerwise --- csrc/layerwise.cpp | 58 ++++++++++++++++++++++++---------------------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/csrc/layerwise.cpp b/csrc/layerwise.cpp index ffaa57ea8d..151671a507 100644 --- a/csrc/layerwise.cpp +++ b/csrc/layerwise.cpp @@ -280,42 +280,44 @@ void LayerwiseTransferGroup::layerwise_transfer( h2d_range_ids[0] = nvtxRangeStartA(h2d_range_names[0].c_str()); } + // Step 0: SSD -> CPU transfer for ALL layers at once (before layerwise loop). + // This is required because the CPU memory uses TP-divided layout where each rank's + // data occupies a contiguous region [rank*tp_stride, (rank+1)*tp_stride). Per-layer-batch + // SSD reads with full strides would land at wrong CPU positions for TP > 1. + if (enable_ssd_ && ssd_block_ids.numel() > 0) { + int num_ssd_blocks = ssd_block_ids.numel(); + int64_t ssd_bytes = cpu_chunk_size_in_bytes * 2 * num_layers * num_ssd_blocks; + double ssd_mb = ssd_bytes / (1024.0 * 1024.0); + char ssd_range_name[128]; + snprintf(ssd_range_name, sizeof(ssd_range_name), + "SSD->CPU AllLayers[0,%d) %.2fMB", num_layers, ssd_mb); + nvtxRangePushA(ssd_range_name); + + torch::Tensor all_layer_ids = + torch::arange(0, num_layers, + torch::TensorOptions().dtype(torch::kInt32)); + transfer_kv_blocks_ssd( + *ioctx_, all_layer_ids, reinterpret_cast(cpu_blocks_), + ssd_block_ids, cpu_block_ids_d2h, cpu_layer_stride_in_bytes, + cpu_kv_stride_in_bytes, ssd_layer_stride_in_bytes, + ssd_kv_stride_in_bytes, cpu_chunk_size_in_bytes, + cpu_block_stride_in_bytes, + true, // is_read: SSD -> CPU + num_blocks_per_file, round_robin, num_threads_per_device, is_mla); + + nvtxRangePop(); + } + int batch_idx = 0; for (int start_layer = 0; start_layer < num_layers; start_layer += layer_granularity) { int 
layers_this_batch = std::min(layer_granularity, num_layers - start_layer); - + batch_start_layers[batch_idx] = start_layer; batch_layers_count[batch_idx] = layers_this_batch; - - // Step 1: SSD -> CPU transfer - if (enable_ssd_ && ssd_block_ids.numel() > 0) { - // Calculate SSD->CPU data size: cpu_chunk_size * 2 (K+V) * layers * num_ssd_blocks - int num_ssd_blocks = ssd_block_ids.numel(); - int64_t ssd_bytes = cpu_chunk_size_in_bytes * 2 * layers_this_batch * num_ssd_blocks; - double ssd_mb = ssd_bytes / (1024.0 * 1024.0); - char ssd_range_name[128]; - snprintf(ssd_range_name, sizeof(ssd_range_name), - "SSD->CPU Layer[%d,%d) %.2fMB", start_layer, start_layer + layers_this_batch, ssd_mb); - nvtxRangePushA(ssd_range_name); - - torch::Tensor layer_id_list = - torch::arange(start_layer, start_layer + layers_this_batch, - torch::TensorOptions().dtype(torch::kInt32)); - transfer_kv_blocks_ssd( - *ioctx_, layer_id_list, reinterpret_cast(cpu_blocks_), - ssd_block_ids, cpu_block_ids_d2h, cpu_layer_stride_in_bytes, - cpu_kv_stride_in_bytes, ssd_layer_stride_in_bytes, - ssd_kv_stride_in_bytes, cpu_chunk_size_in_bytes, - cpu_block_stride_in_bytes, - true, // is_read: SSD -> CPU - num_blocks_per_file, round_robin, num_threads_per_device, is_mla); - - nvtxRangePop(); - } - // Step 2: CPU -> GPU transfer + // Step 1: CPU -> GPU transfer // NVTX range for this batch was already started (by main thread for first batch, // or by previous batch's callback for subsequent batches) From 5a706b100a6e6972fb2a5806dc0cdca306fce77d Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Fri, 10 Apr 2026 07:33:03 +0000 Subject: [PATCH 38/59] dont sync prefetch --- flexkv/kvtask.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index 88038df389..b1be68bc93 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -496,7 +496,7 @@ def get_async(self, dp_id: int = 0, task_id: int = -1, namespace: Optional[List[str]] = None) -> Tuple[int, 
np.ndarray]: - self._sync_prefetch(token_ids, namespace) + # self._sync_prefetch(token_ids, namespace) task_id, return_mask = self._get_match_impl(token_ids, slot_mapping, is_fake_slot_mapping=False, @@ -649,7 +649,7 @@ def get_match(self, task_id: int = -1, namespace: Optional[List[str]] = None) -> Tuple[int, np.ndarray]: nvtx.push_range(f"get match: task_id={task_id}", color=get_nvtx_default_color()) - self._sync_prefetch(token_ids, namespace) + # self._sync_prefetch(token_ids, namespace) if token_mask is None: token_mask = np.ones_like(token_ids, dtype=bool) fake_slot_mapping = np.zeros_like(token_ids[token_mask]) From 9e72cf142de97e24899650b44a9806a86e555e8c Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Fri, 10 Apr 2026 07:42:07 +0000 Subject: [PATCH 39/59] support cpuonly match for prefetch --- flexkv/cache/cache_engine.py | 2 ++ flexkv/kvmanager.py | 3 +++ flexkv/kvtask.py | 11 ++++++++++- flexkv/server/client.py | 2 ++ flexkv/server/request.py | 1 + flexkv/server/server.py | 1 + 6 files changed, 19 insertions(+), 1 deletion(-) diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index a0571aa59e..a37f0d2d7b 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -351,6 +351,8 @@ class CacheStrategy: DEFAULT_CACHE_STRATEGY = CacheStrategy() +CPUONLY_CACHE_STRATEGY = CacheStrategy(ignore_gpu=False, ignore_ssd=True, ignore_remote=True, ignore_gds=True) + class GlobalCacheEngine: def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, redis_meta: RedisMeta = None, event_collector: Optional[KVEventCollector] = None): diff --git a/flexkv/kvmanager.py b/flexkv/kvmanager.py index b5c91e297a..d6cc8c2cd8 100644 --- a/flexkv/kvmanager.py +++ b/flexkv/kvmanager.py @@ -180,6 +180,7 @@ def get_match(self, token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, layer_granularity: int = -1, dp_id: int = 0, + cpu_only: bool = False, namespace: Optional[List[str]] = None, ) -> Tuple[int, np.ndarray]: if 
isinstance(token_ids, torch.Tensor): @@ -190,12 +191,14 @@ def get_match(self, task_id, mask = self.dp_client.get_match(token_ids, token_mask, layer_granularity, + cpu_only=cpu_only, namespace=namespace) else: task_id, mask = self.kv_task_engine.get_match(token_ids, token_mask, layer_granularity, dp_id, + cpu_only=cpu_only, namespace=namespace) return task_id, mask diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index b1be68bc93..47128b5309 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -17,7 +17,7 @@ from flexkv.common.block import hash_token from flexkv.common.transfer import TransferOpGraph, merge_to_batch_graph, get_nvtx_default_color, CompletedOp from flexkv.common.tracer import FlexKVTracer -from flexkv.cache.cache_engine import GlobalCacheEngine, DEFAULT_CACHE_STRATEGY +from flexkv.cache.cache_engine import GlobalCacheEngine, DEFAULT_CACHE_STRATEGY, CPUONLY_CACHE_STRATEGY from flexkv.transfer_manager import TransferManagerHandle, TransferManagerOnRemote from flexkv.common.request import KVResponseStatus, KVResponse from flexkv.transfer_manager import ( @@ -208,6 +208,7 @@ def create_get_task(self, layer_granularity: int = -1, dp_id: int = 0, is_fake_slot_mapping: bool = False, + temp_cache_strategy=DEFAULT_CACHE_STRATEGY, namespace: Optional[List[str]] = None, ) -> None: if task_id in self.tasks: @@ -220,6 +221,7 @@ def create_get_task(self, layer_num=self.model_config.num_layers, layer_granularity=layer_granularity, dp_id=dp_id, + temp_cache_strategy=temp_cache_strategy, namespace=namespace) self.tasks[task_id] = KVTask( task_id=task_id, @@ -646,6 +648,7 @@ def get_match(self, token_mask: Optional[np.ndarray] = None, layer_granularity: int = -1, dp_id: int = 0, + cpu_only: bool = False, task_id: int = -1, namespace: Optional[List[str]] = None) -> Tuple[int, np.ndarray]: nvtx.push_range(f"get match: task_id={task_id}", color=get_nvtx_default_color()) @@ -659,6 +662,7 @@ def get_match(self, token_mask=token_mask, layer_granularity=layer_granularity, 
dp_id=dp_id, + cpu_only=cpu_only, task_id=task_id, namespace=namespace) # trace get match request @@ -681,6 +685,7 @@ def _get_match_impl(self, token_mask: Optional[np.ndarray] = None, layer_granularity: int = -1, dp_id: int = 0, + cpu_only: bool = False, task_id: int = -1, namespace: Optional[List[str]] = None) -> Tuple[int, np.ndarray]: if token_mask is None: @@ -689,6 +694,9 @@ def _get_match_impl(self, layer_granularity = self.model_config.num_layers if task_id == -1: task_id = self._gen_task_id() + temp_cache_strategy = DEFAULT_CACHE_STRATEGY + if cpu_only: + temp_cache_strategy = CPUONLY_CACHE_STRATEGY nvtx.push_range(f"get match: task_id={task_id}", color=get_nvtx_default_color()) self.create_get_task(task_id, token_ids, @@ -697,6 +705,7 @@ def _get_match_impl(self, layer_granularity, dp_id, is_fake_slot_mapping=is_fake_slot_mapping, + temp_cache_strategy=temp_cache_strategy, namespace=namespace) self._process_empty_graph(task_id) nvtx.pop_range() diff --git a/flexkv/server/client.py b/flexkv/server/client.py index f367b1d67b..526efaf372 100644 --- a/flexkv/server/client.py +++ b/flexkv/server/client.py @@ -156,12 +156,14 @@ def get_match( token_ids: np.ndarray, token_mask: Optional[np.ndarray], layer_granularity: int, + cpu_only: bool = False, namespace: Optional[List[str]] = None, ) -> Optional[Tuple[int, np.ndarray]]: req = GetMatchRequest(self.dp_client_id, token_ids, token_mask if token_mask is not None else None, layer_granularity, + cpu_only, self._get_task_id(), namespace) self.send_to_server.send_pyobj(req) diff --git a/flexkv/server/request.py b/flexkv/server/request.py index 757b9e0056..ace0658ab7 100644 --- a/flexkv/server/request.py +++ b/flexkv/server/request.py @@ -71,6 +71,7 @@ class GetMatchRequest: token_ids: np.ndarray token_mask: Optional[np.ndarray] layer_granularity: int + cpu_only: bool = False task_id: int = -1 namespace: Optional[List[str]] = None diff --git a/flexkv/server/server.py b/flexkv/server/server.py index 
6b64866b63..60ca78d5af 100644 --- a/flexkv/server/server.py +++ b/flexkv/server/server.py @@ -398,6 +398,7 @@ def _handle_get_match_request(self, req: GetMatchRequest) -> None: token_mask=req.token_mask, layer_granularity=req.layer_granularity, dp_id=req.dp_client_id, + cpu_only=req.cpu_only, task_id=req.task_id, namespace=req.namespace, ) From d3ebb4bff8a19c7fae8d76ed013e7390d563584d Mon Sep 17 00:00:00 2001 From: phaedonsun Date: Sat, 11 Apr 2026 00:20:04 +0800 Subject: [PATCH 40/59] fix: preload shared libraries via ctypes.CDLL in __init__.py Modifying LD_LIBRARY_PATH at runtime does NOT affect the current process's dynamic linker (ld.so reads it only at startup). Use ctypes.CDLL with RTLD_GLOBAL to pre-load libxxhash.so etc. so that c_ext (loaded via dlopen) can resolve them without requiring LD_LIBRARY_PATH to be set before process start. --- flexkv/__init__.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/flexkv/__init__.py b/flexkv/__init__.py index 936f51cb3f..c8ece6cd74 100644 --- a/flexkv/__init__.py +++ b/flexkv/__init__.py @@ -1,15 +1,24 @@ +import ctypes +import glob import os import sys # Add package lib directory to system library path def _setup_library_path() -> None: - """Setup library path to find shared libraries in the package""" + """Setup library path to find shared libraries in the package. + + Note: Modifying LD_LIBRARY_PATH at runtime does NOT affect the current + process's dynamic linker (ld.so reads it only at startup). We still set it + for child processes, but for the current process we must pre-load required + shared libraries via ctypes.CDLL with RTLD_GLOBAL so that subsequent + dlopen() calls (e.g. when importing c_ext) can resolve them. 
+ """ package_dir = os.path.dirname(os.path.abspath(__file__)) lib_dir = os.path.join(package_dir, "lib") if os.path.exists(lib_dir): - # Add to LD_LIBRARY_PATH for Linux + # Set LD_LIBRARY_PATH for child processes if sys.platform.startswith('linux'): current_ld_path = os.environ.get('LD_LIBRARY_PATH', '') if lib_dir not in current_ld_path: @@ -18,6 +27,14 @@ def _setup_library_path() -> None: else: os.environ['LD_LIBRARY_PATH'] = lib_dir + # Pre-load shared libraries into the current process so that + # c_ext (loaded via dlopen) can find them. + for so_file in sorted(glob.glob(os.path.join(lib_dir, "*.so*"))): + try: + ctypes.CDLL(so_file, mode=ctypes.RTLD_GLOBAL) + except OSError: + pass # non-critical: library may not be needed + # Add to sys.path for loading if lib_dir not in sys.path: sys.path.insert(0, lib_dir) From 5a5f7da82ac07191a40c67cc58ce438d9e490a36 Mon Sep 17 00:00:00 2001 From: leolingli Date: Sat, 11 Apr 2026 19:36:11 +0800 Subject: [PATCH 41/59] add kv_cache_dtype to sglang --- flexkv/integration/config.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/flexkv/integration/config.py b/flexkv/integration/config.py index 5d9946d430..1b7d7ce66a 100644 --- a/flexkv/integration/config.py +++ b/flexkv/integration/config.py @@ -152,7 +152,41 @@ def post_init_from_sglang_config( self.model_config.num_kv_heads = int(getattr(sglang_config, "num_key_value_heads", 0)) self.model_config.head_size = int(getattr(sglang_config, "head_dim", 0)) - self.model_config.dtype = getattr(sglang_config, "dtype", torch.bfloat16) + # Determine KV cache dtype: prioritize user_config.kv_cache_dtype (from + # flexkv_config.yaml or FLEXKV_KV_CACHE_DTYPE env var), then fall back + # to the sglang model dtype. sglang's ModelConfig.dtype is the *model + # weight* dtype (e.g. bfloat16), which may differ from the KV cache + # dtype (e.g. fp8_e4m3 when --kv-cache-dtype fp8_e4m3 is used). 
+ def _parse_dtype_str(dtype_str: str) -> torch.dtype: + dtype_map = { + "float16": torch.float16, + "float32": torch.float32, + "bfloat16": torch.bfloat16, + "fp16": torch.float16, + "fp32": torch.float32, + "bf16": torch.bfloat16, + "fp8": torch.float8_e4m3fn, + "float8": torch.float8_e4m3fn, + "e4m3": torch.float8_e4m3fn, + "fp8_e4m3": torch.float8_e4m3fn, + } + return dtype_map.get(dtype_str.lower(), torch.bfloat16) + + user_dtype_str = self.user_config.kv_cache_dtype + if user_dtype_str is not None: + self.model_config.dtype = _parse_dtype_str(user_dtype_str) + logger.info( + f"[FlexKV] Using kv_cache_dtype from user_config: " + f"'{user_dtype_str}' -> {self.model_config.dtype}" + ) + else: + self.model_config.dtype = getattr(sglang_config, "dtype", torch.bfloat16) + logger.warning( + f"[FlexKV] No kv_cache_dtype in user_config, falling back to sglang " + f"model dtype: {self.model_config.dtype}. If your KV cache uses a " + f"different dtype (e.g. fp8), add 'kv_cache_dtype: fp8' to your " + f"flexkv_config.yaml or set FLEXKV_KV_CACHE_DTYPE=fp8 environment variable." 
+ ) attn_arch = getattr(sglang_config, "attention_arch", None) use_mla = False From ba6c000ea4dad4272d51b96669a3a4371986c712 Mon Sep 17 00:00:00 2001 From: leolingli Date: Sat, 11 Apr 2026 21:47:37 +0800 Subject: [PATCH 42/59] add some log to info the malloc --- flexkv/storage/allocator.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/flexkv/storage/allocator.py b/flexkv/storage/allocator.py index ccb99e5b32..507e08e76d 100644 --- a/flexkv/storage/allocator.py +++ b/flexkv/storage/allocator.py @@ -172,6 +172,12 @@ def allocate(cls, real_file_size = num_blocks_per_file * block_size ssd_files: Dict[int, List[str]] = {} + total_num_files = num_files_per_device * num_ssd_devices + real_total_size = total_num_files * real_file_size + flexkv_logger.info(f"SSD allocator creating {total_num_files} files in {cache_dir}, " + f"each file {real_file_size/1024/1024/1024:.2f} GB, " + f"total {real_total_size/1024/1024/1024:.2f} GB") + file_count = 0 for i in range(num_ssd_devices): ssd_files[i] = [] for j in range(num_files_per_device): @@ -179,9 +185,13 @@ def allocate(cls, with open(file_path, "wb+", buffering=0) as file: cls._create_file(file, real_file_size) ssd_files[i].append(file_path) - total_num_files = num_files_per_device * num_ssd_devices - real_total_size = total_num_files * real_file_size - flexkv_logger.info(f"SSD allocator create total {total_num_files} files in {cache_dir}, " + file_count += 1 + if file_count % max(1, total_num_files // 10) == 0 or file_count == total_num_files: + flexkv_logger.info( + f"SSD allocator progress: {file_count}/{total_num_files} files created " + f"({file_count * 100 // total_num_files}%)" + ) + flexkv_logger.info(f"SSD allocator done: {total_num_files} files in {cache_dir}, " f"each file has {real_file_size/1024/1024/1024:.2f} GB, total size {real_total_size/1024/1024/1024:.2f} GB") return StorageHandle( handle_type=AccessHandleType.FILE, From 7806055cba0b3007023039eedc990892459c05b2 Mon Sep 17 
00:00:00 2001 From: root Date: Tue, 14 Apr 2026 17:35:17 +0000 Subject: [PATCH 43/59] support cp+layerwise --- csrc/bindings.cpp | 320 ++++++++++++++--------------- csrc/tp_transfer_thread_group.cpp | 12 +- csrc/tp_transfer_thread_group.h | 3 +- flexkv/common/config.py | 5 + flexkv/integration/config.py | 28 ++- flexkv/transfer/layerwise.py | 63 ++++-- flexkv/transfer/transfer_engine.py | 15 +- flexkv/transfer/worker.py | 8 +- 8 files changed, 254 insertions(+), 200 deletions(-) diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index 3bb718a992..823412da85 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -1,8 +1,8 @@ #include #include +#include #include #include -#include #include "transfer.cuh" #include @@ -59,7 +59,7 @@ void transfer_kv_blocks_binding( void *cpu_ptr = static_cast(cpu_tensor.data_ptr()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - + // Determine backend type from gpu_block_type parameter flexkv::BackendType backend_type; if (gpu_block_type == 0) { @@ -69,19 +69,16 @@ void transfer_kv_blocks_binding( } else if (gpu_block_type == 2) { backend_type = flexkv::BackendType::SGLANG; } else { - throw std::runtime_error("Unsupported gpu_block_type: " + std::to_string(gpu_block_type)); + throw std::runtime_error("Unsupported gpu_block_type: " + + std::to_string(gpu_block_type)); } - + // Create GTensorHandler flexkv::GTensorHandler handler( - backend_type, - reinterpret_cast(gpu_tensor_ptrs), - num_layers, - gpu_kv_stride_in_bytes, - gpu_block_stride_in_bytes, - gpu_layer_stride_in_bytes - ); - + backend_type, reinterpret_cast(gpu_tensor_ptrs), num_layers, + gpu_kv_stride_in_bytes, gpu_block_stride_in_bytes, + gpu_layer_stride_in_bytes); + // Dispatch to appropriate template instantiation switch (backend_type) { case flexkv::BackendType::VLLM: @@ -109,7 +106,7 @@ void transfer_kv_blocks_binding( use_ce_transfer, is_mla, sync); break; } - + cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { throw 
std::runtime_error(cudaGetErrorString(err)); @@ -117,22 +114,21 @@ void transfer_kv_blocks_binding( } void transfer_kv_blocks_ssd_binding( - flexkv::SSDIOCTX &ioctx, - const torch::Tensor &cpu_layer_id_list, int64_t cpu_tensor_ptr, - const torch::Tensor &ssd_block_ids, const torch::Tensor &cpu_block_ids, - int64_t cpu_layer_stride_in_bytes, int64_t cpu_kv_stride_in_bytes, - int64_t ssd_layer_stride_in_bytes, int64_t ssd_kv_stride_in_bytes, - int64_t chunk_size_in_bytes, int64_t block_stride_in_bytes, bool is_read, - int num_blocks_per_file, int round_robin = 1, - int num_threads_per_device = 8, bool is_mla = false) { + flexkv::SSDIOCTX &ioctx, const torch::Tensor &cpu_layer_id_list, + int64_t cpu_tensor_ptr, const torch::Tensor &ssd_block_ids, + const torch::Tensor &cpu_block_ids, int64_t cpu_layer_stride_in_bytes, + int64_t cpu_kv_stride_in_bytes, int64_t ssd_layer_stride_in_bytes, + int64_t ssd_kv_stride_in_bytes, int64_t chunk_size_in_bytes, + int64_t block_stride_in_bytes, bool is_read, int num_blocks_per_file, + int round_robin = 1, int num_threads_per_device = 8, bool is_mla = false) { TORCH_CHECK(ssd_block_ids.dtype() == torch::kInt64, "ssd_block_ids must be int64"); TORCH_CHECK(cpu_block_ids.dtype() == torch::kInt64, "cpu_block_ids must be int64"); flexkv::transfer_kv_blocks_ssd( - ioctx, cpu_layer_id_list, cpu_tensor_ptr, ssd_block_ids, - cpu_block_ids, cpu_layer_stride_in_bytes, cpu_kv_stride_in_bytes, + ioctx, cpu_layer_id_list, cpu_tensor_ptr, ssd_block_ids, cpu_block_ids, + cpu_layer_stride_in_bytes, cpu_kv_stride_in_bytes, ssd_layer_stride_in_bytes, ssd_kv_stride_in_bytes, chunk_size_in_bytes, block_stride_in_bytes, is_read, num_blocks_per_file, round_robin, num_threads_per_device, is_mla); @@ -281,112 +277,111 @@ void transfer_kv_blocks_gds_binding( } // GDS Manager Python bindings -py::list gds_batch_write_binding(GDSManager& manager, +py::list gds_batch_write_binding(GDSManager &manager, py::list operations_list) { - size_t batch_size = 
operations_list.size(); - std::vector operations(batch_size); - std::vector results(batch_size); - - for (size_t i = 0; i < batch_size; ++i) { - py::dict op_dict = operations_list[i].cast(); - operations[i].filename = op_dict["filename"].cast().c_str(); - operations[i].gpu_data = op_dict["gpu_data"].cast().data_ptr(); - operations[i].size = op_dict["size"].cast(); - operations[i].file_offset = op_dict["file_offset"].cast(); - operations[i].result = &results[i]; - } - - int batch_id = manager.batch_write(operations.data(), batch_size); - - py::list result_list; - result_list.append(batch_id); - for (size_t i = 0; i < batch_size; ++i) { - result_list.append(results[i]); - } - - return result_list; + size_t batch_size = operations_list.size(); + std::vector operations(batch_size); + std::vector results(batch_size); + + for (size_t i = 0; i < batch_size; ++i) { + py::dict op_dict = operations_list[i].cast(); + operations[i].filename = op_dict["filename"].cast().c_str(); + operations[i].gpu_data = + op_dict["gpu_data"].cast().data_ptr(); + operations[i].size = op_dict["size"].cast(); + operations[i].file_offset = op_dict["file_offset"].cast(); + operations[i].result = &results[i]; + } + + int batch_id = manager.batch_write(operations.data(), batch_size); + + py::list result_list; + result_list.append(batch_id); + for (size_t i = 0; i < batch_size; ++i) { + result_list.append(results[i]); + } + + return result_list; } -py::list gds_batch_read_binding(GDSManager& manager, - py::list operations_list) { - size_t batch_size = operations_list.size(); - std::vector operations(batch_size); - std::vector results(batch_size); - - for (size_t i = 0; i < batch_size; ++i) { - py::dict op_dict = operations_list[i].cast(); - operations[i].filename = op_dict["filename"].cast().c_str(); - operations[i].gpu_buffer = op_dict["gpu_buffer"].cast().data_ptr(); - operations[i].size = op_dict["size"].cast(); - operations[i].file_offset = op_dict["file_offset"].cast(); - operations[i].result = 
&results[i]; - } - - int batch_id = manager.batch_read(operations.data(), batch_size); - - py::list result_list; - result_list.append(batch_id); - for (size_t i = 0; i < batch_size; ++i) { - result_list.append(results[i]); - } - - return result_list; +py::list gds_batch_read_binding(GDSManager &manager, py::list operations_list) { + size_t batch_size = operations_list.size(); + std::vector operations(batch_size); + std::vector results(batch_size); + + for (size_t i = 0; i < batch_size; ++i) { + py::dict op_dict = operations_list[i].cast(); + operations[i].filename = op_dict["filename"].cast().c_str(); + operations[i].gpu_buffer = + op_dict["gpu_buffer"].cast().data_ptr(); + operations[i].size = op_dict["size"].cast(); + operations[i].file_offset = op_dict["file_offset"].cast(); + operations[i].result = &results[i]; + } + + int batch_id = manager.batch_read(operations.data(), batch_size); + + py::list result_list; + result_list.append(batch_id); + for (size_t i = 0; i < batch_size; ++i) { + result_list.append(results[i]); + } + + return result_list; } -ssize_t gds_write_binding(GDSManager& manager, - const std::string& filename, - torch::Tensor gpu_data, - size_t file_offset = 0) { - return manager.write(filename.c_str(), gpu_data.data_ptr(), - gpu_data.numel() * gpu_data.element_size(), file_offset); +ssize_t gds_write_binding(GDSManager &manager, const std::string &filename, + torch::Tensor gpu_data, size_t file_offset = 0) { + return manager.write(filename.c_str(), gpu_data.data_ptr(), + gpu_data.numel() * gpu_data.element_size(), file_offset); } -ssize_t gds_read_binding(GDSManager& manager, - const std::string& filename, - torch::Tensor gpu_buffer, - size_t file_offset = 0) { - return manager.read(filename.c_str(), gpu_buffer.data_ptr(), - gpu_buffer.numel() * gpu_buffer.element_size(), file_offset); +ssize_t gds_read_binding(GDSManager &manager, const std::string &filename, + torch::Tensor gpu_buffer, size_t file_offset = 0) { + return 
manager.read(filename.c_str(), gpu_buffer.data_ptr(), + gpu_buffer.numel() * gpu_buffer.element_size(), + file_offset); } -ssize_t gds_write_async_binding(GDSManager& manager, - const std::string& filename, - torch::Tensor gpu_data, - size_t file_offset = 0) { - return manager.write_async(filename.c_str(), gpu_data.data_ptr(), - gpu_data.numel() * gpu_data.element_size(), file_offset); +ssize_t gds_write_async_binding(GDSManager &manager, + const std::string &filename, + torch::Tensor gpu_data, + size_t file_offset = 0) { + return manager.write_async(filename.c_str(), gpu_data.data_ptr(), + gpu_data.numel() * gpu_data.element_size(), + file_offset); } -ssize_t gds_read_async_binding(GDSManager& manager, - const std::string& filename, - torch::Tensor gpu_buffer, - size_t file_offset = 0) { - return manager.read_async(filename.c_str(), gpu_buffer.data_ptr(), - gpu_buffer.numel() * gpu_buffer.element_size(), file_offset); +ssize_t gds_read_async_binding(GDSManager &manager, const std::string &filename, + torch::Tensor gpu_buffer, + size_t file_offset = 0) { + return manager.read_async(filename.c_str(), gpu_buffer.data_ptr(), + gpu_buffer.numel() * gpu_buffer.element_size(), + file_offset); } // Helper function to create and initialize a GDS file with specified size -bool create_gds_file_binding(GDSManager& manager, - const std::string& filename, +bool create_gds_file_binding(GDSManager &manager, const std::string &filename, size_t file_size) { - // First create/truncate the file to the desired size - int fd = open(filename.c_str(), O_CREAT | O_RDWR | O_TRUNC, 0644); - if (fd < 0) { - return false; - } - - // Pre-allocate the file to the specified size - if (ftruncate(fd, file_size) != 0) { - close(fd); - return false; - } - - // Ensure data is written to disk - fsync(fd); + // First create/truncate the file to the desired size + int fd = open(filename.c_str(), O_CREAT | O_RDWR | O_TRUNC, 0644); + if (fd < 0) { + return false; + } + + // Pre-allocate the file to the 
specified size + if (ftruncate(fd, file_size) != 0) { close(fd); - - // Now add the file to GDS manager (this will open it with O_DIRECT and register with cuFile) - return manager.add_file(filename.c_str()); + return false; + } + + // Ensure data is written to disk + fsync(fd); + close(fd); + + // Now add the file to GDS manager (this will open it with O_DIRECT and + // register with cuFile) + return manager.add_file(filename.c_str()); } #endif @@ -422,35 +417,35 @@ PYBIND11_MODULE(c_ext, m) { py::arg("round_robin") = 1, py::arg("num_threads_per_device") = 16, py::arg("is_mla") = false); py::class_(m, "LayerwiseTransferGroup") - .def(py::init> &, - torch::Tensor &, std::map> &, - int, int, torch::Tensor &, torch::Tensor &, torch::Tensor &, - torch::Tensor &, int, int, torch::Tensor &, int>(), - py::arg("num_gpus"), py::arg("gpu_blocks"), py::arg("cpu_blocks"), - py::arg("ssd_files"), py::arg("dp_group_id"), py::arg("num_layers"), - py::arg("gpu_kv_strides_tensor"), - py::arg("gpu_block_strides_tensor"), - py::arg("gpu_layer_strides_tensor"), - py::arg("gpu_chunk_sizes_tensor"), py::arg("iouring_entries"), - py::arg("iouring_flags"), py::arg("layer_eventfds_tensor"), - py::arg("tp_size")) - .def("layerwise_transfer", - &flexkv::LayerwiseTransferGroup::layerwise_transfer, - py::arg("ssd_block_ids"), py::arg("cpu_block_ids_d2h"), - py::arg("ssd_layer_stride_in_bytes"), - py::arg("ssd_kv_stride_in_bytes"), py::arg("num_blocks_per_file"), - py::arg("round_robin"), py::arg("num_threads_per_device"), - py::arg("gpu_block_id_tensor"), py::arg("cpu_block_id_tensor"), - py::arg("cpu_kv_stride_in_bytes"), - py::arg("cpu_layer_stride_in_bytes"), - py::arg("cpu_block_stride_in_bytes"), - py::arg("cpu_chunk_size_in_bytes"), - py::arg("h2d_cpu_kv_stride_in_bytes"), - py::arg("h2d_cpu_layer_stride_in_bytes"), - py::arg("cpu_tp_stride_in_bytes"), py::arg("transfer_cta_num"), - py::arg("use_ce_transfer"), py::arg("num_layers"), - py::arg("layer_granularity"), py::arg("is_mla"), - 
py::arg("counter_id") = 0); + .def(py::init> &, + torch::Tensor &, std::map> &, + int, int, torch::Tensor &, torch::Tensor &, torch::Tensor &, + torch::Tensor &, int, int, torch::Tensor &, int>(), + py::arg("num_gpus"), py::arg("gpu_blocks"), py::arg("cpu_blocks"), + py::arg("ssd_files"), py::arg("dp_group_id"), py::arg("num_layers"), + py::arg("gpu_kv_strides_tensor"), + py::arg("gpu_block_strides_tensor"), + py::arg("gpu_layer_strides_tensor"), + py::arg("gpu_chunk_sizes_tensor"), py::arg("iouring_entries"), + py::arg("iouring_flags"), py::arg("layer_eventfds_tensor"), + py::arg("tp_size")) + .def("layerwise_transfer", + &flexkv::LayerwiseTransferGroup::layerwise_transfer, + py::arg("ssd_block_ids"), py::arg("cpu_block_ids_d2h"), + py::arg("ssd_layer_stride_in_bytes"), + py::arg("ssd_kv_stride_in_bytes"), py::arg("num_blocks_per_file"), + py::arg("round_robin"), py::arg("num_threads_per_device"), + py::arg("gpu_block_id_tensor"), py::arg("cpu_block_id_tensor"), + py::arg("cpu_kv_stride_in_bytes"), + py::arg("cpu_layer_stride_in_bytes"), + py::arg("cpu_block_stride_in_bytes"), + py::arg("cpu_chunk_size_in_bytes"), + py::arg("h2d_cpu_kv_stride_in_bytes"), + py::arg("h2d_cpu_layer_stride_in_bytes"), + py::arg("cpu_tp_stride_in_bytes"), py::arg("transfer_cta_num"), + py::arg("use_ce_transfer"), py::arg("num_layers"), + py::arg("layer_granularity"), py::arg("is_mla"), + py::arg("counter_id") = 0); #ifdef FLEXKV_ENABLE_CFS m.def("transfer_kv_blocks_remote", &transfer_kv_blocks_remote, "Transfer KV blocks between remote and CPU memory", @@ -488,7 +483,8 @@ PYBIND11_MODULE(c_ext, m) { py::arg("block_hashes")); py::class_(m, "SSDIOCTX") - .def(py::init> &, int, int, int>()); + .def( + py::init> &, int, int, int>()); py::class_(m, "TPTransferThreadGroup") .def(py::init &, int, int64_t, int, int, @@ -510,8 +506,8 @@ PYBIND11_MODULE(c_ext, m) { py::arg("cpu_block_stride_in_bytes"), py::arg("cpu_tp_stride_in_bytes"), py::arg("transfer_num_cta"), py::arg("is_host_to_device"), 
py::arg("use_ce_transfer"), - py::arg("layer_id"), py::arg("layer_granularity"), - py::arg("is_mla")); + py::arg("layer_id"), py::arg("layer_granularity"), py::arg("is_mla"), + py::arg("use_sharded_d2h") = false); #ifdef FLEXKV_ENABLE_GDS py::class_(m, "TPGDSTransferThreadGroup") @@ -622,7 +618,8 @@ PYBIND11_MODULE(c_ext, m) { py::arg("evicted_blocks"), py::arg("evicted_block_hashes"), py::arg("num_evicted"), py::call_guard()) .def("total_cached_blocks", &flexkv::CRadixTreeIndex::total_cached_blocks) - .def("total_unready_blocks", &flexkv::CRadixTreeIndex::total_unready_blocks) + .def("total_unready_blocks", + &flexkv::CRadixTreeIndex::total_unready_blocks) .def("total_ready_blocks", &flexkv::CRadixTreeIndex::total_ready_blocks) .def("match_prefix", &flexkv::CRadixTreeIndex::match_prefix, py::arg("block_hashes"), py::arg("num_blocks"), @@ -654,33 +651,35 @@ PYBIND11_MODULE(c_ext, m) { #ifdef FLEXKV_ENABLE_GDS // Add GDS Manager class binding py::class_(m, "GDSManager") - .def(py::init>&, int, int>(), + .def(py::init> &, int, int>(), "Initialize GDS Manager with device-organized files", - py::arg("ssd_files"), py::arg("num_devices"), py::arg("round_robin") = 1) + py::arg("ssd_files"), py::arg("num_devices"), + py::arg("round_robin") = 1) .def("is_ready", &GDSManager::is_ready, "Check if GDS manager is ready for operations") .def("get_last_error", &GDSManager::get_last_error, "Get the last error message") .def("add_file", &GDSManager::add_file, - "Add and register a file with GDS (creates with O_DIRECT)", py::arg("filename")) + "Add and register a file with GDS (creates with O_DIRECT)", + py::arg("filename")) .def("remove_file", &GDSManager::remove_file, "Remove and unregister a file from GDS", py::arg("filename")) - .def("write", &gds_write_binding, - "Write data from GPU memory to file", + .def("write", &gds_write_binding, "Write data from GPU memory to file", py::arg("filename"), py::arg("gpu_data"), py::arg("file_offset") = 0) - .def("read", &gds_read_binding, - 
"Read data from file to GPU memory", - py::arg("filename"), py::arg("gpu_buffer"), py::arg("file_offset") = 0) + .def("read", &gds_read_binding, "Read data from file to GPU memory", + py::arg("filename"), py::arg("gpu_buffer"), + py::arg("file_offset") = 0) .def("write_async", &gds_write_async_binding, "Write data from GPU memory to file asynchronously", py::arg("filename"), py::arg("gpu_data"), py::arg("file_offset") = 0) .def("read_async", &gds_read_async_binding, "Read data from file to GPU memory asynchronously", - py::arg("filename"), py::arg("gpu_buffer"), py::arg("file_offset") = 0) - .def("batch_write", &gds_batch_write_binding, - "Batch write operations", py::arg("operations")) - .def("batch_read", &gds_batch_read_binding, - "Batch read operations", py::arg("operations")) + py::arg("filename"), py::arg("gpu_buffer"), + py::arg("file_offset") = 0) + .def("batch_write", &gds_batch_write_binding, "Batch write operations", + py::arg("operations")) + .def("batch_read", &gds_batch_read_binding, "Batch read operations", + py::arg("operations")) .def("batch_synchronize", &GDSManager::batch_synchronize, "Wait for batch operations to complete", py::arg("batch_id")) .def("synchronize", &GDSManager::synchronize, @@ -694,8 +693,7 @@ PYBIND11_MODULE(c_ext, m) { .def("get_round_robin", &GDSManager::get_round_robin, "Get round-robin granularity") .def("get_file_paths", &GDSManager::get_file_paths, - "Get file paths for a specific device", - py::arg("device_id")) + "Get file paths for a specific device", py::arg("device_id")) .def("create_gds_file", &create_gds_file_binding, "Create and register a GDS file with specified size", py::arg("filename"), py::arg("file_size")); diff --git a/csrc/tp_transfer_thread_group.cpp b/csrc/tp_transfer_thread_group.cpp index 349cc62cc8..9d0886d72b 100644 --- a/csrc/tp_transfer_thread_group.cpp +++ b/csrc/tp_transfer_thread_group.cpp @@ -156,7 +156,8 @@ void TPTransferThreadGroup::tp_group_transfer( const int64_t cpu_block_stride_in_bytes, 
const int64_t cpu_tp_stride_in_bytes, const int transfer_num_cta, const bool is_host_to_device, const bool use_ce_transfer, - const int layer_id, const int layer_granularity, const bool is_mla) { + const int layer_id, const int layer_granularity, const bool is_mla, + const bool use_sharded_d2h) { std::atomic failed{false}; std::string error_msg; @@ -177,20 +178,21 @@ void TPTransferThreadGroup::tp_group_transfer( int64_t *cpu_block_ids = static_cast(cpu_block_id_tensor.data_ptr()); void *cpu_ptr = cpu_blocks_; + bool should_use_sharded_offsets = is_mla || use_sharded_d2h; int64_t cpu_startoff_inside_chunks = i * cpu_tp_stride_in_bytes; - if (is_mla && !is_host_to_device) { + if (should_use_sharded_offsets && !is_host_to_device) { cpu_startoff_inside_chunks = i * gpu_chunk_sizes_in_bytes_[i] / num_gpus_; - } else if (is_mla && is_host_to_device) { + } else if (should_use_sharded_offsets && is_host_to_device) { cpu_startoff_inside_chunks = 0; } int64_t gpu_startoff_inside_chunks = - is_mla && !is_host_to_device + should_use_sharded_offsets && !is_host_to_device ? i * gpu_chunk_sizes_in_bytes_[i] / num_gpus_ : 0; // we assume that the chunk size is the same for all gpus, // even if they have different number of gpu_blocks - int64_t chunk_size = is_mla && !is_host_to_device + int64_t chunk_size = should_use_sharded_offsets && !is_host_to_device ? 
gpu_chunk_sizes_in_bytes_[i] / num_gpus_ : gpu_chunk_sizes_in_bytes_[i]; diff --git a/csrc/tp_transfer_thread_group.h b/csrc/tp_transfer_thread_group.h index 0aceaf2a68..08bf2846ff 100644 --- a/csrc/tp_transfer_thread_group.h +++ b/csrc/tp_transfer_thread_group.h @@ -56,7 +56,8 @@ class TPTransferThreadGroup { const int transfer_num_cta, const bool is_host_to_device, const bool use_ce_transfer, const int layer_id, - const int layer_granularity, const bool is_mla); + const int layer_granularity, const bool is_mla, + const bool use_sharded_d2h); private: using Task = std::function; diff --git a/flexkv/common/config.py b/flexkv/common/config.py index 7a64b6f7a1..e406e51f33 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -36,6 +36,11 @@ class ModelConfig: pp_size: int = 1 pp_rank: int = 0 + # NSA context parallelism: when True, layerwise transfer sends full + # (unpartitioned) KV cache to every rank instead of head-sliced data. + is_nsa_cp: bool = False + cp_size: int = 1 + @property def token_size_in_bytes(self) -> int: kv_dim = 1 if self.use_mla else 2 diff --git a/flexkv/integration/config.py b/flexkv/integration/config.py index 1b7d7ce66a..051c4c622f 100644 --- a/flexkv/integration/config.py +++ b/flexkv/integration/config.py @@ -116,6 +116,9 @@ def post_init_from_sglang_config( pp_rank: int = 0, dp_size: int = 1, dp_rank: int = 0, + is_nsa_cp: bool = False, + cp_size: int = 1, + cp_rank: int = 0, ): """ Initialize FlexKVConfig fields from sglang config. 
@@ -128,6 +131,9 @@ def post_init_from_sglang_config( pp_rank: pipeline parallel rank (default 0) dp_size: data parallel size (default 1, no DP) dp_rank: data parallel rank (default 0) + is_nsa_cp: whether NSA context parallelism is enabled + cp_size: context parallel size (default 1, no CP) + cp_rank: context parallel rank (default 0) """ # cache config: use page_size as tokens_per_block so that FlexKV's # CPU radix tree manages blocks at page granularity, ensuring that @@ -201,6 +207,8 @@ def _parse_dtype_str(dtype_str: str) -> torch.dtype: self.model_config.dp_rank = int(dp_rank if dp_rank is not None else 0) self.model_config.pp_size = int(pp_size) self.model_config.pp_rank = int(pp_rank) + self.model_config.is_nsa_cp = is_nsa_cp + self.model_config.cp_size = int(cp_size if cp_size is not None else 1) update_default_config_from_user_config(self.model_config, self.cache_config, self.user_config) # Each PP rank needs its own IPC ports so that their @@ -217,7 +225,7 @@ def _parse_dtype_str(dtype_str: str) -> torch.dtype: rank_parts = [] if int(tp_size) > 1: - rank_parts.append(f"tp_rank=0") + rank_parts.append("tp_rank=0") if int(pp_size) > 1: rank_parts.append(f"pp_rank={int(pp_rank)}") if int(self.model_config.dp_size) > 1: @@ -249,7 +257,7 @@ def post_init_from_trt_config( # Convert dtype string to torch.dtype dtype_str = config.pytorch_backend_config.kv_cache_dtype flexkv_logger.info(f"[FlexKVConfig] dtype_str from TRT config: {dtype_str}") - + # Helper function to convert dtype string to torch.dtype def _parse_dtype_str(dtype_str: str) -> torch.dtype: dtype_map = { @@ -259,12 +267,12 @@ def _parse_dtype_str(dtype_str: str) -> torch.dtype: "fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16, - "fp8": torch.float8_e4m3fn, + "fp8": torch.float8_e4m3fn, "float8": torch.float8_e4m3fn, - "e4m3": torch.float8_e4m3fn, + "e4m3": torch.float8_e4m3fn, } return dtype_map.get(dtype_str.lower(), torch.bfloat16) - + if dtype_str == "auto": # When dtype_str 
is "auto", try to get kv_cache_dtype from user_config first # This allows users to specify kv_cache_dtype in flexkv_config.json or via environment variable @@ -287,7 +295,7 @@ def _parse_dtype_str(dtype_str: str) -> torch.dtype: self.model_config.dtype = _parse_dtype_str(dtype_str) else: self.model_config.dtype = dtype_str - + # Set model config (parallel configs part) if config.mapping.enable_attention_dp: self.model_config.tp_size = 1 @@ -297,19 +305,19 @@ def _parse_dtype_str(dtype_str: str) -> torch.dtype: self.model_config.dp_size = 1 self.model_config.pp_size = getattr(config.mapping, 'pp_size', 1) self.model_config.pp_rank = getattr(config.mapping, 'pp_rank', 0) - + # self.model_config (model configs part) try: model_path = getattr(config, 'hf_model_dir', None) from transformers import AutoConfig as HFAutoConfig hf_config = HFAutoConfig.from_pretrained( - str(model_path), + str(model_path), trust_remote_code=True ) self.model_config.num_layers = hf_config.num_hidden_layers - self.model_config.use_mla = (hasattr(hf_config, 'kv_lora_rank') and + self.model_config.use_mla = (hasattr(hf_config, 'kv_lora_rank') and hf_config.kv_lora_rank is not None and - hasattr(hf_config, 'qk_rope_head_dim') and + hasattr(hf_config, 'qk_rope_head_dim') and hf_config.qk_rope_head_dim is not None) if self.model_config.use_mla: self.model_config.head_size = hf_config.kv_lora_rank + hf_config.qk_rope_head_dim diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index a97793fe08..8d945a899a 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -73,7 +73,8 @@ def __init__(self, use_ce_transfer_d2h: bool = False, h2d_cta_num: int = 4, d2h_cta_num: int = 4, - enable_eventfd: bool = True) -> None: + enable_eventfd: bool = True, + is_nsa_cp: bool = False) -> None: flexkv_logger.debug( f"[LayerwiseWorker] __init__ started: worker_id={worker_id}, " f"tp_group_size={tp_group_size}, dp_group_id={dp_group_id}, " @@ -99,6 +100,7 @@ def 
__init__(self, self.dp_group_id = dp_group_id self.dp_size = dp_size if dp_size > 0 else 1 self.dp_rank = dp_rank + self.is_nsa_cp = is_nsa_cp # initialize GPU storage self.num_layers = gpu_kv_layouts[0].num_layer @@ -122,11 +124,18 @@ def __init__(self, else: raise ValueError(f"Invalid GPU block type: {num_blocks_first_gpu}") + flexkv_logger.debug(f"[LayerwiseWorker] About to receive eventfds, enable_eventfd={enable_eventfd}") + if enable_eventfd: + layer_eventfds_tensor = self._receive_eventfds_from_sglang(tp_group_size) + else: + layer_eventfds_tensor = torch.empty(0, dtype=torch.int32) + flexkv_logger.debug(f"[LayerwiseWorker] Eventfds received, tensor shape={layer_eventfds_tensor.shape}") + # initialize CPU storage flexkv_logger.info(f"[LayerwiseWorker] Pinning CPU Memory: " f"{cpu_blocks.numel() * cpu_blocks.element_size() / (1024 ** 3):.2f} GB") cudaHostRegister(cpu_blocks) - flexkv_logger.debug(f"[LayerwiseWorker] CPU memory pinned successfully") + flexkv_logger.debug("[LayerwiseWorker] CPU memory pinned successfully") self.cpu_blocks = cpu_blocks self.cpu_chunk_size_in_bytes = cpu_kv_layout.get_chunk_size() * self.dtype.itemsize @@ -135,13 +144,19 @@ def __init__(self, self.cpu_kv_stride_in_bytes = cpu_kv_layout.get_kv_stride() * self.dtype.itemsize self.cpu_layer_stride_in_bytes = cpu_kv_layout.get_layer_stride() * self.dtype.itemsize # TP-divided CPU strides (for CPU->GPU, each rank reads its own portion) - if cpu_kv_layout.type == KVCacheLayoutType.BLOCKFIRST and not self.is_mla: - cpu_kv_layout_tp = cpu_kv_layout.div_head(self.tp_group_size) - else: + if self.is_nsa_cp: + # CP: no head partitioning, every rank gets the full KV cache cpu_kv_layout_tp = cpu_kv_layout + self.cpu_tp_stride_in_bytes = 0 + else: + # TP: partition by heads, each rank reads a different head slice + if cpu_kv_layout.type == KVCacheLayoutType.BLOCKFIRST and not self.is_mla: + cpu_kv_layout_tp = cpu_kv_layout.div_head(self.tp_group_size) + else: + cpu_kv_layout_tp = cpu_kv_layout 
+ self.cpu_tp_stride_in_bytes = self.cpu_block_stride_in_bytes // self.tp_group_size self.h2d_cpu_kv_stride_in_bytes = cpu_kv_layout_tp.get_kv_stride() * self.dtype.itemsize self.h2d_cpu_layer_stride_in_bytes = cpu_kv_layout_tp.get_layer_stride() * self.dtype.itemsize - self.cpu_tp_stride_in_bytes = self.cpu_block_stride_in_bytes // self.tp_group_size self.use_ce_transfer_h2d = use_ce_transfer_h2d self.use_ce_transfer_d2h = use_ce_transfer_d2h @@ -172,15 +187,8 @@ def __init__(self, gpu_chunk_sizes_tensor = torch.tensor(self.gpu_chunk_sizes_in_bytes, dtype=torch.int64) gpu_layer_strides_tensor = torch.tensor(self.gpu_layer_strides_in_bytes, dtype=torch.int64) - flexkv_logger.debug(f"[LayerwiseWorker] About to receive eventfds, enable_eventfd={enable_eventfd}") - if enable_eventfd: - layer_eventfds_tensor = self._receive_eventfds_from_sglang(tp_group_size) - else: - layer_eventfds_tensor = torch.empty(0, dtype=torch.int32) - flexkv_logger.debug(f"[LayerwiseWorker] Eventfds received, tensor shape={layer_eventfds_tensor.shape}") - # Create LayerwiseTransferGroup which handles both SSD->CPU and CPU->GPU transfers - flexkv_logger.debug(f"[LayerwiseWorker] Creating LayerwiseTransferGroup...") + flexkv_logger.debug("[LayerwiseWorker] Creating LayerwiseTransferGroup...") self.layerwise_transfer_group = LayerwiseTransferGroup( self.num_gpus, self.gpu_blocks, cpu_blocks, ssd_files, dp_group_id, self.num_layers, @@ -261,20 +269,33 @@ def cleanup_socket(): break with conn: - metadata = conn.recv(16) + # Accept both 16-byte (legacy: tp_rank, tp_size, num_layers, num_counters) + # and 24-byte (new: tp_rank, tp_size, cp_rank, cp_size, num_layers, num_counters) + metadata = conn.recv(24) if len(metadata) < 16: flexkv_logger.error( f"[LayerwiseWorker] Incomplete metadata on {socket_path}{rank_label}: " f"{len(metadata)} bytes") continue - tp_rank, _, recv_num_layers, recv_num_counters = struct.unpack("iiii", metadata) + if len(metadata) >= 24: + tp_rank, _, cp_rank, cp_size, 
recv_num_layers, recv_num_counters = \ + struct.unpack("iiiiii", metadata[:24]) + else: + tp_rank, _, recv_num_layers, recv_num_counters = \ + struct.unpack("iiii", metadata[:16]) + cp_rank, cp_size = 0, 1 + + # Use cp_rank as the connection key when CP is active, + # otherwise use tp_rank + rank_key = cp_rank if cp_size > 1 else tp_rank if conn_idx == 0: num_layers, num_counters = recv_num_layers, recv_num_counters flexkv_logger.debug( f"[LayerwiseWorker] Connection {conn_idx + 1}: " - f"tp_rank={tp_rank}, num_layers={recv_num_layers}, " + f"tp_rank={tp_rank}, cp_rank={cp_rank}, cp_size={cp_size}, " + f"num_layers={recv_num_layers}, " f"num_counters={recv_num_counters}") rank_eventfds = {} @@ -284,12 +305,12 @@ def cleanup_socket(): rank_eventfds[counter_id] = fds flexkv_logger.debug( f"[LayerwiseWorker] Received counter_id={counter_id}, " - f"num_fds={len(fds)} from tp_rank={tp_rank}") + f"num_fds={len(fds)} from rank_key={rank_key}") - all_rank_eventfds[tp_rank] = rank_eventfds + all_rank_eventfds[rank_key] = rank_eventfds flexkv_logger.info( - f"[LayerwiseWorker] Received all eventfds from tp_rank={tp_rank} " - f"on {socket_path}") + f"[LayerwiseWorker] Received all eventfds from rank_key={rank_key} " + f"(tp_rank={tp_rank}, cp_rank={cp_rank}) on {socket_path}") except Exception as e: flexkv_logger.error( f"[LayerwiseWorker] Error in accept loop on {socket_path}{rank_label}: {e}") diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index 8490781bf7..1c345b0082 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -207,6 +207,8 @@ def _init_workers(self) -> None: dtype=gpu_handles[0].dtype, tp_group_size=self.tp_size, dp_group_id=dp_client_id, + is_nsa_cp=getattr(self.model_config, "is_nsa_cp", False), + cp_size=getattr(self.model_config, "cp_size", 1), use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, 
transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, @@ -226,6 +228,8 @@ def _init_workers(self) -> None: dtype=gpu_handles[0].dtype, tp_group_size=self.tp_size, dp_group_id=dp_client_id, + is_nsa_cp=getattr(self.model_config, "is_nsa_cp", False), + cp_size=getattr(self.model_config, "cp_size", 1), use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, @@ -329,6 +333,10 @@ def _init_workers(self) -> None: ssd_files = {} if self._ssd_handle is None else self._ssd_handle.get_file_list() ssd_kv_layout = None if self._ssd_handle is None else self._ssd_handle.kv_layout num_blocks_per_file = 0 if self._ssd_handle is None else self._ssd_handle.num_blocks_per_file + _is_nsa_cp = getattr(self.model_config, 'is_nsa_cp', False) + _cp_size = getattr(self.model_config, 'cp_size', 1) + # For CP, each CP rank connects via eventfd; for TP, each TP rank connects. 
+ _eventfd_group_size = _cp_size if _is_nsa_cp and _cp_size > 1 else self.tp_size self.layerwise_workers = [ LayerwiseTransferWorker.create_worker( mp_ctx=self.mp_ctx, @@ -341,7 +349,7 @@ def _init_workers(self) -> None: cpu_kv_layout=self._cpu_handle.kv_layout, ssd_kv_layout=ssd_kv_layout, dtype=gpu_handles[0].dtype, - tp_group_size=self.tp_size, + tp_group_size=_eventfd_group_size, dp_group_id=dp_client_id, pp_rank=self.model_config.pp_rank, pp_size=self.model_config.pp_size, @@ -352,6 +360,7 @@ def _init_workers(self) -> None: use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, h2d_cta_num=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, d2h_cta_num=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, + is_nsa_cp=_is_nsa_cp, ) for dp_client_id, gpu_handles in self.gpu_handle_groups.items() ] @@ -439,6 +448,8 @@ def _init_workers(self) -> None: dtype=indexer_gpu_handles_list[0].dtype, tp_group_size=self.tp_size, dp_group_id=dp_client_id, + is_nsa_cp=getattr(self.model_config, "is_nsa_cp", False), + cp_size=getattr(self.model_config, "cp_size", 1), use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, @@ -458,6 +469,8 @@ def _init_workers(self) -> None: dtype=indexer_gpu_handles_list[0].dtype, tp_group_size=self.tp_size, dp_group_id=dp_client_id, + is_nsa_cp=getattr(self.model_config, "is_nsa_cp", False), + cp_size=getattr(self.model_config, "cp_size", 1), use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, diff --git a/flexkv/transfer/worker.py b/flexkv/transfer/worker.py index 0933efce6a..3fa30fcef7 100644 --- a/flexkv/transfer/worker.py +++ b/flexkv/transfer/worker.py @@ -433,6 +433,8 @@ def __init__(self, dtype: torch.dtype, tp_group_size: int, dp_group_id: int, + is_nsa_cp: bool = 
False, + cp_size: int = 1, use_ce_transfer_h2d: bool = False, use_ce_transfer_d2h: bool = False, transfer_num_cta_h2d: int = 4, @@ -454,6 +456,7 @@ def __init__(self, self.num_gpus = len(self.gpu_blocks) self.tp_group_size = tp_group_size self.dp_group_id = dp_group_id + self.use_cp_nsa_d2h_shard = bool(is_nsa_cp and cp_size > 1) flexkv_logger.info(f"Pinning CPU Memory: {cpu_blocks.numel() * cpu_blocks.element_size() / (1024 ** 3):.2f} GB") cudaHostRegister(cpu_blocks) @@ -498,6 +501,8 @@ def __init__(self, gpu_device_ids = [self.gpu_blocks[i][0].device.index for i in range(self.num_gpus)] num_tensors_per_gpu = len(self.gpu_blocks[0]) + flexkv_logger.info(f"num_tensors_per_gpu: {num_tensors_per_gpu}") + self.tp_transfer_thread_group = TPTransferThreadGroup( self.num_gpus, gpu_block_ptrs_flat, @@ -557,6 +562,7 @@ def _transfer_impl(self, layer_id, layer_granularity, self.is_mla, + self.use_cp_nsa_d2h_shard, ) @@ -2127,7 +2133,7 @@ def unregist_node_meta(self, node_id: int = None) -> None: def get_node_meta(self, node_id: int) -> Optional[NodeMetaInfo]: """Get the node meta info by node id. - + Before returning cached or freshly-fetched meta, we verify that the node is still active (its node: key exists in Redis and has not expired). 
This prevents RDMA transfers to stale addresses after a From d56c4049364bfbd2498b2437566e8ca4ee20106c Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Wed, 15 Apr 2026 03:05:45 +0000 Subject: [PATCH 44/59] fix empty token mask --- flexkv/cache/cache_engine.py | 59 ++++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index a37f0d2d7b..292523bb48 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -76,7 +76,7 @@ def __init__(self, self.num_total_blocks = num_total_blocks self.evict_ratio = evict_ratio self.evict_start_threshold = evict_start_threshold - + self.event_collector = event_collector self._metrics_collector = metrics_collector @@ -156,25 +156,25 @@ def take(self, strict: bool = True) -> torch.Tensor: # Calculate current utilization utilization = (self.mempool.num_total_blocks - self.mempool.num_free_blocks) / self.mempool.num_total_blocks if self.mempool.num_total_blocks > 0 else 0 - + # Proactive eviction: trigger when utilization exceeds threshold OR when blocks are needed should_evict = (utilization >= self.evict_start_threshold) or (num_required_blocks > self.mempool.num_free_blocks) - + if should_evict: if protected_node is not None: self.index.lock(protected_node) - + # Calculate how many blocks to evict # Goal: maintain free blocks above (1 - evict_start_threshold) ratio target_free_blocks = int(self.mempool.num_total_blocks * (1.0 - self.evict_start_threshold)) evict_to_reach_target = max(0, target_free_blocks - self.mempool.num_free_blocks) - + evict_block_num = max( num_required_blocks - self.mempool.num_free_blocks, # At least meet current demand evict_to_reach_target, # Or reach target free ratio int(self.mempool.num_total_blocks * self.evict_ratio) if self.evict_ratio > 0 else 0 # Or minimum evict_ratio ) - + if evict_block_num > 0: target_blocks = torch.zeros(evict_block_num, dtype=torch.int64) 
evicted_block_hashes = torch.zeros(evict_block_num, dtype=torch.int64) @@ -196,18 +196,18 @@ def take(self, ) if protected_node is not None: self.index.unlock(protected_node) - + if strict and num_required_blocks > self.mempool.num_free_blocks: raise RuntimeError(f"Not enough free blocks to take, " f"required: {num_required_blocks}, " f"available: {self.mempool.num_free_blocks}") num_allocated_blocks = min(num_required_blocks, self.mempool.num_free_blocks) allocated_blocks = self.mempool.allocate_blocks(num_allocated_blocks) - + # Record allocation metrics if self._metrics_collector is not None and num_allocated_blocks > 0: self._metrics_collector.record_allocation(DEVICE_TYPE[self.device_type].lower(), num_allocated_blocks) - + return allocated_blocks def recycle(self, physical_blocks: np.ndarray) -> None: @@ -290,19 +290,19 @@ def take(self, strict: bool = True) -> np.ndarray: # Calculate current utilization utilization = (self.mempool.num_total_blocks - self.mempool.num_free_blocks) / self.mempool.num_total_blocks if self.mempool.num_total_blocks > 0 else 0 - + # Proactive eviction: trigger when utilization exceeds threshold OR when blocks are needed should_evict = (utilization >= self.evict_start_threshold) or (num_required_blocks > self.mempool.num_free_blocks) - + if should_evict: if protected_node is not None: self.index.lock(protected_node) - + # Calculate how many blocks to evict # Goal: maintain free blocks above (1 - evict_start_threshold) ratio target_free_blocks = int(self.mempool.num_total_blocks * (1.0 - self.evict_start_threshold)) evict_to_reach_target = max(0, target_free_blocks - self.mempool.num_free_blocks) - + evict_block_num = max( num_required_blocks - self.mempool.num_free_blocks, # At least meet current demand evict_to_reach_target, # Or reach target free ratio @@ -311,28 +311,28 @@ def take(self, if evict_block_num > 0: evicted_blocks, evicted_block_hashes = self.index.evict(evict_block_num) self.mempool.recycle_blocks(evicted_blocks) - + 
# Record eviction metrics if self._metrics_collector is not None and len(evicted_blocks) > 0: self._metrics_collector.record_eviction(DEVICE_TYPE[self.device_type].lower(), len(evicted_blocks)) - + if self.event_collector is not None: self.event_collector.publish_removed(block_hashes=evicted_block_hashes, medium=DEVICE_TYPE[self.device_type]) if protected_node is not None: self.index.unlock(protected_node) - + if strict and num_required_blocks > self.mempool.num_free_blocks: raise RuntimeError("Not enough free blocks to take, ", f"required: {num_required_blocks}, " f"available: {self.mempool.num_free_blocks}") num_allocated_blocks = min(num_required_blocks, self.mempool.num_free_blocks) allocated_blocks = self.mempool.allocate_blocks(num_allocated_blocks) - + # Record allocation metrics if self._metrics_collector is not None and num_allocated_blocks > 0: self._metrics_collector.record_allocation(DEVICE_TYPE[self.device_type].lower(), num_allocated_blocks) - + return allocated_blocks def recycle(self, physical_blocks: np.ndarray) -> None: @@ -477,7 +477,7 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, redis_m lambda request_id: (TransferOpGraph.create_empty_graph(), [], {}, {}, {}, 0) self._empty_put_return: Callable[[int], Tuple[TransferOpGraph, List[int], Dict, Dict, Dict, int, int]] = \ lambda request_id: (TransferOpGraph.create_empty_graph(), [], {}, {}, {}, 0, 0) - + # Update initial mempool stats self._update_mempool_metrics() @@ -509,10 +509,10 @@ def _update_mempool_metrics(self) -> None: engine.mempool.num_total_blocks, engine.mempool.num_free_blocks ) - + def _record_transfer_ops(self, transfer_graph: TransferOpGraph, operation: str) -> None: """Record metrics for all transfer operations in the graph. 
- + Args: transfer_graph: The transfer operation graph operation: Operation type ("get" or "put") @@ -559,6 +559,13 @@ def get(self, aligned_token_ids = token_ids[:aligned_length] token_mask[aligned_length:] = False + if aligned_length == 0 or not token_mask.any(): + transfer_graph = TransferOpGraph.create_empty_graph() + transfer_graph.bind_to_dp_group(dp_id) + return_mask = np.zeros_like(token_mask, dtype=np.bool_) + callback = partial(self._transfer_callback, node_to_unlock={}, buffer_to_free={}) + return transfer_graph, return_mask, callback, {}, -1 + block_start_idx, block_end_idx = self._get_block_range(token_mask) assert block_end_idx == aligned_length // self.tokens_per_block gpu_block_ids = self.slot_mapping_to_block_ids(slot_mapping, @@ -624,12 +631,12 @@ def get(self, device_type=op_node_to_ready[op_id][0], node_to_ready=op_node_to_ready[op_id][1], ready_length=op_node_to_ready[op_id][2]) - + # Record metrics for GET operation if self._metrics_collector is not None: self._record_transfer_ops(transfer_graph, "get") self._update_mempool_metrics() - + return transfer_graph, return_mask, callback, op_callback_dict, task_end_op_id def _get_impl_global(self, @@ -967,7 +974,7 @@ def _get_impl_local(self, transfer_graph.add_transfer_op(op_peerh2h) #TODO here we dont combine peer cpu or local cpu match results, so we can safely add remote results to local cpu #TODO here assume all matched blocks are ready blocks for peer cpu - if (cpu_matched_result.insert_to_local_cpu_index and + if (cpu_matched_result.insert_to_local_cpu_index and cpu_matched_result.num_ready_matched_blocks >= block_mask_start and cpu_matched_result.num_ready_matched_blocks == cpu_matched_result.num_matched_blocks): cpu_node_to_unlock = self.cpu_cache_engine.insert(sequence_meta, @@ -1361,11 +1368,11 @@ def _put_impl_local(self, :cpu_matched_result.num_matched_blocks][block_mask_start:block_mask_end] ssd_matched_blocks = ssd_matched_result.physical_blocks[ 
:ssd_matched_result.num_matched_blocks][block_mask_start:block_mask_end] - + #if len(cpu_matched_blocks) > len(ssd_matched_blocks): # print(f"[PUT_LOCAL] CPU matched blocks are greater than SSD matched blocks, skipping") # return self._empty_put_return(request_id) - + num_skipped_blocks = len(cpu_matched_blocks) fragment12_num_blocks = len(gpu_block_ids) - num_skipped_blocks From 930c2a984d16e79addcc8ca03280aed81a505d00 Mon Sep 17 00:00:00 2001 From: phaedonsun Date: Wed, 15 Apr 2026 15:25:43 +0800 Subject: [PATCH 45/59] fix: make eventfd accept loop resilient to per-connection failures - Replace fixed-count for loop with while loop that continues until all ranks have registered or deadline is reached - Add per-connection try-except so a single SCM_RIGHTS failure does not abort the entire accept loop - Increase listen backlog to tp_group_size*3 to accommodate client retries on failed connections - Use per-connection timeout with overall deadline instead of a single global timeout --- flexkv/transfer/layerwise.py | 129 +++++++++++++++++++++-------------- 1 file changed, 76 insertions(+), 53 deletions(-) diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index 8d945a899a..7700232a70 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -237,7 +237,8 @@ def cleanup_socket(): try: server_sock.bind(socket_path) - server_sock.listen(tp_group_size) + # Use a larger backlog to accommodate client retries on failed connections + server_sock.listen(tp_group_size * 3) os.chmod(socket_path, 0o777) flexkv_logger.info( f"[LayerwiseWorker] Eventfd server created{rank_label}: " @@ -248,72 +249,94 @@ def cleanup_socket(): server_sock.close() return torch.empty(0, dtype=torch.int32) - server_sock.settimeout(max_retries * retry_interval) + # Use a per-connection timeout instead of a global one so that + # failed connections can be retried by the client without the server + # giving up too early. The total deadline is still bounded. 
+ per_conn_timeout = 30 # seconds per accept() call + total_deadline = time.time() + max_retries * retry_interval + server_sock.settimeout(per_conn_timeout) all_rank_eventfds: Dict[int, Dict[int, List[int]]] = {} num_layers, num_counters = self.num_layers, 3 + conn_idx = 0 try: - for conn_idx in range(tp_group_size): - flexkv_logger.debug( - f"[LayerwiseWorker] Waiting for connection " - f"{conn_idx + 1}/{tp_group_size} on {socket_path}...") + # Keep accepting until we have eventfds from all ranks or deadline. + while len(all_rank_eventfds) < tp_group_size: + if time.time() > total_deadline: + flexkv_logger.error( + f"[LayerwiseWorker] Deadline exceeded on {socket_path}{rank_label}, " + f"received {len(all_rank_eventfds)}/{tp_group_size} ranks") + break + + remaining = total_deadline - time.time() + server_sock.settimeout(min(per_conn_timeout, max(remaining, 1))) + try: conn, _ = server_sock.accept() + conn_idx += 1 flexkv_logger.info( f"[LayerwiseWorker] Accepted connection " - f"{conn_idx + 1}/{tp_group_size} on {socket_path}{rank_label}") + f"{conn_idx} (registered {len(all_rank_eventfds)}/{tp_group_size}) " + f"on {socket_path}{rank_label}") except socket.timeout: - flexkv_logger.error( + flexkv_logger.warning( f"[LayerwiseWorker] Timeout waiting for connection on {socket_path}{rank_label}, " - f"received {conn_idx}/{tp_group_size}") - break + f"registered {len(all_rank_eventfds)}/{tp_group_size}, retrying...") + continue - with conn: - # Accept both 16-byte (legacy: tp_rank, tp_size, num_layers, num_counters) - # and 24-byte (new: tp_rank, tp_size, cp_rank, cp_size, num_layers, num_counters) - metadata = conn.recv(24) - if len(metadata) < 16: - flexkv_logger.error( - f"[LayerwiseWorker] Incomplete metadata on {socket_path}{rank_label}: " - f"{len(metadata)} bytes") - continue - - if len(metadata) >= 24: - tp_rank, _, cp_rank, cp_size, recv_num_layers, recv_num_counters = \ - struct.unpack("iiiiii", metadata[:24]) - else: - tp_rank, _, recv_num_layers, 
recv_num_counters = \ - struct.unpack("iiii", metadata[:16]) - cp_rank, cp_size = 0, 1 - - # Use cp_rank as the connection key when CP is active, - # otherwise use tp_rank - rank_key = cp_rank if cp_size > 1 else tp_rank - if conn_idx == 0: - num_layers, num_counters = recv_num_layers, recv_num_counters - - flexkv_logger.debug( - f"[LayerwiseWorker] Connection {conn_idx + 1}: " - f"tp_rank={tp_rank}, cp_rank={cp_rank}, cp_size={cp_size}, " - f"num_layers={recv_num_layers}, " - f"num_counters={recv_num_counters}") - - rank_eventfds = {} - for _ in range(recv_num_counters): - fds, extra_data = _recv_fds(conn, recv_num_layers) - counter_id = struct.unpack("i", extra_data[:4])[0] - rank_eventfds[counter_id] = fds - flexkv_logger.debug( - f"[LayerwiseWorker] Received counter_id={counter_id}, " - f"num_fds={len(fds)} from rank_key={rank_key}") + try: + with conn: + # Accept both 16-byte (legacy: tp_rank, tp_size, num_layers, num_counters) + # and 24-byte (new: tp_rank, tp_size, cp_rank, cp_size, num_layers, num_counters) + metadata = conn.recv(24) + if len(metadata) < 16: + flexkv_logger.error( + f"[LayerwiseWorker] Incomplete metadata on {socket_path}{rank_label}: " + f"{len(metadata)} bytes") + continue + + if len(metadata) >= 24: + tp_rank, _, cp_rank, cp_size, recv_num_layers, recv_num_counters = \ + struct.unpack("iiiiii", metadata[:24]) + else: + tp_rank, _, recv_num_layers, recv_num_counters = \ + struct.unpack("iiii", metadata[:16]) + cp_rank, cp_size = 0, 1 + + # Use cp_rank as the connection key when CP is active, + # otherwise use tp_rank + rank_key = cp_rank if cp_size > 1 else tp_rank + if not all_rank_eventfds: + num_layers, num_counters = recv_num_layers, recv_num_counters - all_rank_eventfds[rank_key] = rank_eventfds - flexkv_logger.info( - f"[LayerwiseWorker] Received all eventfds from rank_key={rank_key} " - f"(tp_rank={tp_rank}, cp_rank={cp_rank}) on {socket_path}") + flexkv_logger.debug( + f"[LayerwiseWorker] Connection {conn_idx}: " + 
f"tp_rank={tp_rank}, cp_rank={cp_rank}, cp_size={cp_size}, " + f"num_layers={recv_num_layers}, " + f"num_counters={recv_num_counters}") + + rank_eventfds = {} + for _ in range(recv_num_counters): + fds, extra_data = _recv_fds(conn, recv_num_layers) + counter_id = struct.unpack("i", extra_data[:4])[0] + rank_eventfds[counter_id] = fds + flexkv_logger.debug( + f"[LayerwiseWorker] Received counter_id={counter_id}, " + f"num_fds={len(fds)} from rank_key={rank_key}") + + all_rank_eventfds[rank_key] = rank_eventfds + flexkv_logger.info( + f"[LayerwiseWorker] Received all eventfds from rank_key={rank_key} " + f"(tp_rank={tp_rank}, cp_rank={cp_rank}) on {socket_path}") + except Exception as e: + flexkv_logger.warning( + f"[LayerwiseWorker] Failed to receive eventfds from connection {conn_idx} " + f"on {socket_path}{rank_label}: {e}. " + f"Client will retry, continuing accept loop...") + continue except Exception as e: flexkv_logger.error( - f"[LayerwiseWorker] Error in accept loop on {socket_path}{rank_label}: {e}") + f"[LayerwiseWorker] Fatal error in accept loop on {socket_path}{rank_label}: {e}") finally: server_sock.close() cleanup_socket() From b427e7ef92d03a2b7e4079dfdd930a4743c8a42b Mon Sep 17 00:00:00 2001 From: phaedonsun Date: Wed, 15 Apr 2026 16:34:40 +0800 Subject: [PATCH 46/59] fix: add ACK handshake for layerwise eventfd transfer to prevent race condition --- flexkv/transfer/layerwise.py | 10 ++++++++++ flexkv/transfer_manager.py | 21 +++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index 7700232a70..ea5fd57af7 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -325,10 +325,20 @@ def cleanup_socket(): f"num_fds={len(fds)} from rank_key={rank_key}") all_rank_eventfds[rank_key] = rank_eventfds + # Send ACK to client so it knows the fds were received + try: + conn.sendall(b"\x01") + except Exception: + pass flexkv_logger.info( f"[LayerwiseWorker] Received 
all eventfds from rank_key={rank_key} " f"(tp_rank={tp_rank}, cp_rank={cp_rank}) on {socket_path}") except Exception as e: + # Send NACK so client knows to retry + try: + conn.sendall(b"\x00") + except Exception: + pass flexkv_logger.warning( f"[LayerwiseWorker] Failed to receive eventfds from connection {conn_idx} " f"on {socket_path}{rank_label}: {e}. " diff --git a/flexkv/transfer_manager.py b/flexkv/transfer_manager.py index dd9e2a3292..8d63baffd1 100644 --- a/flexkv/transfer_manager.py +++ b/flexkv/transfer_manager.py @@ -1,5 +1,6 @@ import os import multiprocessing as mp +import signal import time import queue import selectors @@ -645,6 +646,19 @@ def _process_worker(self, gpu_register_port: str, ready_event, start_event) -> None: + # Automatically reap child processes (daemon transfer workers) to + # prevent zombie accumulation. Use a handler that calls waitpid() + # with WNOHANG so that multiprocessing.Process.join() still works + # correctly (SIG_IGN would cause join() to raise ChildProcessError). 
+ def _reap_children(signum, frame): + while True: + try: + pid, _ = os.waitpid(-1, os.WNOHANG) + if pid == 0: + break + except ChildProcessError: + break + signal.signal(signal.SIGCHLD, _reap_children) try: start_event.set() os.environ['MPI4PY_RC_INITIALIZE'] = 'false' @@ -721,6 +735,13 @@ def _process_worker(self, except Exception as e: flexkv_logger.error(f"Error closing selector: {e}") + # Gracefully shut down transfer engine and its worker subprocesses + if 'transfer_manager' in locals(): + try: + transfer_manager.shutdown() + except Exception as e: + flexkv_logger.error(f"Error shutting down transfer manager: {e}") + command_conn.close() result_conn.close() From d91bed1c5426de3f2ef11d8c193cfba4753d6f8d Mon Sep 17 00:00:00 2001 From: staryxchen Date: Wed, 15 Apr 2026 17:27:22 +0800 Subject: [PATCH 47/59] fix: correct variable name and config reference in MoonCakeTransferEngineWrapper - Fix typo: change`engien_port`to`engine_port` - Use`self.config`for attribute assignment instead of`config` Signed-off-by: staryxchen --- flexkv/mooncakeEngineWrapper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flexkv/mooncakeEngineWrapper.py b/flexkv/mooncakeEngineWrapper.py index bc08255e68..2b4d45113c 100644 --- a/flexkv/mooncakeEngineWrapper.py +++ b/flexkv/mooncakeEngineWrapper.py @@ -40,9 +40,9 @@ def __init__( self.config = MooncakeTransferEngineConfig.from_file(mooncake_config_path) else: self.config = config - self.engine_ip = config.engine_ip - self.engien_port = config.engine_port - self.mooncake_addr = f"{self.engine_ip}:{self.engien_port}" + self.engine_ip = self.config.engine_ip + self.engine_port = self.config.engine_port + self.mooncake_addr = f"{self.engine_ip}:{self.engine_port}" flexkv_logger.info(f"Mooncake listen on: {self.mooncake_addr}") supported_backend = ["redis"] From dfe21d4eabfb5a98b45141e55bbf2f034fd211de Mon Sep 17 00:00:00 2001 From: staryxchen Date: Wed, 15 Apr 2026 17:27:58 +0800 Subject: [PATCH 48/59] 
fix(mooncakeEngineWrapper): change unregist_buffer return type from None to int - Update method signature to match underlying engine's return value - Return engine's status code (0 for success, -1 for error) Signed-off-by: staryxchen --- flexkv/mooncakeEngineWrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flexkv/mooncakeEngineWrapper.py b/flexkv/mooncakeEngineWrapper.py index 2b4d45113c..e954f90c56 100644 --- a/flexkv/mooncakeEngineWrapper.py +++ b/flexkv/mooncakeEngineWrapper.py @@ -75,7 +75,7 @@ def regist_buffer(self, buffer_ptr: int, buffer_size: int) -> int: ret = self.engine.register_memory(buffer_ptr, buffer_size) return ret if ret == 0 else -1 - def unregist_buffer(self, buffer_ptr: int) -> None: + def unregist_buffer(self, buffer_ptr: int) -> int: """Unregister the buffer to the mooncake engine.""" ret = self.engine.unregister_memory(buffer_ptr) return ret if ret == 0 else -1 From 3a36508eb0adc1a3a960a8777e8c9ec8ccdbc691 Mon Sep 17 00:00:00 2001 From: staryxchen Date: Wed, 15 Apr 2026 17:34:16 +0800 Subject: [PATCH 49/59] fix(mooncakeEngineWrapper): add return type annotation to transfer_sync_write_with_notify method - Added`-> int`return type annotation to the method signature. 
Signed-off-by: staryxchen --- flexkv/mooncakeEngineWrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flexkv/mooncakeEngineWrapper.py b/flexkv/mooncakeEngineWrapper.py index e954f90c56..b0f24f3eb1 100644 --- a/flexkv/mooncakeEngineWrapper.py +++ b/flexkv/mooncakeEngineWrapper.py @@ -102,7 +102,7 @@ def batch_transfer_sync_write(self, peer_engine_addr: str, src_ptr_list: List[in ret = self.engine.batch_transfer_sync_write(peer_engine_addr, src_ptr_list, dst_ptr_list, data_size_list) return ret if ret == 0 else -1 - def transfer_sync_write_with_notify(self, peer_engine_addr: str, src_ptr: int, dst_ptr: int, data_size: int, notify_name: str, msg : NotifyMsg): + def transfer_sync_write_with_notify(self, peer_engine_addr: str, src_ptr: int, dst_ptr: int, data_size: int, notify_name: str, msg : NotifyMsg) -> int: if not MOONCAKE_AVAILABLE: raise RuntimeError("Mooncake engine is not available") notify = engine.TransferNotify(notify_name, msg.to_string()) From 7148c7e73824770ab43c894984b260627126186e Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Wed, 15 Apr 2026 15:28:11 +0000 Subject: [PATCH 50/59] fix d2h issue for glm5+cp8 --- csrc/bindings.cpp | 2 +- csrc/tp_transfer_thread_group.cpp | 22 ++++++++++------------ csrc/tp_transfer_thread_group.h | 2 +- flexkv/transfer/layerwise.py | 5 ++++- flexkv/transfer/transfer_engine.py | 2 +- flexkv/transfer/worker.py | 5 +++-- 6 files changed, 20 insertions(+), 18 deletions(-) diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index 823412da85..6268d569dd 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -507,7 +507,7 @@ PYBIND11_MODULE(c_ext, m) { py::arg("cpu_tp_stride_in_bytes"), py::arg("transfer_num_cta"), py::arg("is_host_to_device"), py::arg("use_ce_transfer"), py::arg("layer_id"), py::arg("layer_granularity"), py::arg("is_mla"), - py::arg("use_sharded_d2h") = false); + py::arg("is_nsa_cp") = false); #ifdef FLEXKV_ENABLE_GDS py::class_(m, "TPGDSTransferThreadGroup") diff --git 
a/csrc/tp_transfer_thread_group.cpp b/csrc/tp_transfer_thread_group.cpp index 9d0886d72b..77ac01eabd 100644 --- a/csrc/tp_transfer_thread_group.cpp +++ b/csrc/tp_transfer_thread_group.cpp @@ -157,7 +157,7 @@ void TPTransferThreadGroup::tp_group_transfer( const int64_t cpu_tp_stride_in_bytes, const int transfer_num_cta, const bool is_host_to_device, const bool use_ce_transfer, const int layer_id, const int layer_granularity, const bool is_mla, - const bool use_sharded_d2h) { + const bool is_nsa_cp) { std::atomic failed{false}; std::string error_msg; @@ -168,6 +168,8 @@ void TPTransferThreadGroup::tp_group_transfer( std::vector> futures; futures.reserve(num_gpus_); + bool enable_sharded_d2h = is_mla && !is_host_to_device; + for (int i = 0; i < num_gpus_; ++i) { futures.emplace_back(enqueue_for_gpu(i, [&, i]() { try { @@ -178,24 +180,20 @@ void TPTransferThreadGroup::tp_group_transfer( int64_t *cpu_block_ids = static_cast(cpu_block_id_tensor.data_ptr()); void *cpu_ptr = cpu_blocks_; - bool should_use_sharded_offsets = is_mla || use_sharded_d2h; - int64_t cpu_startoff_inside_chunks = i * cpu_tp_stride_in_bytes; - if (should_use_sharded_offsets && !is_host_to_device) { + int64_t cpu_startoff_inside_chunks = 0; + if (enable_sharded_d2h) cpu_startoff_inside_chunks = i * gpu_chunk_sizes_in_bytes_[i] / num_gpus_; - } else if (should_use_sharded_offsets && is_host_to_device) { - cpu_startoff_inside_chunks = 0; - } + else if (!is_mla) + cpu_startoff_inside_chunks = i * cpu_tp_stride_in_bytes; int64_t gpu_startoff_inside_chunks = - should_use_sharded_offsets && !is_host_to_device - ? i * gpu_chunk_sizes_in_bytes_[i] / num_gpus_ - : 0; + enable_sharded_d2h ? i * gpu_chunk_sizes_in_bytes_[i] / num_gpus_ + : 0; // we assume that the chunk size is the same for all gpus, // even if they have different number of gpu_blocks - int64_t chunk_size = should_use_sharded_offsets && !is_host_to_device + int64_t chunk_size = enable_sharded_d2h ? 
gpu_chunk_sizes_in_bytes_[i] / num_gpus_ : gpu_chunk_sizes_in_bytes_[i]; - // Dispatch to the appropriate template based on backend type switch (backend_type_) { case BackendType::VLLM: diff --git a/csrc/tp_transfer_thread_group.h b/csrc/tp_transfer_thread_group.h index 08bf2846ff..551910c0c8 100644 --- a/csrc/tp_transfer_thread_group.h +++ b/csrc/tp_transfer_thread_group.h @@ -57,7 +57,7 @@ class TPTransferThreadGroup { const bool is_host_to_device, const bool use_ce_transfer, const int layer_id, const int layer_granularity, const bool is_mla, - const bool use_sharded_d2h); + const bool is_nsa_cp); private: using Task = std::function; diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index ea5fd57af7..e103498c1e 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -446,7 +446,10 @@ def launch_transfer(self, transfer_op: WorkerLayerwiseTransferOp) -> bool: end_time = time.time() kv_dim = 2 if not self.is_mla else 1 - transfer_size = self.cpu_chunk_size_in_bytes * layer_granularity * num_h2d_blocks * kv_dim + transfer_size = self.cpu_chunk_size_in_bytes * self.num_layers * num_h2d_blocks * kv_dim + + if self.is_nsa_cp or self.is_mla: + transfer_size *= self.tp_group_size self._log_transfer_performance( transfer_op, diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index 1c345b0082..0a6bcb05bf 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -27,7 +27,7 @@ from flexkv.common.debug import flexkv_logger from flexkv.common.storage import StorageHandle -from flexkv.common.transfer import TransferOp, TransferOpGraph, TransferType, CompletedOp +from flexkv.common.transfer import TransferOp, TransferOpGraph, TransferType, CompletedOp, LayerwiseTransferOp from flexkv.common.transfer import get_nvtx_range_color from flexkv.common.storage import KVCacheLayoutType from flexkv.transfer.scheduler import TransferScheduler diff --git 
a/flexkv/transfer/worker.py b/flexkv/transfer/worker.py index 3fa30fcef7..17c9a363c6 100644 --- a/flexkv/transfer/worker.py +++ b/flexkv/transfer/worker.py @@ -452,11 +452,12 @@ def __init__(self, self.gpu_blocks = imported_gpu_blocks self.dtype = dtype # note this should be quantized data type self.is_mla = gpu_kv_layouts[0].is_mla + self.is_nsa_cp = is_nsa_cp + self.cp_size = cp_size self.num_gpus = len(self.gpu_blocks) self.tp_group_size = tp_group_size self.dp_group_id = dp_group_id - self.use_cp_nsa_d2h_shard = bool(is_nsa_cp and cp_size > 1) flexkv_logger.info(f"Pinning CPU Memory: {cpu_blocks.numel() * cpu_blocks.element_size() / (1024 ** 3):.2f} GB") cudaHostRegister(cpu_blocks) @@ -562,7 +563,7 @@ def _transfer_impl(self, layer_id, layer_granularity, self.is_mla, - self.use_cp_nsa_d2h_shard, + self.is_nsa_cp and self.cp_size > 1, ) From d54f9303074b09d774932066c4b5b369e5defbce Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Wed, 15 Apr 2026 15:29:22 +0000 Subject: [PATCH 51/59] fix sglang config issue --- flexkv/integration/config.py | 77 +++++++++++++++++++++++++----------- 1 file changed, 55 insertions(+), 22 deletions(-) diff --git a/flexkv/integration/config.py b/flexkv/integration/config.py index 051c4c622f..c0c0d2c2ee 100644 --- a/flexkv/integration/config.py +++ b/flexkv/integration/config.py @@ -49,18 +49,29 @@ def _detect_indexer_config_from_hf(self, hf_config, source: str = "") -> None: if qk_rope_head_dim is None or qk_rope_head_dim <= 0: return + index_head_dim = getattr(hf_config, 'index_head_dim', None) + if index_head_dim is not None and index_head_dim > 0: + quant_block_size = 128 + head_size = self.cache_config.tokens_per_block * ( + index_head_dim + index_head_dim // quant_block_size * 4 + ) + else: + head_size = qk_rope_head_dim + # tokens_per_block is already set to sglang page_size before this # call, so each FlexKV block = 1 sglang page. The indexer maps - # 1:1 with blocks — no extra page_size grouping is needed. 
+ # 1:1 with blocks — no extra page_size grouping is needed. For + # NSA/DSA models, head_size stores the packed per-page buffer width + # so the CPU layout matches the GPU indexer tensor shape. self.cache_config.indexer = IndexerCacheConfig( - head_size=qk_rope_head_dim, + head_size=head_size, num_kv_heads=1, dtype=torch.uint8, ) source_label = f" ({source})" if source else "" logger.info( f"Detected sparse attention indexer config{source_label}: " - f"head_size={qk_rope_head_dim}, dtype=uint8, " + f"head_size={head_size}, dtype=uint8, " f"tokens_per_block={self.cache_config.tokens_per_block}") except Exception as e: logger.debug(f"Could not detect indexer config ({source}): {e}") @@ -143,20 +154,34 @@ def post_init_from_sglang_config( total_layers = int(getattr(sglang_config, "num_hidden_layers", 0)) self.model_config.num_layers = int(num_local_layers) if num_local_layers > 0 else total_layers - if hasattr(sglang_config, "get_total_num_kv_heads"): - try: - self.model_config.num_kv_heads = int(sglang_config.get_total_num_kv_heads()) - except Exception: - self.model_config.num_kv_heads = int(getattr(sglang_config, "num_key_value_heads", 0)) - elif hasattr(sglang_config, "get_num_kv_heads"): - try: - per_rank = int(sglang_config.get_num_kv_heads(tp_size)) - self.model_config.num_kv_heads = per_rank * tp_size - except Exception: - self.model_config.num_kv_heads = int(getattr(sglang_config, "num_key_value_heads", 0)) + attn_arch = getattr(sglang_config, "attention_arch", None) + use_mla = False + if hasattr(attn_arch, "name"): + use_mla = (attn_arch.name.upper() == "MLA") + elif isinstance(attn_arch, str): + use_mla = (attn_arch.upper() == "MLA") + + if use_mla: + kv_lora_rank = int(getattr(sglang_config, "kv_lora_rank", 0)) + qk_rope_head_dim = int(getattr(sglang_config, "qk_rope_head_dim", 0)) + mla_head_size = kv_lora_rank + qk_rope_head_dim + self.model_config.num_kv_heads = 1 + self.model_config.head_size = int(mla_head_size) else: - 
self.model_config.num_kv_heads = int(getattr(sglang_config, "num_key_value_heads", 0)) - self.model_config.head_size = int(getattr(sglang_config, "head_dim", 0)) + if hasattr(sglang_config, "get_total_num_kv_heads"): + try: + self.model_config.num_kv_heads = int(sglang_config.get_total_num_kv_heads()) + except Exception: + self.model_config.num_kv_heads = int(getattr(sglang_config, "num_key_value_heads", 0)) + elif hasattr(sglang_config, "get_num_kv_heads"): + try: + per_rank = int(sglang_config.get_num_kv_heads(tp_size)) + self.model_config.num_kv_heads = per_rank * tp_size + except Exception: + self.model_config.num_kv_heads = int(getattr(sglang_config, "num_key_value_heads", 0)) + else: + self.model_config.num_kv_heads = int(getattr(sglang_config, "num_key_value_heads", 0)) + self.model_config.head_size = int(getattr(sglang_config, "head_dim", 0)) # Determine KV cache dtype: prioritize user_config.kv_cache_dtype (from # flexkv_config.yaml or FLEXKV_KV_CACHE_DTYPE env var), then fall back @@ -194,12 +219,20 @@ def _parse_dtype_str(dtype_str: str) -> torch.dtype: f"flexkv_config.yaml or set FLEXKV_KV_CACHE_DTYPE=fp8 environment variable." 
) - attn_arch = getattr(sglang_config, "attention_arch", None) - use_mla = False - if hasattr(attn_arch, "name"): - use_mla = (attn_arch.name.upper() == "MLA") - elif isinstance(attn_arch, str): - use_mla = (attn_arch.upper() == "MLA") + if use_mla and getattr(sglang_config, "index_head_dim", None) is not None: + kv_lora_rank = int(getattr(sglang_config, "kv_lora_rank", 0)) + qk_rope_head_dim = int(getattr(sglang_config, "qk_rope_head_dim", 0)) + if self.model_config.dtype == torch.float8_e4m3fn: + assert kv_lora_rank % 128 == 0, ( + f"kv_lora_rank {kv_lora_rank} must be multiple of 128 " + "for NSA FP8 KV cache layout" + ) + self.model_config.head_size = int( + kv_lora_rank + + kv_lora_rank // 128 * 4 + + qk_rope_head_dim * torch.bfloat16.itemsize + ) + self.model_config.use_mla = use_mla self.model_config.tp_size = int(tp_size) From 7ec580f00c0ac9c416c6357a7c4e95261c17b0bf Mon Sep 17 00:00:00 2001 From: zitto Date: Fri, 17 Apr 2026 15:02:25 +0800 Subject: [PATCH 52/59] fix(layerwise): fuse indexer DISK2H/H2D into layerwise worker (#149) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Fuse indexer into layerwise transfer pipeline: - LayerwiseTransferOp/WorkerLayerwiseTransferOp: add indexer_src/dst_block_ids fields - merge_to_batch_graph: populate indexer block_ids (1:1 with main KV) in LAYERWISE op - C++ LayerwiseTransferGroup: accept indexer_ssd_files, init indexer_ioctx_ - C++ layerwise_transfer: add Step 0.5 — indexer SSD→CPU after main KV SSD→CPU - Python LayerwiseTransferWorker: accept indexer SSD params, compute indexer SSD strides, pass indexer SSD block_ids and H2D block_ids to C++ layerwise_transfer 2. Skip redundant worker creation in layerwise mode: - Main KV: skip h2d_workers and cpussd_read_worker creation & registration - Indexer: skip _indexer_h2d_workers and _indexer_disk2h_worker creation & registration - D2H / H2DISK workers preserved (layerwise does not support PUT direction) 3. 
Add startup assertions in layerwise mode: - Assert _worker_map does not contain H2D or DISK2H - Assert _worker_map contains LAYERWISE Co-authored-by: zittozhang --- csrc/bindings.cpp | 29 ++- csrc/layerwise.cpp | 232 ++++++++++++++++++++- csrc/layerwise.h | 39 +++- flexkv/common/transfer.py | 18 +- flexkv/transfer/layerwise.py | 155 +++++++++++++- flexkv/transfer/transfer_engine.py | 323 ++++++++++++++++------------- flexkv/transfer/worker_op.py | 5 + 7 files changed, 640 insertions(+), 161 deletions(-) diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index 6268d569dd..05ccb88b87 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -420,7 +420,11 @@ PYBIND11_MODULE(c_ext, m) { .def(py::init> &, torch::Tensor &, std::map> &, int, int, torch::Tensor &, torch::Tensor &, torch::Tensor &, - torch::Tensor &, int, int, torch::Tensor &, int>(), + torch::Tensor &, int, int, torch::Tensor &, int, + const std::vector> &, + torch::Tensor, torch::Tensor, torch::Tensor, + torch::Tensor, torch::Tensor, + std::map>>(), py::arg("num_gpus"), py::arg("gpu_blocks"), py::arg("cpu_blocks"), py::arg("ssd_files"), py::arg("dp_group_id"), py::arg("num_layers"), py::arg("gpu_kv_strides_tensor"), @@ -428,7 +432,14 @@ PYBIND11_MODULE(c_ext, m) { py::arg("gpu_layer_strides_tensor"), py::arg("gpu_chunk_sizes_tensor"), py::arg("iouring_entries"), py::arg("iouring_flags"), py::arg("layer_eventfds_tensor"), - py::arg("tp_size")) + py::arg("tp_size"), + py::arg("indexer_gpu_blocks") = std::vector>{}, + py::arg("indexer_cpu_blocks") = torch::Tensor(), + py::arg("indexer_gpu_kv_strides_tensor") = torch::Tensor(), + py::arg("indexer_gpu_block_strides_tensor") = torch::Tensor(), + py::arg("indexer_gpu_layer_strides_tensor") = torch::Tensor(), + py::arg("indexer_gpu_chunk_sizes_tensor") = torch::Tensor(), + py::arg("indexer_ssd_files") = std::map>{}) .def("layerwise_transfer", &flexkv::LayerwiseTransferGroup::layerwise_transfer, py::arg("ssd_block_ids"), py::arg("cpu_block_ids_d2h"), @@ -445,7 
+456,19 @@ PYBIND11_MODULE(c_ext, m) { py::arg("cpu_tp_stride_in_bytes"), py::arg("transfer_cta_num"), py::arg("use_ce_transfer"), py::arg("num_layers"), py::arg("layer_granularity"), py::arg("is_mla"), - py::arg("counter_id") = 0); + py::arg("counter_id") = 0, + py::arg("indexer_gpu_block_id_tensor") = torch::Tensor(), + py::arg("indexer_cpu_block_id_tensor") = torch::Tensor(), + py::arg("indexer_cpu_block_stride_in_bytes") = 0, + py::arg("indexer_cpu_layer_stride_in_bytes") = 0, + py::arg("indexer_h2d_cpu_kv_stride_in_bytes") = 0, + py::arg("indexer_h2d_cpu_layer_stride_in_bytes") = 0, + py::arg("indexer_ssd_block_ids") = torch::Tensor(), + py::arg("indexer_cpu_block_ids_d2h") = torch::Tensor(), + py::arg("indexer_ssd_layer_stride_in_bytes") = 0, + py::arg("indexer_ssd_kv_stride_in_bytes") = 0, + py::arg("indexer_cpu_chunk_size_in_bytes") = 0, + py::arg("indexer_num_blocks_per_file") = 0); #ifdef FLEXKV_ENABLE_CFS m.def("transfer_kv_blocks_remote", &transfer_kv_blocks_remote, "Transfer KV blocks between remote and CPU memory", diff --git a/csrc/layerwise.cpp b/csrc/layerwise.cpp index 151671a507..139a91eff9 100644 --- a/csrc/layerwise.cpp +++ b/csrc/layerwise.cpp @@ -66,7 +66,14 @@ LayerwiseTransferGroup::LayerwiseTransferGroup( torch::Tensor &gpu_block_strides_tensor, torch::Tensor &gpu_layer_strides_tensor, torch::Tensor &gpu_chunk_sizes_tensor, int iouring_entries, - int iouring_flags, torch::Tensor &layer_eventfds_tensor, int tp_size) { + int iouring_flags, torch::Tensor &layer_eventfds_tensor, int tp_size, + const std::vector> &indexer_gpu_blocks, + torch::Tensor indexer_cpu_blocks, + torch::Tensor indexer_gpu_kv_strides_tensor, + torch::Tensor indexer_gpu_block_strides_tensor, + torch::Tensor indexer_gpu_layer_strides_tensor, + torch::Tensor indexer_gpu_chunk_sizes_tensor, + std::map> indexer_ssd_files) { num_gpus_ = num_gpus; num_layers_ = num_layers; @@ -168,6 +175,77 @@ LayerwiseTransferGroup::LayerwiseTransferGroup( ioctx_ = std::make_unique(ssd_files, 
ssd_files.size(), iouring_entries, iouring_flags); } + + // Initialize indexer fuse support + enable_indexer_ = !indexer_gpu_blocks.empty(); + if (enable_indexer_) { + indexer_num_tensors_per_gpu_ = indexer_gpu_blocks[0].size(); + cudaMallocHost((void **)&indexer_gpu_blocks_, + num_gpus_ * indexer_num_tensors_per_gpu_ * sizeof(void *)); + for (int i = 0; i < num_gpus_; ++i) { + for (int j = 0; j < indexer_num_tensors_per_gpu_; ++j) { + indexer_gpu_blocks_[i * indexer_num_tensors_per_gpu_ + j] = + indexer_gpu_blocks[i][j].data_ptr(); + } + } + + indexer_cpu_blocks_ = indexer_cpu_blocks.data_ptr(); + + indexer_gpu_kv_strides_in_bytes_ = new int64_t[num_gpus]; + indexer_gpu_block_strides_in_bytes_ = new int64_t[num_gpus]; + indexer_gpu_layer_strides_in_bytes_ = new int64_t[num_gpus]; + indexer_gpu_chunk_sizes_in_bytes_ = new int64_t[num_gpus]; + + int64_t *idx_kv_strides_ptr = indexer_gpu_kv_strides_tensor.data_ptr(); + int64_t *idx_block_strides_ptr = indexer_gpu_block_strides_tensor.data_ptr(); + int64_t *idx_layer_strides_ptr = indexer_gpu_layer_strides_tensor.data_ptr(); + int64_t *idx_chunk_sizes_ptr = indexer_gpu_chunk_sizes_tensor.data_ptr(); + + for (int i = 0; i < num_gpus; i++) { + indexer_gpu_kv_strides_in_bytes_[i] = idx_kv_strides_ptr[i]; + indexer_gpu_block_strides_in_bytes_[i] = idx_block_strides_ptr[i]; + indexer_gpu_layer_strides_in_bytes_[i] = idx_layer_strides_ptr[i]; + indexer_gpu_chunk_sizes_in_bytes_[i] = idx_chunk_sizes_ptr[i]; + } + + // Determine indexer backend type from tensor count (symmetric with main KV) + if (indexer_num_tensors_per_gpu_ == 1) { + indexer_backend_type_ = BackendType::TRTLLM; + } else if (indexer_num_tensors_per_gpu_ == num_layers) { + indexer_backend_type_ = BackendType::VLLM; + } else if (indexer_num_tensors_per_gpu_ == num_layers * 2) { + indexer_backend_type_ = BackendType::SGLANG; + } else { + throw std::runtime_error("Unsupported indexer GPU block type: " + + std::to_string(indexer_num_tensors_per_gpu_)); + } + + // 
Build GTensorHandlers for indexer (symmetric with main KV) + indexer_gpu_tensor_handlers_.reserve(num_gpus_); + for (int i = 0; i < num_gpus_; i++) { + int64_t **idx_gpu_blocks_ptr = reinterpret_cast( + indexer_gpu_blocks_ + i * indexer_num_tensors_per_gpu_); + indexer_gpu_tensor_handlers_.emplace_back( + indexer_backend_type_, idx_gpu_blocks_ptr, num_layers, + indexer_gpu_kv_strides_in_bytes_[i], + indexer_gpu_block_strides_in_bytes_[i], + indexer_gpu_layer_strides_in_bytes_[i]); + } + + fprintf(stderr, "[LayerwiseTransferGroup] Indexer fuse: enabled=true, " + "num_tensors_per_gpu=%d, chunk_size=%ld bytes, backend=%s\n", + indexer_num_tensors_per_gpu_, indexer_gpu_chunk_sizes_in_bytes_[0], + indexer_backend_type_ == BackendType::SGLANG ? "SGLANG" : + indexer_backend_type_ == BackendType::VLLM ? "VLLM" : "TRTLLM"); + } + + // Initialize indexer SSD IO context if indexer_ssd_files is not empty + enable_indexer_ssd_ = !indexer_ssd_files.empty(); + if (enable_indexer_ssd_) { + indexer_ioctx_ = std::make_unique( + indexer_ssd_files, indexer_ssd_files.size(), + iouring_entries, iouring_flags); + } } LayerwiseTransferGroup::~LayerwiseTransferGroup() { @@ -184,6 +262,16 @@ LayerwiseTransferGroup::~LayerwiseTransferGroup() { delete[] gpu_block_strides_in_bytes_; delete[] gpu_layer_strides_in_bytes_; delete[] gpu_chunk_sizes_in_bytes_; + + // Clean up indexer resources + if (enable_indexer_) { + cudaFreeHost(indexer_gpu_blocks_); + indexer_gpu_tensor_handlers_.clear(); + delete[] indexer_gpu_kv_strides_in_bytes_; + delete[] indexer_gpu_block_strides_in_bytes_; + delete[] indexer_gpu_layer_strides_in_bytes_; + delete[] indexer_gpu_chunk_sizes_in_bytes_; + } } void LayerwiseTransferGroup::layer_done_callback(int start_layer, @@ -231,7 +319,19 @@ void LayerwiseTransferGroup::layerwise_transfer( const int64_t cpu_tp_stride_in_bytes, const int transfer_cta_num, const bool use_ce_transfer, const int num_layers, const int layer_granularity, const bool is_mla, - const int 
counter_id) { + const int counter_id, + const torch::Tensor &indexer_gpu_block_id_tensor, + const torch::Tensor &indexer_cpu_block_id_tensor, + const int64_t indexer_cpu_block_stride_in_bytes, + const int64_t indexer_cpu_layer_stride_in_bytes, + const int64_t indexer_h2d_cpu_kv_stride_in_bytes, + const int64_t indexer_h2d_cpu_layer_stride_in_bytes, + const torch::Tensor &indexer_ssd_block_ids, + const torch::Tensor &indexer_cpu_block_ids_d2h, + const int64_t indexer_ssd_layer_stride_in_bytes, + const int64_t indexer_ssd_kv_stride_in_bytes, + const int64_t indexer_cpu_chunk_size_in_bytes, + const int indexer_num_blocks_per_file) { // Set current counter ID for eventfd notification current_counter_id_ = counter_id; @@ -243,6 +343,21 @@ void LayerwiseTransferGroup::layerwise_transfer( static_cast(cpu_block_id_tensor.data_ptr()); void *cpu_ptr = cpu_blocks_; + // Indexer block ids (may be empty if indexer is not enabled or not provided) + bool do_indexer_transfer = enable_indexer_ && + indexer_gpu_block_id_tensor.defined() && + indexer_gpu_block_id_tensor.numel() > 0; + int num_indexer_blocks = 0; + int64_t *indexer_gpu_block_ids = nullptr; + int64_t *indexer_cpu_block_ids = nullptr; + if (do_indexer_transfer) { + num_indexer_blocks = indexer_gpu_block_id_tensor.numel(); + indexer_gpu_block_ids = + static_cast(indexer_gpu_block_id_tensor.data_ptr()); + indexer_cpu_block_ids = + static_cast(indexer_cpu_block_id_tensor.data_ptr()); + } + // Create CUDA events for timing each layer batch (on GPU 0) int num_batches = (num_layers + layer_granularity - 1) / layer_granularity; std::vector timing_events(num_batches + 1); // +1 for start event @@ -269,9 +384,23 @@ void LayerwiseTransferGroup::layerwise_transfer( for (int g = 0; g < num_gpus_; ++g) { bytes_this_batch += gpu_chunk_sizes_in_bytes_[g] * 2 * ltb * num_blocks; } - double mb_this_batch = bytes_this_batch / (1024.0 * 1024.0); - char name[128]; - snprintf(name, sizeof(name), "CPU->GPU Layer[%d,%d) %.2fMB", sl, sl + ltb, 
mb_this_batch); + // Add indexer bytes if applicable + int64_t indexer_bytes_this_batch = 0; + if (do_indexer_transfer) { + for (int g = 0; g < num_gpus_; ++g) { + indexer_bytes_this_batch += indexer_gpu_chunk_sizes_in_bytes_[g] * ltb * num_indexer_blocks; + } + } + double mb_this_batch = (bytes_this_batch + indexer_bytes_this_batch) / (1024.0 * 1024.0); + char name[256]; + if (do_indexer_transfer) { + snprintf(name, sizeof(name), "CPU->GPU Layer[%d,%d) KV:%.2fMB+Idx:%.2fMB", + sl, sl + ltb, bytes_this_batch / (1024.0 * 1024.0), + indexer_bytes_this_batch / (1024.0 * 1024.0)); + } else { + snprintf(name, sizeof(name), "CPU->GPU Layer[%d,%d) %.2fMB", sl, sl + ltb, + bytes_this_batch / (1024.0 * 1024.0)); + } h2d_range_names[b] = name; } @@ -308,6 +437,37 @@ void LayerwiseTransferGroup::layerwise_transfer( nvtxRangePop(); } + // Indexer SSD -> CPU transfer for ALL layers at once. + if (enable_indexer_ssd_ && indexer_ssd_block_ids.defined() && + indexer_ssd_block_ids.numel() > 0) { + int num_indexer_ssd_blocks = indexer_ssd_block_ids.numel(); + int64_t indexer_ssd_bytes = indexer_cpu_chunk_size_in_bytes * num_layers * num_indexer_ssd_blocks; + double indexer_ssd_mb = indexer_ssd_bytes / (1024.0 * 1024.0); + char idx_ssd_range_name[128]; + snprintf(idx_ssd_range_name, sizeof(idx_ssd_range_name), + "Indexer SSD->CPU AllLayers[0,%d) %.2fMB", num_layers, indexer_ssd_mb); + nvtxRangePushA(idx_ssd_range_name); + + torch::Tensor all_layer_ids = + torch::arange(0, num_layers, + torch::TensorOptions().dtype(torch::kInt32)); + transfer_kv_blocks_ssd( + *indexer_ioctx_, all_layer_ids, + reinterpret_cast(indexer_cpu_blocks_), + indexer_ssd_block_ids, indexer_cpu_block_ids_d2h, + indexer_cpu_layer_stride_in_bytes, + indexer_ssd_kv_stride_in_bytes, + indexer_ssd_layer_stride_in_bytes, + indexer_ssd_kv_stride_in_bytes, + indexer_cpu_chunk_size_in_bytes, + indexer_cpu_block_stride_in_bytes, + true, // is_read: SSD -> CPU + indexer_num_blocks_per_file, round_robin, 
num_threads_per_device, + true /* is_mla: indexer always MLA */); + + nvtxRangePop(); + } + int batch_idx = 0; for (int start_layer = 0; start_layer < num_layers; start_layer += layer_granularity) { @@ -356,6 +516,60 @@ void LayerwiseTransferGroup::layerwise_transfer( streams_[i], transfer_cta_num, true, use_ce_transfer, is_mla, false); break; } + + // Fused indexer CPU -> GPU transfer on the same stream + // Uses transfer_kv_blocks (symmetric with main KV Step 2) instead of + // hand-written cudaMemcpyAsync loops for backend-agnostic support. + // Note: indexer uses ReplicatedLinear weights with 1 head (is_mla=true), + // so all TP ranks hold identical data. No TP head-partitioning needed, + // cpu_startoff is always 0 (unlike main KV which may offset by tp_stride). + if (do_indexer_transfer) { + int64_t idx_chunk_size = indexer_gpu_chunk_sizes_in_bytes_[i]; + // idx_cpu_startoff = 0: indexer data is not partitioned across TP ranks + int64_t idx_cpu_startoff = 0; + + switch (indexer_backend_type_) { + case BackendType::VLLM: + flexkv::transfer_kv_blocks( + num_indexer_blocks, start_layer, layers_this_batch, + indexer_gpu_block_ids, indexer_gpu_tensor_handlers_[i], + 0 /* gpu_startoff */, indexer_cpu_block_ids, + indexer_cpu_blocks_, + indexer_h2d_cpu_kv_stride_in_bytes, + indexer_h2d_cpu_layer_stride_in_bytes, + indexer_cpu_block_stride_in_bytes, + idx_cpu_startoff, idx_chunk_size, + streams_[i], transfer_cta_num, true /* h2d */, + use_ce_transfer, true /* is_mla */, false /* sync */); + break; + case BackendType::TRTLLM: + flexkv::transfer_kv_blocks( + num_indexer_blocks, start_layer, layers_this_batch, + indexer_gpu_block_ids, indexer_gpu_tensor_handlers_[i], + 0 /* gpu_startoff */, indexer_cpu_block_ids, + indexer_cpu_blocks_, + indexer_h2d_cpu_kv_stride_in_bytes, + indexer_h2d_cpu_layer_stride_in_bytes, + indexer_cpu_block_stride_in_bytes, + idx_cpu_startoff, idx_chunk_size, + streams_[i], transfer_cta_num, true /* h2d */, + use_ce_transfer, true /* is_mla */, 
false /* sync */); + break; + case BackendType::SGLANG: + flexkv::transfer_kv_blocks( + num_indexer_blocks, start_layer, layers_this_batch, + indexer_gpu_block_ids, indexer_gpu_tensor_handlers_[i], + 0 /* gpu_startoff */, indexer_cpu_block_ids, + indexer_cpu_blocks_, + indexer_h2d_cpu_kv_stride_in_bytes, + indexer_h2d_cpu_layer_stride_in_bytes, + indexer_cpu_block_stride_in_bytes, + idx_cpu_startoff, idx_chunk_size, + streams_[i], transfer_cta_num, true /* h2d */, + use_ce_transfer, true /* is_mla */, false /* sync */); + break; + } + } } // Record event after this batch on GPU 0 @@ -397,6 +611,14 @@ void LayerwiseTransferGroup::layerwise_transfer( for (int g = 0; g < num_gpus_; ++g) { bytes_this_batch += gpu_chunk_sizes_in_bytes_[g] * 2 * batch_layers_count[i] * num_blocks; } + // Include indexer bytes + int64_t indexer_bytes_batch = 0; + if (do_indexer_transfer) { + for (int g = 0; g < num_gpus_; ++g) { + indexer_bytes_batch += indexer_gpu_chunk_sizes_in_bytes_[g] * batch_layers_count[i] * num_indexer_blocks; + } + bytes_this_batch += indexer_bytes_batch; + } double bandwidth_gbps = (bytes_this_batch / (1024.0 * 1024.0 * 1024.0)) / (elapsed_ms / 1000.0); diff --git a/csrc/layerwise.h b/csrc/layerwise.h index 0a3fd7fcae..95f81f70bf 100644 --- a/csrc/layerwise.h +++ b/csrc/layerwise.h @@ -27,7 +27,14 @@ class LayerwiseTransferGroup { torch::Tensor &gpu_block_strides_tensor, torch::Tensor &gpu_layer_strides_tensor, torch::Tensor &gpu_chunk_sizes_tensor, int iouring_entries, - int iouring_flags, torch::Tensor &layer_eventfds_tensor, int tp_size); + int iouring_flags, torch::Tensor &layer_eventfds_tensor, int tp_size, + const std::vector> &indexer_gpu_blocks = {}, + torch::Tensor indexer_cpu_blocks = torch::Tensor(), + torch::Tensor indexer_gpu_kv_strides_tensor = torch::Tensor(), + torch::Tensor indexer_gpu_block_strides_tensor = torch::Tensor(), + torch::Tensor indexer_gpu_layer_strides_tensor = torch::Tensor(), + torch::Tensor indexer_gpu_chunk_sizes_tensor = 
torch::Tensor(), + std::map> indexer_ssd_files = {}); ~LayerwiseTransferGroup(); @@ -53,7 +60,19 @@ class LayerwiseTransferGroup { const int64_t cpu_tp_stride_in_bytes, const int transfer_cta_num, const bool use_ce_transfer, const int num_layers, const int layer_granularity, const bool is_mla, - const int counter_id = 0); // Counter set index for triple buffering + const int counter_id = 0, + const torch::Tensor &indexer_gpu_block_id_tensor = torch::Tensor(), + const torch::Tensor &indexer_cpu_block_id_tensor = torch::Tensor(), + const int64_t indexer_cpu_block_stride_in_bytes = 0, + const int64_t indexer_cpu_layer_stride_in_bytes = 0, + const int64_t indexer_h2d_cpu_kv_stride_in_bytes = 0, + const int64_t indexer_h2d_cpu_layer_stride_in_bytes = 0, + const torch::Tensor &indexer_ssd_block_ids = torch::Tensor(), + const torch::Tensor &indexer_cpu_block_ids_d2h = torch::Tensor(), + const int64_t indexer_ssd_layer_stride_in_bytes = 0, + const int64_t indexer_ssd_kv_stride_in_bytes = 0, + const int64_t indexer_cpu_chunk_size_in_bytes = 0, + const int indexer_num_blocks_per_file = 0); private: int num_gpus_; @@ -77,6 +96,22 @@ class LayerwiseTransferGroup { bool enable_ssd_; std::unique_ptr ioctx_; + // Indexer fuse support + bool enable_indexer_ = false; + void **indexer_gpu_blocks_ = nullptr; + void *indexer_cpu_blocks_ = nullptr; + int indexer_num_tensors_per_gpu_ = 0; + int64_t *indexer_gpu_kv_strides_in_bytes_ = nullptr; + int64_t *indexer_gpu_block_strides_in_bytes_ = nullptr; + int64_t *indexer_gpu_layer_strides_in_bytes_ = nullptr; + int64_t *indexer_gpu_chunk_sizes_in_bytes_ = nullptr; + BackendType indexer_backend_type_ = BackendType::SGLANG; + std::vector indexer_gpu_tensor_handlers_; + + // Indexer SSD IO context + bool enable_indexer_ssd_ = false; + std::unique_ptr indexer_ioctx_; + // Layer eventfds for notification // Shape: [tp_size, num_counters, num_layers] bool enable_eventfd_; diff --git a/flexkv/common/transfer.py b/flexkv/common/transfer.py index 
3b38a03b5a..ba1da3b46b 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -125,6 +125,9 @@ class LayerwiseTransferOp(TransferOp): src_block_ids_disk2h: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) dst_block_ids_disk2h: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) counter_id: int = 0 # Counter set index for triple buffering eventfd notification + # Indexer block_ids for fused indexer transfer (1:1 with main KV block_ids) + indexer_src_block_ids: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) + indexer_dst_block_ids: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) def __init__(self, graph_id: int, @@ -135,12 +138,18 @@ def __init__(self, layer_id: int = 0, layer_granularity: int = 1, dp_id: int = 0, - counter_id: int = 0) -> None: + counter_id: int = 0, + indexer_src_block_ids: Optional[np.ndarray] = None, + indexer_dst_block_ids: Optional[np.ndarray] = None) -> None: self.src_block_ids_h2d = src_block_ids_h2d self.dst_block_ids_h2d = dst_block_ids_h2d self.src_block_ids_disk2h = src_block_ids_disk2h self.dst_block_ids_disk2h = dst_block_ids_disk2h self.counter_id = counter_id + self.indexer_src_block_ids = indexer_src_block_ids if indexer_src_block_ids is not None \ + else np.array([], dtype=np.int64) + self.indexer_dst_block_ids = indexer_dst_block_ids if indexer_dst_block_ids is not None \ + else np.array([], dtype=np.int64) super().__init__( graph_id=graph_id, @@ -160,11 +169,14 @@ def __post_init__(self) -> None: self.layer_granularity = 1 assert self.src_block_ids_h2d.size == self.dst_block_ids_h2d.size assert self.src_block_ids_disk2h.size == self.dst_block_ids_disk2h.size + assert self.indexer_src_block_ids.size == self.indexer_dst_block_ids.size assert self.src_block_ids_h2d.dtype == np.int64 assert self.dst_block_ids_h2d.dtype == np.int64 assert self.src_block_ids_disk2h.dtype == np.int64 assert self.dst_block_ids_disk2h.dtype == 
np.int64 + assert self.indexer_src_block_ids.dtype == np.int64 + assert self.indexer_dst_block_ids.dtype == np.int64 class TransferOpGraph: _next_graph_id = 0 @@ -471,6 +483,10 @@ def merge_to_batch_graph(batch_id: int, layer_granularity=1, dp_id=ops_by_type[TransferType.H2D][0].dp_id, counter_id=counter_id, + # Indexer maps 1:1 with main KV blocks, use same block_ids + # CPU side (src) and GPU side (dst) for H2D direction + indexer_src_block_ids=merged_h2d_op.src_block_ids.copy(), + indexer_dst_block_ids=merged_h2d_op.dst_block_ids.copy(), ) merged_graph.add_transfer_op(layerwise_transfer_op) batch_end_op_id = -1 diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index e103498c1e..785846fb1e 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -74,7 +74,15 @@ def __init__(self, h2d_cta_num: int = 4, d2h_cta_num: int = 4, enable_eventfd: bool = True, - is_nsa_cp: bool = False) -> None: + is_nsa_cp: bool = False, + indexer_gpu_blocks: Optional[List[List[TensorSharedHandle]]] = None, + indexer_cpu_blocks: Optional[torch.Tensor] = None, + indexer_gpu_kv_layouts: Optional[List[KVCacheLayout]] = None, + indexer_cpu_kv_layout: Optional[KVCacheLayout] = None, + indexer_dtype: Optional[torch.dtype] = None, + indexer_ssd_files: Optional[Dict[int, List[str]]] = None, + indexer_ssd_kv_layout: Optional[KVCacheLayout] = None, + indexer_num_blocks_per_file: int = 0) -> None: flexkv_logger.debug( f"[LayerwiseWorker] __init__ started: worker_id={worker_id}, " f"tp_group_size={tp_group_size}, dp_group_id={dp_group_id}, " @@ -189,6 +197,105 @@ def __init__(self, # Create LayerwiseTransferGroup which handles both SSD->CPU and CPU->GPU transfers flexkv_logger.debug("[LayerwiseWorker] Creating LayerwiseTransferGroup...") + + # Initialize indexer fuse support + self.enable_indexer = (indexer_gpu_blocks is not None and indexer_cpu_blocks is not None) + indexer_constructor_kwargs = {} + if self.enable_indexer: + assert 
indexer_gpu_kv_layouts is not None + assert indexer_cpu_kv_layout is not None + assert indexer_dtype is not None + + # Import indexer GPU tensor handles + imported_indexer_gpu_blocks = [] + for handles_in_one_gpu in indexer_gpu_blocks: + blocks_in_one_gpu = [] + for handle in handles_in_one_gpu: + blocks_in_one_gpu.append(handle.get_tensor()) + imported_indexer_gpu_blocks.append(blocks_in_one_gpu) + + # Pin indexer CPU memory + flexkv_logger.info( + f"[LayerwiseWorker] Pinning indexer CPU Memory: " + f"{indexer_cpu_blocks.numel() * indexer_cpu_blocks.element_size() / (1024 ** 3):.4f} GB") + cudaHostRegister(indexer_cpu_blocks) + + # Compute indexer GPU stride tensors + indexer_gpu_kv_strides = [layout.get_kv_stride() * indexer_dtype.itemsize + for layout in indexer_gpu_kv_layouts] + indexer_gpu_block_strides = [layout.get_block_stride() * indexer_dtype.itemsize + for layout in indexer_gpu_kv_layouts] + indexer_gpu_layer_strides = [layout.get_layer_stride() * indexer_dtype.itemsize + for layout in indexer_gpu_kv_layouts] + indexer_gpu_chunk_sizes = [layout.get_chunk_size() * indexer_dtype.itemsize + for layout in indexer_gpu_kv_layouts] + + # Compute indexer CPU strides. + # Indexer is always is_mla=True (1 head, ReplicatedLinear weights), + # so all TP ranks hold identical data and no head-partitioning is needed. + # Therefore indexer has no tp_stride — cpu_startoff is always 0. 
+ self.indexer_cpu_block_stride_in_bytes = indexer_cpu_kv_layout.get_block_stride() * indexer_dtype.itemsize + self.indexer_cpu_layer_stride_in_bytes = indexer_cpu_kv_layout.get_layer_stride() * indexer_dtype.itemsize + self.indexer_h2d_cpu_kv_stride_in_bytes = indexer_cpu_kv_layout.get_kv_stride() * indexer_dtype.itemsize + self.indexer_h2d_cpu_layer_stride_in_bytes = indexer_cpu_kv_layout.get_layer_stride() * indexer_dtype.itemsize + + self.indexer_gpu_blocks = imported_indexer_gpu_blocks + self.indexer_cpu_blocks = indexer_cpu_blocks + self.indexer_gpu_kv_strides_tensor = torch.tensor(indexer_gpu_kv_strides, dtype=torch.int64) + self.indexer_gpu_block_strides_tensor = torch.tensor(indexer_gpu_block_strides, dtype=torch.int64) + self.indexer_gpu_layer_strides_tensor = torch.tensor(indexer_gpu_layer_strides, dtype=torch.int64) + self.indexer_gpu_chunk_sizes_tensor = torch.tensor(indexer_gpu_chunk_sizes, dtype=torch.int64) + + flexkv_logger.info( + f"[LayerwiseWorker] Indexer fuse enabled: " + f"gpu_blocks={len(imported_indexer_gpu_blocks)}, " + f"cpu_size={indexer_cpu_blocks.numel() * indexer_cpu_blocks.element_size() / (1024 ** 2):.2f} MB, " + f"chunk_size={indexer_gpu_chunk_sizes[0]} bytes, " + f"cpu_block_stride={self.indexer_cpu_block_stride_in_bytes} bytes, " + f"cpu_layer_stride={self.indexer_cpu_layer_stride_in_bytes} bytes") + else: + self.indexer_cpu_block_stride_in_bytes = 0 + self.indexer_cpu_layer_stride_in_bytes = 0 + self.indexer_h2d_cpu_kv_stride_in_bytes = 0 + self.indexer_h2d_cpu_layer_stride_in_bytes = 0 + self.indexer_gpu_blocks = [] + self.indexer_cpu_blocks = torch.Tensor() + self.indexer_gpu_kv_strides_tensor = torch.empty(0, dtype=torch.int64) + self.indexer_gpu_block_strides_tensor = torch.empty(0, dtype=torch.int64) + self.indexer_gpu_layer_strides_tensor = torch.empty(0, dtype=torch.int64) + self.indexer_gpu_chunk_sizes_tensor = torch.empty(0, dtype=torch.int64) + + # Initialize indexer SSD support + self.enable_indexer_ssd = ( + 
self.enable_indexer and + indexer_ssd_files is not None and len(indexer_ssd_files) > 0 and + indexer_ssd_kv_layout is not None + ) + if self.enable_indexer_ssd: + assert indexer_dtype is not None + self.indexer_ssd_files = indexer_ssd_files + self.indexer_num_blocks_per_file = indexer_num_blocks_per_file + + indexer_ssd_kv_layout_per_file = indexer_ssd_kv_layout.div_block( + sum(len(fl) for fl in indexer_ssd_files.values()), padding=True) + self.indexer_ssd_kv_stride_in_bytes = indexer_ssd_kv_layout_per_file.get_kv_stride() * indexer_dtype.itemsize + self.indexer_ssd_layer_stride_in_bytes = indexer_ssd_kv_layout_per_file.get_layer_stride() * indexer_dtype.itemsize + self.indexer_cpu_chunk_size_in_bytes = indexer_cpu_kv_layout.get_chunk_size() * indexer_dtype.itemsize + + flexkv_logger.info( + f"[LayerwiseWorker] Indexer SSD fuse enabled: " + f"num_files={sum(len(fl) for fl in indexer_ssd_files.values())}, " + f"num_blocks_per_file={indexer_num_blocks_per_file}, " + f"ssd_kv_stride={self.indexer_ssd_kv_stride_in_bytes}, " + f"ssd_layer_stride={self.indexer_ssd_layer_stride_in_bytes}, " + f"cpu_chunk_size={self.indexer_cpu_chunk_size_in_bytes}") + else: + self.indexer_ssd_files = {} + self.indexer_num_blocks_per_file = 0 + self.indexer_ssd_kv_stride_in_bytes = 0 + self.indexer_ssd_layer_stride_in_bytes = 0 + self.indexer_cpu_chunk_size_in_bytes = 0 + self.layerwise_transfer_group = LayerwiseTransferGroup( self.num_gpus, self.gpu_blocks, cpu_blocks, ssd_files, dp_group_id, self.num_layers, @@ -196,7 +303,11 @@ def __init__(self, gpu_layer_strides_tensor, gpu_chunk_sizes_tensor, GLOBAL_CONFIG_FROM_ENV.iouring_entries, GLOBAL_CONFIG_FROM_ENV.iouring_flags, - layer_eventfds_tensor, tp_group_size) + layer_eventfds_tensor, tp_group_size, + self.indexer_gpu_blocks, self.indexer_cpu_blocks, + self.indexer_gpu_kv_strides_tensor, self.indexer_gpu_block_strides_tensor, + self.indexer_gpu_layer_strides_tensor, self.indexer_gpu_chunk_sizes_tensor, + self.indexer_ssd_files) 
flexkv_logger.info(f"[LayerwiseWorker] __init__ completed successfully, worker_id={worker_id}") def _receive_eventfds_from_sglang(self, tp_group_size: int, @@ -378,6 +489,8 @@ def _transfer_impl(self, dst_block_ids_disk2h: Optional[torch.Tensor], layer_granularity: int, counter_id: int = 0, + indexer_src_block_ids: Optional[torch.Tensor] = None, + indexer_dst_block_ids: Optional[torch.Tensor] = None, **kwargs: Any) -> None: assert src_block_ids_h2d.dtype == torch.int64 assert dst_block_ids_h2d.dtype == torch.int64 @@ -392,6 +505,21 @@ def _transfer_impl(self, cpu_block_ids_d2h = dst_block_ids_disk2h if dst_block_ids_disk2h is not None \ else torch.empty(0, dtype=torch.int64) + # Prepare indexer block_ids for fused transfer + indexer_gpu_block_id_tensor = torch.Tensor() + indexer_cpu_block_id_tensor = torch.Tensor() + if self.enable_indexer and indexer_dst_block_ids is not None and len(indexer_dst_block_ids) > 0: + indexer_gpu_block_id_tensor = indexer_dst_block_ids + indexer_cpu_block_id_tensor = indexer_src_block_ids + + # Prepare indexer SSD block_ids for fused DISK2H transfer + indexer_ssd_block_ids_tensor = torch.Tensor() + indexer_cpu_block_ids_d2h_tensor = torch.Tensor() + if self.enable_indexer_ssd and src_block_ids_disk2h is not None: + # Indexer SSD block_ids mirror main KV's DISK2H block_ids (1:1 mapping) + indexer_ssd_block_ids_tensor = ssd_block_ids + indexer_cpu_block_ids_d2h_tensor = cpu_block_ids_d2h + self.layerwise_transfer_group.layerwise_transfer( ssd_block_ids, cpu_block_ids_d2h, @@ -415,6 +543,18 @@ def _transfer_impl(self, layer_granularity, self.is_mla, counter_id, + indexer_gpu_block_id_tensor, + indexer_cpu_block_id_tensor, + self.indexer_cpu_block_stride_in_bytes, + self.indexer_cpu_layer_stride_in_bytes, + self.indexer_h2d_cpu_kv_stride_in_bytes, + self.indexer_h2d_cpu_layer_stride_in_bytes, + indexer_ssd_block_ids_tensor, + indexer_cpu_block_ids_d2h_tensor, + self.indexer_ssd_layer_stride_in_bytes, + self.indexer_ssd_kv_stride_in_bytes, 
+ self.indexer_cpu_chunk_size_in_bytes, + self.indexer_num_blocks_per_file, ) def launch_transfer(self, transfer_op: WorkerLayerwiseTransferOp) -> bool: @@ -432,6 +572,15 @@ def launch_transfer(self, transfer_op: WorkerLayerwiseTransferOp) -> bool: src_block_ids_disk2h = None dst_block_ids_disk2h = None + # Extract indexer block_ids if available + indexer_src_block_ids = None + indexer_dst_block_ids = None + if self.enable_indexer and transfer_op.indexer_src_block_ids.size > 0: + indexer_src_block_ids = torch.from_numpy( + transfer_op.indexer_src_block_ids).to(dtype=torch.int64).pin_memory() + indexer_dst_block_ids = torch.from_numpy( + transfer_op.indexer_dst_block_ids).to(dtype=torch.int64).pin_memory() + num_h2d_blocks = len(src_block_ids_h2d) start_time = time.time() @@ -442,6 +591,8 @@ def launch_transfer(self, transfer_op: WorkerLayerwiseTransferOp) -> bool: dst_block_ids_disk2h, layer_granularity, transfer_op.counter_id, + indexer_src_block_ids=indexer_src_block_ids, + indexer_dst_block_ids=indexer_dst_block_ids, ) end_time = time.time() diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index 0a6bcb05bf..55f291e178 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -155,27 +155,57 @@ def _init_workers(self) -> None: self._worker_map: Dict[TransferType, Union[WorkerHandle, List[WorkerHandle]]] = {} assert self._cpu_handle is not None + _enable_layerwise = GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer # Use num_gpu_groups to support multi-instance mode # Use gpu_device_id from StorageHandle for correct CUDA device selection + + # H2D worker + if not _enable_layerwise: + if self.tp_size == 1: + self.h2d_workers: List[WorkerHandle] = [ + GPUCPUTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self.finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=gpu_handles[0].get_tensor_handle_list(), + cpu_blocks=self._cpu_handle.get_tensor(), 
+ gpu_kv_layout=gpu_handles[0].kv_layout, + cpu_kv_layout=self._cpu_handle.kv_layout, + dtype=gpu_handles[0].dtype, + gpu_device_id=gpu_handles[0].gpu_device_id, + use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, + use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, + transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, + transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, + ) + for _, gpu_handles in self.gpu_handle_groups.items() + ] + else: + self.h2d_workers = [ + tpGPUCPUTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self.finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=[gpu_handle.get_tensor_handle_list() for gpu_handle in gpu_handles], + cpu_blocks=self._cpu_handle.get_tensor(), + gpu_kv_layouts=[gpu_handle.kv_layout for gpu_handle in gpu_handles], + cpu_kv_layout=self._cpu_handle.kv_layout, + dtype=gpu_handles[0].dtype, + tp_group_size=self.tp_size, + dp_group_id=dp_client_id, + is_nsa_cp=getattr(self.model_config, "is_nsa_cp", False), + cp_size=getattr(self.model_config, "cp_size", 1), + use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, + use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, + transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, + transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, + ) + for dp_client_id, gpu_handles in self.gpu_handle_groups.items() + ] + self._worker_map[TransferType.H2D] = self.h2d_workers + + # D2H worker if self.tp_size == 1: - self.h2d_workers: List[WorkerHandle] = [ - GPUCPUTransferWorker.create_worker( - mp_ctx=self.mp_ctx, - finished_ops_queue=self.finished_ops_queue, - op_buffer_tensor=self.pin_buffer.get_buffer(), - gpu_blocks=gpu_handles[0].get_tensor_handle_list(), - cpu_blocks=self._cpu_handle.get_tensor(), - gpu_kv_layout=gpu_handles[0].kv_layout, - cpu_kv_layout=self._cpu_handle.kv_layout, - dtype=gpu_handles[0].dtype, - 
gpu_device_id=gpu_handles[0].gpu_device_id, - use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, - use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, - transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, - transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, - ) - for _, gpu_handles in self.gpu_handle_groups.items() - ] self.d2h_workers: List[WorkerHandle] = [ GPUCPUTransferWorker.create_worker( mp_ctx=self.mp_ctx, @@ -195,7 +225,7 @@ def _init_workers(self) -> None: for _, gpu_handles in self.gpu_handle_groups.items() ] else: - self.h2d_workers = [ + self.d2h_workers = [ tpGPUCPUTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, @@ -216,43 +246,26 @@ def _init_workers(self) -> None: ) for dp_client_id, gpu_handles in self.gpu_handle_groups.items() ] - self.d2h_workers = [ - tpGPUCPUTransferWorker.create_worker( + self._worker_map[TransferType.D2H] = self.d2h_workers + + if self._ssd_handle is not None and self._cpu_handle is not None: + # DISK2H worker + if not _enable_layerwise: + self.cpussd_read_worker: WorkerHandle = CPUSSDDiskTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, - op_buffer_tensor=self.pin_buffer.get_buffer(), - gpu_blocks=[gpu_handle.get_tensor_handle_list() for gpu_handle in gpu_handles], + op_buffer_tensor = self.pin_buffer.get_buffer(), cpu_blocks=self._cpu_handle.get_tensor(), - gpu_kv_layouts=[gpu_handle.kv_layout for gpu_handle in gpu_handles], + ssd_files=self._ssd_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, - dtype=gpu_handles[0].dtype, - tp_group_size=self.tp_size, - dp_group_id=dp_client_id, - is_nsa_cp=getattr(self.model_config, "is_nsa_cp", False), - cp_size=getattr(self.model_config, "cp_size", 1), - use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, - use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, - 
transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, - transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, + ssd_kv_layout=self._ssd_handle.kv_layout, + dtype=self._cpu_handle.dtype, + num_blocks_per_file=self._ssd_handle.num_blocks_per_file, + cache_config=self._cache_config, ) - for dp_client_id, gpu_handles in self.gpu_handle_groups.items() - ] - self._worker_map[TransferType.H2D] = self.h2d_workers - self._worker_map[TransferType.D2H] = self.d2h_workers + self._worker_map[TransferType.DISK2H] = self.cpussd_read_worker - if self._ssd_handle is not None and self._cpu_handle is not None: - self.cpussd_read_worker: WorkerHandle = CPUSSDDiskTransferWorker.create_worker( - mp_ctx=self.mp_ctx, - finished_ops_queue=self.finished_ops_queue, - op_buffer_tensor = self.pin_buffer.get_buffer(), - cpu_blocks=self._cpu_handle.get_tensor(), - ssd_files=self._ssd_handle.get_file_list(), - cpu_kv_layout=self._cpu_handle.kv_layout, - ssd_kv_layout=self._ssd_handle.kv_layout, - dtype=self._cpu_handle.dtype, - num_blocks_per_file=self._ssd_handle.num_blocks_per_file, - cache_config=self._cache_config, - ) + # H2DISK worker self.cpussd_write_worker: WorkerHandle = CPUSSDDiskTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, @@ -266,7 +279,6 @@ def _init_workers(self) -> None: cache_config=self._cache_config, ) self._worker_map[TransferType.H2DISK] = self.cpussd_write_worker - self._worker_map[TransferType.DISK2H] = self.cpussd_read_worker if self._remote_handle is not None and self._cpu_handle is not None: self.remotecpu_read_worker: WorkerHandle = CPURemoteTransferWorker.create_worker( mp_ctx=self.mp_ctx, @@ -337,8 +349,21 @@ def _init_workers(self) -> None: _cp_size = getattr(self.model_config, 'cp_size', 1) # For CP, each CP rank connects via eventfd; for TP, each TP rank connects. 
_eventfd_group_size = _cp_size if _is_nsa_cp and _cp_size > 1 else self.tp_size - self.layerwise_workers = [ - LayerwiseTransferWorker.create_worker( + + # Prepare indexer handles for fused layerwise transfer + has_indexer_for_layerwise = ( + self._indexer_gpu_handles is not None and + self._indexer_cpu_handle is not None + ) + + self.layerwise_workers = [] + for dp_client_id, gpu_handles in self.gpu_handle_groups.items(): + # Resolve indexer handles for this dp_client_id + idx_handles = None + if has_indexer_for_layerwise: + idx_handles = self._indexer_gpu_handles.get(dp_client_id) + + worker = LayerwiseTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), @@ -361,9 +386,22 @@ def _init_workers(self) -> None: h2d_cta_num=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, d2h_cta_num=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, is_nsa_cp=_is_nsa_cp, + indexer_gpu_blocks=[h.get_tensor_handle_list() for h in idx_handles] if idx_handles else None, + indexer_cpu_blocks=self._indexer_cpu_handle.get_tensor() if idx_handles else None, + indexer_gpu_kv_layouts=[h.kv_layout for h in idx_handles] if idx_handles else None, + indexer_cpu_kv_layout=self._indexer_cpu_handle.kv_layout if idx_handles else None, + indexer_dtype=idx_handles[0].dtype if idx_handles else None, + indexer_ssd_files=self._indexer_ssd_handle.get_file_list() if (idx_handles and self._indexer_ssd_handle) else None, + indexer_ssd_kv_layout=self._indexer_ssd_handle.kv_layout if (idx_handles and self._indexer_ssd_handle) else None, + indexer_num_blocks_per_file=self._indexer_ssd_handle.num_blocks_per_file if (idx_handles and self._indexer_ssd_handle) else 0, ) - for dp_client_id, gpu_handles in self.gpu_handle_groups.items() - ] + self.layerwise_workers.append(worker) + + flexkv_logger.debug( + f"[TransferEngine] Created layerwise worker for dp_client_id={dp_client_id}: " + f"tp_size={self.tp_size}, has_indexer={idx_handles is not 
None}, " + f"has_ssd={len(ssd_files) > 0}") + self._worker_map[TransferType.LAYERWISE] = self.layerwise_workers if self.cache_config.enable_kv_sharing and self._cpu_handle is not None and (self.cache_config.enable_p2p_cpu \ @@ -398,25 +436,53 @@ def _init_workers(self) -> None: and self._indexer_cpu_handle is not None): self._indexer_finished_ops_queue = self.mp_ctx.Queue() self._indexer_worker_map: Dict[TransferType, Union[WorkerHandle, List[WorkerHandle]]] = {} + # H2D indexer worker + if not _enable_layerwise: + if self.tp_size == 1: + self._indexer_h2d_workers = [ + GPUCPUTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=indexer_gpu_handles_list[0].get_tensor_handle_list(), + cpu_blocks=self._indexer_cpu_handle.get_tensor(), + gpu_kv_layout=indexer_gpu_handles_list[0].kv_layout, + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + dtype=indexer_gpu_handles_list[0].dtype, + gpu_device_id=indexer_gpu_handles_list[0].gpu_device_id, + use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, + use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, + transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, + transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, + ) + for _, indexer_gpu_handles_list in self._indexer_gpu_handles.items() + ] + else: + self._indexer_h2d_workers = [ + tpGPUCPUTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=[h.get_tensor_handle_list() for h in indexer_gpu_handles_list], + cpu_blocks=self._indexer_cpu_handle.get_tensor(), + gpu_kv_layouts=[h.kv_layout for h in indexer_gpu_handles_list], + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + dtype=indexer_gpu_handles_list[0].dtype, + tp_group_size=self.tp_size, + dp_group_id=dp_client_id, + 
is_nsa_cp=getattr(self.model_config, "is_nsa_cp", False), + cp_size=getattr(self.model_config, "cp_size", 1), + use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, + use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, + transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, + transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, + ) + for dp_client_id, indexer_gpu_handles_list in self._indexer_gpu_handles.items() + ] + self._indexer_worker_map[TransferType.H2D] = self._indexer_h2d_workers + + # D2H indexer worker if self.tp_size == 1: - self._indexer_h2d_workers = [ - GPUCPUTransferWorker.create_worker( - mp_ctx=self.mp_ctx, - finished_ops_queue=self._indexer_finished_ops_queue, - op_buffer_tensor=self.pin_buffer.get_buffer(), - gpu_blocks=indexer_gpu_handles_list[0].get_tensor_handle_list(), - cpu_blocks=self._indexer_cpu_handle.get_tensor(), - gpu_kv_layout=indexer_gpu_handles_list[0].kv_layout, - cpu_kv_layout=self._indexer_cpu_handle.kv_layout, - dtype=indexer_gpu_handles_list[0].dtype, - gpu_device_id=indexer_gpu_handles_list[0].gpu_device_id, - use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, - use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, - transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, - transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, - ) - for _, indexer_gpu_handles_list in self._indexer_gpu_handles.items() - ] self._indexer_d2h_workers = [ GPUCPUTransferWorker.create_worker( mp_ctx=self.mp_ctx, @@ -436,27 +502,6 @@ def _init_workers(self) -> None: for _, indexer_gpu_handles_list in self._indexer_gpu_handles.items() ] else: - self._indexer_h2d_workers = [ - tpGPUCPUTransferWorker.create_worker( - mp_ctx=self.mp_ctx, - finished_ops_queue=self._indexer_finished_ops_queue, - op_buffer_tensor=self.pin_buffer.get_buffer(), - gpu_blocks=[h.get_tensor_handle_list() for h in indexer_gpu_handles_list], - 
cpu_blocks=self._indexer_cpu_handle.get_tensor(), - gpu_kv_layouts=[h.kv_layout for h in indexer_gpu_handles_list], - cpu_kv_layout=self._indexer_cpu_handle.kv_layout, - dtype=indexer_gpu_handles_list[0].dtype, - tp_group_size=self.tp_size, - dp_group_id=dp_client_id, - is_nsa_cp=getattr(self.model_config, "is_nsa_cp", False), - cp_size=getattr(self.model_config, "cp_size", 1), - use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, - use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, - transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, - transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, - ) - for dp_client_id, indexer_gpu_handles_list in self._indexer_gpu_handles.items() - ] self._indexer_d2h_workers = [ tpGPUCPUTransferWorker.create_worker( mp_ctx=self.mp_ctx, @@ -478,9 +523,9 @@ def _init_workers(self) -> None: ) for dp_client_id, indexer_gpu_handles_list in self._indexer_gpu_handles.items() ] - self._indexer_worker_map[TransferType.H2D] = self._indexer_h2d_workers self._indexer_worker_map[TransferType.D2H] = self._indexer_d2h_workers if self._indexer_ssd_handle is not None and self._indexer_cpu_handle is not None: + # H2DISK indexer worker self._indexer_h2disk_worker = CPUSSDDiskTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self._indexer_finished_ops_queue, @@ -493,20 +538,22 @@ def _init_workers(self) -> None: num_blocks_per_file=self._indexer_ssd_handle.num_blocks_per_file, cache_config=self._cache_config, ) - self._indexer_disk2h_worker = CPUSSDDiskTransferWorker.create_worker( - mp_ctx=self.mp_ctx, - finished_ops_queue=self._indexer_finished_ops_queue, - op_buffer_tensor=self.pin_buffer.get_buffer(), - cpu_blocks=self._indexer_cpu_handle.get_tensor(), - ssd_files=self._indexer_ssd_handle.get_file_list(), - cpu_kv_layout=self._indexer_cpu_handle.kv_layout, - ssd_kv_layout=self._indexer_ssd_handle.kv_layout, - dtype=self._indexer_cpu_handle.dtype, - 
num_blocks_per_file=self._indexer_ssd_handle.num_blocks_per_file, - cache_config=self._cache_config, - ) self._indexer_worker_map[TransferType.H2DISK] = self._indexer_h2disk_worker - self._indexer_worker_map[TransferType.DISK2H] = self._indexer_disk2h_worker + # DISK2H indexer worker + if not _enable_layerwise: + self._indexer_disk2h_worker = CPUSSDDiskTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self._indexer_finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + cpu_blocks=self._indexer_cpu_handle.get_tensor(), + ssd_files=self._indexer_ssd_handle.get_file_list(), + cpu_kv_layout=self._indexer_cpu_handle.kv_layout, + ssd_kv_layout=self._indexer_ssd_handle.kv_layout, + dtype=self._indexer_cpu_handle.dtype, + num_blocks_per_file=self._indexer_ssd_handle.num_blocks_per_file, + cache_config=self._cache_config, + ) + self._indexer_worker_map[TransferType.DISK2H] = self._indexer_disk2h_worker flexkv_logger.info("TransferEngine: indexer SSD workers initialized") if self._indexer_remote_handle is not None and self._indexer_cpu_handle is not None: self._indexer_h2remote_worker = CPURemoteTransferWorker.create_worker( @@ -594,43 +641,15 @@ def _init_workers(self) -> None: if self.cache_config.enable_p2p_ssd: self._indexer_worker_map[TransferType.PEERSSD2H] = self._indexer_cpu_remote_cpu_worker flexkv_logger.info("TransferEngine: indexer P2P workers initialized") - if GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer: - indexer_ssd_files = {} if self._indexer_ssd_handle is None else self._indexer_ssd_handle.get_file_list() - indexer_ssd_kv_layout = None if self._indexer_ssd_handle is None else self._indexer_ssd_handle.kv_layout - indexer_num_blocks_per_file = 0 if self._indexer_ssd_handle is None else self._indexer_ssd_handle.num_blocks_per_file - self._indexer_layerwise_workers = [ - LayerwiseTransferWorker.create_worker( - mp_ctx=self.mp_ctx, - finished_ops_queue=self._indexer_finished_ops_queue, - 
op_buffer_tensor=self.pin_buffer.get_buffer(), - gpu_blocks=[h.get_tensor_handle_list() for h in indexer_gpu_handles_list], - cpu_blocks=self._indexer_cpu_handle.get_tensor(), - ssd_files=indexer_ssd_files, - gpu_kv_layouts=[h.kv_layout for h in indexer_gpu_handles_list], - cpu_kv_layout=self._indexer_cpu_handle.kv_layout, - ssd_kv_layout=indexer_ssd_kv_layout, - dtype=indexer_gpu_handles_list[0].dtype, - tp_group_size=self.tp_size, - dp_group_id=dp_client_id, - pp_rank=self.model_config.pp_rank, - pp_size=self.model_config.pp_size, - dp_size=self.model_config.dp_size, - dp_rank=dp_client_id, - num_blocks_per_file=indexer_num_blocks_per_file, - use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, - use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, - h2d_cta_num=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, - d2h_cta_num=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, - enable_eventfd=False, - ) - for dp_client_id, indexer_gpu_handles_list in self._indexer_gpu_handles.items() - ] - self._indexer_worker_map[TransferType.LAYERWISE] = self._indexer_layerwise_workers - flexkv_logger.info("TransferEngine: indexer Layerwise workers initialized") self._has_indexer = True - flexkv_logger.info( - f"TransferEngine: indexer inline workers initialized " - f"({len(self._indexer_h2d_workers)} H2D + {len(self._indexer_d2h_workers)} D2H)") + if not _enable_layerwise: + flexkv_logger.info( + f"TransferEngine: indexer inline workers initialized " + f"({len(self._indexer_h2d_workers)} H2D + {len(self._indexer_d2h_workers)} D2H)") + else: + flexkv_logger.info( + f"TransferEngine: indexer inline workers initialized " + f"(H2D fused into layerwise, {len(self._indexer_d2h_workers)} D2H)") if len(self._worker_map) == 0: raise ValueError("No workers initialized, please check the config") @@ -657,6 +676,14 @@ def _init_workers(self) -> None: flexkv_logger.info(f"waiting for indexer {transfer_type.name} worker {worker.worker_id} to ready") worker.ready_event.wait() 
flexkv_logger.info(f"indexer {transfer_type.name} worker {worker.worker_id} is ready") + # Startup assertions: verify layerwise mode worker map consistency + if _enable_layerwise: + assert TransferType.H2D not in self._worker_map, \ + "H2D worker should not exist in layerwise mode (fused into layerwise worker)" + assert TransferType.DISK2H not in self._worker_map, \ + "DISK2H worker should not exist in layerwise mode (fused into layerwise worker)" + assert TransferType.LAYERWISE in self._worker_map, \ + "LAYERWISE worker must exist when layerwise transfer is enabled" # Start scheduler thread self._running = True self._scheduler_thread = threading.Thread(target=self._scheduler_loop) @@ -848,10 +875,10 @@ def _assign_op_to_worker(self, op: TransferOp) -> None: op.pending_count += 1 flexkv_logger.debug( - f"[TransferEngine] Created indexer op {indexer_op.op_id} " - f"for parent op {op.op_id}: {num_pages} pages, " - f"type={op.transfer_type.name}" - ) + f"[TransferEngine] === Indexer Op Dispatched (non-layerwise) ===" + f"\n parent_op_id={op.op_id}, indexer_op_id={indexer_op.op_id}" + f"\n type={op.transfer_type.name}, dp_id={op.dp_id}" + f"\n num_pages={num_pages}, pending_count={op.pending_count}") indexer_worker = self._indexer_worker_map[op.transfer_type] if isinstance(indexer_worker, List): diff --git a/flexkv/transfer/worker_op.py b/flexkv/transfer/worker_op.py index a271435275..ecc7b29f9c 100644 --- a/flexkv/transfer/worker_op.py +++ b/flexkv/transfer/worker_op.py @@ -52,6 +52,9 @@ class WorkerLayerwiseTransferOp: src_block_ids_disk2h: np.ndarray dst_block_ids_disk2h: np.ndarray counter_id: int # Counter set index for triple buffering eventfd notification + # Indexer block_ids for fused indexer transfer + indexer_src_block_ids: np.ndarray + indexer_dst_block_ids: np.ndarray def __init__(self, transfer_op: LayerwiseTransferOp): self.transfer_op_id = transfer_op.op_id @@ -65,3 +68,5 @@ def __init__(self, transfer_op: LayerwiseTransferOp): self.src_block_ids_disk2h 
= transfer_op.src_block_ids_disk2h self.dst_block_ids_disk2h = transfer_op.dst_block_ids_disk2h self.counter_id = transfer_op.counter_id + self.indexer_src_block_ids = transfer_op.indexer_src_block_ids + self.indexer_dst_block_ids = transfer_op.indexer_dst_block_ids From 1bf5eca762a506ff48bfa1467a039bc841c0bf23 Mon Sep 17 00:00:00 2001 From: zitto Date: Fri, 17 Apr 2026 16:13:58 +0800 Subject: [PATCH 53/59] fix server_args kv cache dtype (#151) Co-authored-by: zittozhang --- flexkv/integration/config.py | 12 +- tests/test_kvmanager.py | 360 +++++++++++++++++++++++++++++------ 2 files changed, 311 insertions(+), 61 deletions(-) diff --git a/flexkv/integration/config.py b/flexkv/integration/config.py index c0c0d2c2ee..a288d544a5 100644 --- a/flexkv/integration/config.py +++ b/flexkv/integration/config.py @@ -3,7 +3,7 @@ import os import torch import tempfile -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from dataclasses import dataclass, field from flexkv.common.debug import flexkv_logger @@ -130,6 +130,7 @@ def post_init_from_sglang_config( is_nsa_cp: bool = False, cp_size: int = 1, cp_rank: int = 0, + kv_cache_dtype: Optional[str] = None, ): """ Initialize FlexKVConfig fields from sglang config. @@ -210,10 +211,17 @@ def _parse_dtype_str(dtype_str: str) -> torch.dtype: f"[FlexKV] Using kv_cache_dtype from user_config: " f"'{user_dtype_str}' -> {self.model_config.dtype}" ) + elif kv_cache_dtype is not None and kv_cache_dtype != "auto": + # Use the kv_cache_dtype from sglang server_args (e.g. 
"fp8_e4m3") + self.model_config.dtype = _parse_dtype_str(kv_cache_dtype) + logger.info( + f"[FlexKV] Using kv_cache_dtype from sglang server_args: " + f"'{kv_cache_dtype}' -> {self.model_config.dtype}" + ) else: self.model_config.dtype = getattr(sglang_config, "dtype", torch.bfloat16) logger.warning( - f"[FlexKV] No kv_cache_dtype in user_config, falling back to sglang " + f"[FlexKV] No kv_cache_dtype in user_config or server_args, falling back to sglang " f"model dtype: {self.model_config.dtype}. If your KV cache uses a " f"different dtype (e.g. fp8), add 'kv_cache_dtype: fp8' to your " f"flexkv_config.yaml or set FLEXKV_KV_CACHE_DTYPE=fp8 environment variable." diff --git a/tests/test_kvmanager.py b/tests/test_kvmanager.py index 74f7f1e0be..c446dda667 100644 --- a/tests/test_kvmanager.py +++ b/tests/test_kvmanager.py @@ -491,9 +491,19 @@ def hash_all_values(self, layer_id, token_ids): token_hash += int(token_id) * (i + 17) return torch.tensor(((layer_id + 1) * 29 + token_hash) % 251 + 1, dtype=self.dtype).item() - def fill_gpu_blocks(self, token_ids, block_ids): - assert len(token_ids) == len(block_ids) * self.tokens_per_block - + def fill_gpu_blocks(self, block_ids, main_kv_tokens_per_block, token_ids): + """Fill indexer GPU blocks with deterministic hash values. + + Indexer uses tokens_per_block=1 on CPU/SSD side. Each indexer block + corresponds to one main-KV block (1:1 page mapping). We hash the + *entire page* of token_ids from the main KV request to produce a + single deterministic value per (layer, block). + + Args: + block_ids: block IDs to fill (same as main KV block_ids). + main_kv_tokens_per_block: tokens_per_block of main KV (e.g. 16). + token_ids: full token_ids tensor from the request. 
+ """ if not isinstance(token_ids, torch.Tensor): token_ids = torch.tensor(token_ids, dtype=torch.int64) if not isinstance(block_ids, torch.Tensor): @@ -503,12 +513,13 @@ def fill_gpu_blocks(self, token_ids, block_ids): for layer_id in range(self.num_layers): gpu_tensor = self.gpu_blocks[tp_id][layer_id] for block_idx, block_id in enumerate(block_ids): - start_token_idx = block_idx * self.tokens_per_block - end_token_idx = start_token_idx + self.tokens_per_block + start_token_idx = block_idx * main_kv_tokens_per_block + end_token_idx = start_token_idx + main_kv_tokens_per_block hash_value = self.hash_all_values( layer_id, token_ids[start_token_idx:end_token_idx], ) + # gpu_tensor shape: (num_blocks, tokens_per_block=1, head_size) gpu_tensor[block_id, :, :] = hash_value def clear_gpu_blocks(self, block_ids): @@ -519,9 +530,14 @@ def clear_gpu_blocks(self, block_ids): for layer_id in range(self.num_layers): self.gpu_blocks[tp_id][layer_id][block_ids, :, :] = 0 - def verify_gpu_blocks(self, token_ids, block_ids) -> bool: - assert len(token_ids) == len(block_ids) * self.tokens_per_block + def verify_gpu_blocks(self, block_ids, main_kv_tokens_per_block, token_ids) -> bool: + """Verify indexer GPU blocks after round-trip transfer. + Args: + block_ids: block IDs to verify. + main_kv_tokens_per_block: tokens_per_block of main KV. + token_ids: full token_ids tensor from the request. 
+ """ if not isinstance(token_ids, torch.Tensor): token_ids = torch.tensor(token_ids, dtype=torch.int64) if not isinstance(block_ids, torch.Tensor): @@ -534,8 +550,8 @@ def verify_gpu_blocks(self, token_ids, block_ids) -> bool: for layer_id in range(self.num_layers): gpu_tensor = self.gpu_blocks[tp_id][layer_id] for block_idx, block_id in enumerate(block_ids): - start_token_idx = block_idx * self.tokens_per_block - end_token_idx = start_token_idx + self.tokens_per_block + start_token_idx = block_idx * main_kv_tokens_per_block + end_token_idx = start_token_idx + main_kv_tokens_per_block expected_hash_value = self.hash_all_values( layer_id, token_ids[start_token_idx:end_token_idx], @@ -598,11 +614,12 @@ def run_tp_client_with_indexer(dp_client_id, raise ValueError(f"Invalid GPU layout type for indexer test: {gpu_layout_type}") # Derive indexer params from cache_config.indexer (IndexerCacheConfig). - # Shared fields (num_layers, tokens_per_block) come from main model_config / cache_config. + # Indexer uses tokens_per_block=1 (one indexer entry per page/block), + # matching the CPU/SSD layout in StorageEngine. 
indexer_cfg = cache_config.indexer assert indexer_cfg is not None, "cache_config.indexer must be set for indexer shadow transfer tests" - indexer_tokens_per_block = cache_config.tokens_per_block # shared with main KV - indexer_num_layers = model_config.num_layers # shared with main KV + indexer_tokens_per_block = 1 # indexer: 1 entry per page (not main KV tokens_per_block) + indexer_num_layers = model_config.num_layers # Create indexer GPU blocks (MLA-style: 3D tensors) indexer_blocks = [] @@ -661,21 +678,12 @@ def run_tp_client_with_indexer(dp_client_id, child_conn.close() -@pytest.mark.parametrize( - "model_config", - [ - {"tp_size": 1, "dp_size": 1}, - ], indirect=True, -) -@pytest.mark.parametrize("cache_config", [ - {'enable_cpu': True, 'enable_ssd': False, 'num_cpu_blocks': 1024}, -], indirect=True) -@pytest.mark.parametrize("test_config", [ - {'num_gpu_blocks': 256, 'requests_per_block': 16, 'initial_write_ratio': 0.4}, -], indirect=True) -@pytest.mark.parametrize("gpu_layout_type", [0]) -def test_kvmanager_with_indexer(model_config, cache_config, test_config, gpu_layout_type): - """Test KVManager shadow transfer mode with attached indexer buffers.""" +def _run_indexer_test(model_config, cache_config, test_config, gpu_layout_type, test_label="indexer", layerwise=False): + """Core test logic for KVManager with indexer shadow transfer. + + Shared by test_kvmanager_with_indexer (non-layerwise) and + test_kvmanager_with_indexer_layerwise (layerwise mode). + """ tp_size = model_config.tp_size tokens_per_block = cache_config.tokens_per_block num_gpu_blocks = test_config["num_gpu_blocks"] @@ -685,9 +693,6 @@ def test_kvmanager_with_indexer(model_config, cache_config, test_config, gpu_lay skip_if_insufficient_gpus(tp_size) - # Set indexer config inside cache_config.indexer (IndexerCacheConfig). - # Only indexer-unique fields are stored here; shared fields (num_layers, - # tokens_per_block, num_cpu_blocks) are read from model_config / cache_config. 
from flexkv.common.config import IndexerCacheConfig cache_config.indexer = IndexerCacheConfig( head_size=64, @@ -759,7 +764,7 @@ def test_kvmanager_with_indexer(model_config, cache_config, test_config, gpu_lay type=KVCacheLayoutType.LAYERFIRST, num_layer=model_config.num_layers, num_block=num_gpu_blocks, - tokens_per_block=cache_config.tokens_per_block, + tokens_per_block=1, # indexer: 1 entry per page num_head=indexer_cfg.num_kv_heads, head_size=indexer_cfg.head_size, is_mla=True, @@ -773,19 +778,19 @@ def test_kvmanager_with_indexer(model_config, cache_config, test_config, gpu_lay while not kvmanager.is_ready(): time.sleep(1) - flexkv_logger.info("waiting for flexkv (with indexer shadow transfer) to be ready") - print("[Test] KVManager (with indexer shadow transfer) is ready") + flexkv_logger.info(f"waiting for flexkv ({test_label}) to be ready") + print(f"[Test] KVManager ({test_label}) is ready") request_pairs = [generate_request_pair(i, block_per_request, num_gpu_blocks, tokens_per_block, 1) for i in range(num_requests)] initial_write_num = int(num_requests * initial_write_ratio) - print("[Test] Testing put flow with indexer shadow transfer...") + print(f"[Test] Testing put flow ({test_label})...") for token_ids, block_ids, dp_id in request_pairs[:initial_write_num]: if gpu_kv_verifier is not None: gpu_kv_verifier.fill_gpu_blocks(token_ids, block_ids) if indexer_kv_verifier is not None: - indexer_kv_verifier.fill_gpu_blocks(token_ids, block_ids) + indexer_kv_verifier.fill_gpu_blocks(block_ids, tokens_per_block, token_ids) write_request = kvmanager.put_async( token_ids=token_ids, slot_mapping=block_ids_2_slot_mapping(block_ids, tokens_per_block), @@ -798,15 +803,18 @@ def test_kvmanager_with_indexer(model_config, cache_config, test_config, gpu_lay gpu_kv_verifier.clear_gpu_blocks(block_ids) if indexer_kv_verifier is not None: indexer_kv_verifier.clear_gpu_blocks(block_ids) - print(f"[Test] Initial {initial_write_num} put operations completed with indexer shadow 
transfer") + print(f"[Test] Initial {initial_write_num} put operations completed ({test_label})") - print("[Test] Testing get flow with indexer shadow transfer...") + print(f"[Test] Testing get flow ({test_label})...") total_cache_hit = 0 total_cache_miss = 0 running_get_requests = [] req_id2block_ids = {} req_id2token_ids = {} + batch_task_ids = [] + batch_slot_mappings = [] + for i in range(min(initial_write_num, num_requests)): token_ids, block_ids, dp_id = request_pairs[i] slot_mapping = block_ids_2_slot_mapping(block_ids, tokens_per_block) @@ -816,38 +824,76 @@ def test_kvmanager_with_indexer(model_config, cache_config, test_config, gpu_lay token_mask=None, dp_id=dp_id, ) - kvmanager.launch(request_id, slot_mapping) - running_get_requests.append(request_id) + batch_task_ids.append(request_id) + batch_slot_mappings.append(slot_mapping) req_id2block_ids[request_id] = block_ids req_id2token_ids[request_id] = token_ids - if running_get_requests: - return_results = kvmanager.wait(running_get_requests, completely=True) - for req_id, kvresponse in return_results.items(): - assert kvresponse.status == KVResponseStatus.SUCCESS - total_cache_hit += kvresponse.return_mask.sum().item() - total_cache_miss += len(kvresponse.return_mask) - kvresponse.return_mask.sum().item() + if layerwise: + # Layerwise mode: launch all GETs as a single batch so that + # merge_to_batch_graph produces a LAYERWISE op (fused DISK2H+H2D). 
+ returned_ids = kvmanager.launch( + task_ids=batch_task_ids, + slot_mappings=batch_slot_mappings, + as_batch=True, + layerwise_transfer=True, + ) + batch_id = returned_ids[0] + batch_results = kvmanager.wait(batch_id, completely=True) + kvresponse = batch_results[batch_id] + assert kvresponse.status == KVResponseStatus.SUCCESS, \ + f"Layerwise batch GET failed: {kvresponse.status}" + for idx, orig_req_id in enumerate(batch_task_ids): + mask = kvresponse.return_mask[idx] + total_cache_hit += mask.sum().item() + total_cache_miss += len(mask) - mask.sum().item() if gpu_kv_verifier is not None: - valid_fetched_tokens = kvresponse.return_mask.sum().item() // tokens_per_block * tokens_per_block + valid_fetched_tokens = mask.sum().item() // tokens_per_block * tokens_per_block if valid_fetched_tokens > 0: assert gpu_kv_verifier.verify_kv_blocks( - req_id2token_ids[req_id][:valid_fetched_tokens], - req_id2block_ids[req_id][:valid_fetched_tokens // tokens_per_block]) + req_id2token_ids[orig_req_id][:valid_fetched_tokens], + req_id2block_ids[orig_req_id][:valid_fetched_tokens // tokens_per_block]) if indexer_kv_verifier is not None: - valid_fetched_blocks = kvresponse.return_mask.sum().item() // tokens_per_block + valid_fetched_blocks = mask.sum().item() // tokens_per_block if valid_fetched_blocks > 0: assert indexer_kv_verifier.verify_gpu_blocks( - req_id2token_ids[req_id][:valid_fetched_blocks * tokens_per_block], - req_id2block_ids[req_id][:valid_fetched_blocks]) - print(f"[Test] Get flow completed: hit={total_cache_hit}, miss={total_cache_miss}") + req_id2block_ids[orig_req_id][:valid_fetched_blocks], + tokens_per_block, + req_id2token_ids[orig_req_id][:valid_fetched_blocks * tokens_per_block]) + else: + # Non-layerwise: launch each GET individually + for req_id in batch_task_ids: + kvmanager.launch(req_id, batch_slot_mappings[batch_task_ids.index(req_id)]) + running_get_requests.append(req_id) - print("[Test] Testing try_wait flow with indexer shadow transfer...") + if 
running_get_requests: + return_results = kvmanager.wait(running_get_requests, completely=True) + for req_id, kvresponse in return_results.items(): + assert kvresponse.status == KVResponseStatus.SUCCESS + total_cache_hit += kvresponse.return_mask.sum().item() + total_cache_miss += len(kvresponse.return_mask) - kvresponse.return_mask.sum().item() + if gpu_kv_verifier is not None: + valid_fetched_tokens = kvresponse.return_mask.sum().item() // tokens_per_block * tokens_per_block + if valid_fetched_tokens > 0: + assert gpu_kv_verifier.verify_kv_blocks( + req_id2token_ids[req_id][:valid_fetched_tokens], + req_id2block_ids[req_id][:valid_fetched_tokens // tokens_per_block]) + if indexer_kv_verifier is not None: + valid_fetched_blocks = kvresponse.return_mask.sum().item() // tokens_per_block + if valid_fetched_blocks > 0: + assert indexer_kv_verifier.verify_gpu_blocks( + req_id2block_ids[req_id][:valid_fetched_blocks], + tokens_per_block, + req_id2token_ids[req_id][:valid_fetched_blocks * tokens_per_block]) + print(f"[Test] Get flow completed ({test_label}): hit={total_cache_hit}, miss={total_cache_miss}") + + print(f"[Test] Testing try_wait flow ({test_label})...") if initial_write_num < num_requests: token_ids, block_ids, dp_id = request_pairs[initial_write_num] if gpu_kv_verifier is not None: gpu_kv_verifier.fill_gpu_blocks(token_ids, block_ids) if indexer_kv_verifier is not None: - indexer_kv_verifier.fill_gpu_blocks(token_ids, block_ids) + indexer_kv_verifier.fill_gpu_blocks(block_ids, tokens_per_block, token_ids) write_request = kvmanager.put_async( token_ids=token_ids, slot_mapping=block_ids_2_slot_mapping(block_ids, tokens_per_block), @@ -866,13 +912,209 @@ def test_kvmanager_with_indexer(model_config, cache_config, test_config, gpu_lay gpu_kv_verifier.clear_gpu_blocks(block_ids) if indexer_kv_verifier is not None: indexer_kv_verifier.clear_gpu_blocks(block_ids) - print("[Test] try_wait flow completed") + print(f"[Test] try_wait flow completed ({test_label})") - 
print("[Test] Testing shutdown with indexer shadow transfer...") - if cache_config.enable_cpu and cache_config.num_cpu_blocks >= num_gpu_blocks: + # Cache miss assertion: when total capacity >= GPU blocks, expect 0 miss + enable_cpu = cache_config.enable_cpu + enable_ssd = cache_config.enable_ssd + num_cpu_blocks = cache_config.num_cpu_blocks + num_ssd_blocks = cache_config.num_ssd_blocks + if (enable_cpu and num_cpu_blocks >= num_gpu_blocks) or \ + (enable_ssd and num_ssd_blocks >= num_gpu_blocks): assert total_cache_miss == 0, f"Expected 0 cache miss, got {total_cache_miss}" shutdown_tp_client(tp_client_processes) kvmanager.shutdown() - print("[Test] Shutdown completed successfully") - print("[Test] test_kvmanager_with_indexer PASSED") + print(f"[Test] {test_label} PASSED") + + +@pytest.mark.parametrize( + "model_config", + [ + {"tp_size": 1, "dp_size": 1}, + ], indirect=True, +) +@pytest.mark.parametrize("cache_config", [ + {'enable_cpu': True, 'enable_ssd': False, 'num_cpu_blocks': 1024}, + {'enable_cpu': True, 'enable_ssd': True, 'num_cpu_blocks': 256, 'num_ssd_blocks': 2048}, +], indirect=True) +@pytest.mark.parametrize("test_config", [ + {'num_gpu_blocks': 256, 'requests_per_block': 16, 'initial_write_ratio': 0.4}, +], indirect=True) +@pytest.mark.parametrize("gpu_layout_type", [0]) +def test_kvmanager_with_indexer(model_config, cache_config, test_config, gpu_layout_type): + """Test KVManager with indexer: GPU↔CPU (and optionally ↔SSD) data correctness.""" + ssd_label = "+ssd" if cache_config.enable_ssd else "" + _run_indexer_test(model_config, cache_config, test_config, gpu_layout_type, + test_label=f"indexer{ssd_label}") + + +import ctypes +import socket +import struct +import threading + +# ---- Mock SGLang eventfd client for layerwise unit tests ---- + +_libc = ctypes.CDLL("libc.so.6", use_errno=True) + + +def _sys_eventfd(initval: int = 0, flags: int = 0) -> int: + """Create an eventfd file descriptor via libc.""" + fd = 
_libc.eventfd(ctypes.c_uint(initval), ctypes.c_int(flags)) + if fd == -1: + err = ctypes.get_errno() + raise OSError(err, f"eventfd failed: {os.strerror(err)}") + return fd + + +_EFD_SEMAPHORE = 0x1 + + +def _send_fds_via_scm(sock: socket.socket, fds: list, extra_data: bytes = b"x"): + """Send fds via SCM_RIGHTS (mirrors SGLang's send_fds).""" + fds_packed = struct.pack(f"{len(fds)}i", *fds) + ancdata = [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds_packed)] + sock.sendmsg([extra_data], ancdata) + + +def _mock_sglang_eventfd_client(socket_path: str, + tp_rank: int, + tp_size: int, + num_layers: int, + num_counters: int = 3, + max_retries: int = 120, + retry_interval: float = 0.5): + """Simulate SGLang sending eventfds to the LayerwiseTransferWorker. + + Runs in a background thread. Creates real eventfds so the C++ + LayerwiseTransferGroup receives valid file descriptors. The eventfds + are never read by anyone in the test, but that is fine: the C++ + ``enable_eventfd_`` flag will be ``true`` and ``eventfd_write`` will + simply increment the counter without blocking. 
+ """ + created_fds = [] + try: + # Create real eventfds + for _ in range(num_counters * num_layers): + created_fds.append(_sys_eventfd(0, _EFD_SEMAPHORE)) + + # Retry connecting until the worker process binds the socket + sock = None + for attempt in range(max_retries): + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + sock.connect(socket_path) + print(f"[MockEventfdClient] Connected to {socket_path} " + f"(attempt {attempt + 1})") + break + except (FileNotFoundError, ConnectionRefusedError): + sock.close() + sock = None + time.sleep(retry_interval) + + if sock is None: + print(f"[MockEventfdClient] FAILED to connect to {socket_path} " + f"after {max_retries} attempts") + return + + # Send 24-byte metadata: tp_rank, tp_size, cp_rank, cp_size, + # num_layers, num_counters + metadata = struct.pack("iiiiii", + tp_rank, tp_size, + 0, 1, # cp_rank=0, cp_size=1 + num_layers, num_counters) + sock.sendall(metadata) + + # Send eventfds for each counter via SCM_RIGHTS + fd_idx = 0 + for counter_id in range(num_counters): + fds = created_fds[fd_idx:fd_idx + num_layers] + fd_idx += num_layers + _send_fds_via_scm(sock, fds, struct.pack("i", counter_id)) + + # Wait for ACK + sock.settimeout(30.0) + ack = sock.recv(1) + if ack and ack[0] == 1: + print(f"[MockEventfdClient] Eventfd handshake OK " + f"(counters={num_counters}, layers={num_layers})") + else: + print(f"[MockEventfdClient] Unexpected ACK: {ack!r}") + sock.close() + except Exception as e: + print(f"[MockEventfdClient] Error: {e}") + traceback.print_exc() + # Note: we intentionally do NOT close the eventfds here. + # They must remain valid for the lifetime of the LayerwiseTransferGroup + # in the worker subprocess. They will be cleaned up when the worker + # process exits and the OS reclaims the file descriptors. 
@pytest.mark.parametrize(
    "model_config",
    [
        {"tp_size": 1, "dp_size": 1},
    ], indirect=True,
)
@pytest.mark.parametrize("cache_config", [
    {'enable_cpu': True, 'enable_ssd': False, 'num_cpu_blocks': 1024},
    {'enable_cpu': True, 'enable_ssd': True, 'num_cpu_blocks': 256, 'num_ssd_blocks': 2048},
], indirect=True)
@pytest.mark.parametrize("test_config", [
    {'num_gpu_blocks': 256, 'requests_per_block': 16, 'initial_write_ratio': 0.4},
], indirect=True)
@pytest.mark.parametrize("gpu_layout_type", [0])
def test_kvmanager_with_indexer_layerwise(model_config, cache_config, test_config, gpu_layout_type):
    """Run the indexer round-trip test with layerwise transfer switched on.

    Round-trip under test:
        PUT: D2H + H2DISK (non-layerwise, same as the normal path)
        GET: LAYERWISE (fused DISK2H + H2D)
    Data correctness is checked for both the main KV cache and the
    indexer (DSA) KV cache once the round-trip has finished.

    The LayerwiseTransferWorker needs an eventfd handshake during start-up;
    a daemon thread running ``_mock_sglang_eventfd_client`` plays the SGLang
    side of that handshake so no production code has to change.  The
    layerwise switch (environment variable and in-process config flag) is
    saved first and restored in ``finally`` so other tests are unaffected.
    """
    from flexkv.common.config import GLOBAL_CONFIG_FROM_ENV

    # Snapshot both forms of the layerwise switch before touching them.
    saved_env_value = os.environ.get('FLEXKV_ENABLE_LAYERWISE_TRANSFER')
    saved_config_flag = GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer

    # Socket the worker will listen on.  With tp_size=1, pp_size=1 and
    # dp_size=1 the path carries no rank suffix.
    uds_path = os.environ.get('FLEXKV_LAYERWISE_EVENTFD_SOCKET',
                              '/tmp/flexkv_layerwise_eventfd.sock')

    try:
        # Turn layerwise transfer on for this test only.
        os.environ['FLEXKV_ENABLE_LAYERWISE_TRANSFER'] = '1'
        GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer = True

        # The mock client must be running BEFORE kvmanager.start() so it can
        # connect as soon as the worker process binds the socket.
        handshake_thread = threading.Thread(
            target=_mock_sglang_eventfd_client,
            args=(socket_path := uds_path, 0, 1, model_config.num_layers)[0:4],
            daemon=True,
        ) if False else threading.Thread(
            target=_mock_sglang_eventfd_client,
            args=(uds_path, 0, 1, model_config.num_layers),
            daemon=True,
        )
        handshake_thread.start()

        if cache_config.enable_ssd:
            tier_tag = "+ssd"
        else:
            tier_tag = ""
        _run_indexer_test(
            model_config,
            cache_config,
            test_config,
            gpu_layout_type,
            test_label=f"layerwise+indexer{tier_tag}",
            layerwise=True,
        )

        handshake_thread.join(timeout=10)
    finally:
        # Put the environment variable and the config flag back as found.
        if saved_env_value is not None:
            os.environ['FLEXKV_ENABLE_LAYERWISE_TRANSFER'] = saved_env_value
        else:
            os.environ.pop('FLEXKV_ENABLE_LAYERWISE_TRANSFER', None)
        GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer = saved_config_flag
+ if model_config_for_transfer.is_nsa_cp and model_config_for_transfer.cp_size > 1: + model_config_for_transfer.cp_size //= self.tp_node_count combine_with_trtllm = os.getenv("FLEXKV_WITH_TRTLLM", "0") == "1" if not combine_with_trtllm: diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index 785846fb1e..7acf56127e 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -320,6 +320,12 @@ def _receive_eventfds_from_sglang(self, tp_group_size: int, sock_suffix += f"_pp{int(self.pp_rank)}" if int(self.dp_size) > 1: sock_suffix += f"_dp{int(self.dp_rank)}" + # Multi-node TP: add _node{id} suffix to match sglang connector's + # socket path. The sglang connector sets FLEXKV_NODE_ID when it + # detects cross-node TP; the subprocess inherits this env var. + _node_id_str = os.environ.get("FLEXKV_NODE_ID") + if _node_id_str is not None: + sock_suffix += f"_node{_node_id_str}" if sock_suffix: root, ext = os.path.splitext(base_socket_path) socket_path = f"{root}{sock_suffix}{ext}" From 7ece123c10887cbaaab06843da97f347f91c38f0 Mon Sep 17 00:00:00 2001 From: zittozhang Date: Fri, 24 Apr 2026 14:27:00 +0800 Subject: [PATCH 55/59] refactor: align multi-node TP topology with framework (nnodes/node_rank) and unify eventfd socket path Replace device_count()-based probing and FLEXKV_MASTER_HOST/FLEXKV_NODE_ID/FLEXKV_LOCAL_GPU_COUNT env-var plumbing with explicit ModelConfig fields driven by the framework (sglang server_args / TRT-LLM launcher), so FlexKV and its callers cannot drift on the multi-node layout. Key changes: - ModelConfig: new fields nnodes, node_rank, master_host; FlexKVConfig.post_init_from_sglang_config accepts them. - kvtask.py: derive gpus_per_node=(tp*pp)//nnodes and nnodes_per_tp_group=ceil(tp/gpus_per_node); rename self.tp_node_count->self.nnodes_per_tp_group and drop self.is_multinode_tp (use nnodes_per_tp_group>1). NSA CP division logic preserved. 
- transfer_manager.py: rename get_master_host_and_ports_from_env()->resolve_master_host_and_ports(master_host=None); TransferManagerOnRemote accepts master_host kwarg (env remains as fallback). - transfer/layerwise.py: new helper build_layerwise_eventfd_socket_path(model_config) as single source of truth for the UDS path (suffix uses pp_rank/dp_rank only; node_rank intentionally omitted since UDS is kernel-local). LayerwiseTransferWorker takes layerwise_eventfd_socket kwarg instead of reading env in the subprocess. - transfer/transfer_engine.py: compute the socket path once and pass it down; replace getattr(model_config, 'is_nsa_cp'/'cp_size', ...) with direct attribute access now that they are ModelConfig fields. --- flexkv/common/config.py | 18 +++++++ flexkv/integration/config.py | 15 ++++++ flexkv/kvtask.py | 67 +++++++++++++++++++------ flexkv/transfer/layerwise.py | 79 +++++++++++++++++------------- flexkv/transfer/transfer_engine.py | 37 ++++++++------ flexkv/transfer_manager.py | 38 +++++++++++--- 6 files changed, 182 insertions(+), 72 deletions(-) diff --git a/flexkv/common/config.py b/flexkv/common/config.py index e406e51f33..383653204f 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -36,6 +36,24 @@ class ModelConfig: pp_size: int = 1 pp_rank: int = 0 + # topology configs + # nnodes : number of physical machines spanned by one replica + # (== server_args.nnodes in SGLang) + # node_rank : index of this machine within ``nnodes`` + # (== server_args.node_rank in SGLang). Used by + # KVTaskEngine's multi-node topology derivation and + # for logging; NOT embedded in the layerwise UDS + # socket path (UDS endpoints are kernel-local). + nnodes: int = 1 + node_rank: int = 0 + + # Multi-node bootstrap: master node's IP for TransferManager rendezvous. + # ``None`` falls back to ``FLEXKV_MASTER_HOST`` env var (default + # ``"localhost"``) inside ``resolve_master_host_and_ports``. Set this + # from the framework's own launch config (e.g. 
sglang's + # ``--dist-init-addr``) to avoid exposing an extra env knob. + master_host: Optional[str] = None + # NSA context parallelism: when True, layerwise transfer sends full # (unpartitioned) KV cache to every rank instead of head-sliced data. is_nsa_cp: bool = False diff --git a/flexkv/integration/config.py b/flexkv/integration/config.py index a288d544a5..580aab57cf 100644 --- a/flexkv/integration/config.py +++ b/flexkv/integration/config.py @@ -127,10 +127,13 @@ def post_init_from_sglang_config( pp_rank: int = 0, dp_size: int = 1, dp_rank: int = 0, + nnodes: int = 1, + node_rank: int = 0, is_nsa_cp: bool = False, cp_size: int = 1, cp_rank: int = 0, kv_cache_dtype: Optional[str] = None, + master_host: Optional[str] = None, ): """ Initialize FlexKVConfig fields from sglang config. @@ -143,9 +146,13 @@ def post_init_from_sglang_config( pp_rank: pipeline parallel rank (default 0) dp_size: data parallel size (default 1, no DP) dp_rank: data parallel rank (default 0) + nnodes: number of nodes (aligned with server_args.nnodes, default 1) + node_rank: index of this node (aligned with server_args.node_rank, default 0) is_nsa_cp: whether NSA context parallelism is enabled cp_size: context parallel size (default 1, no CP) cp_rank: context parallel rank (default 0) + kv_cache_dtype: KV cache dtype (default None, use model dtype) + master_host: master host for multi-node setup (default None, use localhost) """ # cache config: use page_size as tokens_per_block so that FlexKV's # CPU radix tree manages blocks at page granularity, ensuring that @@ -250,6 +257,14 @@ def _parse_dtype_str(dtype_str: str) -> torch.dtype: self.model_config.pp_rank = int(pp_rank) self.model_config.is_nsa_cp = is_nsa_cp self.model_config.cp_size = int(cp_size if cp_size is not None else 1) + # Topology: nnodes + node_rank (aligned with sglang server_args). + # ``gpus_per_node`` is no longer stored on model_config; KVTaskEngine + # derives it locally as (tp_size * pp_size) // nnodes. 
+ self.model_config.nnodes = max(1, int(nnodes)) + self.model_config.node_rank = int(node_rank) + # Multi-node bootstrap: master host (derived from sglang --dist-init-addr). + # ``None`` here falls back to FLEXKV_MASTER_HOST env var downstream. + self.model_config.master_host = master_host update_default_config_from_user_config(self.model_config, self.cache_config, self.user_config) # Each PP rank needs its own IPC ports so that their diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index d096d02ad3..3a336b8fdf 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -9,7 +9,6 @@ import os from expiring_dict import ExpiringDict import nvtx -import torch import numpy as np from flexkv.common.config import CacheConfig, ModelConfig, GLOBAL_CONFIG_FROM_ENV @@ -21,7 +20,7 @@ from flexkv.transfer_manager import TransferManagerHandle, TransferManagerOnRemote from flexkv.common.request import KVResponseStatus, KVResponse from flexkv.transfer_manager import ( - get_master_host_and_ports_from_env, + resolve_master_host_and_ports, get_trtllm_subprocess_host_and_ports_from_env ) from flexkv.cache.redis_meta import RedisMeta @@ -107,27 +106,61 @@ def __init__(self, self.model_config = model_config self._check_config(model_config, cache_config) - self.is_multinode_tp = False - self.tp_node_count = 1 - if self.model_config.tp_size > torch.cuda.device_count(): - if self.model_config.tp_size != torch.cuda.device_count() * 2: - raise ValueError("Only support 2 nodes TP for now") - assert self.model_config.dp_size == 1 - self.tp_node_count = self.model_config.tp_size // torch.cuda.device_count() - self.is_multinode_tp = True + # ---- Multi-node topology ---- + nnodes = self.model_config.nnodes + pp_size = self.model_config.pp_size + tp_size = self.model_config.tp_size + + total_gpus = tp_size * pp_size + if total_gpus % nnodes != 0: + raise ValueError( + f"[KVTaskEngine] cannot derive gpus_per_node: " + f"tp*pp={total_gpus} not divisible by nnodes={nnodes}" + ) + gpus_per_node = 
total_gpus // nnodes + + self.nnodes_per_tp_group = max( + (tp_size + gpus_per_node - 1) // gpus_per_node, 1 + ) + if self.nnodes_per_tp_group > 2: + raise ValueError( + f"Only support 2-nodes TP for now, but got " + f"nnodes_per_tp_group={self.nnodes_per_tp_group} " + f"(tp_size={tp_size}, gpus_per_node={gpus_per_node})" + ) + + if tp_size % self.nnodes_per_tp_group != 0: + raise ValueError( + f"[KVTaskEngine] tp_size={tp_size} not divisible by " + f"nnodes_per_tp_group={self.nnodes_per_tp_group}" + ) + tp_size_per_node = tp_size // self.nnodes_per_tp_group + + flexkv_logger.info( + f"[KVTaskEngine] topology: " + f"nnodes={nnodes}, " + f"node_rank={self.model_config.node_rank}, " + f"gpus_per_node={gpus_per_node}, " + f"tp_size={tp_size}, " + f"pp_size={pp_size}, " + f"dp_size={self.model_config.dp_size}, " + f"nnodes_per_tp_group={self.nnodes_per_tp_group}, " + f"tp_size_per_node={tp_size_per_node}, " + f"master_host={self.model_config.master_host!r}" + ) self.cache_engine = GlobalCacheEngine(cache_config, model_config, redis_meta, event_collector) model_config_for_transfer = copy.deepcopy(self.model_config) - if self.is_multinode_tp: - model_config_for_transfer.tp_size //= self.tp_node_count + if self.nnodes_per_tp_group > 1: + model_config_for_transfer.tp_size //= self.nnodes_per_tp_group if not self.model_config.use_mla: - model_config_for_transfer.num_kv_heads //= self.tp_node_count + model_config_for_transfer.num_kv_heads //= self.nnodes_per_tp_group # When NSA CP is active, cp_size mirrors tp_size and must also # be divided so that TransferEngine's _eventfd_group_size matches # the number of local GPUs on each node. 
if model_config_for_transfer.is_nsa_cp and model_config_for_transfer.cp_size > 1: - model_config_for_transfer.cp_size //= self.tp_node_count + model_config_for_transfer.cp_size //= self.nnodes_per_tp_group combine_with_trtllm = os.getenv("FLEXKV_WITH_TRTLLM", "0") == "1" if not combine_with_trtllm: @@ -155,8 +188,10 @@ def __init__(self, ] self.transfer_handles[0]._handle.send_config_to_remotes() - if self.is_multinode_tp: - master_host, master_ports = get_master_host_and_ports_from_env() + if self.nnodes_per_tp_group > 1: + master_host, master_ports = resolve_master_host_and_ports( + master_host=self.model_config.master_host + ) self.transfer_handles.append(TransferManagerHandle( model_config_for_transfer, self.cache_config, diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index 7acf56127e..8ebe616c11 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -1,36 +1,59 @@ -import copy -import torch.multiprocessing as mp -import threading import time import os import socket import struct -from abc import ABC, abstractmethod -from dataclasses import dataclass -from torch.multiprocessing import Queue as MPQueue, Pipe as MPPipe +from torch.multiprocessing import Queue as MPQueue from multiprocessing.connection import Connection -from threading import Thread from typing import List, Any, Dict, Union, Optional, Tuple -import ctypes -import numpy as np -import nvtx import torch -from flexkv import c_ext - from flexkv.c_ext import LayerwiseTransferGroup from flexkv.common.debug import flexkv_logger from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType -from flexkv.common.transfer import TransferOp, TransferType, PartitionBlockType -from flexkv.common.transfer import get_nvtx_range_color -from flexkv.common.config import CacheConfig, GLOBAL_CONFIG_FROM_ENV +from flexkv.common.config import ModelConfig, GLOBAL_CONFIG_FROM_ENV -from flexkv.transfer.worker_op 
def build_layerwise_eventfd_socket_path(model_config: ModelConfig) -> str:
    """Return the Unix-domain-socket path used by the LayerwiseWorker.

    The base path comes from the ``FLEXKV_LAYERWISE_EVENTFD_SOCKET``
    environment variable (default ``/tmp/flexkv_layerwise_eventfd.sock``).
    A ``_pp{pp_rank}`` tag is added when ``pp_size > 1`` and a
    ``_dp{dp_rank}`` tag when ``dp_size > 1``; the tags are inserted in
    front of the base path's file extension so that every PP stage and DP
    replica on one host gets its own endpoint.

    ``node_rank`` is intentionally left out of the path: Unix domain
    sockets are kernel-local, so instances on different physical hosts
    cannot collide even if ``/tmp`` lives on a shared filesystem.
    Deployments that run several containers on one host with a shared
    ``/tmp`` should make the base path unique through
    ``FLEXKV_LAYERWISE_EVENTFD_SOCKET`` (for example by embedding
    ``$HOSTNAME`` or the container id).

    This helper is meant to be the single source of truth for the path:
    the sglang-side consumer
    (``sglang.srt.mem_cache.storage.flexkv.flexkv_connector``) is expected
    to import it directly, so both ends derive the path from the same
    ``ModelConfig`` fields without passing env vars between processes.

    Args:
        model_config: Configuration providing ``pp_size``, ``pp_rank``,
            ``dp_size`` and ``dp_rank``.

    Returns:
        The socket path, with rank tags applied when needed.
    """
    socket_base = os.environ.get(
        'FLEXKV_LAYERWISE_EVENTFD_SOCKET',
        '/tmp/flexkv_layerwise_eventfd.sock',
    )

    # Collect one tag per parallelism dimension that actually has >1 member.
    rank_tags = []
    if model_config.pp_size > 1:
        rank_tags.append(f"_pp{model_config.pp_rank}")
    if model_config.dp_size > 1:
        rank_tags.append(f"_dp{model_config.dp_rank}")

    if rank_tags:
        # Keep the extension last: "<stem><tags><ext>".
        stem, extension = os.path.splitext(socket_base)
        return stem + "".join(rank_tags) + extension
    return socket_base
- _node_id_str = os.environ.get("FLEXKV_NODE_ID") - if _node_id_str is not None: - sock_suffix += f"_node{_node_id_str}" - if sock_suffix: - root, ext = os.path.splitext(base_socket_path) - socket_path = f"{root}{sock_suffix}{ext}" - else: - socket_path = base_socket_path + socket_path = self.layerwise_eventfd_socket rank_parts = [] if int(self.tp_group_size) > 1: diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index 55f291e178..371d24db26 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -18,18 +18,17 @@ import multiprocessing as mp import selectors import os -from queue import Queue -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union import contextlib import nvtx import numpy as np +import torch from flexkv.common.debug import flexkv_logger from flexkv.common.storage import StorageHandle -from flexkv.common.transfer import TransferOp, TransferOpGraph, TransferType, CompletedOp, LayerwiseTransferOp +from flexkv.common.transfer import TransferOp, TransferOpGraph, TransferType, CompletedOp from flexkv.common.transfer import get_nvtx_range_color -from flexkv.common.storage import KVCacheLayoutType from flexkv.transfer.scheduler import TransferScheduler from flexkv.transfer.worker import ( WorkerHandle, @@ -41,7 +40,10 @@ tpGDSTransferWorker, PEER2CPUTransferWorker, ) -from flexkv.transfer.layerwise import LayerwiseTransferWorker +from flexkv.transfer.layerwise import ( + LayerwiseTransferWorker, + build_layerwise_eventfd_socket_path, +) from flexkv.common.config import CacheConfig, ModelConfig, GLOBAL_CONFIG_FROM_ENV from flexkv.common.ring_buffer import SharedOpPool @@ -193,8 +195,8 @@ def _init_workers(self) -> None: dtype=gpu_handles[0].dtype, tp_group_size=self.tp_size, dp_group_id=dp_client_id, - is_nsa_cp=getattr(self.model_config, "is_nsa_cp", False), - cp_size=getattr(self.model_config, "cp_size", 1), + 
is_nsa_cp=self.model_config.is_nsa_cp, + cp_size=self.model_config.cp_size, use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, @@ -237,8 +239,8 @@ def _init_workers(self) -> None: dtype=gpu_handles[0].dtype, tp_group_size=self.tp_size, dp_group_id=dp_client_id, - is_nsa_cp=getattr(self.model_config, "is_nsa_cp", False), - cp_size=getattr(self.model_config, "cp_size", 1), + is_nsa_cp=self.model_config.is_nsa_cp, + cp_size=self.model_config.cp_size, use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, @@ -345,11 +347,15 @@ def _init_workers(self) -> None: ssd_files = {} if self._ssd_handle is None else self._ssd_handle.get_file_list() ssd_kv_layout = None if self._ssd_handle is None else self._ssd_handle.kv_layout num_blocks_per_file = 0 if self._ssd_handle is None else self._ssd_handle.num_blocks_per_file - _is_nsa_cp = getattr(self.model_config, 'is_nsa_cp', False) - _cp_size = getattr(self.model_config, 'cp_size', 1) + _is_nsa_cp = self.model_config.is_nsa_cp + _cp_size = self.model_config.cp_size # For CP, each CP rank connects via eventfd; for TP, each TP rank connects. 
_eventfd_group_size = _cp_size if _is_nsa_cp and _cp_size > 1 else self.tp_size + _layerwise_eventfd_socket = build_layerwise_eventfd_socket_path( + self.model_config + ) + # Prepare indexer handles for fused layerwise transfer has_indexer_for_layerwise = ( self._indexer_gpu_handles is not None and @@ -380,6 +386,7 @@ def _init_workers(self) -> None: pp_size=self.model_config.pp_size, dp_size=self.model_config.dp_size, dp_rank=dp_client_id, + layerwise_eventfd_socket=_layerwise_eventfd_socket, num_blocks_per_file=num_blocks_per_file, use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, @@ -470,8 +477,8 @@ def _init_workers(self) -> None: dtype=indexer_gpu_handles_list[0].dtype, tp_group_size=self.tp_size, dp_group_id=dp_client_id, - is_nsa_cp=getattr(self.model_config, "is_nsa_cp", False), - cp_size=getattr(self.model_config, "cp_size", 1), + is_nsa_cp=self.model_config.is_nsa_cp, + cp_size=self.model_config.cp_size, use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, @@ -514,8 +521,8 @@ def _init_workers(self) -> None: dtype=indexer_gpu_handles_list[0].dtype, tp_group_size=self.tp_size, dp_group_id=dp_client_id, - is_nsa_cp=getattr(self.model_config, "is_nsa_cp", False), - cp_size=getattr(self.model_config, "cp_size", 1), + is_nsa_cp=self.model_config.is_nsa_cp, + cp_size=self.model_config.cp_size, use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, diff --git a/flexkv/transfer_manager.py b/flexkv/transfer_manager.py index 8d63baffd1..5b1a55ee10 100644 --- a/flexkv/transfer_manager.py +++ b/flexkv/transfer_manager.py @@ -228,10 +228,30 @@ def shutdown(self) -> None: if hasattr(self, 'transfer_engine'): 
def resolve_master_host_and_ports(
    master_host: Optional[str] = None,
) -> Tuple[str, Tuple[str, str, str]]:
    """Resolve the ``(master_host, master_ports)`` pair for multi-node transfer.

    ``master_host`` resolution order:
      1. the explicit ``master_host`` argument, when the caller provides one
         (e.g. derived from sglang ``--dist-init-addr``);
      2. the ``FLEXKV_MASTER_HOST`` environment variable (used by
         framework-agnostic launchers such as TRT-LLM's
         ``multi_node_launch.sh``);
      3. the ``"localhost"`` default.

    ``master_ports`` always comes from ``FLEXKV_MASTER_PORTS`` (default
    ``"5556,5557,5558"``), split on commas.

    Args:
        master_host: Host name or IP of the master node, or ``None`` to fall
            back to the environment variable / default.

    Returns:
        ``("tcp://<host>", ports)`` where ``ports`` is the tuple of port
        strings obtained from ``FLEXKV_MASTER_PORTS``.  The tuple has three
        entries with the default value; its length otherwise follows the
        number of comma-separated items in the variable.
    """
    # Record where the host came from BEFORE applying the fallback: once the
    # fallback has run, master_host is never None and the log line below
    # could no longer tell an explicit argument from the env/default value.
    host_source = "arg" if master_host is not None else "env/default"
    if master_host is None:
        master_host = os.getenv("FLEXKV_MASTER_HOST", "localhost")

    master_ports = tuple(
        os.getenv("FLEXKV_MASTER_PORTS", "5556,5557,5558").split(",")
    )
    flexkv_logger.info(
        f"[TransferManager] resolved master endpoint: "
        f"host={master_host!r} (source={host_source}), "
        f"ports={master_ports}"
    )
    return "tcp://" + master_host, master_ports
""" - def __init__(self, mode: str = "Default"): + def __init__(self, mode: str = "Default", master_host: Optional[str] = None): if mode == "Default": - self.master_host, self.master_ports = get_master_host_and_ports_from_env() + self.master_host, self.master_ports = resolve_master_host_and_ports( + master_host=master_host + ) elif mode == "TrtllmSubprocess": self.master_host, self.master_ports = get_trtllm_subprocess_host_and_ports_from_env() else: @@ -842,20 +864,20 @@ def _bind_master_ports(self) -> None: try: command_addr = f"{self.master_host}:{self.master_ports[0]}" self.command_socket.bind(command_addr) - flexkv_logger.debug(f"Master bound command port at {command_addr}") + flexkv_logger.info(f"Master bound command port at {command_addr}") result_addr = f"{self.master_host}:{self.master_ports[1]}" self.result_socket.bind(result_addr) - flexkv_logger.debug(f"Master bound result port at {result_addr}") + flexkv_logger.info(f"Master bound result port at {result_addr}") query_addr = f"{self.master_host}:{self.master_ports[2]}" self.query_socket.bind(query_addr) - flexkv_logger.debug(f"Master bound query port at {query_addr}") + flexkv_logger.info(f"Master bound query port at {query_addr}") self.result_socket.setsockopt(zmq.RCVTIMEO, 0) self._connected = True - flexkv_logger.debug("Master transfer manager ready for remote connections") + flexkv_logger.info("Master transfer manager ready for remote connections") except Exception as e: flexkv_logger.error(f"Master failed to bind ports: {e}") From 50e702e89892af4d842336fdde64b2ed438f94d0 Mon Sep 17 00:00:00 2001 From: Stary Date: Thu, 30 Apr 2026 10:24:18 +0800 Subject: [PATCH 56/59] feat(hugepage): add HugePage support (#158) - Add configuration entries for HugePage allocation - Implement HugePageAllocator and allocate CPU KV cache on hugetlbfs - Support HugePage for temporary buffer in PEER2CPUTransferWorker - Add tests and documentation for HugePage feature Signed-off-by: staryxchen --- 
docs/flexkv_config_reference/README_zh.md | 3 + docs/hugepage/README_en.md | 299 +++++++++++ docs/hugepage/README_zh.md | 299 +++++++++++ flexkv/__init__.py | 9 + flexkv/common/config.py | 20 + flexkv/common/storage.py | 6 + flexkv/storage/allocator.py | 465 +++++++++++++++++- flexkv/storage/storage_engine.py | 28 +- flexkv/transfer/host_buffer.py | 142 ++++++ flexkv/transfer/layerwise.py | 7 +- flexkv/transfer/transfer_engine.py | 40 +- flexkv/transfer/worker.py | 64 +-- tests/hugepage/conftest.py | 35 ++ tests/hugepage/test_hugepage_transfer_e2e.py | 185 +++++++ tests/hugepage/test_hugepage_unit.py | 218 ++++++++ .../test_hugepage_worker_integration.py | 142 ++++++ tests/test_config_hugepage.py | 44 ++ 17 files changed, 1948 insertions(+), 58 deletions(-) create mode 100644 docs/hugepage/README_en.md create mode 100644 docs/hugepage/README_zh.md create mode 100644 flexkv/transfer/host_buffer.py create mode 100644 tests/hugepage/conftest.py create mode 100644 tests/hugepage/test_hugepage_transfer_e2e.py create mode 100644 tests/hugepage/test_hugepage_unit.py create mode 100644 tests/hugepage/test_hugepage_worker_integration.py create mode 100644 tests/test_config_hugepage.py diff --git a/docs/flexkv_config_reference/README_zh.md b/docs/flexkv_config_reference/README_zh.md index 562f1d8a38..e24f055784 100644 --- a/docs/flexkv_config_reference/README_zh.md +++ b/docs/flexkv_config_reference/README_zh.md @@ -47,6 +47,9 @@ enable_gds: false | `FLEXKV_SSD_CACHE_GB` | int | 0 | SSD 缓存层容量,单位为 GB。建议设置大于 `FLEXKV_CPU_CACHE_GB`并为`FLEXKV_MAX_FILE_SIZE_GB`的整数倍,若仅用CPU缓存则设为 0(此时不启用 SSD 缓存) | | `FLEXKV_SSD_CACHE_DIR` | str | "./flexkv_ssd" | SSD 缓存数据的存放目录。若有多块 SSD,可通过分号 `;` 分隔多个挂载路径。例如 `"/data0/flexkv_ssd/;/data1/flexkv_ssd/"`,以提升带宽 | | `FLEXKV_ENABLE_GDS` | bool | 0 | 是否启用 GPU Direct Storage(GDS)。如硬件和驱动支持,开启后可提升 SSD 到 GPU 的数据吞吐能力。默认关闭,开启请设为 1 | +| `FLEXKV_USE_HUGEPAGE_CPU_BUFFER` | bool | 0 | 是否为通用 CPU KV cache 启用 HugePage。默认关闭,开启请设为 1 | +| `FLEXKV_USE_HUGEPAGE_TMP_BUFFER` | 
bool | 0 | 是否为 `enable_p2p_ssd` 场景下的 tmp CPU staging buffer 启用 HugePage。默认关闭,开启请设为 1 | +| `FLEXKV_HUGEPAGE_SIZE_BYTES` | int | 2097152 | HugePage 大小,默认 2 MiB。如果宿主机准备的是 1 GiB HugePage,可设为 `1073741824` | --- diff --git a/docs/hugepage/README_en.md b/docs/hugepage/README_en.md new file mode 100644 index 0000000000..f097b1d98b --- /dev/null +++ b/docs/hugepage/README_en.md @@ -0,0 +1,299 @@ +# FlexKV HugePage User Guide + +## 1. Overview + +FlexKV currently exposes three HugePage-related configuration fields: + +- `use_hugepage_cpu_buffer` + Controls whether the main CPU KV cache should be allocated from HugePages. +- `use_hugepage_tmp_buffer` + Controls whether the temporary CPU staging buffer in the `enable_p2p_ssd=true` path should be allocated from HugePages. +- `hugepage_size_bytes` + Controls the HugePage size used by both allocation paths. + +The two HugePage switches are independent. They can be enabled separately or together. + +There is one important implementation constraint today: + +- `use_hugepage_cpu_buffer` for the main CPU KV cache must use a `hugetlbfs`-backed file so the same HugePage mapping can be reopened inside `spawn` workers. +- Anonymous `MAP_HUGETLB` is not sufficient for the main CPU cache path because once that tensor is sent into `TransferEngine` workers, PyTorch serializes ordinary CPU tensors into new shared-memory storage, which breaks the original HugePage backing. +- `use_hugepage_tmp_buffer` does not have that cross-process sharing requirement, so it may still succeed through either anonymous HugePages or `hugetlbfs`. + +--- + +## 2. Recommended Use Cases + +HugePage is recommended in the following situations: + +- The CPU KV cache is large and CPU-side page table or TLB overhead is non-trivial. +- `enable_p2p_ssd=true` is enabled and you want to optimize the temporary staging buffer in the `SSD -> CPU -> GPU` path. +- The host has already been prepared with reserved HugePages and a working hugetlbfs mount. 
+ +HugePage is not recommended when the host has no HugePage reservation or when the target workload does not materially benefit from CPU cache or p2p SSD path optimization. + +--- + +## 3. Prerequisites + +### 3.1 HugePages Must Be Reserved on the Host + +Check the current HugePage status: + +```bash +grep -E 'HugePages_|Hugepagesize' /proc/meminfo +``` + +For 2 MiB HugePages, the following command reserves 4096 pages, which is about 8 GiB: + +```bash +sudo sysctl -w vm.nr_hugepages=4096 +``` + +If you plan to use 1 GiB HugePages, they usually need to be reserved through kernel boot parameters, for example: + +```text +default_hugepagesz=1G hugepagesz=1G hugepages=N +``` + +### 3.2 hugetlbfs Must Be Mounted + +FlexKV uses `/mnt/hugepages` as the default hugetlbfs mount point. + +For `use_hugepage_cpu_buffer`, this is a hard requirement rather than a recommendation. If `FLEXKV_HUGETLBFS_DIR` points to a normal filesystem, FlexKV now rejects it and falls back instead of silently treating regular 4 KiB pages as HugePages. + +Check the mount status: + +```bash +mount | grep hugetlbfs +ls -ld /mnt/hugepages +``` + +If hugetlbfs is not mounted yet: + +```bash +sudo mkdir -p /mnt/hugepages +sudo mount -t hugetlbfs none /mnt/hugepages +``` + +If your actual hugetlbfs mount point is different, set it explicitly: + +```bash +export FLEXKV_HUGETLBFS_DIR=/path/to/hugetlbfs +``` + +### 3.3 CUDA Runtime Is Required for the tmp Buffer Path + +The `use_hugepage_tmp_buffer` path performs `cudaHostRegister` after HugePage allocation succeeds. This path therefore requires: + +- A working CUDA runtime +- `libcudart.so` to be discoverable +- A sufficiently large `memlock` limit on the host or in the container + +Basic check: + +```bash +python3 - <<'PY' +import torch +print(torch.cuda.is_available()) +PY +``` + +Note: `use_hugepage_cpu_buffer` does not depend on `cudaHostRegister`. 
+ +Additional note: although `use_hugepage_cpu_buffer` does not require CUDA runtime, it does require a writable `hugetlbfs` mount because the main CPU KV cache must be reopened from the same HugePage-backed file inside spawned workers. + +--- + +## 4. Configuration + +HugePage is now a formal user-facing configuration surface in FlexKV. It can be configured through either configuration files or environment variables. + +### 4.1 Configuration File + +YAML example: + +```yaml +cpu_cache_gb: 32 +ssd_cache_gb: 1024 +ssd_cache_dir: /data/flexkv_ssd/ +enable_p2p_ssd: true +use_hugepage_cpu_buffer: true +use_hugepage_tmp_buffer: true +hugepage_size_bytes: 2097152 +``` + +JSON example: + +```json +{ + "cpu_cache_gb": 32, + "ssd_cache_gb": 1024, + "ssd_cache_dir": "/data/flexkv_ssd/", + "enable_p2p_ssd": true, + "use_hugepage_cpu_buffer": true, + "use_hugepage_tmp_buffer": true, + "hugepage_size_bytes": 2097152 +} +``` + +### 4.2 Environment Variables + +```bash +export FLEXKV_USE_HUGEPAGE_CPU_BUFFER=1 +export FLEXKV_USE_HUGEPAGE_TMP_BUFFER=1 +export FLEXKV_HUGEPAGE_SIZE_BYTES=2097152 +``` + +Meaning: + +- `FLEXKV_USE_HUGEPAGE_CPU_BUFFER=1` + Enables HugePage allocation for the main CPU KV cache. +- `FLEXKV_USE_HUGEPAGE_TMP_BUFFER=1` + Enables HugePage allocation for the temporary CPU staging buffer in the p2p SSD path. +- `FLEXKV_HUGEPAGE_SIZE_BYTES=2097152` + Uses 2 MiB HugePages. + +If the host is configured for 1 GiB HugePages: + +```bash +export FLEXKV_HUGEPAGE_SIZE_BYTES=1073741824 +``` + +### 4.3 How to Choose Between the Two Switches + +- If you only want to optimize the main CPU KV cache, enable `use_hugepage_cpu_buffer`. +- If you only want to optimize the temporary staging buffer in the p2p SSD path, enable `use_hugepage_tmp_buffer` and make sure `enable_p2p_ssd=true` is set. +- If both paths matter, enable both switches. + +--- + +## 5. Recommended Enablement Order + +For a first rollout, the recommended order is: + +1. 
Prepare 2 MiB HugePages on the host and confirm that hugetlbfs is mounted correctly. +2. Enable `use_hugepage_cpu_buffer=true` first and verify that the main CPU KV cache works correctly. +3. If you also need to validate the p2p SSD path, then enable `use_hugepage_tmp_buffer=true`. +4. Only after the feature is stable should you evaluate switching to 1 GiB HugePages. + +For initial validation, 2 MiB HugePages are recommended because host setup is simpler and troubleshooting is more straightforward. + +--- + +## 6. How to Verify It Is Working + +### 6.1 Check the Logs + +If the tmp staging buffer successfully uses HugePages, logs will contain a message similar to: + +```text +[PEER2CPUTransferWorker] tmp_cpu_buffer allocated on HugePages: 2.000 GB +``` + +When `use_hugepage_cpu_buffer=true`, the allocator logs the following line before it attempts the HugePage allocation, so this line alone does not confirm success; the main CPU KV cache is on HugePages only if it is not followed by the fallback warning shown further below: + +```text +HugePage allocate total_size: ... GB (page_size=2MiB) +``` + +If the HugePage path for the tmp staging buffer fails and falls back, logs will contain a message similar to: + +```text +[PEER2CPUTransferWorker] HugePage allocation for tmp_cpu_buffer failed (...); falling back to torch.empty(pin_memory=True). +``` + +If `use_hugepage_cpu_buffer=true` but the hugetlbfs mount is invalid, logs will contain a message similar to: + +```text +HugePage allocation failed (HugePage: /path is not a hugetlbfs mount ...); falling back to regular CPU memory. +``` + +### 6.2 Check HugePage Counters + +Before and after the service starts, run: + +```bash +grep -E 'HugePages_Total|HugePages_Free|Hugepagesize' /proc/meminfo +``` + +If HugePage allocation is active, you will typically observe: + +- `HugePages_Total` unchanged +- `HugePages_Free` decreased + +After the service exits and releases resources, `HugePages_Free` should return close to its original value. 
+ +### 6.3 Run the Test Suite + +If the machine already satisfies the HugePage and CUDA requirements, run: + +```bash +PYTHONDONTWRITEBYTECODE=1 python3 -m pytest -q tests/hugepage -rs +``` + +This test suite validates: + +- HugePage allocation and release +- HugePage configuration flow for the CPU KV cache +- HugePage configuration flow for the tmp staging buffer +- Fallback behavior when HugePage allocation cannot be used + +--- + +## 7. Common Configuration Errors + +### 7.1 `use_hugepage_tmp_buffer` Is Enabled but Does Not Take Effect + +Check the following items in order: + +- `enable_p2p_ssd=true` is actually enabled +- The host has enough HugePages reserved +- hugetlbfs is mounted +- `FLEXKV_HUGETLBFS_DIR` points to the correct mount +- CUDA runtime is available +- `memlock` is not too small + +### 7.2 HugePage Is Enabled but There Is No Error and No Performance Gain + +This usually means the HugePage path has already fallen back to regular memory. + +FlexKV treats HugePage as a best-effort optimization with automatic fallback. Service startup success alone is not sufficient evidence that HugePage is active. You must confirm through logs and `/proc/meminfo`. + +Also distinguish the two common cases: + +- For `use_hugepage_cpu_buffer`, a writable `hugetlbfs` mount is mandatory even if the host has reserved HugePages. +- For `use_hugepage_tmp_buffer`, anonymous HugePages may still work even without `hugetlbfs`. + +### 7.3 1 GiB HugePages Do Not Work After Configuration + +The most common reason is that the host does not actually have a 1 GiB HugePage pool available. Confirm the following: + +- Kernel boot parameters are set correctly +- The host has a real 1 GiB HugePage reservation +- `hugepage_size_bytes` matches the HugePage type actually available on the machine + +For initial rollout, it is better to validate functionality with 2 MiB HugePages first and move to 1 GiB only afterward. + +--- + +## 8. 
Minimal Working Examples + +If you want to validate both the main CPU KV cache and the p2p SSD tmp buffer HugePage paths, the following is a minimal example: + +```yaml +cpu_cache_gb: 32 +ssd_cache_gb: 1024 +ssd_cache_dir: /data/flexkv_ssd/ +enable_p2p_ssd: true +use_hugepage_cpu_buffer: true +use_hugepage_tmp_buffer: true +hugepage_size_bytes: 2097152 +``` + +If you only want to validate the main CPU KV cache path: + +```yaml +cpu_cache_gb: 32 +use_hugepage_cpu_buffer: true +hugepage_size_bytes: 2097152 +``` diff --git a/docs/hugepage/README_zh.md b/docs/hugepage/README_zh.md new file mode 100644 index 0000000000..b4799f7898 --- /dev/null +++ b/docs/hugepage/README_zh.md @@ -0,0 +1,299 @@ +# FlexKV HugePage 使用指南 + +## 一、功能概述 + +FlexKV 当前提供三个 HugePage 相关配置项: + +- `use_hugepage_cpu_buffer` + 控制通用 CPU KV Cache 是否优先使用 HugePage 分配。 +- `use_hugepage_tmp_buffer` + 控制 `enable_p2p_ssd=true` 场景下临时 CPU staging buffer 是否优先使用 HugePage 分配。 +- `hugepage_size_bytes` + 控制上述两类内存申请时使用的 HugePage 大小。 + +前两个开关可以独立启用,也可以同时启用。 + +当前实现上有一个重要限制: + +- `use_hugepage_cpu_buffer` 对主 CPU KV cache 的生效路径,必须依赖 `hugetlbfs` 挂载文件来保证在 `spawn` worker 场景下仍然保持 HugePage backing。 +- 纯匿名 `MAP_HUGETLB` 只用于单进程或不可共享场景;一旦主 CPU cache 被传给 `TransferEngine` 的子进程,PyTorch 会把普通 CPU tensor 序列化成新的 shared-memory storage,匿名映射无法继续作为跨进程共享后端,因此不能满足主 CPU cache 的目标语义。 +- `use_hugepage_tmp_buffer` 仍然可以走匿名 HugePage 或 hugetlbfs,两者都不会经过上述主 cache 的跨进程共享问题。 + +--- + +## 二、适用场景 + +建议在以下场景启用 HugePage: + +- CPU KV Cache 容量较大,希望降低页表和 TLB 开销。 +- 已启用 `enable_p2p_ssd=true`,并希望优化 `SSD -> CPU -> GPU` 数据路径中的临时 staging buffer。 +- 已完成宿主机 HugePage 预留和 hugetlbfs 挂载,具备稳定的系统运行条件。 + +如果机器没有预留 HugePage,或者当前并不依赖 CPU KV Cache / p2p SSD 路径上的性能收益,不建议启用。 + +--- + +## 三、前置条件 + +### 3.1 宿主机已预留 HugePage + +先检查系统状态: + +```bash +grep -E 'HugePages_|Hugepagesize' /proc/meminfo +``` + +以 2 MiB HugePage 为例,预留 4096 个页,即约 8 GiB: + +```bash +sudo sysctl -w vm.nr_hugepages=4096 +``` + +如果使用 1 GiB HugePage,通常需要在内核启动参数中预留,例如: + +```text +default_hugepagesz=1G 
hugepagesz=1G hugepages=N +``` + +### 3.2 宿主机已挂载 hugetlbfs + +FlexKV 默认使用 `/mnt/hugepages` 作为 hugetlbfs 挂载点。 + +说明:对于 `use_hugepage_cpu_buffer`,这一步不是“建议”,而是必须条件。如果 `FLEXKV_HUGETLBFS_DIR` 指向普通文件系统,FlexKV 现在会直接判定失败并回退,不再把普通 4 KiB 页误判成 HugePage 成功。 + +检查挂载状态: + +```bash +mount | grep hugetlbfs +ls -ld /mnt/hugepages +``` + +如果尚未挂载,可执行: + +```bash +sudo mkdir -p /mnt/hugepages +sudo mount -t hugetlbfs none /mnt/hugepages +``` + +如果实际挂载点不是 `/mnt/hugepages`,需要显式设置: + +```bash +export FLEXKV_HUGETLBFS_DIR=/path/to/hugetlbfs +``` + +### 3.3 tmp buffer 场景需要 CUDA 运行时 + +`use_hugepage_tmp_buffer` 对应的 staging buffer 在 HugePage 分配成功后还会执行 `cudaHostRegister`。因此这一路径要求: + +- CUDA runtime 可用 +- `libcudart.so` 可正常加载 +- 容器或宿主机的 `memlock` 限制不要过小 + +可先做基础检查: + +```bash +python3 - <<'PY' +import torch +print(torch.cuda.is_available()) +PY +``` + +说明:`use_hugepage_cpu_buffer` 不依赖 `cudaHostRegister`。 + +补充说明:`use_hugepage_cpu_buffer` 虽然不依赖 CUDA runtime,但依赖可写的 hugetlbfs 挂载点,因为主 CPU KV cache 需要通过该文件在 `spawn` worker 间重新打开同一块 HugePage-backed 映射。 + +--- + +## 四、配置方式 + +FlexKV 已将 HugePage 作为正式用户配置项,支持配置文件和环境变量两种方式。 + +### 4.1 配置文件 + +YAML 示例: + +```yaml +cpu_cache_gb: 32 +ssd_cache_gb: 1024 +ssd_cache_dir: /data/flexkv_ssd/ +enable_p2p_ssd: true +use_hugepage_cpu_buffer: true +use_hugepage_tmp_buffer: true +hugepage_size_bytes: 2097152 +``` + +JSON 示例: + +```json +{ + "cpu_cache_gb": 32, + "ssd_cache_gb": 1024, + "ssd_cache_dir": "/data/flexkv_ssd/", + "enable_p2p_ssd": true, + "use_hugepage_cpu_buffer": true, + "use_hugepage_tmp_buffer": true, + "hugepage_size_bytes": 2097152 +} +``` + +### 4.2 环境变量 + +```bash +export FLEXKV_USE_HUGEPAGE_CPU_BUFFER=1 +export FLEXKV_USE_HUGEPAGE_TMP_BUFFER=1 +export FLEXKV_HUGEPAGE_SIZE_BYTES=2097152 +``` + +说明: + +- `FLEXKV_USE_HUGEPAGE_CPU_BUFFER=1` + 为通用 CPU KV Cache 启用 HugePage。 +- `FLEXKV_USE_HUGEPAGE_TMP_BUFFER=1` + 为 p2p SSD 场景下的临时 CPU staging buffer 启用 HugePage。 +- `FLEXKV_HUGEPAGE_SIZE_BYTES=2097152` + 表示使用 2 MiB HugePage。 + +如果宿主机准备的是 1 GiB 
HugePage,可设置为: + +```bash +export FLEXKV_HUGEPAGE_SIZE_BYTES=1073741824 +``` + +### 4.3 两个开关的选择原则 + +- 只需要优化通用 CPU KV Cache:开启 `use_hugepage_cpu_buffer`。 +- 只需要优化 p2p SSD 的临时 staging buffer:开启 `use_hugepage_tmp_buffer`,同时确保 `enable_p2p_ssd=true`。 +- 两条路径都需要:两个开关同时开启。 + +--- + +## 五、推荐启用顺序 + +首次接入建议按以下顺序进行: + +1. 先在宿主机准备 2 MiB HugePage,并确认 hugetlbfs 挂载正常。 +2. 先只启用 `use_hugepage_cpu_buffer=true`,验证通用 CPU KV Cache 可正常工作。 +3. 如果还需要验证 p2p SSD 路径,再启用 `use_hugepage_tmp_buffer=true`。 +4. 在确认功能稳定后,再根据机器环境评估是否切换到 1 GiB HugePage。 + +推荐第一轮验证优先使用 2 MiB HugePage。它的系统准备成本更低,排障也更直接。 + +--- + +## 六、如何确认已经生效 + +### 6.1 检查日志 + +如果 tmp staging buffer 成功使用 HugePage,日志会出现类似信息: + +```text +[PEER2CPUTransferWorker] tmp_cpu_buffer allocated on HugePages: 2.000 GB +``` + +启用 `use_hugepage_cpu_buffer=true` 后,分配器会在尝试 HugePage 分配之前先输出如下日志;仅凭这条日志不能说明分配成功,只有其后没有出现下文的回退告警时,才表示主 CPU KV cache 已使用 HugePage: + +```text +HugePage allocate total_size: ... GB (page_size=2MiB) +``` + +如果 tmp staging buffer 的 HugePage 路径失败并回退,日志会出现类似信息: + +```text +[PEER2CPUTransferWorker] HugePage allocation for tmp_cpu_buffer failed (...); falling back to torch.empty(pin_memory=True). +``` + +如果 `use_hugepage_cpu_buffer=true` 但 hugetlbfs 挂载不正确,日志会出现类似信息: + +```text +HugePage allocation failed (HugePage: /path is not a hugetlbfs mount ...); falling back to regular CPU memory. 
+``` + +### 6.2 检查 HugePage 计数 + +在服务启动前后分别执行: + +```bash +grep -E 'HugePages_Total|HugePages_Free|Hugepagesize' /proc/meminfo +``` + +如果 HugePage 分配生效,通常可以观察到: + +- `HugePages_Total` 不变 +- `HugePages_Free` 下降 + +服务退出并释放资源后,`HugePages_Free` 应恢复到接近启动前的水平。 + +### 6.3 运行测试 + +如果机器已具备 HugePage 和 CUDA 条件,可执行: + +```bash +PYTHONDONTWRITEBYTECODE=1 python3 -m pytest -q tests/hugepage -rs +``` + +该测试集可用于验证: + +- HugePage 分配与释放 +- CPU KV Cache 的 HugePage 配置路径 +- tmp staging buffer 的 HugePage 配置路径 +- HugePage 失败后的回退行为 + +--- + +## 七、常见配置错误 + +### 7.1 开启了 `use_hugepage_tmp_buffer`,但实际上没有生效 + +请依次检查: + +- 是否同时设置了 `enable_p2p_ssd=true` +- 宿主机是否预留了足够的 HugePage +- hugetlbfs 是否已挂载 +- `FLEXKV_HUGETLBFS_DIR` 是否指向正确挂载点 +- CUDA runtime 是否可用 +- `memlock` 限制是否过小 + +### 7.2 开启了 HugePage,但服务没有报错也没有性能收益 + +这通常意味着 HugePage 路径已经回退到普通内存分配。 + +FlexKV 对 HugePage 采用的是“失败自动回退”策略,因此不能仅以服务是否启动成功来判断功能是否生效,必须结合日志和 `/proc/meminfo` 一起确认。 + +另外需要区分两类原因: + +- `use_hugepage_cpu_buffer` 场景下,如果没有可写 hugetlbfs 挂载点,即使系统里预留了 HugePage,也不会被视为可用配置。 +- `use_hugepage_tmp_buffer` 场景下,如果匿名 HugePage 成功,可能不依赖 hugetlbfs 挂载。 + +### 7.3 使用 1 GiB HugePage 后启动失败或无法生效 + +最常见原因是宿主机并未真正准备 1 GiB HugePage 池。请确认: + +- 内核启动参数已正确设置 +- 宿主机已实际预留 1 GiB HugePage +- `hugepage_size_bytes` 与系统中实际可用的 HugePage 类型一致 + +如果是首次接入,建议先回到 2 MiB HugePage 完成功能验证,再切换到 1 GiB。 + +--- + +## 八、最小可用配置示例 + +如果你的目标是同时验证通用 CPU KV Cache 和 p2p SSD tmp buffer 的 HugePage 功能,可以使用以下最小配置: + +```yaml +cpu_cache_gb: 32 +ssd_cache_gb: 1024 +ssd_cache_dir: /data/flexkv_ssd/ +enable_p2p_ssd: true +use_hugepage_cpu_buffer: true +use_hugepage_tmp_buffer: true +hugepage_size_bytes: 2097152 +``` + +如果你当前只想验证通用 CPU KV Cache,则可以只保留: + +```yaml +cpu_cache_gb: 32 +use_hugepage_cpu_buffer: true +hugepage_size_bytes: 2097152 +``` diff --git a/flexkv/__init__.py b/flexkv/__init__.py index c8ece6cd74..97a6943e67 100644 --- a/flexkv/__init__.py +++ b/flexkv/__init__.py @@ -42,3 +42,12 @@ def _setup_library_path() -> None: # Call the setup function when the package is imported 
_setup_library_path() + +# ``flexkv.c_ext`` is a PyTorch C++ extension and dynamically links against +# ``libc10.so`` / ``libtorch*.so`` from the installed ``torch`` package. Those +# libraries live under ``<site-packages>/torch/lib`` and are NOT on the system +# linker search path. Importing ``torch`` here causes Python to ``dlopen`` them +# (with RTLD_GLOBAL), so any subsequent ``import flexkv.c_ext`` can resolve +# them without requiring the caller to ``import torch`` first or to set +# ``LD_LIBRARY_PATH``. +import torch # noqa: E402,F401 (side-effect import: load libtorch/libc10) diff --git a/flexkv/common/config.py b/flexkv/common/config.py index 383653204f..4adc83153f 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -80,6 +80,17 @@ class CacheConfig: distributed_node_id: int = -1 # only used when distributed cpu/ssd and only can be set when redis_meta_client initialized num_tmp_cpu_blocks: int = 500 # only used when distributed ssd p2p, it controls the number blocks of temp cpu buffer which used for copy data from ssd to cpu + # When True, the main CPU KV cache is allocated from Linux HugePages via + # a hugetlbfs-backed ``mmap`` instead of regular CPU memory. Requires pre-reserved + # huge pages and a hugetlbfs mount on the host (see ``/proc/sys/vm/nr_hugepages``). Falls back + # to regular CPU memory (logging a warning) if allocation fails. + use_hugepage_cpu_buffer: bool = False + # When True, the temporary SSD->CPU staging buffer (used by PEER2CPUTransferWorker + # under enable_p2p_ssd) is allocated from Linux HugePages via ``mmap(MAP_HUGETLB)`` + # instead of a pinned ``torch.empty``. Requires pre-reserved huge pages on the host + # (see ``/proc/sys/vm/nr_hugepages``). Falls back to a pinned ``torch.empty`` (and logs the failure) if allocation fails. 
+ use_hugepage_tmp_buffer: bool = False + hugepage_size_bytes: int = 2 * 1024 * 1024 # 2 MiB by default; set to 1<<30 for 1GiB # Indexer configuration @@ -186,6 +197,9 @@ class UserConfig: ssd_cache_gb: int = 0 # 0 means disable ssd ssd_cache_dir: Union[str, List[str]] = "./ssd_cache" enable_gds: bool = False + use_hugepage_cpu_buffer: bool = False + use_hugepage_tmp_buffer: bool = False + hugepage_size_bytes: int = 2 * 1024 * 1024 enable_p2p_cpu: bool = False enable_p2p_ssd: bool = False enable_3rd_remote: bool = False @@ -249,6 +263,9 @@ def load_user_config_from_env() -> UserConfig: ssd_cache_gb=int(os.getenv('FLEXKV_SSD_CACHE_GB', 0)), ssd_cache_dir=parse_path_list(os.getenv('FLEXKV_SSD_CACHE_DIR', "./flexkv_ssd")), enable_gds=bool(int(os.getenv('FLEXKV_ENABLE_GDS', 0))), + use_hugepage_cpu_buffer=bool(int(os.getenv('FLEXKV_USE_HUGEPAGE_CPU_BUFFER', 0))), + use_hugepage_tmp_buffer=bool(int(os.getenv('FLEXKV_USE_HUGEPAGE_TMP_BUFFER', 0))), + hugepage_size_bytes=int(os.getenv('FLEXKV_HUGEPAGE_SIZE_BYTES', 2 * 1024 * 1024)), kv_cache_dtype=os.getenv('FLEXKV_KV_CACHE_DTYPE', None), ) @@ -269,6 +286,9 @@ def update_default_config_from_user_config(model_config: ModelConfig, cache_config.ssd_cache_dir = user_config.ssd_cache_dir cache_config.enable_ssd = user_config.ssd_cache_gb > 0 cache_config.enable_gds = user_config.enable_gds + cache_config.use_hugepage_cpu_buffer = user_config.use_hugepage_cpu_buffer + cache_config.use_hugepage_tmp_buffer = user_config.use_hugepage_tmp_buffer + cache_config.hugepage_size_bytes = user_config.hugepage_size_bytes cache_config.enable_p2p_cpu = user_config.enable_p2p_cpu cache_config.enable_p2p_ssd = user_config.enable_p2p_ssd cache_config.enable_3rd_remote = user_config.enable_3rd_remote diff --git a/flexkv/common/storage.py b/flexkv/common/storage.py index ffd1d2cd46..bc89cb2aed 100644 --- a/flexkv/common/storage.py +++ b/flexkv/common/storage.py @@ -172,6 +172,7 @@ class StorageHandle: num_blocks_per_file: Optional[int] = None 
gpu_device_id: Optional[int] = None remote_config_custom: Optional[Dict[str, Any]] = None + worker_data: Optional[Any] = None def get_tensor_list(self) -> List[torch.Tensor]: assert isinstance(self.data, list) and \ @@ -195,6 +196,11 @@ def get_tensor(self) -> torch.Tensor: else: raise ValueError(f"Invalid handle type: {self.handle_type}, expected TENSOR") + def get_worker_tensor(self) -> Any: + if self.worker_data is not None: + return self.worker_data + return self.get_tensor() + def get_file_list(self) -> Union[List[str], Dict[int, List[str]]]: if self.handle_type == AccessHandleType.FILE: return self.data # type: ignore diff --git a/flexkv/storage/allocator.py b/flexkv/storage/allocator.py index 507e08e76d..8e60841541 100644 --- a/flexkv/storage/allocator.py +++ b/flexkv/storage/allocator.py @@ -1,6 +1,10 @@ +import ctypes +import mmap import os +import weakref +from dataclasses import dataclass from abc import ABC, abstractmethod -from typing import Tuple, Optional, List, Union, Dict, Any, BinaryIO +from typing import Tuple, List, Union, Dict, Any, BinaryIO try: from flexkv.c_ext import Pcfs except ImportError: @@ -128,6 +132,463 @@ def from_raw_data(cls, dtype=dtype, ) +# --------------------------------------------------------------------------- +# HugePage helpers (standalone, reusable outside of BaseStorageAllocator) +# --------------------------------------------------------------------------- +DEFAULT_HUGE_PAGE_SIZE = 2 * 1024 * 1024 # 2 MiB +DEFAULT_HUGETLBFS_DIR = "/mnt/hugepages" + +_MAP_SHARED = 0x01 +_MAP_PRIVATE = 0x02 +_MAP_ANONYMOUS = 0x20 +_MAP_HUGETLB = 0x40000 +_MAP_HUGE_SHIFT = 26 +_PROT_READ = 0x1 +_PROT_WRITE = 0x2 +_MAP_FAILED = ctypes.c_void_p(-1).value # (void*)-1 +_HUGETLBFS_MAGIC = 0x958458F6 + +_libc = ctypes.CDLL("libc.so.6", use_errno=True) +_libc.mmap.restype = ctypes.c_void_p +_libc.mmap.argtypes = [ + ctypes.c_void_p, + ctypes.c_size_t, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_long, +] +_libc.munmap.restype = 
ctypes.c_int +_libc.munmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t] +_libc.ftruncate.restype = ctypes.c_int +_libc.ftruncate.argtypes = [ctypes.c_int, ctypes.c_long] +_libc.close.restype = ctypes.c_int +_libc.close.argtypes = [ctypes.c_int] + + +class _StatFS(ctypes.Structure): + _fields_ = [ + ("f_type", ctypes.c_long), + ("f_bsize", ctypes.c_long), + ("f_blocks", ctypes.c_ulong), + ("f_bfree", ctypes.c_ulong), + ("f_bavail", ctypes.c_ulong), + ("f_files", ctypes.c_ulong), + ("f_ffree", ctypes.c_ulong), + ("f_fsid", ctypes.c_int * 2), + ("f_namelen", ctypes.c_long), + ("f_frsize", ctypes.c_long), + ("f_flags", ctypes.c_long), + ("f_spare", ctypes.c_long * 4), + ] + + +_libc.statfs.restype = ctypes.c_int +_libc.statfs.argtypes = [ctypes.c_char_p, ctypes.POINTER(_StatFS)] + +@dataclass +class _HugePageMapping: + finalizer: Any + aligned: int + path: str | None = None + + +_live_hugepage_mappings: "Dict[int, _HugePageMapping]" = {} + + +@dataclass(frozen=True) +class HugePageTensorHandle: + path: str + num_elements: int + dtype: torch.dtype + aligned: int + + def get_tensor(self) -> torch.Tensor: + return _materialize_shareable_hugepage_tensor( + path=self.path, + num_elements=self.num_elements, + dtype=self.dtype, + aligned=self.aligned, + ) + + +def _cleanup_hugepage_mapping(addr: int, aligned: int, fd: int, + path: str | None, data_ptr: int) -> None: + _munmap_huge(addr, aligned) + if fd >= 0: + _libc.close(fd) + if path is not None: + try: + os.unlink(path) + except FileNotFoundError: + pass + _live_hugepage_mappings.pop(data_ptr, None) + + +def _cleanup_hugepage_mmap(mm: mmap.mmap, path: str | None, data_ptr: int) -> None: + try: + mm.close() + finally: + if path is not None: + try: + os.unlink(path) + except FileNotFoundError: + pass + _live_hugepage_mappings.pop(data_ptr, None) + + +def _statfs_type(path: str) -> int: + statfs_buf = _StatFS() + if _libc.statfs(path.encode(), ctypes.byref(statfs_buf)) != 0: + err = ctypes.get_errno() + raise RuntimeError( 
+ f"HugePage: statfs({path}) failed: {os.strerror(err)} (errno={err})" + ) + return int(statfs_buf.f_type) + + +def _ensure_hugetlbfs_mount(mnt_dir: str) -> None: + if not os.path.isdir(mnt_dir): + raise RuntimeError(f"HugePage: hugetlbfs directory does not exist: {mnt_dir}") + fs_type = _statfs_type(mnt_dir) + if fs_type != _HUGETLBFS_MAGIC: + raise RuntimeError( + f"HugePage: {mnt_dir} is not a hugetlbfs mount " + f"(f_type=0x{fs_type:x}, expected=0x{_HUGETLBFS_MAGIC:x})" + ) + + +def _create_hugetlbfs_file(aligned: int) -> tuple[str, int]: + mnt_dir = os.environ.get("FLEXKV_HUGETLBFS_DIR", DEFAULT_HUGETLBFS_DIR) + _ensure_hugetlbfs_mount(mnt_dir) + path = os.path.join(mnt_dir, f"flexkv_hugepage_{os.getpid()}_{id(object()):x}") + fd = os.open(path, os.O_CREAT | os.O_RDWR | os.O_EXCL, 0o600) + try: + ctypes.set_errno(0) + if _libc.ftruncate(fd, aligned) != 0: + err = ctypes.get_errno() + raise RuntimeError( + f"HugePage: ftruncate({path}, {aligned}) failed: " + f"{os.strerror(err)} (errno={err})" + ) + except Exception: + os.close(fd) + try: + os.unlink(path) + except FileNotFoundError: + pass + raise + return path, fd + + +def _wrap_mmap_tensor(mm: mmap.mmap, + aligned: int, + num_elements: int, + dtype: torch.dtype, + cleanup_path: str | None) -> torch.Tensor: + num_bytes = num_elements * dtype.itemsize + tensor = torch.frombuffer(mm, dtype=torch.uint8, count=num_bytes).view(dtype)[:num_elements] + ptr = tensor.data_ptr() + finalizer = weakref.finalize(tensor, _cleanup_hugepage_mmap, mm, cleanup_path, ptr) + _live_hugepage_mappings[ptr] = _HugePageMapping( + finalizer=finalizer, + aligned=aligned, + path=cleanup_path, + ) + return tensor + + +def _materialize_shareable_hugepage_tensor(path: str, + num_elements: int, + dtype: torch.dtype, + aligned: int) -> torch.Tensor: + fd = os.open(path, os.O_RDWR) + try: + mm = mmap.mmap( + fd, + aligned, + flags=mmap.MAP_SHARED, + prot=mmap.PROT_READ | mmap.PROT_WRITE, + ) + finally: + os.close(fd) + return 
_wrap_mmap_tensor(mm, aligned, num_elements, dtype, cleanup_path=None) + + +def materialize_worker_tensor(data: Union[torch.Tensor, HugePageTensorHandle]) -> torch.Tensor: + if isinstance(data, torch.Tensor): + return data + if isinstance(data, HugePageTensorHandle): + return data.get_tensor() + raise TypeError(f"Unsupported worker tensor type: {type(data)}") + + +def get_worker_hugepage_handle(tensor: torch.Tensor, + num_elements: int, + dtype: torch.dtype) -> HugePageTensorHandle | None: + mapping = _live_hugepage_mappings.get(tensor.data_ptr()) + if mapping is None or mapping.path is None: + return None + return HugePageTensorHandle( + path=mapping.path, + num_elements=num_elements, + dtype=dtype, + aligned=mapping.aligned, + ) + + +def _align_to_page(num_bytes: int, page_size_bytes: int) -> int: + """Round *num_bytes* up to the next multiple of *page_size_bytes*.""" + return (num_bytes + page_size_bytes - 1) & ~(page_size_bytes - 1) + + +def _read_hugepages_free(page_size_bytes: int) -> int: + """Return the number of free huge pages for *page_size_bytes*.""" + try: + size_kb = page_size_bytes // 1024 + cur_kb = 0 + free = 0 + with open("/proc/meminfo") as f: + for line in f: + if line.startswith("Hugepagesize:"): + cur_kb = int(line.split()[1]) + elif line.startswith("HugePages_Free:"): + free = int(line.split()[1]) + if cur_kb != size_kb: + return 0 # wrong page size pool + return free + except Exception: + return 0 + + +def _mmap_huge(num_bytes: int, page_size_bytes: int) -> Tuple[int, int, int]: + if num_bytes <= 0: + raise ValueError(f"HugePage: num_bytes must be > 0, got {num_bytes}") + if page_size_bytes <= 0 or (page_size_bytes & (page_size_bytes - 1)) != 0: + raise ValueError( + f"HugePage: page_size_bytes must be a power of two, got {page_size_bytes}" + ) + + aligned = _align_to_page(num_bytes, page_size_bytes) + page_shift = page_size_bytes.bit_length() - 1 + + # 1) Anonymous MAP_HUGETLB — no hugetlbfs mount needed. 
+ free_pages = _read_hugepages_free(page_size_bytes) + if free_pages and aligned > free_pages * page_size_bytes: + flexkv_logger.warning( + f"HugePage: requested {aligned // page_size_bytes} pages " + f"({aligned / (1024**3):.3f} GiB) but only {free_pages} free " + f"(page_size={page_size_bytes // (1024*1024)} MiB). " + f"The kernel may fall back to regular pages or overcommit." + ) + + ctypes.set_errno(0) + huge_flags = _MAP_PRIVATE | _MAP_ANONYMOUS | _MAP_HUGETLB | (page_shift << _MAP_HUGE_SHIFT) + ret = _libc.mmap(None, aligned, _PROT_READ | _PROT_WRITE, huge_flags, -1, 0) + if ret is not None and ret != _MAP_FAILED: + return int(ret), aligned, -1 + + # 2) Fallback: file-backed hugetlbfs. Reject non-hugetlbfs mounts so we + # never silently succeed on regular 4 KiB pages. + fd = -1 + try: + path, fd = _create_hugetlbfs_file(aligned) + try: + os.unlink(path) + except OSError: + pass + + ctypes.set_errno(0) + ret = _libc.mmap( + None, + aligned, + _PROT_READ | _PROT_WRITE, + _MAP_SHARED, + fd, + 0, + ) + if ret is None or ret == _MAP_FAILED: + err = ctypes.get_errno() + raise RuntimeError( + f"HugePage: mmap({path}, {aligned}) failed: " + f"{os.strerror(err)} (errno={err})" + ) + return int(ret), aligned, fd + except Exception: # noqa: BLE001 + if fd >= 0: + os.close(fd) + raise + + +def _munmap_huge(addr: int, length: int) -> None: + if _libc.munmap(ctypes.c_void_p(addr), length) != 0: + err = ctypes.get_errno() + flexkv_logger.warning( + f"HugePage: munmap({hex(addr)}, {length}) failed: " + f"{os.strerror(err)} (errno={err})" + ) + + +def alloc_hugepage_tensor(num_elements: int, + dtype: torch.dtype, + page_size_bytes: int = DEFAULT_HUGE_PAGE_SIZE, + shareable: bool = False) -> torch.Tensor: + """Allocate ``num_elements`` values of ``dtype`` on HugePage-backed memory. + + Returns a 1-D ``torch.Tensor`` that zero-copy wraps the mmap'd region. + The tensor's ``data_ptr()`` can be passed to ``cudaHostRegister`` or to + other RDMA-style registration APIs. 
+ + Use ``free_hugepage_tensor(tensor)`` to explicitly release the mapping; + otherwise it will be released when the tensor (and all references to it) + are garbage-collected. + + Raises: + RuntimeError: if the mmap fails or if the resulting VMA is not backed + by huge pages of the requested size (i.e. no silent fallback). + """ + num_bytes = num_elements * dtype.itemsize + + if shareable: + aligned = _align_to_page(num_bytes, page_size_bytes) + path, fd = _create_hugetlbfs_file(aligned) + try: + mm = mmap.mmap( + fd, + aligned, + flags=mmap.MAP_SHARED, + prot=mmap.PROT_READ | mmap.PROT_WRITE, + ) + finally: + os.close(fd) + return _wrap_mmap_tensor(mm, aligned, num_elements, dtype, cleanup_path=path) + + addr, aligned, fd = _mmap_huge(num_bytes, page_size_bytes) + + # Zero-copy wrap: build a numpy uint8 array pointing at the raw memory, + # then view it as the requested dtype via ``torch.frombuffer``. The numpy + # array keeps a reference (``_base_keepalive``) so Python's GC cannot free + # the underlying bytes while the tensor is still live. + buf_type = (ctypes.c_uint8 * aligned) + raw = buf_type.from_address(addr) + np_arr = np.frombuffer(raw, dtype=np.uint8, count=num_bytes) + tensor = torch.frombuffer(np_arr, dtype=torch.uint8, count=num_bytes) \ + .view(dtype)[:num_elements] + + ptr = tensor.data_ptr() + finalizer = weakref.finalize(tensor, _cleanup_hugepage_mapping, + addr, aligned, fd, None, ptr) + _live_hugepage_mappings[ptr] = _HugePageMapping( + finalizer=finalizer, + aligned=aligned, + path=None, + ) + return tensor + + +def free_hugepage_tensor(tensor: torch.Tensor) -> None: + """Release the HugePage mapping previously created by :func:`alloc_hugepage_tensor`. + + No-op if ``tensor`` is not known to be HugePage-backed. + The caller must ensure no other references to the tensor's memory remain + in active use (e.g. ``cudaHostUnregister`` should be called first, and + any Python reference to ``tensor`` should be dropped after this call). 
+ """ + if not isinstance(tensor, torch.Tensor): + return + ptr = tensor.data_ptr() + mapping = _live_hugepage_mappings.pop(ptr, None) + if mapping is None: + return + mapping.finalizer() + + +class HugePageAllocator(BaseStorageAllocator): + """CPU KV-cache allocator backed by hugetlbfs HugePages. + + Unlike :class:`CPUAllocator` (which relies on ``torch.empty`` on top of 4KiB + pages), this allocator maps a hugetlbfs file and wraps the resulting buffer + into a 1-D ``torch.Tensor`` (zero-copy). + + Benefits: + * Reduced TLB pressure for large KV caches (2MiB / 1GiB pages). + * The returned tensor's ``data_ptr()`` can still be passed to + ``cudaHostRegister`` for pinned H2D/D2H transfers. + + Prerequisites: + * The kernel must have huge pages reserved, e.g. for 2MiB pages:: + + echo N > /proc/sys/vm/nr_hugepages + # or, per-size on recent kernels: + echo N > /sys/kernel/mm/hugepages/hugepages-2048kB/nr_hugepages + + For 1GiB pages the kernel usually needs ``hugepagesz=1G`` at boot + and a corresponding ``hugepages=N`` reservation. + + kwargs: + page_size_bytes (int): Huge page size in bytes. Supported values: + ``2 * 1024 * 1024`` (default) or ``1024 * 1024 * 1024``. + """ + + @classmethod + def allocate(cls, + layout: KVCacheLayout, + dtype: torch.dtype, + **kwargs: Any) -> StorageHandle: + page_size_bytes = int(kwargs.get("page_size_bytes", DEFAULT_HUGE_PAGE_SIZE)) + total_elements = layout.get_total_elements() + element_size = dtype.itemsize + + flexkv_logger.info( + f"HugePage allocate total_size: " + f"{total_elements * element_size / 1024 / 1024 / 1024:.4f} GB " + f"(page_size={page_size_bytes // (1024 * 1024)}MiB)" + ) + try: + physical_tensor = alloc_hugepage_tensor( + total_elements, + dtype, + page_size_bytes, + shareable=True, + ) + except Exception as e: # noqa: BLE001 + flexkv_logger.warning( + f"HugePage allocation failed ({e}); falling back to regular CPU memory." 
+ ) + return CPUAllocator.allocate(layout, dtype, **kwargs) + worker_data = get_worker_hugepage_handle(physical_tensor, total_elements, dtype) + return StorageHandle( + handle_type=AccessHandleType.TENSOR, + data=physical_tensor, + kv_layout=layout, + dtype=dtype, + worker_data=worker_data, + ) + + @classmethod + def free(cls, accessible_handle: StorageHandle) -> None: + if accessible_handle.handle_type != AccessHandleType.TENSOR: + return + tensor = accessible_handle.data + if isinstance(tensor, torch.Tensor): + free_hugepage_tensor(tensor) + + @classmethod + def from_raw_data(cls, + data: torch.Tensor, # type: ignore + layout: KVCacheLayout, + dtype: torch.dtype, + **kwargs: Any) -> StorageHandle: + # We assume the caller already backs ``data`` with huge pages (or does + # not care). We do not take ownership of any mmap here. + return StorageHandle( + handle_type=AccessHandleType.TENSOR, + data=data, + kv_layout=layout, + dtype=dtype, + ) + + class SSDAllocator(BaseStorageAllocator): @classmethod def allocate(cls, @@ -138,7 +599,7 @@ def allocate(cls, file_prefix = kwargs.get("file_prefix", "flexkv_ssd_cache") cfg_max_file_size_gb = kwargs.get("max_file_size_gb", -1) cfg_max_blocks_per_file = int(1e9) - + if cache_dir is None: raise ValueError("cache_dir is required for SSD allocator") if isinstance(cache_dir, str): diff --git a/flexkv/storage/storage_engine.py b/flexkv/storage/storage_engine.py index 9d99c05303..195662c121 100644 --- a/flexkv/storage/storage_engine.py +++ b/flexkv/storage/storage_engine.py @@ -10,10 +10,21 @@ from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import StorageHandle, KVCacheLayout, KVCacheLayoutType from flexkv.common.transfer import DeviceType -from flexkv.storage.allocator import CPUAllocator, GPUAllocator, SSDAllocator, RemoteAllocator +from flexkv.storage.allocator import ( + CPUAllocator, + GPUAllocator, + HugePageAllocator, + RemoteAllocator, + SSDAllocator, +) class StorageEngine: + def 
_cpu_allocator(self) -> type[CPUAllocator] | type[HugePageAllocator]: + if self._cache_config.use_hugepage_cpu_buffer: + return HugePageAllocator + return CPUAllocator + def __init__(self, model_config: ModelConfig, cache_config: CacheConfig): @@ -226,21 +237,28 @@ def allocate(self, storage_handle: StorageHandle if device_type == DeviceType.CPU: + cpu_allocator = self._cpu_allocator() pin_memory = kwargs.get('pin_memory', False) + page_size_bytes = kwargs.get( + 'page_size_bytes', + self._cache_config.hugepage_size_bytes, + ) if raw_data is not None: assert isinstance(raw_data, torch.Tensor), \ "raw_data for CPUAllocator must be Tensor" - storage_handle = CPUAllocator.from_raw_data( + storage_handle = cpu_allocator.from_raw_data( data=raw_data, # type: ignore layout=layout, dtype=dtype, - pin_memory=pin_memory + pin_memory=pin_memory, + page_size_bytes=page_size_bytes, ) else: - storage_handle = CPUAllocator.allocate( + storage_handle = cpu_allocator.allocate( layout=layout, dtype=dtype, - pin_memory=pin_memory + pin_memory=pin_memory, + page_size_bytes=page_size_bytes, ) elif device_type == DeviceType.GPU: num_chunks = kwargs.get('num_chunks', 1) diff --git a/flexkv/transfer/host_buffer.py b/flexkv/transfer/host_buffer.py new file mode 100644 index 0000000000..bc2b9ea44d --- /dev/null +++ b/flexkv/transfer/host_buffer.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import ctypes +from dataclasses import dataclass +from typing import Optional + +import torch + +from flexkv.common.debug import flexkv_logger +from flexkv.storage.allocator import alloc_hugepage_tensor, free_hugepage_tensor + +_cudart = None +_cudart_load_error: Optional[OSError] = None + + +def _get_cudart(): + global _cudart + global _cudart_load_error + + if _cudart is None and _cudart_load_error is None: + try: + _cudart = ctypes.CDLL("libcudart.so") + except OSError as e: + _cudart_load_error = e + + if _cudart is None: + raise RuntimeError(f"libcudart.so is unavailable: 
{_cudart_load_error}") + return _cudart + + +def cuda_host_registration_available() -> bool: + try: + _get_cudart() + except RuntimeError: + return False + return True + + +def cudaHostRegister(tensor: torch.Tensor) -> None: + cudart = _get_cudart() + ptr = tensor.data_ptr() + size = tensor.numel() * tensor.element_size() + ret = cudart.cudaHostRegister(ctypes.c_void_p(ptr), ctypes.c_size_t(size), 1) + if ret != 0: + raise RuntimeError(f"cudaHostRegister failed with error code {ret}") + + +def cudaHostUnregister(tensor: torch.Tensor) -> None: + cudart = _get_cudart() + ptr = tensor.data_ptr() + ret = cudart.cudaHostUnregister(ctypes.c_void_p(ptr)) + if ret != 0: + raise RuntimeError(f"cudaHostUnregister failed with error code {ret}") + + +@dataclass +class HostBufferHandle: + tensor: torch.Tensor + is_hugepage: bool = False + is_cuda_registered: bool = False + + def __post_init__(self) -> None: + if self.is_cuda_registered and not self.is_hugepage: + raise ValueError("CUDA-registered host buffer must be HugePage-backed") + + @classmethod + def pinned(cls, tensor: torch.Tensor) -> HostBufferHandle: + return cls(tensor=tensor) + + @classmethod + def hugepage(cls, tensor: torch.Tensor) -> HostBufferHandle: + return cls(tensor=tensor, is_hugepage=True, is_cuda_registered=True) + + def release(self) -> None: + if not self.is_hugepage: + return + + if self.is_cuda_registered: + try: + cudaHostUnregister(self.tensor) + except Exception as e: + flexkv_logger.warning( + f"[host_buffer] release hugepage host buffer: cuda unregister failed ({e})" + ) + self.is_cuda_registered = False + + free_hugepage_tensor(self.tensor) + flexkv_logger.info("[host_buffer] release hugepage host buffer") + self.is_hugepage = False + + +def _allocate_pinned_cpu_tensor(num_elements: int, dtype: torch.dtype) -> HostBufferHandle: + return HostBufferHandle.pinned( + torch.empty( + num_elements, + dtype=dtype, + device="cpu", + pin_memory=True, + ) + ) + + +def _fallback_to_pinned( + num_elements: 
int, + dtype: torch.dtype, + reason: Exception, +) -> HostBufferHandle: + flexkv_logger.warning( + f"[host_buffer] fallback to pinned host buffer ({reason})" + ) + return _allocate_pinned_cpu_tensor(num_elements, dtype) + + +def allocate_host_buffer( + num_elements: int, + dtype: torch.dtype, + use_hugepage: bool, + hugepage_size_bytes: int, +) -> HostBufferHandle: + if not use_hugepage: + return _allocate_pinned_cpu_tensor(num_elements, dtype) + + flexkv_logger.info("[host_buffer] attempt hugepage host buffer") + + hugepage_buf = None + try: + hugepage_buf = alloc_hugepage_tensor( + num_elements=num_elements, + dtype=dtype, + page_size_bytes=hugepage_size_bytes, + ) + cudaHostRegister(hugepage_buf) + except Exception as e: + if hugepage_buf is not None: + free_hugepage_tensor(hugepage_buf) + return _fallback_to_pinned(num_elements, dtype, e) + + flexkv_logger.info( + f"[host_buffer] hugepage host buffer ready: " + f"{hugepage_buf.numel() * hugepage_buf.element_size() / (1024 ** 3):.3f} GB" + ) + return HostBufferHandle.hugepage(hugepage_buf) diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index 8ebe616c11..df2b1dc25d 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -13,6 +13,7 @@ from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType from flexkv.common.config import ModelConfig, GLOBAL_CONFIG_FROM_ENV +from flexkv.storage.allocator import HugePageTensorHandle, materialize_worker_tensor from flexkv.transfer.worker_op import WorkerLayerwiseTransferOp from flexkv.transfer.worker import TransferWorkerBase, cudaHostRegister @@ -79,7 +80,7 @@ def __init__(self, finished_ops_queue: MPQueue, op_buffer_tensor: torch.Tensor, gpu_blocks: List[List[TensorSharedHandle]], - cpu_blocks: torch.Tensor, + cpu_blocks: Union[torch.Tensor, HugePageTensorHandle], ssd_files: Dict[int, List[str]], gpu_kv_layouts: List[KVCacheLayout], cpu_kv_layout: KVCacheLayout, 
@@ -100,7 +101,7 @@ def __init__(self, enable_eventfd: bool = True, is_nsa_cp: bool = False, indexer_gpu_blocks: Optional[List[List[TensorSharedHandle]]] = None, - indexer_cpu_blocks: Optional[torch.Tensor] = None, + indexer_cpu_blocks: Optional[Union[torch.Tensor, HugePageTensorHandle]] = None, indexer_gpu_kv_layouts: Optional[List[KVCacheLayout]] = None, indexer_cpu_kv_layout: Optional[KVCacheLayout] = None, indexer_dtype: Optional[torch.dtype] = None, @@ -115,6 +116,7 @@ def __init__(self, f"num_gpu_blocks={[len(b) for b in gpu_blocks]}") super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) assert len(gpu_blocks) == tp_group_size, f"len(gpu_blocks) = {len(gpu_blocks)}, tp_group_size = {tp_group_size}" + cpu_blocks = materialize_worker_tensor(cpu_blocks) imported_gpu_blocks = [] for handles_in_one_gpu in gpu_blocks: blocks_in_one_gpu = [] @@ -234,6 +236,7 @@ def __init__(self, assert indexer_gpu_kv_layouts is not None assert indexer_cpu_kv_layout is not None assert indexer_dtype is not None + indexer_cpu_blocks = materialize_worker_tensor(indexer_cpu_blocks) # Import indexer GPU tensor handles imported_indexer_gpu_blocks = [] diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index 371d24db26..e4779932f6 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -170,7 +170,7 @@ def _init_workers(self) -> None: finished_ops_queue=self.finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), gpu_blocks=gpu_handles[0].get_tensor_handle_list(), - cpu_blocks=self._cpu_handle.get_tensor(), + cpu_blocks=self._cpu_handle.get_worker_tensor(), gpu_kv_layout=gpu_handles[0].kv_layout, cpu_kv_layout=self._cpu_handle.kv_layout, dtype=gpu_handles[0].dtype, @@ -189,7 +189,7 @@ def _init_workers(self) -> None: finished_ops_queue=self.finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), gpu_blocks=[gpu_handle.get_tensor_handle_list() for gpu_handle in 
gpu_handles], - cpu_blocks=self._cpu_handle.get_tensor(), + cpu_blocks=self._cpu_handle.get_worker_tensor(), gpu_kv_layouts=[gpu_handle.kv_layout for gpu_handle in gpu_handles], cpu_kv_layout=self._cpu_handle.kv_layout, dtype=gpu_handles[0].dtype, @@ -214,7 +214,7 @@ def _init_workers(self) -> None: finished_ops_queue=self.finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), gpu_blocks=gpu_handles[0].get_tensor_handle_list(), - cpu_blocks=self._cpu_handle.get_tensor(), + cpu_blocks=self._cpu_handle.get_worker_tensor(), gpu_kv_layout=gpu_handles[0].kv_layout, cpu_kv_layout=self._cpu_handle.kv_layout, dtype=gpu_handles[0].dtype, @@ -233,7 +233,7 @@ def _init_workers(self) -> None: finished_ops_queue=self.finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), gpu_blocks=[gpu_handle.get_tensor_handle_list() for gpu_handle in gpu_handles], - cpu_blocks=self._cpu_handle.get_tensor(), + cpu_blocks=self._cpu_handle.get_worker_tensor(), gpu_kv_layouts=[gpu_handle.kv_layout for gpu_handle in gpu_handles], cpu_kv_layout=self._cpu_handle.kv_layout, dtype=gpu_handles[0].dtype, @@ -257,7 +257,7 @@ def _init_workers(self) -> None: mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, op_buffer_tensor = self.pin_buffer.get_buffer(), - cpu_blocks=self._cpu_handle.get_tensor(), + cpu_blocks=self._cpu_handle.get_worker_tensor(), ssd_files=self._ssd_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, ssd_kv_layout=self._ssd_handle.kv_layout, @@ -272,7 +272,7 @@ def _init_workers(self) -> None: mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, op_buffer_tensor = self.pin_buffer.get_buffer(), - cpu_blocks=self._cpu_handle.get_tensor(), + cpu_blocks=self._cpu_handle.get_worker_tensor(), ssd_files=self._ssd_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, ssd_kv_layout=self._ssd_handle.kv_layout, @@ -286,7 +286,7 @@ def _init_workers(self) -> None: mp_ctx=self.mp_ctx, 
finished_ops_queue=self.finished_ops_queue, op_buffer_tensor = self.pin_buffer.get_buffer(), - cpu_blocks=self._cpu_handle.get_tensor(), + cpu_blocks=self._cpu_handle.get_worker_tensor(), remote_file=self._remote_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, remote_kv_layout=self._remote_handle.kv_layout, @@ -298,7 +298,7 @@ def _init_workers(self) -> None: mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, op_buffer_tensor = self.pin_buffer.get_buffer(), - cpu_blocks=self._cpu_handle.get_tensor(), + cpu_blocks=self._cpu_handle.get_worker_tensor(), remote_file=self._remote_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, remote_kv_layout=self._remote_handle.kv_layout, @@ -374,7 +374,7 @@ def _init_workers(self) -> None: finished_ops_queue=self.finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), gpu_blocks=[handle.get_tensor_handle_list() for handle in gpu_handles], - cpu_blocks=self._cpu_handle.get_tensor(), + cpu_blocks=self._cpu_handle.get_worker_tensor(), ssd_files=ssd_files, gpu_kv_layouts=[handle.kv_layout for handle in gpu_handles], cpu_kv_layout=self._cpu_handle.kv_layout, @@ -394,7 +394,7 @@ def _init_workers(self) -> None: d2h_cta_num=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, is_nsa_cp=_is_nsa_cp, indexer_gpu_blocks=[h.get_tensor_handle_list() for h in idx_handles] if idx_handles else None, - indexer_cpu_blocks=self._indexer_cpu_handle.get_tensor() if idx_handles else None, + indexer_cpu_blocks=self._indexer_cpu_handle.get_worker_tensor() if idx_handles else None, indexer_gpu_kv_layouts=[h.kv_layout for h in idx_handles] if idx_handles else None, indexer_cpu_kv_layout=self._indexer_cpu_handle.kv_layout if idx_handles else None, indexer_dtype=idx_handles[0].dtype if idx_handles else None, @@ -421,7 +421,7 @@ def _init_workers(self) -> None: mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, op_buffer_tensor = self.pin_buffer.get_buffer(), - 
cpu_blocks=self._cpu_handle.get_tensor(), + cpu_blocks=self._cpu_handle.get_worker_tensor(), cpu_kv_layout=self._cpu_handle.kv_layout, # TODO: get remote kv_layout, now we can assume that remote kv layout is same as current node remote_kv_layout=self._cpu_handle.kv_layout, @@ -452,7 +452,7 @@ def _init_workers(self) -> None: finished_ops_queue=self._indexer_finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), gpu_blocks=indexer_gpu_handles_list[0].get_tensor_handle_list(), - cpu_blocks=self._indexer_cpu_handle.get_tensor(), + cpu_blocks=self._indexer_cpu_handle.get_worker_tensor(), gpu_kv_layout=indexer_gpu_handles_list[0].kv_layout, cpu_kv_layout=self._indexer_cpu_handle.kv_layout, dtype=indexer_gpu_handles_list[0].dtype, @@ -471,7 +471,7 @@ def _init_workers(self) -> None: finished_ops_queue=self._indexer_finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), gpu_blocks=[h.get_tensor_handle_list() for h in indexer_gpu_handles_list], - cpu_blocks=self._indexer_cpu_handle.get_tensor(), + cpu_blocks=self._indexer_cpu_handle.get_worker_tensor(), gpu_kv_layouts=[h.kv_layout for h in indexer_gpu_handles_list], cpu_kv_layout=self._indexer_cpu_handle.kv_layout, dtype=indexer_gpu_handles_list[0].dtype, @@ -496,7 +496,7 @@ def _init_workers(self) -> None: finished_ops_queue=self._indexer_finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), gpu_blocks=indexer_gpu_handles_list[0].get_tensor_handle_list(), - cpu_blocks=self._indexer_cpu_handle.get_tensor(), + cpu_blocks=self._indexer_cpu_handle.get_worker_tensor(), gpu_kv_layout=indexer_gpu_handles_list[0].kv_layout, cpu_kv_layout=self._indexer_cpu_handle.kv_layout, dtype=indexer_gpu_handles_list[0].dtype, @@ -515,7 +515,7 @@ def _init_workers(self) -> None: finished_ops_queue=self._indexer_finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), gpu_blocks=[h.get_tensor_handle_list() for h in indexer_gpu_handles_list], - cpu_blocks=self._indexer_cpu_handle.get_tensor(), + 
cpu_blocks=self._indexer_cpu_handle.get_worker_tensor(), gpu_kv_layouts=[h.kv_layout for h in indexer_gpu_handles_list], cpu_kv_layout=self._indexer_cpu_handle.kv_layout, dtype=indexer_gpu_handles_list[0].dtype, @@ -537,7 +537,7 @@ def _init_workers(self) -> None: mp_ctx=self.mp_ctx, finished_ops_queue=self._indexer_finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), - cpu_blocks=self._indexer_cpu_handle.get_tensor(), + cpu_blocks=self._indexer_cpu_handle.get_worker_tensor(), ssd_files=self._indexer_ssd_handle.get_file_list(), cpu_kv_layout=self._indexer_cpu_handle.kv_layout, ssd_kv_layout=self._indexer_ssd_handle.kv_layout, @@ -552,7 +552,7 @@ def _init_workers(self) -> None: mp_ctx=self.mp_ctx, finished_ops_queue=self._indexer_finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), - cpu_blocks=self._indexer_cpu_handle.get_tensor(), + cpu_blocks=self._indexer_cpu_handle.get_worker_tensor(), ssd_files=self._indexer_ssd_handle.get_file_list(), cpu_kv_layout=self._indexer_cpu_handle.kv_layout, ssd_kv_layout=self._indexer_ssd_handle.kv_layout, @@ -567,7 +567,7 @@ def _init_workers(self) -> None: mp_ctx=self.mp_ctx, finished_ops_queue=self._indexer_finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), - cpu_blocks=self._indexer_cpu_handle.get_tensor(), + cpu_blocks=self._indexer_cpu_handle.get_worker_tensor(), remote_file=self._indexer_remote_handle.get_file_list(), cpu_kv_layout=self._indexer_cpu_handle.kv_layout, remote_kv_layout=self._indexer_remote_handle.kv_layout, @@ -579,7 +579,7 @@ def _init_workers(self) -> None: mp_ctx=self.mp_ctx, finished_ops_queue=self._indexer_finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), - cpu_blocks=self._indexer_cpu_handle.get_tensor(), + cpu_blocks=self._indexer_cpu_handle.get_worker_tensor(), remote_file=self._indexer_remote_handle.get_file_list(), cpu_kv_layout=self._indexer_cpu_handle.kv_layout, remote_kv_layout=self._indexer_remote_handle.kv_layout, @@ -634,7 +634,7 @@ 
def _init_workers(self) -> None: mp_ctx=self.mp_ctx, finished_ops_queue=self._indexer_finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), - cpu_blocks=self._indexer_cpu_handle.get_tensor(), + cpu_blocks=self._indexer_cpu_handle.get_worker_tensor(), cpu_kv_layout=self._indexer_cpu_handle.kv_layout, remote_kv_layout=self._indexer_cpu_handle.kv_layout, dtype=self._indexer_cpu_handle.dtype, diff --git a/flexkv/transfer/worker.py b/flexkv/transfer/worker.py index 17c9a363c6..534a315c7f 100644 --- a/flexkv/transfer/worker.py +++ b/flexkv/transfer/worker.py @@ -10,7 +10,6 @@ from threading import Thread from typing import List, Any, Dict, Union, Optional, Tuple -import ctypes import numpy as np import nvtx import torch @@ -34,6 +33,11 @@ from flexkv.common.transfer import TransferOp, TransferType, PartitionBlockType from flexkv.common.transfer import get_nvtx_range_color, LayerwiseTransferOp from flexkv.common.config import CacheConfig, GLOBAL_CONFIG_FROM_ENV, MooncakeTransferEngineConfig +from flexkv.storage.allocator import HugePageTensorHandle, materialize_worker_tensor +from flexkv.transfer.host_buffer import ( + allocate_host_buffer, + cudaHostRegister, +) from flexkv.transfer.worker_op import WorkerTransferOp, WorkerLayerwiseTransferOp from flexkv.mooncakeEngineWrapper import MoonCakeTransferEngineWrapper @@ -49,24 +53,6 @@ transfer_kv_blocks_remote = None shared_transfer_kv_blocks_remote_read = None - -cudart = ctypes.CDLL('libcudart.so') - -def cudaHostRegister(tensor: torch.Tensor) -> None: - """Register a CPU tensor with CUDA for pinned memory access""" - ptr = tensor.data_ptr() - size = tensor.numel() * tensor.element_size() - ret = cudart.cudaHostRegister(ctypes.c_void_p(ptr), ctypes.c_size_t(size), 1) # 1 means cudaHostRegisterPortable - if ret != 0: - raise RuntimeError(f"cudaHostRegister failed with error code {ret}") - -def cudaHostUnregister(tensor: torch.Tensor) -> None: - """Unregister a CPU tensor from CUDA for pinned memory access""" - 
ptr = tensor.data_ptr() - size = tensor.numel() * tensor.element_size() - ret = cudart.cudaHostUnregister(ctypes.c_void_p(ptr)) - - class TransferWorkerBase(ABC): _worker_id_counter = 0 _worker_id_lock = threading.Lock() @@ -276,7 +262,7 @@ def __init__(self, finished_ops_queue: MPQueue, op_buffer_tensor: torch.Tensor, gpu_blocks: List[TensorSharedHandle], - cpu_blocks: torch.Tensor, + cpu_blocks: Union[torch.Tensor, HugePageTensorHandle], gpu_kv_layout: KVCacheLayout, cpu_kv_layout: KVCacheLayout, dtype: torch.dtype, @@ -288,6 +274,7 @@ def __init__(self, # initialize worker in a new process super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) # Register CPU tensors with CUDA + cpu_blocks = materialize_worker_tensor(cpu_blocks) flexkv_logger.info(f"Pinning CPU Memory: {cpu_blocks.numel() * cpu_blocks.element_size() / (1024 ** 3):.2f} GB") cudaHostRegister(cpu_blocks) self.gpu_blocks = [wrapper.get_tensor() for wrapper in gpu_blocks] @@ -427,7 +414,7 @@ def __init__(self, finished_ops_queue: MPQueue, op_buffer_tensor: torch.Tensor, gpu_blocks: List[List[TensorSharedHandle]], - cpu_blocks: torch.Tensor, + cpu_blocks: Union[torch.Tensor, HugePageTensorHandle], gpu_kv_layouts: List[KVCacheLayout], cpu_kv_layout: KVCacheLayout, dtype: torch.dtype, @@ -442,6 +429,7 @@ def __init__(self, super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) assert len(gpu_blocks) == tp_group_size + cpu_blocks = materialize_worker_tensor(cpu_blocks) # Handle tensor import for multi-process case imported_gpu_blocks = [] for handles_in_one_gpu in gpu_blocks: @@ -604,7 +592,7 @@ def __init__(self, transfer_conn: Connection, finished_ops_queue: MPQueue, op_buffer_tensor: torch.Tensor, - cpu_blocks: torch.Tensor, + cpu_blocks: Union[torch.Tensor, HugePageTensorHandle], ssd_files: Dict[int, List[str]], # ssd_device_id -> file_paths cpu_kv_layout: KVCacheLayout, ssd_kv_layout: KVCacheLayout, @@ -612,6 +600,7 @@ def __init__(self, 
num_blocks_per_file: int, cache_config: CacheConfig): super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) + cpu_blocks = materialize_worker_tensor(cpu_blocks) self.ssd_files = ssd_files self.num_blocks_per_file = num_blocks_per_file self.num_files = sum(len(file_list) for file_list in ssd_files.values()) @@ -728,7 +717,7 @@ def __init__(self, transfer_conn: Connection, finished_ops_queue: MPQueue, op_buffer_tensor: torch.Tensor, - cpu_blocks: List[torch.Tensor], + cpu_blocks: Union[List[torch.Tensor], torch.Tensor, HugePageTensorHandle], remote_file: List[str], cpu_kv_layout: KVCacheLayout, remote_kv_layout: KVCacheLayout, @@ -739,6 +728,8 @@ def __init__(self, raise RuntimeError("transfer_kv_blocks_remote not available, please build with FLEXKV_ENABLE_CFS=1") super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) + cpu_blocks = materialize_worker_tensor(cpu_blocks) + self.cpu_layer_ptrs = self._get_layer_ptrs(cpu_blocks) self.remote_files = remote_file self.num_remote_files = len(remote_file) @@ -1328,7 +1319,7 @@ def __init__(self, transfer_conn: Connection, finished_ops_queue: MPQueue, op_buffer_tensor: torch.Tensor, - cpu_blocks: torch.Tensor, + cpu_blocks: Union[torch.Tensor, HugePageTensorHandle], cpu_kv_layout: KVCacheLayout, remote_kv_layout: KVCacheLayout, dtype: torch.dtype, @@ -1339,6 +1330,7 @@ def __init__(self, mooncake_config_path: str = None, ): super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) + cpu_blocks = materialize_worker_tensor(cpu_blocks) self.cpu_layer_ptrs = self._get_layer_ptrs(cpu_blocks) self.num_layers = cpu_kv_layout.num_layer self.num_cpu_blocks = cpu_kv_layout.num_block @@ -1435,14 +1427,25 @@ def __init__(self, self.cpu_kv_layout.num_head, self.cpu_kv_layout.head_size, self.cpu_kv_layout.is_mla, - self.cpu_kv_layout._kv_shape, + self.cpu_kv_layout.kv_shape, ) - self.tmp_cpu_buffer = torch.empty( - self.tmp_cpu_buffer_layout.get_total_elements(), + 
# Allocate the temporary SSD->CPU staging buffer. + # + # Two backends are supported: + # (a) HugePage-backed mmap (when ``cache_config.use_hugepage_tmp_buffer`` + # is True and the kernel has huge pages reserved). We still need + # to pin it for CUDA via ``cudaHostRegister`` because the region + # is not allocated through PyTorch's pinned-memory allocator. + # (b) Pinned ``torch.empty`` (the original behavior, default). + tmp_num_elements = self.tmp_cpu_buffer_layout.get_total_elements() + self._tmp_cpu_buffer_handle = allocate_host_buffer( + num_elements=tmp_num_elements, dtype=self.dtype, - device="cpu", - pin_memory=True, + use_hugepage=self.cache_config.use_hugepage_tmp_buffer, + hugepage_size_bytes=self.cache_config.hugepage_size_bytes, ) + self.tmp_cpu_buffer = self._tmp_cpu_buffer_handle.tensor + self.mooncake_transfer_engine.regist_buffer( self.tmp_cpu_buffer.data_ptr(), self.tmp_cpu_buffer.numel() * self.tmp_cpu_buffer.element_size(), @@ -1515,6 +1518,9 @@ def shutdown(self): self.mooncake_transfer_engine.unregist_buffer(self.cpu_blocks.data_ptr()) if self.cache_config.enable_p2p_ssd: self.mooncake_transfer_engine.unregist_buffer(self.tmp_cpu_buffer.data_ptr()) + # Release CUDA pinning & HugePage mapping, if any. 
+ if hasattr(self, "_tmp_cpu_buffer_handle"): + self._tmp_cpu_buffer_handle.release() # unregist node info from redis server self.unregist_node_meta() diff --git a/tests/hugepage/conftest.py b/tests/hugepage/conftest.py new file mode 100644 index 0000000000..88bdd3f94f --- /dev/null +++ b/tests/hugepage/conftest.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from contextlib import suppress + +import pytest +import torch + +from flexkv.storage.allocator import alloc_hugepage_tensor +from flexkv.transfer import host_buffer + + +def alloc_hugepage_or_skip( + num_elements: int, + dtype: torch.dtype, + page_size_bytes: int, +) -> torch.Tensor: + try: + return alloc_hugepage_tensor( + num_elements=num_elements, + dtype=dtype, + page_size_bytes=page_size_bytes, + ) + except Exception as e: + pytest.skip(f"hugepage allocation failed: {e}") + + +def cuda_ops_or_skip() -> tuple: + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + return host_buffer.cudaHostRegister, host_buffer.cudaHostUnregister + + +def unregister_suppress(tensor: torch.Tensor) -> None: + with suppress(Exception): + host_buffer.cudaHostUnregister(tensor) diff --git a/tests/hugepage/test_hugepage_transfer_e2e.py b/tests/hugepage/test_hugepage_transfer_e2e.py new file mode 100644 index 0000000000..faa80d722f --- /dev/null +++ b/tests/hugepage/test_hugepage_transfer_e2e.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +import ctypes +import shutil +import tempfile +from pathlib import Path + +import numpy as np +import pytest +import torch + +from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType +from flexkv.storage.allocator import ( + DEFAULT_HUGE_PAGE_SIZE, + free_hugepage_tensor, +) +from tests.hugepage.conftest import ( + alloc_hugepage_or_skip, + cuda_ops_or_skip, + unregister_suppress, +) + +PAGE = DEFAULT_HUGE_PAGE_SIZE +_NUM_LAYERS = 1 +_NUM_BLOCKS = 4 +_TOKENS_PER_BLOCK = 16 +_NUM_HEADS = 8 +_HEAD_SIZE = 128 +_DTYPE = torch.bfloat16 
+_ELEM_SIZE = _DTYPE.itemsize + +_CPU_LAYOUT = KVCacheLayout( + type=KVCacheLayoutType.LAYERFIRST, + num_layer=_NUM_LAYERS, + num_block=_NUM_BLOCKS, + tokens_per_block=_TOKENS_PER_BLOCK, + num_head=_NUM_HEADS, + head_size=_HEAD_SIZE, + is_mla=False, +) + +_CHUNK = _CPU_LAYOUT.get_chunk_size() +_BLOCK_STRIDE = _CPU_LAYOUT.get_block_stride() +_KV_STRIDE = _CPU_LAYOUT.get_kv_stride() +_LAYER_STRIDE = _CPU_LAYOUT.get_layer_stride() + + +def _ensure_c_ext(): + try: + from flexkv.c_ext import SSDIOCTX, transfer_kv_blocks_ssd + except ImportError: + pytest.skip("c_ext not built or SSD support disabled") + return SSDIOCTX, transfer_kv_blocks_ssd + + +def _ssd_layout_for(num_blocks_per_file: int) -> KVCacheLayout: + return KVCacheLayout( + type=KVCacheLayoutType.LAYERFIRST, + num_layer=_NUM_LAYERS, + num_block=num_blocks_per_file, + tokens_per_block=_TOKENS_PER_BLOCK, + num_head=_NUM_HEADS, + head_size=_HEAD_SIZE, + is_mla=False, + ) + + +def _verify_ssd_read( + buf: np.ndarray, + pattern: np.ndarray, + kv_stride_bytes: int, + layer_stride_bytes: int, + chunk_bytes: int, + num_blocks: int, + kv_dim: int, +) -> None: + pattern_u8 = pattern.view(np.uint8) + lid = 0 + for bid in range(num_blocks): + for kv in range(kv_dim): + buf_start = lid * layer_stride_bytes + kv * kv_stride_bytes + bid * chunk_bytes + pat_start = buf_start + actual = buf[buf_start:buf_start + chunk_bytes] + expected = pattern_u8[pat_start:pat_start + chunk_bytes] + assert np.array_equal(actual, expected), ( + f"block {bid} {'K' if kv == 0 else 'V'} mismatch " + f"at offset={buf_start}, size={chunk_bytes}" + ) + + +def test_hugepage_ssd_to_gpu_roundtrip() -> None: + SSDIOCTX, transfer_kv_blocks_ssd = _ensure_c_ext() + cudaHostRegister, _ = cuda_ops_or_skip() + + num_blocks_per_file = _NUM_BLOCKS + kv_dim = 2 + _ssd_layout_for(num_blocks_per_file) + + chunk_bytes = _CHUNK * _ELEM_SIZE + block_stride_bytes = _BLOCK_STRIDE * _ELEM_SIZE + kv_stride_bytes = _KV_STRIDE * _ELEM_SIZE + layer_stride_bytes = 
_LAYER_STRIDE * _ELEM_SIZE + cpu_chunk_bytes = chunk_bytes + cpu_kv_stride_bytes = kv_stride_bytes + cpu_layer_stride_bytes = layer_stride_bytes + ssd_kv_stride_bytes = kv_stride_bytes + ssd_layer_stride_bytes = layer_stride_bytes + ssd_chunk_bytes = chunk_bytes + ssd_block_stride_bytes = block_stride_bytes + file_size = kv_dim * num_blocks_per_file * chunk_bytes + + pattern = np.arange(file_size // 2, dtype=np.int16) + pattern_bytes = pattern.view(np.uint8) + + tmpdir = Path(tempfile.mkdtemp(prefix="flexkv_e2e_")) + ssd_path = tmpdir / "ssd_0.bin" + pattern_bytes.tofile(ssd_path) + + hugepage_tensor = alloc_hugepage_or_skip( + _CPU_LAYOUT.get_total_elements(), + _DTYPE, + PAGE, + ) + ptr = hugepage_tensor.data_ptr() + needs_unpin = False + + try: + cudaHostRegister(hugepage_tensor) + needs_unpin = True + + ioctx = SSDIOCTX({0: [str(ssd_path)]}, 1, 0, 0) + layer_ids = torch.arange(0, _NUM_LAYERS, dtype=torch.int32) + ssd_block_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int64) + cpu_block_ids = torch.tensor([0, 1, 2, 3], dtype=torch.int64) + + transfer_kv_blocks_ssd( + ioctx=ioctx, + cpu_layer_id_list=layer_ids, + cpu_tensor_ptr=hugepage_tensor.data_ptr(), + ssd_block_ids=ssd_block_ids, + cpu_block_ids=cpu_block_ids, + cpu_layer_stride_in_bytes=cpu_layer_stride_bytes, + cpu_kv_stride_in_bytes=cpu_kv_stride_bytes, + ssd_layer_stride_in_bytes=ssd_layer_stride_bytes, + ssd_kv_stride_in_bytes=ssd_kv_stride_bytes, + chunk_size_in_bytes=ssd_chunk_bytes, + block_stride_in_bytes=ssd_block_stride_bytes, + is_read=True, + num_blocks_per_file=num_blocks_per_file, + round_robin=1, + num_threads_per_device=1, + is_mla=False, + ) + + buf_np = np.frombuffer( + (ctypes.c_uint8 * (hugepage_tensor.numel() * _ELEM_SIZE)).from_address(ptr), + dtype=np.uint8, + ) + _verify_ssd_read( + buf_np, + pattern, + cpu_kv_stride_bytes, + cpu_layer_stride_bytes, + cpu_chunk_bytes, + num_blocks_per_file, + kv_dim, + ) + + gpu_tensor = torch.empty_like(hugepage_tensor, device="cuda") + 
gpu_tensor.copy_(hugepage_tensor, non_blocking=True) + torch.cuda.synchronize() + + roundtrip = torch.empty_like(hugepage_tensor) + roundtrip.copy_(gpu_tensor, non_blocking=True) + torch.cuda.synchronize() + + assert torch.equal( + hugepage_tensor.view(torch.int16), + roundtrip.view(torch.int16), + ) + finally: + if needs_unpin: + unregister_suppress(hugepage_tensor) + free_hugepage_tensor(hugepage_tensor) + shutil.rmtree(tmpdir) diff --git a/tests/hugepage/test_hugepage_unit.py b/tests/hugepage/test_hugepage_unit.py new file mode 100644 index 0000000000..48de07092f --- /dev/null +++ b/tests/hugepage/test_hugepage_unit.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +import os + +import pytest +import torch + +from flexkv.common.config import CacheConfig, ModelConfig +from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType +from flexkv.common.transfer import DeviceType +from flexkv.storage.storage_engine import StorageEngine +from flexkv.storage.allocator import ( + DEFAULT_HUGE_PAGE_SIZE, + HugePageTensorHandle, + HugePageAllocator, + _live_hugepage_mappings, + alloc_hugepage_tensor, + free_hugepage_tensor, + get_worker_hugepage_handle, + materialize_worker_tensor, +) +from tests.hugepage.conftest import ( + alloc_hugepage_or_skip, + cuda_ops_or_skip, + unregister_suppress, +) + +PAGE = DEFAULT_HUGE_PAGE_SIZE + + +def test_basic_alloc_free() -> None: + n_bytes = 16 * 1024 * 1024 + n_elem = n_bytes // 2 + tensor = alloc_hugepage_or_skip(n_elem, torch.bfloat16, PAGE) + addr = tensor.data_ptr() + + try: + assert isinstance(tensor, torch.Tensor) + assert tensor.numel() == n_elem + assert tensor.dtype == torch.bfloat16 + assert tensor.device.type == "cpu" + assert addr != 0 + assert addr % PAGE == 0 + assert addr in _live_hugepage_mappings + + tensor.view(torch.int16).fill_(0x5A5A) + assert int(tensor.view(torch.int16)[0].item()) == 0x5A5A + finally: + free_hugepage_tensor(tensor) + + assert addr not in _live_hugepage_mappings + + +def 
test_invalid_args() -> None: + with pytest.raises(ValueError): + alloc_hugepage_tensor(0, torch.float32, page_size_bytes=PAGE) + + with pytest.raises(ValueError): + alloc_hugepage_tensor(1, torch.float32, page_size_bytes=PAGE + 1) + + +def test_non_hugetlbfs_fallback_is_rejected(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None: + monkeypatch.setenv("FLEXKV_HUGETLBFS_DIR", str(tmp_path)) + + with pytest.raises(RuntimeError, match="is not a hugetlbfs mount"): + alloc_hugepage_tensor(1024, torch.float32, page_size_bytes=1024 * 1024 * 1024) + + +def test_worker_hugepage_handle_round_trip() -> None: + tensor = alloc_hugepage_or_skip(1024 * 1024, torch.bfloat16, PAGE) + + try: + handle = get_worker_hugepage_handle(tensor, tensor.numel(), tensor.dtype) + if handle is None: + pytest.skip("non-shareable hugepage allocation path on this host") + + rebuilt = materialize_worker_tensor(handle) + rebuilt.view(torch.int16)[0] = 0x1234 + + assert isinstance(handle, HugePageTensorHandle) + assert int(tensor.view(torch.int16)[0].item()) == 0x1234 + free_hugepage_tensor(rebuilt) + finally: + free_hugepage_tensor(tensor) + + +def test_hugepage_allocator_fallback() -> None: + layout = KVCacheLayout( + type=KVCacheLayoutType.LAYERFIRST, + num_layer=2, + num_block=64, + tokens_per_block=16, + num_head=8, + head_size=128, + is_mla=False, + ) + dtype = torch.bfloat16 + page_size = 1024 * 1024 * 1024 + old_dir = os.environ.get("FLEXKV_HUGETLBFS_DIR") + os.environ["FLEXKV_HUGETLBFS_DIR"] = "/nonexistent/flexkv_hugetlbfs" + try: + handle = HugePageAllocator.allocate( + layout=layout, + dtype=dtype, + page_size_bytes=page_size, + ) + finally: + if old_dir is None: + os.environ.pop("FLEXKV_HUGETLBFS_DIR", None) + else: + os.environ["FLEXKV_HUGETLBFS_DIR"] = old_dir + + tensor = handle.get_tensor() + assert isinstance(tensor, torch.Tensor) + assert tensor.numel() == layout.get_total_elements() + assert tensor.dtype == dtype + assert tensor.data_ptr() not in _live_hugepage_mappings + 
HugePageAllocator.free(handle) + + +def test_cuda_host_register() -> None: + cudaHostRegister, _ = cuda_ops_or_skip() + tensor = alloc_hugepage_or_skip(1024 * 1024, torch.bfloat16, PAGE) + + try: + cudaHostRegister(tensor) + + gpu_tensor = torch.empty_like(tensor, device="cuda") + tensor.fill_(1.25) + gpu_tensor.copy_(tensor, non_blocking=True) + torch.cuda.synchronize() + out = torch.empty_like(tensor) + out.copy_(gpu_tensor, non_blocking=True) + torch.cuda.synchronize() + assert torch.all(out == 1.25).item() + finally: + unregister_suppress(tensor) + free_hugepage_tensor(tensor) + + +def test_host_buffer_release_is_idempotent() -> None: + from flexkv.transfer.host_buffer import allocate_host_buffer + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available for pinned host buffer allocation") + + handle = allocate_host_buffer( + num_elements=1024, + dtype=torch.bfloat16, + use_hugepage=False, + hugepage_size_bytes=PAGE, + ) + + handle.release() + handle.release() + + assert not handle.is_hugepage + assert not handle.is_cuda_registered + + +def test_storage_engine_cpu_cache_uses_hugepage_when_enabled() -> None: + if not os.path.isdir(os.environ.get("FLEXKV_HUGETLBFS_DIR", "/mnt/hugepages")): + pytest.skip("hugetlbfs mount not available") + + model_config = ModelConfig( + num_layers=1, + num_kv_heads=1, + head_size=128, + use_mla=False, + dtype=torch.bfloat16, + ) + cache_config = CacheConfig( + enable_cpu=True, + enable_ssd=False, + num_cpu_blocks=8, + tokens_per_block=16, + use_hugepage_cpu_buffer=True, + hugepage_size_bytes=PAGE, + ) + + storage_engine = StorageEngine(model_config, cache_config) + cpu_handle = storage_engine.get_storage_handle(DeviceType.CPU) + cpu_tensor = cpu_handle.get_tensor() + + try: + if cpu_tensor.data_ptr() not in _live_hugepage_mappings: + pytest.skip("hugepage CPU cache allocation fell back on this host") + assert cpu_tensor.data_ptr() in _live_hugepage_mappings + worker_tensor = cpu_handle.get_worker_tensor() + assert 
isinstance(worker_tensor, HugePageTensorHandle) + finally: + HugePageAllocator.free(cpu_handle) + + +def test_storage_engine_cpu_cache_falls_back_when_hugepage_unavailable(monkeypatch: pytest.MonkeyPatch) -> None: + model_config = ModelConfig( + num_layers=1, + num_kv_heads=1, + head_size=128, + use_mla=False, + dtype=torch.bfloat16, + ) + cache_config = CacheConfig( + enable_cpu=True, + enable_ssd=False, + num_cpu_blocks=8, + tokens_per_block=16, + use_hugepage_cpu_buffer=True, + hugepage_size_bytes=1024 * 1024 * 1024, + ) + monkeypatch.setenv("FLEXKV_HUGETLBFS_DIR", "/nonexistent/flexkv_hugetlbfs") + + storage_engine = StorageEngine(model_config, cache_config) + cpu_handle = storage_engine.get_storage_handle(DeviceType.CPU) + cpu_tensor = cpu_handle.get_tensor() + + assert cpu_tensor.data_ptr() not in _live_hugepage_mappings + HugePageAllocator.free(cpu_handle) diff --git a/tests/hugepage/test_hugepage_worker_integration.py b/tests/hugepage/test_hugepage_worker_integration.py new file mode 100644 index 0000000000..abab51821d --- /dev/null +++ b/tests/hugepage/test_hugepage_worker_integration.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import gc +from pathlib import Path +from unittest.mock import patch + +import pytest +import torch + +from flexkv.storage.allocator import ( + DEFAULT_HUGE_PAGE_SIZE, + alloc_hugepage_tensor, + free_hugepage_tensor, +) +from tests.hugepage.conftest import cuda_ops_or_skip, unregister_suppress + +PAGE = DEFAULT_HUGE_PAGE_SIZE +_NUM_PAGES = 8 +_NUM_BYTES = _NUM_PAGES * PAGE +_NUM_ELEMENTS = _NUM_BYTES // 2 + + +def _read_meminfo_hugepages() -> tuple[int, int, int]: + total = free = size_kb = 0 + with open("/proc/meminfo", encoding="utf-8") as f: + for line in f: + if line.startswith("HugePages_Total:"): + total = int(line.split()[1]) + elif line.startswith("HugePages_Free:"): + free = int(line.split()[1]) + elif line.startswith("Hugepagesize:"): + size_kb = int(line.split()[1]) + return total, free, size_kb * 1024 + + 
+def _require_hugepages(num_pages: int) -> tuple[int, int]: + total, free, _ = _read_meminfo_hugepages() + if total < num_pages: + pytest.skip(f"need at least {num_pages} huge pages") + return total, free + + +def _simulate_tmp_cpu_buffer_init(tmp_num_elements: int) -> tuple[torch.Tensor, bool]: + tmp_cpu_buffer = torch.empty( + tmp_num_elements, + dtype=torch.bfloat16, + device="cpu", + pin_memory=True, + ) + needs_unpin = False + hugepage_buf = None + cudaHostRegister, _ = cuda_ops_or_skip() + try: + hugepage_buf = alloc_hugepage_tensor( + num_elements=tmp_num_elements, + dtype=torch.bfloat16, + page_size_bytes=PAGE, + ) + cudaHostRegister(hugepage_buf) + except Exception: + if hugepage_buf is not None: + free_hugepage_tensor(hugepage_buf) + else: + tmp_cpu_buffer = hugepage_buf + needs_unpin = True + + return tmp_cpu_buffer, needs_unpin + + +class MockMooncakeEngine: + def __init__(self) -> None: + self.registered: set[int] = set() + + def regist_buffer(self, ptr: int, size: int) -> int: + assert ptr != 0 + assert size > 0 + self.registered.add(ptr) + return 0 + + def unregist_buffer(self, ptr: int) -> int: + assert ptr in self.registered + self.registered.discard(ptr) + return 0 + + +def test_full_lifecycle_hugepage() -> None: + _, free_before = _require_hugepages(_NUM_PAGES) + mooncake = MockMooncakeEngine() + tmp_cpu_buffer, needs_unpin = _simulate_tmp_cpu_buffer_init(_NUM_ELEMENTS) + + mooncake.regist_buffer( + tmp_cpu_buffer.data_ptr(), + tmp_cpu_buffer.numel() * tmp_cpu_buffer.element_size(), + ) + + _, free_after_alloc, _ = _read_meminfo_hugepages() + consumed = free_before - free_after_alloc + if needs_unpin: + assert consumed == _NUM_PAGES + assert len(mooncake.registered) == 1 + else: + assert consumed == 0 + + tmp_cpu_buffer.view(torch.int16).fill_(0x7B7B) + assert int(tmp_cpu_buffer.view(torch.int16)[0].item()) == 0x7B7B + + mooncake.unregist_buffer(tmp_cpu_buffer.data_ptr()) + if needs_unpin: + unregister_suppress(tmp_cpu_buffer) + 
free_hugepage_tensor(tmp_cpu_buffer) + + assert len(mooncake.registered) == 0 + + del tmp_cpu_buffer + gc.collect() + + _, free_after_free, _ = _read_meminfo_hugepages() + assert free_after_free == free_before + + +def test_fallback_when_cuda_host_register_fails() -> None: + _, free_before = _require_hugepages(_NUM_PAGES) + + with patch( + "flexkv.transfer.host_buffer.cudaHostRegister", + side_effect=RuntimeError("injected cudaHostRegister failure"), + ): + tmp_cpu_buffer, needs_unpin = _simulate_tmp_cpu_buffer_init(_NUM_ELEMENTS) + assert not needs_unpin + + mooncake = MockMooncakeEngine() + mooncake.regist_buffer( + tmp_cpu_buffer.data_ptr(), + tmp_cpu_buffer.numel() * tmp_cpu_buffer.element_size(), + ) + mooncake.unregist_buffer(tmp_cpu_buffer.data_ptr()) + + del tmp_cpu_buffer + gc.collect() + + _, free_after, _ = _read_meminfo_hugepages() + assert free_after == free_before diff --git a/tests/test_config_hugepage.py b/tests/test_config_hugepage.py new file mode 100644 index 0000000000..1a14291ab8 --- /dev/null +++ b/tests/test_config_hugepage.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from flexkv.common.config import ( + CacheConfig, + ModelConfig, + UserConfig, + load_user_config_from_env, + update_default_config_from_user_config, +) + + +def test_load_user_config_from_env_reads_hugepage_flags(monkeypatch) -> None: + monkeypatch.setenv("FLEXKV_USE_HUGEPAGE_CPU_BUFFER", "1") + monkeypatch.setenv("FLEXKV_USE_HUGEPAGE_TMP_BUFFER", "1") + monkeypatch.setenv("FLEXKV_HUGEPAGE_SIZE_BYTES", str(1 << 30)) + + user_config = load_user_config_from_env() + + assert user_config.use_hugepage_cpu_buffer is True + assert user_config.use_hugepage_tmp_buffer is True + assert user_config.hugepage_size_bytes == 1 << 30 + + +def test_update_default_config_from_user_config_applies_hugepage_flags() -> None: + model_config = ModelConfig( + num_layers=1, + num_kv_heads=1, + head_size=128, + use_mla=False, + ) + cache_config = CacheConfig() + user_config = UserConfig( + 
cpu_cache_gb=16, + ssd_cache_gb=0, + use_hugepage_cpu_buffer=True, + use_hugepage_tmp_buffer=True, + hugepage_size_bytes=1 << 30, + ) + + update_default_config_from_user_config(model_config, cache_config, user_config) + + assert cache_config.use_hugepage_cpu_buffer is True + assert cache_config.use_hugepage_tmp_buffer is True + assert cache_config.hugepage_size_bytes == 1 << 30 From 2508b112d2b0ca22c291c7b0c5fcd591b385f429 Mon Sep 17 00:00:00 2001 From: zittozhang Date: Thu, 30 Apr 2026 14:45:10 +0800 Subject: [PATCH 57/59] feat: add PP support with centralized data plane and WorkerKey abstraction Extend ModelConfig with PP-aware fields: pp_start_layer/pp_end_layer, enable_dp_attention, attn_cp_size/attn_cp_rank, and derived properties (attn_tp_size, num_layers_per_pp_stage, token_size_in_bytes_per_pp_stage). Add freeze() to prevent post-init mutation of parallel config. Introduce WorkerKey(dp_rank, pp_rank) as the unique worker identifier, replacing the flat dp_id throughout the transfer stack. This allows TransferEngine to manage multiple PP stages within a single centralized data plane instance, while each PP stage retains independent control decisions (decentralized control plane). 
Key changes: - TransferOp: dp_id -> (dp_rank, pp_rank) - TransferOpGraph: add clear_gpu_blocks()/set_gpu_blocks() for deferred GPU block binding (PP stages share a graph template but - Integration adapters (vllm/sglang/trt-llm): compute pp_start_layer/pp_end_layer per stage, freeze ModelConfig after init - C++ layerwise/tp_transfer_thread_group: use tp_size_per_node instead of global tp_size for correct node-local eventfd grouping --- CMakeLists.txt | 16 +- build.sh | 6 +- csrc/bindings.cpp | 16 +- csrc/gds/tp_gds_transfer_thread_group.cpp | 2 - csrc/gds/tp_gds_transfer_thread_group.h | 2 - csrc/layerwise.cpp | 4 +- csrc/layerwise.h | 2 +- csrc/tp_transfer_thread_group.cpp | 6 +- csrc/tp_transfer_thread_group.h | 6 +- flexkv/cache/cache_engine.py | 22 +- flexkv/cache/hie_cache_engine.py | 26 +- flexkv/cache/redis_meta.py | 52 +-- flexkv/cache/transfer_pattern.py | 3 +- flexkv/common/config.py | 189 ++++++++-- flexkv/common/tracer.py | 6 +- flexkv/common/transfer.py | 49 ++- flexkv/integration/config.py | 240 +++++++----- .../tensorrt_llm/trtllm_adapter.py | 24 +- flexkv/integration/vllm/vllm_v1_adapter.py | 9 +- flexkv/kvmanager.py | 35 +- flexkv/kvtask.py | 135 ++++--- flexkv/server/client.py | 53 ++- flexkv/server/request.py | 9 +- flexkv/server/server.py | 15 +- flexkv/storage/storage_engine.py | 18 +- flexkv/transfer/layerwise.py | 125 ++----- flexkv/transfer/transfer_engine.py | 342 +++++++++++------- flexkv/transfer/worker.py | 14 +- flexkv/transfer_manager.py | 178 +++++++-- tests/replay_from_tracer.py | 19 +- tests/test_kvmanager.py | 49 ++- tests/test_transfer_engine_atomic_eviction.py | 6 +- 32 files changed, 1070 insertions(+), 608 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 34c777d811..5040ed3d95 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -67,10 +67,20 @@ target_include_directories(xxhash PUBLIC install(FILES ${XXHASH_HEADERS} DESTINATION include) # ==================== prometheus-cpp Library ==================== -# Option 
to enable/disable monitoring (env FLEXKV_ENABLE_METRICS=0 or -DFLEXKV_ENABLE_MONITORING=OFF) -set(_FLEXKV_MONITORING_DEFAULT ON) +# Step 1: Auto-detect default from directory existence +if(IS_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/third_party/prometheus-cpp") + set(_FLEXKV_MONITORING_DEFAULT ON) +else() + set(_FLEXKV_MONITORING_DEFAULT OFF) + message(STATUS "third_party/prometheus-cpp not found, Prometheus monitoring defaults to OFF") +endif() + +# Step 2: Environment variable override (FLEXKV_ENABLE_METRICS=1/0) if(DEFINED ENV{FLEXKV_ENABLE_METRICS}) - if("$ENV{FLEXKV_ENABLE_METRICS}" STREQUAL "0" OR "$ENV{FLEXKV_ENABLE_METRICS}" STREQUAL "OFF" + if("$ENV{FLEXKV_ENABLE_METRICS}" STREQUAL "1" OR "$ENV{FLEXKV_ENABLE_METRICS}" STREQUAL "ON" + OR "$ENV{FLEXKV_ENABLE_METRICS}" STREQUAL "YES" OR "$ENV{FLEXKV_ENABLE_METRICS}" STREQUAL "TRUE") + set(_FLEXKV_MONITORING_DEFAULT ON) + elseif("$ENV{FLEXKV_ENABLE_METRICS}" STREQUAL "0" OR "$ENV{FLEXKV_ENABLE_METRICS}" STREQUAL "OFF" OR "$ENV{FLEXKV_ENABLE_METRICS}" STREQUAL "NO" OR "$ENV{FLEXKV_ENABLE_METRICS}" STREQUAL "FALSE") set(_FLEXKV_MONITORING_DEFAULT OFF) endif() diff --git a/build.sh b/build.sh index 12f4c018c5..05078f1277 100755 --- a/build.sh +++ b/build.sh @@ -66,7 +66,11 @@ fi echo "=== Building in ${BUILD_TYPE} mode ===" # Install submodules -git submodule update --init --recursive +if git rev-parse --is-inside-work-tree >/dev/null 2>&1; then + git submodule update --init --recursive +else + echo "WARNING: Not a git repository, skipping submodule update. If submodules are missing, please clone the repo instead of copying." 
+fi mkdir -p build cd build diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index 05ccb88b87..659802cdd8 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -419,14 +419,14 @@ PYBIND11_MODULE(c_ext, m) { py::class_(m, "LayerwiseTransferGroup") .def(py::init> &, torch::Tensor &, std::map> &, - int, int, torch::Tensor &, torch::Tensor &, torch::Tensor &, + int, torch::Tensor &, torch::Tensor &, torch::Tensor &, torch::Tensor &, int, int, torch::Tensor &, int, const std::vector> &, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::map>>(), py::arg("num_gpus"), py::arg("gpu_blocks"), py::arg("cpu_blocks"), - py::arg("ssd_files"), py::arg("dp_group_id"), py::arg("num_layers"), + py::arg("ssd_files"), py::arg("num_layers"), py::arg("gpu_kv_strides_tensor"), py::arg("gpu_block_strides_tensor"), py::arg("gpu_layer_strides_tensor"), @@ -510,13 +510,13 @@ PYBIND11_MODULE(c_ext, m) { py::init> &, int, int, int>()); py::class_(m, "TPTransferThreadGroup") - .def(py::init &, int, int64_t, int, int, + .def(py::init &, int, int64_t, int, const std::vector &, const std::vector &, const std::vector &, const std::vector &, const std::vector &>(), py::arg("num_gpus"), py::arg("gpu_block_ptrs_flat"), py::arg("num_tensors_per_gpu"), py::arg("cpu_blocks_ptr"), - py::arg("dp_group_id"), py::arg("num_layers"), + py::arg("num_layers"), py::arg("gpu_kv_strides_in_bytes"), py::arg("gpu_block_strides_in_bytes"), py::arg("gpu_layer_strides_in_bytes"), @@ -529,19 +529,19 @@ PYBIND11_MODULE(c_ext, m) { py::arg("cpu_block_stride_in_bytes"), py::arg("cpu_tp_stride_in_bytes"), py::arg("transfer_num_cta"), py::arg("is_host_to_device"), py::arg("use_ce_transfer"), - py::arg("layer_id"), py::arg("layer_granularity"), py::arg("is_mla"), - py::arg("is_nsa_cp") = false); + py::arg("layer_id"), py::arg("layer_granularity"), py::arg("is_mla") + ); #ifdef FLEXKV_ENABLE_GDS py::class_(m, "TPGDSTransferThreadGroup") .def(py::init &, int, - std::map> &, int, int, + 
std::map> &, int, const std::vector &, const std::vector &, const std::vector &, const std::vector &, const std::vector &>(), py::arg("num_gpus"), py::arg("gpu_block_ptrs_flat"), py::arg("num_tensors_per_gpu"), py::arg("ssd_files"), - py::arg("dp_group_id"), py::arg("num_layers"), + py::arg("num_layers"), py::arg("gpu_kv_strides_in_bytes"), py::arg("gpu_block_strides_in_bytes"), py::arg("gpu_layer_strides_in_bytes"), diff --git a/csrc/gds/tp_gds_transfer_thread_group.cpp b/csrc/gds/tp_gds_transfer_thread_group.cpp index f75e35bfe6..30aecae204 100644 --- a/csrc/gds/tp_gds_transfer_thread_group.cpp +++ b/csrc/gds/tp_gds_transfer_thread_group.cpp @@ -9,7 +9,6 @@ TPGDSTransferThreadGroup::TPGDSTransferThreadGroup( const std::vector &gpu_block_ptrs_flat, int num_tensors_per_gpu, std::map> &ssd_files, - int dp_group_id, int num_layers, const std::vector &gpu_kv_strides_in_bytes, const std::vector &gpu_block_strides_in_bytes, @@ -19,7 +18,6 @@ TPGDSTransferThreadGroup::TPGDSTransferThreadGroup( num_gpus_ = num_gpus; num_tensors_per_gpu_ = num_tensors_per_gpu; - dp_group_id_ = dp_group_id; // per-GPU layout parameters gpu_kv_strides_in_bytes_ = new int64_t[num_gpus]; diff --git a/csrc/gds/tp_gds_transfer_thread_group.h b/csrc/gds/tp_gds_transfer_thread_group.h index 616c788b30..359548477a 100644 --- a/csrc/gds/tp_gds_transfer_thread_group.h +++ b/csrc/gds/tp_gds_transfer_thread_group.h @@ -27,7 +27,6 @@ class TPGDSTransferThreadGroup { const std::vector &gpu_block_ptrs_flat, int num_tensors_per_gpu, std::map> &ssd_files, - int dp_group_id, int num_layers, const std::vector &gpu_kv_strides_in_bytes, const std::vector &gpu_block_strides_in_bytes, @@ -54,7 +53,6 @@ class TPGDSTransferThreadGroup { std::future enqueue_for_gpu(int gpu_idx, Task task); int num_gpus_; - int dp_group_id_; std::vector gpu_device_ids_; void **gpu_blocks_; int num_tensors_per_gpu_; diff --git a/csrc/layerwise.cpp b/csrc/layerwise.cpp index 139a91eff9..40a217ec25 100644 --- a/csrc/layerwise.cpp +++ 
b/csrc/layerwise.cpp @@ -61,7 +61,7 @@ static void CUDART_CB layer_done_host_callback(void *userData) { LayerwiseTransferGroup::LayerwiseTransferGroup( int num_gpus, const std::vector> &gpu_blocks, torch::Tensor &cpu_blocks, - std::map> &ssd_files, int dp_group_id, + std::map> &ssd_files, int num_layers, torch::Tensor &gpu_kv_strides_tensor, torch::Tensor &gpu_block_strides_tensor, torch::Tensor &gpu_layer_strides_tensor, @@ -147,8 +147,6 @@ LayerwiseTransferGroup::LayerwiseTransferGroup( cpu_blocks_ = cpu_blocks.data_ptr(); - dp_group_id_ = dp_group_id; - // Get GPU device IDs from tensors (like tp_transfer_thread_group.cpp) gpu_device_ids_.resize(num_gpus_); for (int i = 0; i < num_gpus_; ++i) { diff --git a/csrc/layerwise.h b/csrc/layerwise.h index 95f81f70bf..2de0a23550 100644 --- a/csrc/layerwise.h +++ b/csrc/layerwise.h @@ -22,7 +22,7 @@ class LayerwiseTransferGroup { LayerwiseTransferGroup( int num_gpus, const std::vector> &gpu_blocks, torch::Tensor &cpu_blocks, - std::map> &ssd_files, int dp_group_id, + std::map> &ssd_files, int num_layers, torch::Tensor &gpu_kv_strides_tensor, torch::Tensor &gpu_block_strides_tensor, torch::Tensor &gpu_layer_strides_tensor, diff --git a/csrc/tp_transfer_thread_group.cpp b/csrc/tp_transfer_thread_group.cpp index 77ac01eabd..d0fa757244 100644 --- a/csrc/tp_transfer_thread_group.cpp +++ b/csrc/tp_transfer_thread_group.cpp @@ -22,7 +22,7 @@ namespace flexkv { TPTransferThreadGroup::TPTransferThreadGroup( int num_gpus, const std::vector &gpu_block_ptrs_flat, - int num_tensors_per_gpu, int64_t cpu_blocks_ptr, int dp_group_id, + int num_tensors_per_gpu, int64_t cpu_blocks_ptr, int num_layers, const std::vector &gpu_kv_strides_in_bytes, const std::vector &gpu_block_strides_in_bytes, const std::vector &gpu_layer_strides_in_bytes, @@ -30,7 +30,6 @@ TPTransferThreadGroup::TPTransferThreadGroup( const std::vector &gpu_device_ids) { num_gpus_ = num_gpus; num_tensors_per_gpu_ = num_tensors_per_gpu; - dp_group_id_ = dp_group_id; 
gpu_kv_strides_in_bytes_ = new int64_t[num_gpus]; gpu_block_strides_in_bytes_ = new int64_t[num_gpus]; @@ -156,8 +155,7 @@ void TPTransferThreadGroup::tp_group_transfer( const int64_t cpu_block_stride_in_bytes, const int64_t cpu_tp_stride_in_bytes, const int transfer_num_cta, const bool is_host_to_device, const bool use_ce_transfer, - const int layer_id, const int layer_granularity, const bool is_mla, - const bool is_nsa_cp) { + const int layer_id, const int layer_granularity, const bool is_mla) { std::atomic failed{false}; std::string error_msg; diff --git a/csrc/tp_transfer_thread_group.h b/csrc/tp_transfer_thread_group.h index 551910c0c8..4a4aacd373 100644 --- a/csrc/tp_transfer_thread_group.h +++ b/csrc/tp_transfer_thread_group.h @@ -38,7 +38,7 @@ class TPTransferThreadGroup { TPTransferThreadGroup(int num_gpus, const std::vector &gpu_block_ptrs_flat, int num_tensors_per_gpu, int64_t cpu_blocks_ptr, - int dp_group_id, int num_layers, + int num_layers, const std::vector &gpu_kv_strides_in_bytes, const std::vector &gpu_block_strides_in_bytes, const std::vector &gpu_layer_strides_in_bytes, @@ -56,15 +56,13 @@ class TPTransferThreadGroup { const int transfer_num_cta, const bool is_host_to_device, const bool use_ce_transfer, const int layer_id, - const int layer_granularity, const bool is_mla, - const bool is_nsa_cp); + const int layer_granularity, const bool is_mla); private: using Task = std::function; std::future enqueue_for_gpu(int gpu_idx, Task task); int num_gpus_; - int dp_group_id_; std::vector gpu_device_ids_; void **gpu_blocks_; void *cpu_blocks_; diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index 292523bb48..39e6cba16f 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -398,7 +398,7 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, redis_m if cache_config.enable_cpu: if cache_config.enable_p2p_cpu: - self.cpu_cache_engine = 
HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.CPU, meta=self.redis_meta, pp_rank=self.model_config.pp_rank, pp_size=self.model_config.pp_size) #TODO + self.cpu_cache_engine = HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.CPU, meta=self.redis_meta) elif self.index_accel: self.cpu_cache_engine = CacheEngineAccel(DeviceType.CPU, cache_config.num_cpu_blocks, @@ -422,7 +422,7 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, redis_m self.cache_engines[DeviceType.CPU] = self.cpu_cache_engine if cache_config.enable_ssd: if cache_config.enable_p2p_ssd: - self.ssd_cache_engine = HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.SSD, meta=self.redis_meta, pp_rank=self.model_config.pp_rank, pp_size=self.model_config.pp_size) #TODO + self.ssd_cache_engine = HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.SSD, meta=self.redis_meta) elif self.index_accel: self.ssd_cache_engine = CacheEngineAccel(DeviceType.SSD, cache_config.num_ssd_blocks, @@ -447,7 +447,7 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, redis_m if cache_config.enable_remote: if cache_config.enable_kv_sharing: # Build PCFSCacheEngine from CacheConfig directly (replacing RemotePCFSCacheEngine) TODO - self.remote_cache_engine = HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.REMOTE, meta=self.redis_meta, pp_rank=self.model_config.pp_rank, pp_size=self.model_config.pp_size) + self.remote_cache_engine = HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.REMOTE, meta=self.redis_meta) elif self.index_accel: self.remote_cache_engine = CacheEngineAccel(DeviceType.REMOTE, cache_config.num_remote_blocks, @@ -532,14 +532,15 @@ def get(self, slot_mapping: np.ndarray, layer_num: int = -1, layer_granularity: int = -1, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, 
temp_cache_strategy: CacheStrategy = DEFAULT_CACHE_STRATEGY, namespace: Optional[List[str]] = None) \ -> Tuple[TransferOpGraph, np.ndarray, Callable, Dict, int]: self._check_input(token_ids, token_mask, slot_mapping) if layer_num == -1: - layer_num = self.model_config.num_layers + layer_num = self.model_config.num_layers_per_pp_stage if layer_granularity == -1: layer_granularity = layer_num @@ -561,7 +562,7 @@ def get(self, if aligned_length == 0 or not token_mask.any(): transfer_graph = TransferOpGraph.create_empty_graph() - transfer_graph.bind_to_dp_group(dp_id) + transfer_graph.bind_to_worker(dp_rank, pp_rank) return_mask = np.zeros_like(token_mask, dtype=np.bool_) callback = partial(self._transfer_callback, node_to_unlock={}, buffer_to_free={}) return transfer_graph, return_mask, callback, {}, -1 @@ -616,7 +617,7 @@ def get(self, # finished_ops_ids=finished_ops_ids, # layer_num=layer_num, # layer_granularity=layer_granularity) - transfer_graph.bind_to_dp_group(dp_id) + transfer_graph.bind_to_worker(dp_rank, pp_rank) for device_type in node_to_unlock: self.cache_engines[device_type].lock_node(node_to_unlock[device_type][0]) @@ -1073,14 +1074,15 @@ def put(self, token_mask: np.ndarray, slot_mapping: np.ndarray, layer_num : int = -1, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, temp_cache_strategy: CacheStrategy = DEFAULT_CACHE_STRATEGY, namespace: Optional[List[str]] = None) \ -> Tuple[TransferOpGraph, np.ndarray, Callable, Dict, int]: self._check_input(token_ids, token_mask, slot_mapping) if layer_num == -1: - layer_num = self.model_config.num_layers + layer_num = self.model_config.num_layers_per_pp_stage # ignore the last incomplete block aligned_length = (token_ids.shape[0] // self.tokens_per_block) * self.tokens_per_block aligned_token_ids = token_ids[:aligned_length] @@ -1131,7 +1133,7 @@ def put(self, return_mask = np.zeros_like(token_mask, dtype=np.bool_) return_mask[(block_start_idx + skipped_gpu_blocks)* self.tokens_per_block: 
(block_start_idx + skipped_gpu_blocks + num_gpu_blocks_to_transfer) * self.tokens_per_block] = True - transfer_graph.bind_to_dp_group(dp_id) + transfer_graph.bind_to_worker(dp_rank, pp_rank) for device_type in node_to_unlock: self.cache_engines[device_type].lock_node(node_to_unlock[device_type][0]) diff --git a/flexkv/cache/hie_cache_engine.py b/flexkv/cache/hie_cache_engine.py index a4acafdb26..41f37ee801 100644 --- a/flexkv/cache/hie_cache_engine.py +++ b/flexkv/cache/hie_cache_engine.py @@ -1,5 +1,6 @@ from typing import Optional, Tuple, TYPE_CHECKING, List, Dict +import time import numpy as np import torch @@ -37,9 +38,7 @@ def __init__(self, evict_start_threshold: float = 1.0, hit_reward_seconds: int = 0, eviction_policy: str = "lru", - meta: Optional[RedisMeta] = None, - pp_rank: int = 0, - pp_size: int = 1) -> None: + meta: Optional[RedisMeta] = None) -> None: if num_total_blocks <= 0: raise ValueError(f"Invalid num_total_blocks: {num_total_blocks}") if tokens_per_block <= 0 or (tokens_per_block & (tokens_per_block - 1)) != 0: @@ -92,8 +91,6 @@ def __init__(self, self.num_total_blocks = num_total_blocks self.evict_ratio = evict_ratio self.evict_start_threshold = evict_start_threshold - self.pp_rank = pp_rank - self.pp_size = pp_size # cumulative statistics: for analyzing distributed KV reuse benefits self._stats_total_queried_tokens = 0 # total tokens queried @@ -116,12 +113,8 @@ def start(self) -> None: else: raise ValueError(f"Invalid device type: {self.device_type}") - if self.pp_size > 1: - local_ch_block_key = f"{base_key}:pp{self.pp_rank}" - remote_ch_block_key = f"{base_key}:pp{self.pp_rank}" - else: - local_ch_block_key = base_key - remote_ch_block_key = base_key + local_ch_block_key = base_key + remote_ch_block_key = base_key self.remote_ch = self._meta.get_redis_meta_channel(remote_ch_block_key) self.local_ch = self._meta.get_redis_meta_channel(local_ch_block_key) # Load and store mapping of node_id -> file_nodeids from Redis @@ -161,7 +154,6 @@ 
def match_all(self, sequence_meta: SequenceMeta, gpu_matched_blocks: int = 0) -> num_blocks = sequence_meta.num_blocks # Query both local and remote - import time t0 = time.perf_counter() mr_local = self.local_index.match_prefix(block_hashes_t, int(num_blocks), True) t1 = time.perf_counter() @@ -452,7 +444,7 @@ def recycle(self, physical_blocks: np.ndarray) -> None: #TODO pfcs may not work now @classmethod - def pcfs_ce_from_cache_config(cls, cache_config: "CacheConfig", node_id: int, meta: Optional[RedisMeta] = None, pp_rank: int = 0, pp_size: int = 1) -> "HierarchyLRCacheEngine": + def pcfs_ce_from_cache_config(cls, cache_config: "CacheConfig", node_id: int, meta: Optional[RedisMeta] = None) -> "HierarchyLRCacheEngine": """Create a PCFSCacheEngine from CacheConfig. This replaces RemotePCFSCacheEngine. It wires both local and remote @@ -531,16 +523,14 @@ def pcfs_ce_from_cache_config(cls, cache_config: "CacheConfig", node_id: int, me local_safety_ttl_ms=int(GLOBAL_CONFIG_FROM_ENV.safety_ttl_ms), eviction_policy=GLOBAL_CONFIG_FROM_ENV.eviction_policy, meta=meta, - pp_rank=pp_rank, - pp_size=pp_size, ) #TODO is this enough for peercpu and peerssd? 
@classmethod - def from_cache_config(cls, cache_config: "CacheConfig", node_id: int, device_type: DeviceType, meta: Optional[RedisMeta] = None, pp_rank: int = 0, pp_size: int = 1) -> "HierarchyLRCacheEngine": + def from_cache_config(cls, cache_config: "CacheConfig", node_id: int, device_type: DeviceType, meta: Optional[RedisMeta] = None) -> "HierarchyLRCacheEngine": if device_type == DeviceType.REMOTE: - return cls.pcfs_ce_from_cache_config(cache_config, node_id, meta, pp_rank=pp_rank, pp_size=pp_size) + return cls.pcfs_ce_from_cache_config(cache_config, node_id, meta) else: # select correct blocks configuration based on device_type if device_type == DeviceType.CPU: @@ -574,7 +564,5 @@ def from_cache_config(cls, cache_config: "CacheConfig", node_id: int, device_typ hit_reward_seconds=int(GLOBAL_CONFIG_FROM_ENV.hit_reward_seconds), eviction_policy=GLOBAL_CONFIG_FROM_ENV.eviction_policy, meta=meta, - pp_rank=pp_rank, - pp_size=pp_size, ) raise ValueError("Invalid device type: {cache_config.device_type}") diff --git a/flexkv/cache/redis_meta.py b/flexkv/cache/redis_meta.py index 536d58c3ab..1721f80f85 100644 --- a/flexkv/cache/redis_meta.py +++ b/flexkv/cache/redis_meta.py @@ -680,14 +680,13 @@ def add_node_ids(self, node_ids: Iterable[Union[int, str]]) -> int: # rpush returns the new length of the list return int(r.rpush(f"pcfs:{nid}", *values)) - def regist_buffer(self, mrs: Iterable[object], pp_rank: int = 0, pp_size: int = 1) -> int: + def regist_buffer(self, mrs: Iterable[object]) -> int: """Register RDMA memory regions in Redis. Each element in mrs can be one of: - dict with keys {"buffer_ptr": ..., "buffer_size": ...} - tuple/list (buffer_ptr, buffer_size) - Stored as hash: key = buffer:[:pp]:, field "buffer_size" = . - When pp_size > 1, pp_rank is included in the key for isolation. + Stored as hash: key = buffer::, field "buffer_size" = . Returns the number of regions processed. 
""" nid = self.get_node_id() @@ -704,27 +703,21 @@ def regist_buffer(self, mrs: Iterable[object], pp_rank: int = 0, pp_size: int = continue if ptr is None or size is None: continue - if pp_size > 1: - key = f"buffer:{nid}:pp{pp_rank}:{int(ptr)}" - else: - key = f"buffer:{nid}:{int(ptr)}" + key = f"buffer:{nid}:{int(ptr)}" pipe.hset(key, mapping={"buffer_size": int(size)}) processed += 1 if processed: pipe.execute() return processed - def unregist_buffer(self, buffer_ptr: Union[int, str], pp_rank: int = 0, pp_size: int = 1) -> bool: + def unregist_buffer(self, buffer_ptr: Union[int, str]) -> bool: """Unregister a previously registered RDMA memory region by buffer_ptr. - Looks up key buffer:[:pp]: and deletes it if present. + Looks up key buffer:: and deletes it if present. Returns True if the key existed and was deleted, otherwise False. """ nid = self.get_node_id() - if pp_size > 1: - key = f"buffer:{nid}:pp{pp_rank}:{int(buffer_ptr)}" - else: - key = f"buffer:{nid}:{int(buffer_ptr)}" + key = f"buffer:{nid}:{int(buffer_ptr)}" r = self._client() exists = bool(r.exists(key)) if exists: @@ -732,40 +725,31 @@ def unregist_buffer(self, buffer_ptr: Union[int, str], pp_rank: int = 0, pp_size return True return False - def regist_node_meta(self, node_id: int, addr: str, zmq_addr: str, cpu_buffer_ptr: int, ssd_buffer_ptr: int, pp_rank: int = 0, pp_size: int = 1) -> None: + def regist_node_meta(self, node_id: int, addr: str, zmq_addr: str, cpu_buffer_ptr: int, ssd_buffer_ptr: int) -> None: """Register node meta information as a Redis hash. - Key: meta:[:pp] - When pp_size > 1, pp_rank is included in the key for PP rank isolation. 
+ Key: meta: Fields: node_id (int), addr (str), cpu_buffer_ptr (int), ssd_buffer_ptr (int) """ r = self._client() - if pp_size > 1: - key = f"meta:{int(node_id)}:pp{pp_rank}" - else: - key = f"meta:{int(node_id)}" + key = f"meta:{int(node_id)}" r.hset(key, mapping={ "node_id": int(node_id), "addr": str(addr), "zmq_addr": str(zmq_addr), "cpu_buffer_ptr": int(cpu_buffer_ptr), "ssd_buffer_ptr": int(ssd_buffer_ptr), - "pp_rank": int(pp_rank), - "pp_size": int(pp_size), }) - def get_node_meta(self, node_id: int, pp_rank: int = 0, pp_size: int = 1) -> dict: + def get_node_meta(self, node_id: int) -> dict: """Get node meta information from Redis. - Reads key meta:[:pp] and returns a dict with fields: + Reads key meta: and returns a dict with fields: node_id (int), addr (str), cpu_buffer_ptr (int), ssd_buffer_ptr (int). Returns empty dict if the key does not exist. """ r = self._client() - if pp_size > 1: - key = f"meta:{int(node_id)}:pp{pp_rank}" - else: - key = f"meta:{int(node_id)}" + key = f"meta:{int(node_id)}" data = r.hgetall(key) if not data: return {} @@ -780,16 +764,10 @@ def get_node_meta(self, node_id: int, pp_rank: int = 0, pp_size: int = 1) -> dic out["ssd_buffer_ptr"] = int(sb) if sb is not None and sb != "" else 0 return out - def unregist_node_meta(self, node_id: int, pp_rank: int = 0, pp_size: int = 1) -> bool: - """Unregister node meta by node_id. Returns True if deleted. - - When pp_size > 1, only deletes the key for the specified pp_rank. - """ + def unregist_node_meta(self, node_id: int) -> bool: + """Unregister node meta by node_id. 
Returns True if deleted.""" r = self._client() - if pp_size > 1: - key = f"meta:{int(node_id)}:pp{pp_rank}" - else: - key = f"meta:{int(node_id)}" + key = f"meta:{int(node_id)}" return bool(r.delete(key)) diff --git a/flexkv/cache/transfer_pattern.py b/flexkv/cache/transfer_pattern.py index fcf207408b..6290e3e69a 100644 --- a/flexkv/cache/transfer_pattern.py +++ b/flexkv/cache/transfer_pattern.py @@ -61,7 +61,8 @@ def convert_read_graph_to_layer_wise_graph( layer_id=i * layer_granularity, layer_granularity=layer_granularity, # Inherit these fields directly - dp_id=op.dp_id, + dp_rank=op.dp_rank, + pp_rank=op.pp_rank, ) new_graph.add_transfer_op(new_op) split_op_ids.append(new_op.op_id) diff --git a/flexkv/common/config.py b/flexkv/common/config.py index 4adc83153f..8b92e92c72 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -1,6 +1,7 @@ import os import json -from dataclasses import dataclass +import yaml +from dataclasses import dataclass, field, fields from enum import Enum from typing import Optional, List, Union, Dict, Any from argparse import Namespace @@ -29,21 +30,40 @@ class ModelConfig: use_mla: bool = False dtype: torch.dtype = torch.bfloat16 - # parallel configs + # ------------------------------------------------------------------ + # Parallel configs + # ------------------------------------------------------------------ tp_size: int = 1 - dp_size: int = 1 - dp_rank: int = 0 + tp_rank: int = 0 + pp_size: int = 1 pp_rank: int = 0 - # topology configs - # nnodes : number of physical machines spanned by one replica - # (== server_args.nnodes in SGLang) - # node_rank : index of this machine within ``nnodes`` - # (== server_args.node_rank in SGLang). Used by - # KVTaskEngine's multi-node topology derivation and - # for logging; NOT embedded in the layerwise UDS - # socket path (UDS endpoints are kernel-local). + dp_size: int = 1 + dp_rank: int = 0 + + # pp_start_layer / pp_end_layer: [start, end) layer indices for this PP stage. 
+ pp_start_layer: int = 0 + pp_end_layer: int = -1 # -1 → lazily resolved to num_layers + + # ------------------------------------------------------------------ + # Attention-level parallel configs + # ------------------------------------------------------------------ + # enable_dp_attention: whether DP-attention is enabled (sglang + # ``--enable-dp-attention`` or TRT-LLM ``enable_attention_dp``). + # When True, the physical TP group is split into + # attn_tp × attn_cp × attn_dp. + enable_dp_attention: bool = False + + # attn_cp_size / attn_cp_rank: context-parallel size/rank. + attn_cp_size: int = 1 + attn_cp_rank: int = 0 + + # ------------------------------------------------------------------ + # Topology configs + # ------------------------------------------------------------------ + # nnodes: number of physical machines spanned by one replica + # node_rank: index of this machine within ``nnodes`` nnodes: int = 1 node_rank: int = 0 @@ -54,15 +74,144 @@ class ModelConfig: # ``--dist-init-addr``) to avoid exposing an extra env knob. master_host: Optional[str] = None - # NSA context parallelism: when True, layerwise transfer sends full - # (unpartitioned) KV cache to every rank instead of head-sliced data. + # NSA context parallelism: when True, every rank in the CP group holds + # identical (full) KV cache. FlexKV treats CP ranks the same as TP ranks + # in MLA mode. is_nsa_cp: bool = False - cp_size: int = 1 + + # ------------------------------------------------------------------ + # Freeze mechanism: after post_init, ModelConfig must not be mutated + # ------------------------------------------------------------------ + _frozen: bool = field(default=False, init=False, repr=False) + + def freeze(self) -> None: + """Lock the config so that any subsequent __setattr__ raises an error. + """ + object.__setattr__(self, '_frozen', True) + flexkv_logger.info( + f"[FlexKV] ModelConfig FROZEN — primitive vars are now immutable. 
" + f"Derived: attn_tp_size={self.attn_tp_size}, attn_tp_rank={self.attn_tp_rank}, " + f"tp_size_per_node={self.tp_size_per_node}, " + f"nnodes_per_pp_rank={self.nnodes_per_pp_rank}, " + f"nnodes_per_tp_group={self.nnodes_per_tp_group}, " + f"total_gpus={self.total_gpus}, " + f"gpus_per_node={self.gpus_per_node}, " + f"num_kv_heads_per_node={self.num_kv_heads_per_node}, " + f"tp_rank_per_node={self.tp_rank_per_node}, " + f"local_rank={self.local_rank}" + ) + + def __setattr__(self, name: str, value) -> None: + if name == '_frozen': + return object.__setattr__(self, name, value) + if getattr(self, '_frozen', False): + raise AttributeError( + f"ModelConfig is frozen — cannot set '{name}'. " + f"All primitive fields must be set during post_init_from_*(), " + f"after which freeze() is called. Derived fields (attn_tp_size, " + f"attn_tp_rank) are @property " + f"and cannot be set at all." + ) + object.__setattr__(self, name, value) + + # ------------------------------------------------------------------ + # Derived topology properties + # ------------------------------------------------------------------ + @property + def total_gpus(self) -> int: + """Total GPUs across all nodes for one FlexKV instance.""" + return self.dp_size * self.tp_size * self.pp_size + + @property + def gpus_per_node(self) -> int: + """Total GPUs on this node (across all DP, PP stages and TP groups).""" + return self.total_gpus // self.nnodes + + @property + def nnodes_per_pp_rank(self) -> int: + """Number of nodes spanned by one PP stage.""" + return max(self.nnodes // self.pp_size, 1) + + @property + def nnodes_per_tp_group(self) -> int: + """Number of nodes spanned by one TP group.""" + return self.nnodes_per_pp_rank + + @property + def tp_size_per_node(self) -> int: + """Number of TP ranks on this node within one TP group.""" + return self.tp_size // self.nnodes_per_tp_group + + @property + def tp_rank_per_node(self) -> int: + """TP rank index within the local node (within one TP group).""" + 
return self.tp_rank % self.tp_size_per_node + + @property + def local_rank(self) -> int: + """Local GPU device index within the node (a.k.a. ``LOCAL_RANK`` in + PyTorch distributed / sglang / vllm). + + Matches the standard Megatron-style rank layout: + global_rank = dp_rank * pp_size * tp_size + pp_rank * tp_size + tp_rank + + When DP-attention is enabled, DP replicas share the same physical + GPUs, so the DP dimension is not reflected in the device index. + The formula then reduces to: + local_rank = pp_rank_per_node * tp_size_per_node + tp_rank_per_node + + When DP-attention is disabled, each DP replica has its own GPUs: + local_rank = (dp_rank_per_node * pp_size_per_node + pp_rank_per_node) + * tp_size_per_node + tp_rank_per_node + where the ``_per_node`` values are derived inline from the global ranks + and topology. + """ + pp_size_per_node = max(self.pp_size // self.nnodes, 1) + pp_rank_per_node = self.pp_rank % pp_size_per_node + if self.enable_dp_attention: + return pp_rank_per_node * self.tp_size_per_node + self.tp_rank_per_node + dp_size_per_node = self.gpus_per_node // (pp_size_per_node * self.tp_size_per_node) + dp_rank_per_node = self.dp_rank % dp_size_per_node + return (dp_rank_per_node * pp_size_per_node + pp_rank_per_node) * self.tp_size_per_node + self.tp_rank_per_node + + @property + def attn_tp_size(self) -> int: + """Attention-level TP size derived from tp / attn_dp / attn_cp.""" + attn_dp = max(1, self.dp_size) if self.enable_dp_attention else 1 + cp = max(1, self.attn_cp_size) + return max(1, max(1, self.tp_size) // (attn_dp * cp)) + + @property + def attn_tp_rank(self) -> int: + """Attention-level TP rank derived from tp_rank / attn_tp_size.""" + return self.tp_rank % max(1, self.attn_tp_size) + + @property + def num_kv_heads_per_node(self) -> int: + """Number of KV heads visible to a single node.""" + if self.use_mla: + return self.num_kv_heads + return self.num_kv_heads * self.tp_size_per_node // max(1, self.attn_tp_size) + + @property + 
def kv_dim(self) -> int: + """KV dimension: 1 for MLA (no head split), 2 for standard (head split).""" + return 1 if self.use_mla else 2 + + @property + def num_layers_per_pp_stage(self) -> int: + """Number of layers managed by this PP stage.""" + end = self.pp_end_layer if self.pp_end_layer >= 0 else self.num_layers + return end - self.pp_start_layer @property def token_size_in_bytes(self) -> int: - kv_dim = 1 if self.use_mla else 2 - return self.num_layers * self.num_kv_heads * self.head_size * kv_dim * self.dtype.itemsize + return self.num_layers * self.num_kv_heads * self.head_size * self.kv_dim * self.dtype.itemsize + + @property + def token_size_in_bytes_per_pp_stage(self) -> int: + """Token size in bytes for one PP stage (used by data plane).""" + return self.num_layers_per_pp_stage * self.num_kv_heads * self.head_size * self.kv_dim * self.dtype.itemsize @dataclass class CacheConfig: @@ -229,10 +378,6 @@ def parse_path_list(path_str: str) -> List[str]: return paths def load_user_config_from_file(config_file: str) -> UserConfig: - import json - import yaml - from dataclasses import fields - # read json config file or yaml config file if config_file.endswith('.json'): with open(config_file) as f: @@ -275,7 +420,7 @@ def convert_to_block_num(size_in_GB: float, block_size_in_bytes: int) -> int: def update_default_config_from_user_config(model_config: ModelConfig, cache_config: CacheConfig, user_config: UserConfig) -> None: - block_size_in_bytes = model_config.token_size_in_bytes * cache_config.tokens_per_block + block_size_in_bytes = model_config.token_size_in_bytes_per_pp_stage * cache_config.tokens_per_block assert user_config.cpu_cache_gb > 0 assert user_config.ssd_cache_gb >= 0 diff --git a/flexkv/common/tracer.py b/flexkv/common/tracer.py index 4b054f575b..47b463a2a5 100644 --- a/flexkv/common/tracer.py +++ b/flexkv/common/tracer.py @@ -195,7 +195,8 @@ def trace_request(self, slot_mapping: Union[torch.Tensor, np.ndarray], token_mask: 
Optional[Union[torch.Tensor, np.ndarray]] = None, layer_granularity: int = -1, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, **kwargs): """Record a request operation""" if not self.enabled: @@ -211,7 +212,8 @@ def trace_request(self, "slot_mapping": self._convert_tensor_to_list(slot_mapping), "token_mask": self._convert_tensor_to_list(token_mask) if token_mask is not None else None, "layer_granularity": layer_granularity, - "dp_id": dp_id, + "dp_rank": dp_rank, + "pp_rank": pp_rank, "token_ids_shape": list(token_ids.shape), "slot_mapping_shape": list(slot_mapping.shape), "token_mask_shape": list(token_mask.shape) if token_mask is not None else None, diff --git a/flexkv/common/transfer.py b/flexkv/common/transfer.py index ba1da3b46b..cca69bd756 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -8,6 +8,17 @@ from flexkv.common.debug import flexkv_logger +@dataclass(frozen=True) +class WorkerKey: + """Immutable, hashable key that uniquely identifies a worker by (dp_rank, pp_rank). + + Used as dict keys in TransferEngine's worker maps and TransferManager's + GPU grouping instead of raw ``Tuple[int, int]`` to avoid ambiguity. 
+ """ + dp_rank: int + pp_rank: int + + @dataclass(frozen=True) class CompletedOp: graph_id: int @@ -91,7 +102,8 @@ class TransferOp: # this will keep the full info successors: Set[int] = field(default_factory=set) status: TransferOpStatus = TransferOpStatus.PENDING - dp_id: int = 0 + dp_rank: int = 0 + pp_rank: int = 0 # used for get block ids inner worker process src_slot_id: int = -1 dst_slot_id: int = -1 @@ -137,7 +149,8 @@ def __init__(self, dst_block_ids_disk2h: np.ndarray, layer_id: int = 0, layer_granularity: int = 1, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, counter_id: int = 0, indexer_src_block_ids: Optional[np.ndarray] = None, indexer_dst_block_ids: Optional[np.ndarray] = None) -> None: @@ -158,7 +171,8 @@ def __init__(self, dst_block_ids=np.array([], dtype=np.int64), layer_id=layer_id, layer_granularity=layer_granularity, - dp_id=dp_id, + dp_rank=dp_rank, + pp_rank=pp_rank, ) def __post_init__(self) -> None: @@ -289,13 +303,30 @@ def set_gpu_blocks(self, gpu_blocks: np.ndarray) -> None: assert op.src_block_ids.size == op.dst_block_ids.size, \ f"src_block_ids.size={op.src_block_ids.size}, dst_block_ids.size={op.dst_block_ids.size}" + def clear_gpu_blocks(self) -> None: + """Clear GPU block_ids from the graph. + """ + for op_id in self._gpu_transfer_op_id: + op = self._op_map[op_id] + # Replace with empty arrays; set_gpu_blocks() will fill them later + if op.src_block_ids.size > 0: + op.src_block_ids = np.array([], dtype=op.src_block_ids.dtype) + if op.dst_block_ids.size > 0: + op.dst_block_ids = np.array([], dtype=op.dst_block_ids.dtype) + @property def num_ops(self) -> int: return len(self._op_map) - def bind_to_dp_group(self, dp_id: int) -> None: + def bind_to_worker(self, dp_rank: int, pp_rank: int) -> None: + """Bind all ops in this graph to the specified DP group and PP stage. + + Both fields are always set together because they jointly determine + which worker (GPU) handles the transfer. 
+ """ for op in self._op_map.values(): - op.dp_id = dp_id + op.dp_rank = dp_rank + op.pp_rank = pp_rank def visualize(self) -> str: """ @@ -351,7 +382,7 @@ def format_blocks(block_ids, max_show=4): dst_str = format_blocks(op.dst_block_ids) lines.append(f"║ ├─ src_blocks: {src_str}".ljust(71) + "║") lines.append(f"║ ├─ dst_blocks: {dst_str}".ljust(71) + "║") - lines.append(f"║ └─ layer_id={op.layer_id}, dp_id={op.dp_id}".ljust(71) + "║") + lines.append(f"║ └─ layer_id={op.layer_id}, dp_rank={op.dp_rank}, pp_rank={op.pp_rank}".ljust(71) + "║") else: lines.append("║ └─ (VIRTUAL - no blocks)".ljust(71) + "║") @@ -395,7 +426,8 @@ def _merge_ops(ops: List[TransferOp], transfer_type: TransferType, dst_block_ids=dst_blocks, layer_id=ops[0].layer_id, layer_granularity=ops[0].layer_granularity, - dp_id=ops[0].dp_id, + dp_rank=ops[0].dp_rank, + pp_rank=ops[0].pp_rank, ) if callbacks: if len(callbacks) == 1: @@ -481,7 +513,8 @@ def merge_to_batch_graph(batch_id: int, else np.array([], dtype=np.int64), layer_id=0, layer_granularity=1, - dp_id=ops_by_type[TransferType.H2D][0].dp_id, + dp_rank=ops_by_type[TransferType.H2D][0].dp_rank, + pp_rank=ops_by_type[TransferType.H2D][0].pp_rank, counter_id=counter_id, # Indexer maps 1:1 with main KV blocks, use same block_ids # CPU side (src) and GPU side (dst) for H2D direction diff --git a/flexkv/integration/config.py b/flexkv/integration/config.py index 580aab57cf..875c7d9a4d 100644 --- a/flexkv/integration/config.py +++ b/flexkv/integration/config.py @@ -16,6 +16,28 @@ logger = flexkv_logger + +def _parse_dtype_str(dtype_str: str) -> torch.dtype: + """Convert a dtype string (e.g. 'fp8', 'bfloat16', 'fp8_e4m3') to torch.dtype. + + Shared by sglang / vllm / TRT-LLM integration adapters so that dtype + parsing logic is defined in exactly one place. 
+ """ + dtype_map = { + "float16": torch.float16, + "float32": torch.float32, + "bfloat16": torch.bfloat16, + "fp16": torch.float16, + "fp32": torch.float32, + "bf16": torch.bfloat16, + "fp8": torch.float8_e4m3fn, + "float8": torch.float8_e4m3fn, + "e4m3": torch.float8_e4m3fn, + "fp8_e4m3": torch.float8_e4m3fn, + } + return dtype_map.get(dtype_str.lower(), torch.bfloat16) + + @dataclass class FlexKVConfig: enable_flexkv: bool = True @@ -106,6 +128,17 @@ def post_init_from_vllm_config( self.model_config.dp_size = vllm_config.parallel_config.data_parallel_size self.model_config.pp_size = vllm_config.parallel_config.pipeline_parallel_size self.model_config.pp_rank = getattr(vllm_config.parallel_config, 'pipeline_parallel_rank', 0) + + if self.model_config.pp_size > 1: + from vllm.distributed.utils import get_pp_indices as vllm_get_pp_indices + start_layer, end_layer = vllm_get_pp_indices( + self.model_config.num_layers, self.model_config.pp_rank, self.model_config.pp_size + ) + self.model_config.pp_start_layer = start_layer + self.model_config.pp_end_layer = end_layer + else: + self.model_config.pp_start_layer = 0 + self.model_config.pp_end_layer = self.model_config.num_layers if self.model_config.use_mla: self.model_config.num_kv_heads = 1 else: @@ -117,57 +150,66 @@ def post_init_from_vllm_config( hf_config = getattr(vllm_config.model_config, 'hf_config', None) self._detect_indexer_config_from_hf(hf_config, source="vllm") + logger.info( + f"[FlexKV vllm] Primitive vars set: tp_size={self.model_config.tp_size}, " + f"dp_size={self.model_config.dp_size}, dp_rank={self.model_config.dp_rank}, " + f"pp_size={self.model_config.pp_size}, pp_rank={self.model_config.pp_rank}, " + f"enable_dp_attention={self.model_config.enable_dp_attention}, " + f"attn_cp_size={self.model_config.attn_cp_size}, " + f"attn_cp_rank={self.model_config.attn_cp_rank}" + ) + logger.info( + f"[FlexKV vllm] Derived vars: attn_tp_size={self.model_config.attn_tp_size}, " + 
f"attn_tp_rank={self.model_config.attn_tp_rank}, " + f"local_rank={self.model_config.local_rank}" + ) + + # Freeze model_config — no further mutations allowed + self.model_config.freeze() + def post_init_from_sglang_config( self, sglang_config, - tp_size: int, - page_size: int, - num_local_layers: int = 0, - pp_size: int = 1, - pp_rank: int = 0, - dp_size: int = 1, - dp_rank: int = 0, - nnodes: int = 1, - node_rank: int = 0, - is_nsa_cp: bool = False, - cp_size: int = 1, - cp_rank: int = 0, - kv_cache_dtype: Optional[str] = None, - master_host: Optional[str] = None, + server_args, + page_size: int = 64, + tp_rank: Optional[int] = 0, + pp_rank: Optional[int] = 0, + dp_rank: Optional[int] = 0, + attn_cp_rank: Optional[int] = 0, ): """ Initialize FlexKVConfig fields from sglang config. Args: sglang_config: sglang.srt.configs.model_config.ModelConfig-like object - tp_size: tensor parallel size used by sglang + server_args: sglang ServerArgs — source of tp_size, dp_size, + nnodes, node_rank, enable_dp_attention, attn_cp_size, + is_nsa_cp, kv_cache_dtype, dist_init_addr page_size: KV block size (tokens per block) used by sglang - num_local_layers: number of layers on this PP rank (0 means no PP, use total layers) - pp_size: pipeline parallel size (default 1, no PP) - pp_rank: pipeline parallel rank (default 0) - dp_size: data parallel size (default 1, no DP) - dp_rank: data parallel rank (default 0) - nnodes: number of nodes (aligned with server_args.nnodes, default 1) - node_rank: index of this node (aligned with server_args.node_rank, default 0) - is_nsa_cp: whether NSA context parallelism is enabled - cp_size: context parallel size (default 1, no CP) - cp_rank: context parallel rank (default 0) - kv_cache_dtype: KV cache dtype (default None, use model dtype) - master_host: master host for multi-node setup (default None, use localhost) + tp_rank: physical tensor parallel rank (runtime, from process group) + pp_rank: pipeline parallel rank (runtime, from process group) 
+ dp_rank: data parallel rank (runtime, from process group) + attn_cp_rank: attention-level context parallel rank (runtime) """ + # Extract parallelism params from server_args + tp_size = server_args.tp_size + pp_size = server_args.pp_size + dp_size = server_args.dp_size + nnodes = server_args.nnodes + node_rank = server_args.node_rank + enable_dp_attention = server_args.enable_dp_attention + attn_cp_size = server_args.attn_cp_size + is_nsa_cp = getattr(server_args, 'enable_nsa_prefill_context_parallel', False) + kv_cache_dtype = getattr(server_args, 'kv_cache_dtype', None) + # cache config: use page_size as tokens_per_block so that FlexKV's # CPU radix tree manages blocks at page granularity, ensuring that # hash generation, matching, insertion and eviction are all page-aligned. self.cache_config.tokens_per_block = page_size - total_layers = int(getattr(sglang_config, "num_hidden_layers", 0)) - self.model_config.num_layers = int(num_local_layers) if num_local_layers > 0 else total_layers + self.model_config.num_layers = int(getattr(sglang_config, "num_hidden_layers", 0)) - attn_arch = getattr(sglang_config, "attention_arch", None) - use_mla = False - if hasattr(attn_arch, "name"): - use_mla = (attn_arch.name.upper() == "MLA") - elif isinstance(attn_arch, str): - use_mla = (attn_arch.upper() == "MLA") + from sglang.srt.configs.model_config import AttentionArch + use_mla = getattr(sglang_config, "attention_arch", None) == AttentionArch.MLA if use_mla: kv_lora_rank = int(getattr(sglang_config, "kv_lora_rank", 0)) @@ -196,21 +238,6 @@ def post_init_from_sglang_config( # to the sglang model dtype. sglang's ModelConfig.dtype is the *model # weight* dtype (e.g. bfloat16), which may differ from the KV cache # dtype (e.g. fp8_e4m3 when --kv-cache-dtype fp8_e4m3 is used). 
- def _parse_dtype_str(dtype_str: str) -> torch.dtype: - dtype_map = { - "float16": torch.float16, - "float32": torch.float32, - "bfloat16": torch.bfloat16, - "fp16": torch.float16, - "fp32": torch.float32, - "bf16": torch.bfloat16, - "fp8": torch.float8_e4m3fn, - "float8": torch.float8_e4m3fn, - "e4m3": torch.float8_e4m3fn, - "fp8_e4m3": torch.float8_e4m3fn, - } - return dtype_map.get(dtype_str.lower(), torch.bfloat16) - user_dtype_str = self.user_config.kv_cache_dtype if user_dtype_str is not None: self.model_config.dtype = _parse_dtype_str(user_dtype_str) @@ -251,48 +278,35 @@ def _parse_dtype_str(dtype_str: str) -> torch.dtype: self.model_config.use_mla = use_mla self.model_config.tp_size = int(tp_size) + self.model_config.tp_rank = int(tp_rank) self.model_config.dp_size = int(dp_size if dp_size is not None else 1) self.model_config.dp_rank = int(dp_rank if dp_rank is not None else 0) self.model_config.pp_size = int(pp_size) self.model_config.pp_rank = int(pp_rank) + + if pp_size > 1: + from sglang.srt.distributed.utils import get_pp_indices as sglang_get_pp_indices + start_layer, end_layer = sglang_get_pp_indices( + self.model_config.num_layers, self.model_config.pp_rank, self.model_config.pp_size + ) + self.model_config.pp_start_layer = start_layer + self.model_config.pp_end_layer = end_layer + else: + self.model_config.pp_start_layer = 0 + self.model_config.pp_end_layer = self.model_config.num_layers + self.model_config.enable_dp_attention = bool(enable_dp_attention) + self.model_config.attn_cp_size = int(attn_cp_size) + self.model_config.attn_cp_rank = int(attn_cp_rank) self.model_config.is_nsa_cp = is_nsa_cp - self.model_config.cp_size = int(cp_size if cp_size is not None else 1) - # Topology: nnodes + node_rank (aligned with sglang server_args). - # ``gpus_per_node`` is no longer stored on model_config; KVTaskEngine - # derives it locally as (tp_size * pp_size) // nnodes. 
self.model_config.nnodes = max(1, int(nnodes)) self.model_config.node_rank = int(node_rank) # Multi-node bootstrap: master host (derived from sglang --dist-init-addr). # ``None`` here falls back to FLEXKV_MASTER_HOST env var downstream. + _dist_init_addr = getattr(server_args, 'dist_init_addr', None) + master_host = _dist_init_addr.split(":")[0] if _dist_init_addr and int(nnodes) > 1 else None self.model_config.master_host = master_host update_default_config_from_user_config(self.model_config, self.cache_config, self.user_config) - # Each PP rank needs its own IPC ports so that their - # KVManager / TransferManager instances do not collide on the same - # ZMQ endpoint. DP ranks share the same KVServer (only DP0 creates - # it), so they must use the same IPC port. - _dp_rank = int(dp_rank if dp_rank is not None else 0) - port_suffix = "" - if int(pp_size) > 1: - port_suffix += f"_pp{int(pp_rank)}" - if port_suffix: - self.server_recv_port = f"{self.server_recv_port}{port_suffix}" - self.gpu_register_port = f"{self.server_recv_port}_gpu_register" - - rank_parts = [] - if int(tp_size) > 1: - rank_parts.append("tp_rank=0") - if int(pp_size) > 1: - rank_parts.append(f"pp_rank={int(pp_rank)}") - if int(self.model_config.dp_size) > 1: - rank_parts.append(f"dp_rank={_dp_rank}") - rank_label = f" [{', '.join(rank_parts)}]" if rank_parts else "" - logger.info( - f"[FlexKV] IPC ports configured{rank_label}: " - f"server_recv_port={self.server_recv_port}, " - f"gpu_register_port={self.gpu_register_port}" - ) - hf_config = getattr(sglang_config, 'hf_config', None) self._detect_indexer_config_from_hf(hf_config, source="sglang") @@ -305,6 +319,29 @@ def _parse_dtype_str(dtype_str: str) -> torch.dtype: f"tokens_per_block={self.cache_config.tokens_per_block}" ) + # Log primitive and derived variables for verification + logger.info( + f"[FlexKV sglang] Primitive vars set: tp_size={self.model_config.tp_size}, " + f"tp_rank={self.model_config.tp_rank}, 
dp_size={self.model_config.dp_size}, " + f"dp_rank={self.model_config.dp_rank}, pp_size={self.model_config.pp_size}, " + f"pp_rank={self.model_config.pp_rank}, " + f"enable_dp_attention={self.model_config.enable_dp_attention}, " + f"attn_cp_size={self.model_config.attn_cp_size}, " + f"attn_cp_rank={self.model_config.attn_cp_rank}, " + f"is_nsa_cp={self.model_config.is_nsa_cp}, " + f"nnodes={self.model_config.nnodes}, node_rank={self.model_config.node_rank}" + ) + logger.info( + f"[FlexKV sglang] Derived vars: attn_tp_size={self.model_config.attn_tp_size}, " + f"attn_tp_rank={self.model_config.attn_tp_rank}, " + f"tp_rank_per_node={self.model_config.tp_rank_per_node}, " + f"tp_size_per_node={self.model_config.tp_size_per_node}, " + f"local_rank={self.model_config.local_rank}" + ) + + # Freeze model_config — no further mutations allowed + self.model_config.freeze() + def post_init_from_trt_config( self, config, @@ -314,21 +351,6 @@ def post_init_from_trt_config( dtype_str = config.pytorch_backend_config.kv_cache_dtype flexkv_logger.info(f"[FlexKVConfig] dtype_str from TRT config: {dtype_str}") - # Helper function to convert dtype string to torch.dtype - def _parse_dtype_str(dtype_str: str) -> torch.dtype: - dtype_map = { - "float16": torch.float16, - "float32": torch.float32, - "bfloat16": torch.bfloat16, - "fp16": torch.float16, - "fp32": torch.float32, - "bf16": torch.bfloat16, - "fp8": torch.float8_e4m3fn, - "float8": torch.float8_e4m3fn, - "e4m3": torch.float8_e4m3fn, - } - return dtype_map.get(dtype_str.lower(), torch.bfloat16) - if dtype_str == "auto": # When dtype_str is "auto", try to get kv_cache_dtype from user_config first # This allows users to specify kv_cache_dtype in flexkv_config.json or via environment variable @@ -362,6 +384,14 @@ def _parse_dtype_str(dtype_str: str) -> torch.dtype: self.model_config.pp_size = getattr(config.mapping, 'pp_size', 1) self.model_config.pp_rank = getattr(config.mapping, 'pp_rank', 0) + if self.model_config.pp_size > 1: + 
layers_range = config.mapping.pp_layers(self.model_config.num_layers) + self.model_config.pp_start_layer = layers_range[0] + self.model_config.pp_end_layer = layers_range[-1] + 1 + else: + self.model_config.pp_start_layer = 0 + self.model_config.pp_end_layer = self.model_config.num_layers + # self.model_config (model configs part) try: model_path = getattr(config, 'hf_model_dir', None) @@ -392,3 +422,23 @@ def _parse_dtype_str(dtype_str: str) -> torch.dtype: flexkv_logger.error(f"Failed to load config from {model_path}: {e}") # Update cache config with user config after model config is initialized update_default_config_from_user_config(self.model_config, self.cache_config, self.user_config) + + # Log primitive and derived variables for verification + flexkv_logger.info( + f"[FlexKV TRT-LLM] Primitive vars set: tp_size={self.model_config.tp_size}, " + f"tp_rank={self.model_config.tp_rank}, dp_size={self.model_config.dp_size}, " + f"dp_rank={self.model_config.dp_rank}, pp_size={self.model_config.pp_size}, " + f"pp_rank={self.model_config.pp_rank}, " + f"enable_dp_attention={self.model_config.enable_dp_attention}, " + f"attn_cp_size={self.model_config.attn_cp_size}, " + f"attn_cp_rank={self.model_config.attn_cp_rank}, " + f"nnodes={self.model_config.nnodes}, node_rank={self.model_config.node_rank}" + ) + flexkv_logger.info( + f"[FlexKV TRT-LLM] Derived vars: attn_tp_size={self.model_config.attn_tp_size}, " + f"attn_tp_rank={self.model_config.attn_tp_rank}, " + f"local_rank={self.model_config.local_rank}" + ) + + # Freeze model_config — no further mutations allowed + self.model_config.freeze() diff --git a/flexkv/integration/tensorrt_llm/trtllm_adapter.py b/flexkv/integration/tensorrt_llm/trtllm_adapter.py index 58853da657..940e682e6d 100644 --- a/flexkv/integration/tensorrt_llm/trtllm_adapter.py +++ b/flexkv/integration/tensorrt_llm/trtllm_adapter.py @@ -45,7 +45,7 @@ def __init__(self, config: ExecutorConfig): self.flexkv_manager = 
KVManager(model_config=self.model_config, cache_config=self.cache_config, server_recv_port=flexkv_config.server_recv_port, - dp_client_id=self.dp_rank) + dp_client_id=self.model_config.dp_rank) self.flexkv_manager.start() # self.dp_client = KVDPClient(self.server_recv_port, self.model_config) @@ -209,6 +209,8 @@ def _get_match( task_id, matched_mask = self.flexkv_manager.get_match( token_ids=np_token_ids, token_mask=np_token_mask, + dp_rank=self.flexkv_config.model_config.dp_rank, + pp_rank=self.flexkv_config.model_config.pp_rank, namespace=namespace, ) num_new_matched_tokens = matched_mask.sum().item() @@ -366,6 +368,8 @@ def _put_match( task_id, unmatched_mask = self.flexkv_manager.put_match( token_ids=np_token_ids, + dp_rank=self.flexkv_config.model_config.dp_rank, + pp_rank=self.flexkv_config.model_config.pp_rank, namespace=namespace, ) @@ -538,26 +542,30 @@ def __init__(self, config: ExecutorConfig): self.remote_process = TransferManagerOnRemote.create_process() flexkv_logger.info(f"TransferManagerOnRemote process created, PID: {self.remote_process.pid}") - flexkv_logger.info(f"Start init FlexKVWorkerConnector to {flexkv_config.gpu_register_port}, dp_client_id: {dp_client_id}") - self.tp_client = KVTPClient(flexkv_config.gpu_register_port, dp_client_id, current_device_id) + flexkv_logger.info(f"Start init FlexKVWorkerConnector to {flexkv_config.gpu_register_port}, dp_rank: {dp_rank}") + self.tp_client = KVTPClient(flexkv_config.gpu_register_port, dp_rank=dp_rank, pp_rank=0, device_id=current_device_id) flexkv_logger.info("Finish init FlexKVWorkerConnector") def _need_to_create_remote_process(self) -> bool: """Check if need to create TransferManagerOnRemote process. 
Returns True when all of the following conditions are met: - - Multi-node TP is detected (tp_size > gpus_per_node) + - Multi-node TP is detected (nnodes_per_tp_group > 1) - Current node is not master node (node_rank > 0) - - Current worker is worker0 in TP group (tp_rank == 0) + - Current worker is worker0 in the local TP group (tp_rank_per_node == 0) Returns: bool: True if need to create TransferManagerOnRemote process, False otherwise. """ try: is_master_node = self.node_rank == 0 - is_first_worker = self.tp_rank % 8 == 0 - is_multinode_tp = self.flexkv_config.model_config.tp_size > torch.cuda.device_count() - flexkv_logger.info(f"{is_master_node=}, {is_first_worker=}, {is_multinode_tp=}") + is_first_worker = self.flexkv_config.model_config.tp_rank_per_node == 0 + is_multinode_tp = self.flexkv_config.model_config.nnodes_per_tp_group > 1 + flexkv_logger.info( + f"{is_master_node=}, {is_first_worker=}, {is_multinode_tp=}, " + f"nnodes_per_tp_group={self.flexkv_config.model_config.nnodes_per_tp_group}, " + f"tp_rank_per_node={self.flexkv_config.model_config.tp_rank_per_node}" + ) return is_multinode_tp and not is_master_node and is_first_worker except Exception as e: diff --git a/flexkv/integration/vllm/vllm_v1_adapter.py b/flexkv/integration/vllm/vllm_v1_adapter.py index 015cbb026a..74cebf5f78 100644 --- a/flexkv/integration/vllm/vllm_v1_adapter.py +++ b/flexkv/integration/vllm/vllm_v1_adapter.py @@ -336,6 +336,8 @@ def _get_match( task_id, matched_mask = self.flexkv_manager.get_match( token_ids=np_token_ids, token_mask=np_token_mask, + dp_rank=self.flexkv_config.model_config.dp_rank, + pp_rank=self.flexkv_config.model_config.pp_rank, namespace=namespace, ) num_new_matched_tokens = matched_mask.sum().item() @@ -484,6 +486,8 @@ def _put_match( namespace = self._extract_namespace(request) task_id, unmatched_mask = self.flexkv_manager.put_match( token_ids=np_token_ids, + dp_rank=self.flexkv_config.model_config.dp_rank, + 
pp_rank=self.flexkv_config.model_config.pp_rank, namespace=namespace, ) @@ -704,7 +708,10 @@ def __init__( logger.info(f"Start init FlexKVWorkerConnector to {flexkv_config.gpu_register_port}, " f"server_client_mode={server_client_mode}, dp_client_id={dp_client_id}, " f"client_id={client_id}, device_id={device_id}") - self.tp_client = KVTPClient(flexkv_config.gpu_register_port, client_id, device_id) + self.tp_client = KVTPClient(flexkv_config.gpu_register_port, + dp_rank=client_id, + pp_rank=self.flexkv_config.model_config.pp_rank, + device_id=device_id) logger.info("Finish init FlexKVWorkerConnector") def register_to_server(self, kv_caches: dict[str, torch.Tensor]): diff --git a/flexkv/kvmanager.py b/flexkv/kvmanager.py index d6cc8c2cd8..20dce81e01 100644 --- a/flexkv/kvmanager.py +++ b/flexkv/kvmanager.py @@ -151,7 +151,8 @@ def get_async(self, slot_mapping: Union[torch.Tensor, np.ndarray], token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, layer_granularity: int = -1, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, namespace: Optional[List[str]] = None, ) -> int: if isinstance(token_ids, torch.Tensor): @@ -165,13 +166,15 @@ def get_async(self, slot_mapping, token_mask, layer_granularity, + pp_rank=pp_rank, namespace=namespace) else: task_id, _ = self.kv_task_engine.get_async(token_ids, slot_mapping, token_mask, layer_granularity, - dp_id, + dp_rank, + pp_rank=pp_rank, namespace=namespace) return task_id @@ -179,7 +182,8 @@ def get_match(self, token_ids: Union[torch.Tensor, np.ndarray], token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, layer_granularity: int = -1, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, cpu_only: bool = False, namespace: Optional[List[str]] = None, ) -> Tuple[int, np.ndarray]: @@ -191,13 +195,15 @@ def get_match(self, task_id, mask = self.dp_client.get_match(token_ids, token_mask, layer_granularity, + pp_rank=pp_rank, cpu_only=cpu_only, namespace=namespace) else: task_id, mask = 
self.kv_task_engine.get_match(token_ids, token_mask, layer_granularity, - dp_id, + dp_rank, + pp_rank=pp_rank, cpu_only=cpu_only, namespace=namespace) return task_id, mask @@ -206,7 +212,8 @@ def put_async(self, token_ids: Union[torch.Tensor, np.ndarray], slot_mapping: Union[torch.Tensor, np.ndarray], token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, namespace: Optional[List[str]] = None, ) -> int: if isinstance(token_ids, torch.Tensor): @@ -216,15 +223,16 @@ def put_async(self, if isinstance(token_mask, torch.Tensor): token_mask = token_mask.numpy() if self.server_client_mode: - task_id = self.dp_client.put_async(token_ids, slot_mapping, token_mask, namespace=namespace) + task_id = self.dp_client.put_async(token_ids, slot_mapping, token_mask, pp_rank=pp_rank, namespace=namespace) else: - task_id, _ = self.kv_task_engine.put_async(token_ids, slot_mapping, token_mask, dp_id, namespace=namespace) + task_id, _ = self.kv_task_engine.put_async(token_ids, slot_mapping, token_mask, dp_rank, pp_rank=pp_rank, namespace=namespace) return task_id def put_match(self, token_ids: Union[torch.Tensor, np.ndarray], token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, namespace: Optional[List[str]] = None, ) -> Tuple[int, np.ndarray]: if isinstance(token_ids, torch.Tensor): @@ -232,21 +240,22 @@ def put_match(self, if isinstance(token_mask, torch.Tensor): token_mask = token_mask.numpy() if self.server_client_mode: - task_id, mask = self.dp_client.put_match(token_ids, token_mask, namespace=namespace) + task_id, mask = self.dp_client.put_match(token_ids, token_mask, pp_rank=pp_rank, namespace=namespace) else: - task_id, mask = self.kv_task_engine.put_match(token_ids, token_mask, dp_id, namespace=namespace) + task_id, mask = self.kv_task_engine.put_match(token_ids, token_mask, dp_rank, pp_rank=pp_rank, namespace=namespace) return task_id, mask def 
prefetch_async(self, token_ids: np.ndarray, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, namespace: Optional[List[str]] = None) -> int: if isinstance(token_ids, torch.Tensor): token_ids = token_ids.numpy() if self.server_client_mode: - task_id = self.dp_client.prefetch_async(token_ids, namespace=namespace) + task_id = self.dp_client.prefetch_async(token_ids, pp_rank=pp_rank, namespace=namespace) else: - task_id = self.kv_task_engine.prefetch_async(token_ids, dp_id=dp_id, namespace=namespace) + task_id = self.kv_task_engine.prefetch_async(token_ids, dp_rank=dp_rank, pp_rank=pp_rank, namespace=namespace) return task_id def launch(self, diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index 3a336b8fdf..99e85b5b3c 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -25,6 +25,7 @@ ) from flexkv.cache.redis_meta import RedisMeta from flexkv.integration.dynamo.collector import KVEventCollector +from flexkv.transfer_manager import TransferManagerMultiNodeHandle class TaskStatus(Enum): # slot mapping is not ready @@ -60,7 +61,6 @@ class KVTask: token_ids: np.ndarray slot_mapping: np.ndarray token_mask: Optional[np.ndarray] - dp_id: int # cache engine return graph: TransferOpGraph @@ -68,6 +68,9 @@ class KVTask: callback: Optional[Union[Callable, List[Callable]]] op_callback_dict: Dict[int, Callable] + dp_rank: int = 0 + pp_rank: int = 0 + # batch: points to the batch task id if this task was merged into a batch batch_task_id: Optional[int] = None @@ -151,21 +154,10 @@ def __init__(self, self.cache_engine = GlobalCacheEngine(cache_config, model_config, redis_meta, event_collector) - model_config_for_transfer = copy.deepcopy(self.model_config) - if self.nnodes_per_tp_group > 1: - model_config_for_transfer.tp_size //= self.nnodes_per_tp_group - if not self.model_config.use_mla: - model_config_for_transfer.num_kv_heads //= self.nnodes_per_tp_group - # When NSA CP is active, cp_size mirrors tp_size and must also - # be divided so that TransferEngine's 
_eventfd_group_size matches - # the number of local GPUs on each node. - if model_config_for_transfer.is_nsa_cp and model_config_for_transfer.cp_size > 1: - model_config_for_transfer.cp_size //= self.nnodes_per_tp_group - combine_with_trtllm = os.getenv("FLEXKV_WITH_TRTLLM", "0") == "1" if not combine_with_trtllm: self.transfer_handles = [TransferManagerHandle( - model_config_for_transfer, + self.model_config, self.cache_config, mode="process", gpu_register_port=gpu_register_port @@ -178,7 +170,7 @@ def __init__(self, self.remote_process = TransferManagerOnRemote.create_process(mode="TrtllmSubprocess") self.transfer_handles = [ TransferManagerHandle( - model_config_for_transfer, + self.model_config, self.cache_config, mode="remote", gpu_register_port=gpu_register_port, @@ -188,12 +180,12 @@ def __init__(self, ] self.transfer_handles[0]._handle.send_config_to_remotes() - if self.nnodes_per_tp_group > 1: + if self.model_config.nnodes > 1: master_host, master_ports = resolve_master_host_and_ports( master_host=self.model_config.master_host ) self.transfer_handles.append(TransferManagerHandle( - model_config_for_transfer, + self.model_config, self.cache_config, mode="remote", gpu_register_port=gpu_register_port, @@ -246,7 +238,8 @@ def create_get_task(self, slot_mapping: np.ndarray, token_mask: Optional[np.ndarray] = None, layer_granularity: int = -1, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, is_fake_slot_mapping: bool = False, temp_cache_strategy=DEFAULT_CACHE_STRATEGY, namespace: Optional[List[str]] = None, @@ -258,9 +251,10 @@ def create_get_task(self, token_ids=token_ids, token_mask=token_mask, slot_mapping=slot_mapping, - layer_num=self.model_config.num_layers, + layer_num=self.model_config.num_layers_per_pp_stage, layer_granularity=layer_granularity, - dp_id=dp_id, + dp_rank=dp_rank, + pp_rank=pp_rank, temp_cache_strategy=temp_cache_strategy, namespace=namespace) self.tasks[task_id] = KVTask( @@ -272,7 +266,8 @@ def create_get_task(self, 
token_ids=token_ids, slot_mapping=slot_mapping, token_mask=token_mask, - dp_id=dp_id, + dp_rank=dp_rank, + pp_rank=pp_rank, graph=graph, return_mask=return_mask, callback=callback, @@ -285,7 +280,8 @@ def create_put_task(self, token_ids: np.ndarray, slot_mapping: np.ndarray, token_mask: Optional[np.ndarray] = None, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, is_fake_slot_mapping: bool = False, namespace: Optional[List[str]] = None, ) -> None: @@ -296,8 +292,9 @@ def create_put_task(self, token_ids=token_ids, token_mask=token_mask, slot_mapping=slot_mapping, - layer_num=self.model_config.num_layers, - dp_id=dp_id, + layer_num=self.model_config.num_layers_per_pp_stage, + dp_rank=dp_rank, + pp_rank=pp_rank, namespace=namespace) self.tasks[task_id] = KVTask( task_id=task_id, @@ -308,7 +305,8 @@ def create_put_task(self, token_ids=token_ids, slot_mapping=slot_mapping, token_mask=token_mask, - dp_id=dp_id, + dp_rank=dp_rank, + pp_rank=pp_rank, graph=graph, return_mask=return_mask, callback=callback, @@ -318,6 +316,7 @@ def create_put_task(self, def create_prefetch_task(self, task_id: int, token_ids: np.ndarray, + pp_rank: int = 0, namespace: Optional[List[str]] = None, ) -> None: if task_id in self.tasks: @@ -332,7 +331,9 @@ def create_prefetch_task(self, token_ids=token_ids, token_mask=fake_token_mask, slot_mapping=fake_slot_mapping, - layer_num=self.model_config.num_layers, + layer_num=self.model_config.num_layers_per_pp_stage, + dp_rank=0, # dp_rank irrelevant: prefetch only uploads to CPU (ignore_gpu=True) + pp_rank=pp_rank, temp_cache_strategy=temp_cache_strategy, namespace=namespace) self.tasks[task_id] = KVTask( @@ -344,7 +345,8 @@ def create_prefetch_task(self, token_ids=token_ids, slot_mapping=fake_slot_mapping, # ignore slot_mapping for prefetch token_mask=fake_token_mask, # ignore token_mask for prefetch - dp_id=0, # ignore dp_id for prefetch + dp_rank=0, # ignore dp_rank for prefetch + pp_rank=pp_rank, graph=graph, return_mask=return_mask, 
callback=callback, @@ -361,7 +363,21 @@ def _launch_task(self, task_id: int) -> None: nvtx.mark(f"launch task: task_id={task_id}, graph_id={transfer_graph.graph_id}") if transfer_graph.num_ops > 0: for transfer_handle in self.transfer_handles: - transfer_handle.submit(transfer_graph) + # For remote handles: deepcopy graph and clear GPU blocks when + # it's a cross-machine PP handle (different PP stages have + # different GPU block_ids). Cross-machine TP handles share + # the same slot_mapping, so no clear is needed. + if isinstance(transfer_handle._handle, TransferManagerMultiNodeHandle): + if self.model_config.nnodes > 1 and self.model_config.pp_size > 1: + # Cross-machine PP: each PP rank has different GPU blocks + graph_copy = copy.deepcopy(transfer_graph) + graph_copy.clear_gpu_blocks() + transfer_handle.submit(graph_copy, task_end_op_id=self.tasks[task_id].task_end_op_id) + else: + # Cross-machine TP: same slot_mapping across TP ranks + transfer_handle.submit(transfer_graph, task_end_op_id=self.tasks[task_id].task_end_op_id) + else: + transfer_handle.submit(transfer_graph, task_end_op_id=self.tasks[task_id].task_end_op_id) def _update_tasks(self, timeout: float = 0.001) -> None: completed_ops = self._get_completed_ops(timeout) @@ -498,7 +514,7 @@ def _check_config(self, model_config: ModelConfig, cache_config: CacheConfig) -> raise ValueError("remote_file_size must not None if use file_size model") if model_config.use_mla: kv_size = ( - model_config.num_layers + model_config.num_layers_per_pp_stage * cache_config.tokens_per_block * model_config.num_kv_heads * model_config.head_size @@ -506,7 +522,7 @@ def _check_config(self, model_config: ModelConfig, cache_config: CacheConfig) -> ) else: kv_size = ( - model_config.num_layers + model_config.num_layers_per_pp_stage * 2 * cache_config.tokens_per_block * model_config.num_kv_heads @@ -535,7 +551,8 @@ def get_async(self, slot_mapping: np.ndarray, token_mask: Optional[np.ndarray] = None, layer_granularity: int = -1, 
- dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, task_id: int = -1, namespace: Optional[List[str]] = None) -> Tuple[int, np.ndarray]: # self._sync_prefetch(token_ids, namespace) @@ -544,7 +561,8 @@ def get_async(self, is_fake_slot_mapping=False, token_mask=token_mask, layer_granularity=layer_granularity, - dp_id=dp_id, + dp_rank=dp_rank, + pp_rank=pp_rank, task_id=task_id, namespace=namespace) # trace get request @@ -555,7 +573,8 @@ def get_async(self, slot_mapping=slot_mapping, token_mask=token_mask, layer_granularity=layer_granularity, - dp_id=dp_id + dp_rank=dp_rank, + pp_rank=pp_rank ) self._launch_task(task_id) return task_id, return_mask @@ -564,14 +583,16 @@ def put_async(self, token_ids: np.ndarray, slot_mapping: np.ndarray, token_mask: Optional[np.ndarray] = None, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, task_id: int = -1, namespace: Optional[List[str]] = None) -> Tuple[int, np.ndarray]: task_id, return_mask = self._put_match_impl(token_ids, slot_mapping, is_fake_slot_mapping=False, token_mask=token_mask, - dp_id=dp_id, + dp_rank=dp_rank, + pp_rank=pp_rank, task_id=task_id, namespace=namespace) # trace put request @@ -582,7 +603,8 @@ def put_async(self, slot_mapping=slot_mapping, token_mask=token_mask, layer_granularity=-1, # put has no layer_granularity parameter - dp_id=dp_id + dp_rank=dp_rank, + pp_rank=pp_rank ) self._launch_task(task_id) return task_id, return_mask @@ -687,7 +709,8 @@ def get_match(self, token_ids: np.ndarray, token_mask: Optional[np.ndarray] = None, layer_granularity: int = -1, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, cpu_only: bool = False, task_id: int = -1, namespace: Optional[List[str]] = None) -> Tuple[int, np.ndarray]: @@ -701,7 +724,8 @@ def get_match(self, is_fake_slot_mapping=True, token_mask=token_mask, layer_granularity=layer_granularity, - dp_id=dp_id, + dp_rank=dp_rank, + pp_rank=pp_rank, cpu_only=cpu_only, task_id=task_id, namespace=namespace) @@ -713,7 +737,8 @@ def 
get_match(self, slot_mapping=fake_slot_mapping, token_mask=token_mask, layer_granularity=layer_granularity, - dp_id=dp_id + dp_rank=dp_rank, + pp_rank=pp_rank ) nvtx.pop_range() return result_task_id, return_mask @@ -724,14 +749,15 @@ def _get_match_impl(self, is_fake_slot_mapping: bool = False, token_mask: Optional[np.ndarray] = None, layer_granularity: int = -1, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, cpu_only: bool = False, task_id: int = -1, namespace: Optional[List[str]] = None) -> Tuple[int, np.ndarray]: if token_mask is None: token_mask = np.ones_like(token_ids) if layer_granularity == -1: - layer_granularity = self.model_config.num_layers + layer_granularity = self.model_config.num_layers_per_pp_stage if task_id == -1: task_id = self._gen_task_id() temp_cache_strategy = DEFAULT_CACHE_STRATEGY @@ -743,7 +769,8 @@ def _get_match_impl(self, slot_mapping, token_mask, layer_granularity, - dp_id, + dp_rank, + pp_rank=pp_rank, is_fake_slot_mapping=is_fake_slot_mapping, temp_cache_strategy=temp_cache_strategy, namespace=namespace) @@ -754,7 +781,8 @@ def _get_match_impl(self, def put_match(self, token_ids: np.ndarray, token_mask: Optional[np.ndarray] = None, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, task_id: int = -1, namespace: Optional[List[str]] = None) -> Tuple[int, np.ndarray]: fake_slot_mapping = np.zeros_like(token_ids) @@ -762,7 +790,8 @@ def put_match(self, fake_slot_mapping, is_fake_slot_mapping=True, token_mask=token_mask, - dp_id=dp_id, + dp_rank=dp_rank, + pp_rank=pp_rank, task_id=task_id, namespace=namespace) # trace put match request @@ -773,7 +802,8 @@ def put_match(self, slot_mapping=fake_slot_mapping, token_mask=token_mask, layer_granularity=-1, # put has no layer_granularity parameter - dp_id=dp_id + dp_rank=dp_rank, + pp_rank=pp_rank ) return result_task_id, return_mask @@ -782,7 +812,8 @@ def _put_match_impl(self, slot_mapping: np.ndarray, is_fake_slot_mapping: bool = False, token_mask: Optional[np.ndarray] = 
None, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, task_id: int = -1, namespace: Optional[List[str]] = None) -> Tuple[int, np.ndarray]: if token_mask is None: @@ -794,7 +825,8 @@ def _put_match_impl(self, token_ids, slot_mapping, token_mask, - dp_id, + dp_rank, + pp_rank=pp_rank, is_fake_slot_mapping=is_fake_slot_mapping, namespace=namespace) self._process_empty_graph(task_id) @@ -803,13 +835,14 @@ def _put_match_impl(self, def prefetch_async(self, token_ids: np.ndarray, - dp_id: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, task_id: int = -1, namespace: Optional[List[str]] = None) -> int: if task_id == -1: task_id = self._gen_task_id() nvtx.push_range(f"prefetch match: task_id={task_id}", color=get_nvtx_default_color()) - self.create_prefetch_task(task_id, token_ids, namespace=namespace) + self.create_prefetch_task(task_id, token_ids, pp_rank=pp_rank, namespace=namespace) self._process_empty_graph(task_id) nvtx.pop_range() # trace prefetch async request @@ -820,7 +853,8 @@ def prefetch_async(self, slot_mapping=np.zeros_like(token_ids), token_mask=np.ones_like(token_ids), layer_granularity=-1, - dp_id=dp_id + dp_rank=dp_rank, + pp_rank=pp_rank ) self._launch_task(task_id) return task_id @@ -864,7 +898,8 @@ def merge_to_batch_kvtask(self, task_end_op_id=task_end_op_id, task_end_op_finished=False, status=TaskStatus.READY, - dp_id=self.tasks[task_ids[0]].dp_id, + dp_rank=self.tasks[task_ids[0]].dp_rank, + pp_rank=self.tasks[task_ids[0]].pp_rank, graph=batch_task_graph, return_mask=return_masks, callback=callbacks, diff --git a/flexkv/server/client.py b/flexkv/server/client.py index 526efaf372..cc0ac1d267 100644 --- a/flexkv/server/client.py +++ b/flexkv/server/client.py @@ -94,6 +94,7 @@ def put_async( token_ids: np.ndarray, slot_mapping: np.ndarray, token_mask: Optional[np.ndarray], + pp_rank: int = 0, namespace: Optional[List[str]] = None, ) -> int: req = PutRequest(self.dp_client_id, @@ -101,6 +102,7 @@ def put_async( slot_mapping, token_mask if 
token_mask is not None else None, self._get_task_id(), + pp_rank, namespace) self.send_to_server.send_pyobj(req) return req.task_id @@ -109,12 +111,14 @@ def put_match( self, token_ids: np.ndarray, token_mask: Optional[np.ndarray], + pp_rank: int = 0, namespace: Optional[List[str]] = None, ) -> Optional[Tuple[int, np.ndarray]]: req = PutMatchRequest(self.dp_client_id, token_ids, token_mask if token_mask is not None else None, self._get_task_id(), + pp_rank, namespace) self.send_to_server.send_pyobj(req) response: Response = self.recv_from_server.recv_pyobj() @@ -127,9 +131,10 @@ def put_match( def prefetch_async( self, token_ids: np.ndarray, + pp_rank: int = 0, namespace: Optional[List[str]] = None, ) -> int: - req = PrefetchRequest(self.dp_client_id, token_ids, self._get_task_id(), namespace) + req = PrefetchRequest(self.dp_client_id, token_ids, self._get_task_id(), pp_rank, namespace) self.send_to_server.send_pyobj(req) return req.task_id @@ -139,6 +144,7 @@ def get_async( slot_mapping: np.ndarray, token_mask: Optional[np.ndarray], layer_granularity: int, + pp_rank: int = 0, namespace: Optional[List[str]] = None, ) -> int: req = GetRequest(self.dp_client_id, @@ -147,6 +153,7 @@ def get_async( token_mask if token_mask is not None else None, self._get_task_id(), layer_granularity, + pp_rank, namespace) self.send_to_server.send_pyobj(req) return req.task_id @@ -156,6 +163,7 @@ def get_match( token_ids: np.ndarray, token_mask: Optional[np.ndarray], layer_granularity: int, + pp_rank: int = 0, cpu_only: bool = False, namespace: Optional[List[str]] = None, ) -> Optional[Tuple[int, np.ndarray]]: @@ -165,6 +173,7 @@ def get_match( layer_granularity, cpu_only, self._get_task_id(), + pp_rank, namespace) self.send_to_server.send_pyobj(req) response: Response = self.recv_from_server.recv_pyobj() @@ -239,7 +248,8 @@ class KVTPClient: def __init__( self, gpu_register_port: str, - dp_client_id: int, + dp_rank: int, + pp_rank: int, device_id: int, ): # Init inter-process 
communication @@ -248,11 +258,39 @@ def __init__( context, zmq.SocketType.PUSH, gpu_register_port, False ) - self.dp_client_id = dp_client_id + self.dp_rank = dp_rank + self.pp_rank = pp_rank self.device_id = device_id - flexkv_logger.info(f"KVTPClient {device_id} of KVDPClient {self.dp_client_id} Initialized! " - f"(gpu_register_port={gpu_register_port})") + flexkv_logger.info(f"KVTPClient {device_id} of DP {self.dp_rank} Initialized! " + f"(gpu_register_port={gpu_register_port}, dp_rank={dp_rank}, pp_rank={pp_rank})") + + def set_slot_mapping(self, task_id: int, slot_mapping: np.ndarray) -> None: + """Send set_slot_mapping message to TransferManagerOnRemote via existing ZMQ channel. + + Reuses the same PUSH socket (send_to_server) that connects to + TransferManagerOnRemote's command_socket — no separate IPC socket needed. + """ + message = { + 'type': 'set_slot_mapping', + 'task_id': task_id, + 'slot_mapping': slot_mapping, + } + try: + self.send_to_server.send_pyobj(message, flags=zmq.NOBLOCK) + flexkv_logger.debug( + f"KVTPClient {self.device_id}: set_slot_mapping sent for task_id={task_id}" + ) + except zmq.Again: + flexkv_logger.warning( + f"KVTPClient {self.device_id}: zmq.Again when sending set_slot_mapping, " + f"retrying with blocking send..." 
+ ) + self.send_to_server.send_pyobj(message) + flexkv_logger.info( + f"KVTPClient {self.device_id}: set_slot_mapping sent (blocking retry) " + f"for task_id={task_id}" + ) def register_to_server( self, @@ -281,7 +319,8 @@ def register_to_server( indexer_handles.append(TensorSharedHandle(tensor, device_id)) register_req = RegisterTPClientRequest( - self.dp_client_id, + self.dp_rank, + self.pp_rank, device_id, handles, kv_layout, @@ -293,7 +332,7 @@ def register_to_server( self.send_to_server.send_pyobj(register_req, flags=zmq.NOBLOCK) flexkv_logger.info( f"KVTPClient {device_id}: registration message sent " - f"(dp_client_id={self.dp_client_id}, num_kv_caches={len(kv_caches)})") + f"(dp_rank={self.dp_rank}, num_kv_caches={len(kv_caches)})") except zmq.Again: flexkv_logger.error( f"KVTPClient {device_id}: zmq.Again when sending registration " diff --git a/flexkv/server/request.py b/flexkv/server/request.py index ace0658ab7..f1e22fdf62 100644 --- a/flexkv/server/request.py +++ b/flexkv/server/request.py @@ -18,7 +18,8 @@ class RegisterDPClientRequest: @dataclass class RegisterTPClientRequest: - dp_client_id: int + dp_rank: int + pp_rank: int device_id: int handles: List[TensorSharedHandle] gpu_layout: KVCacheLayout @@ -37,6 +38,7 @@ class PutRequest: slot_mapping: np.ndarray token_mask: Optional[np.ndarray] task_id: int = -1 + pp_rank: int = 0 namespace: Optional[List[str]] = None @@ -48,6 +50,7 @@ class GetRequest: token_mask: Optional[np.ndarray] task_id: int = -1 layer_granularity: int = -1 + pp_rank: int = 0 namespace: Optional[List[str]] = None @dataclass @@ -55,6 +58,7 @@ class PrefetchRequest: dp_client_id: int token_ids: np.ndarray task_id: int = -1 + pp_rank: int = 0 namespace: Optional[List[str]] = None @dataclass @@ -63,6 +67,7 @@ class PutMatchRequest: token_ids: np.ndarray token_mask: Optional[np.ndarray] task_id: int = -1 + pp_rank: int = 0 namespace: Optional[List[str]] = None @dataclass @@ -73,6 +78,7 @@ class GetMatchRequest: layer_granularity: int 
cpu_only: bool = False task_id: int = -1 + pp_rank: int = 0 namespace: Optional[List[str]] = None @dataclass @@ -131,3 +137,4 @@ class ShutdownRequest: @dataclass class CheckRunningRequest: dp_client_id: int + diff --git a/flexkv/server/server.py b/flexkv/server/server.py index 60ca78d5af..208259c40e 100644 --- a/flexkv/server/server.py +++ b/flexkv/server/server.py @@ -376,7 +376,8 @@ def _handle_get_request(self, req: GetRequest) -> None: slot_mapping=req.slot_mapping, token_mask=req.token_mask, layer_granularity=req.layer_granularity, - dp_id=req.dp_client_id, + dp_rank=req.dp_client_id, + pp_rank=req.pp_rank, namespace=req.namespace, ) @@ -386,7 +387,8 @@ def _handle_put_request(self, req: PutRequest) -> None: token_ids=req.token_ids, slot_mapping=req.slot_mapping, token_mask=req.token_mask, - dp_id=req.dp_client_id, + dp_rank=req.dp_client_id, + pp_rank=req.pp_rank, task_id=req.task_id, namespace=req.namespace, ) @@ -397,7 +399,8 @@ def _handle_get_match_request(self, req: GetMatchRequest) -> None: token_ids=req.token_ids, token_mask=req.token_mask, layer_granularity=req.layer_granularity, - dp_id=req.dp_client_id, + dp_rank=req.dp_client_id, + pp_rank=req.pp_rank, cpu_only=req.cpu_only, task_id=req.task_id, namespace=req.namespace, @@ -411,7 +414,8 @@ def _handle_put_match_request(self, req: PutMatchRequest) -> None: req_id, mask = self.kv_task_engine.put_match( token_ids=req.token_ids, token_mask=req.token_mask, - dp_id=req.dp_client_id, + dp_rank=req.dp_client_id, + pp_rank=req.pp_rank, task_id=req.task_id, namespace=req.namespace, ) @@ -423,7 +427,8 @@ def _handle_prefetch_request(self, req: PrefetchRequest) -> None: """Handle Prefetch request""" task_id = self.kv_task_engine.prefetch_async( token_ids=req.token_ids, - dp_id=req.dp_client_id, + dp_rank=req.dp_client_id, + pp_rank=req.pp_rank, task_id=req.task_id, namespace=req.namespace, ) diff --git a/flexkv/storage/storage_engine.py b/flexkv/storage/storage_engine.py index 195662c121..46eb8824ec 100644 
--- a/flexkv/storage/storage_engine.py +++ b/flexkv/storage/storage_engine.py @@ -38,10 +38,10 @@ def __init__(self, if self._cache_config.enable_cpu: self._cpu_layout: Optional[KVCacheLayout] = KVCacheLayout( type=GLOBAL_CONFIG_FROM_ENV.cpu_layout_type, - num_layer=self._model_config.num_layers, + num_layer=self._model_config.num_layers_per_pp_stage, num_block=self._cache_config.num_cpu_blocks, tokens_per_block=self._cache_config.tokens_per_block, - num_head=self._model_config.num_kv_heads, + num_head=self._model_config.num_kv_heads_per_node, head_size=self._model_config.head_size, is_mla=self._model_config.use_mla ) @@ -56,7 +56,7 @@ def __init__(self, # tokens_per_block is 1 (one indexer entry per page). indexer_cpu_layout = KVCacheLayout( type=GLOBAL_CONFIG_FROM_ENV.cpu_layout_type, - num_layer=self._model_config.num_layers, + num_layer=self._model_config.num_layers_per_pp_stage, num_block=self._cache_config.num_cpu_blocks, tokens_per_block=1, num_head=self._indexer_config.num_kv_heads, @@ -75,10 +75,10 @@ def __init__(self, raise ValueError(f"SSD layout type must be the same as CPU layout type: {self._cpu_layout.type}") self._ssd_layout: Optional[KVCacheLayout] = KVCacheLayout( type=GLOBAL_CONFIG_FROM_ENV.ssd_layout_type, - num_layer=self._model_config.num_layers, + num_layer=self._model_config.num_layers_per_pp_stage, num_block=self._cache_config.num_ssd_blocks, tokens_per_block=self._cache_config.tokens_per_block, - num_head=self._model_config.num_kv_heads, + num_head=self._model_config.num_kv_heads_per_node, head_size=self._model_config.head_size, is_mla=self._model_config.use_mla ) @@ -92,7 +92,7 @@ def __init__(self, if self._indexer_config is not None: indexer_ssd_layout = KVCacheLayout( type=GLOBAL_CONFIG_FROM_ENV.ssd_layout_type, - num_layer=self._model_config.num_layers, + num_layer=self._model_config.num_layers_per_pp_stage, num_block=self._cache_config.num_ssd_blocks, tokens_per_block=1, num_head=self._indexer_config.num_kv_heads, @@ -113,10 +113,10 
@@ def __init__(self, raise ValueError(f"Remote layout type must be the same as CPU layout type: {self._cpu_layout.type}") self._remote_layout: Optional[KVCacheLayout] = KVCacheLayout( type=GLOBAL_CONFIG_FROM_ENV.remote_layout_type, - num_layer=self._model_config.num_layers, + num_layer=self._model_config.num_layers_per_pp_stage, num_block=self._cache_config.num_remote_blocks, tokens_per_block=self._cache_config.tokens_per_block, - num_head=self._model_config.num_kv_heads, + num_head=self._model_config.num_kv_heads_per_node, head_size=self._model_config.head_size, is_mla=self._model_config.use_mla ) @@ -130,7 +130,7 @@ def __init__(self, if self._indexer_config is not None: indexer_remote_layout = KVCacheLayout( type=GLOBAL_CONFIG_FROM_ENV.remote_layout_type, - num_layer=self._model_config.num_layers, + num_layer=self._model_config.num_layers_per_pp_stage, num_block=self._cache_config.num_remote_blocks, tokens_per_block=1, num_head=self._indexer_config.num_kv_heads, diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py index df2b1dc25d..e3161d9cf1 100644 --- a/flexkv/transfer/layerwise.py +++ b/flexkv/transfer/layerwise.py @@ -19,36 +19,26 @@ from flexkv.transfer.worker import TransferWorkerBase, cudaHostRegister -def build_layerwise_eventfd_socket_path(model_config: ModelConfig) -> str: +def build_layerwise_eventfd_socket_path( + pp_rank: int, + dp_rank: int, + pp_size: int = 1, + dp_size: int = 1, +) -> str: """Construct the LayerwiseWorker's UDS socket path. Disambiguated by ``(pp_rank, dp_rank)`` so multiple PP stages and DP replicas on the same host each get their own endpoint. - - We deliberately do NOT embed ``node_rank`` in the path: Unix domain - sockets are kernel-local, so two FlexKV instances on different - physical hosts cannot collide even when ``/tmp`` happens to be on a - shared filesystem (NFS and friends propagate the inode, not the - socket endpoint). 
Deployments that stack multiple containers on one - host with a shared ``/tmp`` should disambiguate via the - ``FLEXKV_LAYERWISE_EVENTFD_SOCKET`` env var (e.g. embed ``$HOSTNAME`` - or the container id in the base path). - - Must stay in sync with the sglang-side consumer at - ``sglang.srt.mem_cache.storage.flexkv.flexkv_connector``, which - imports this helper directly so the two ends cannot drift. Both - sides derive the path from the same ``ModelConfig`` fields, so no - env-var plumbing between processes is required. """ base = os.environ.get( 'FLEXKV_LAYERWISE_EVENTFD_SOCKET', '/tmp/flexkv_layerwise_eventfd.sock', ) suffix = "" - if model_config.pp_size > 1: - suffix += f"_pp{model_config.pp_rank}" - if model_config.dp_size > 1: - suffix += f"_dp{model_config.dp_rank}" + if pp_size > 1: + suffix += f"_pp{pp_rank}" + if dp_size > 1: + suffix += f"_dp{dp_rank}" if not suffix: return base root, ext = os.path.splitext(base) @@ -87,11 +77,6 @@ def __init__(self, ssd_kv_layout: KVCacheLayout, dtype: torch.dtype, tp_group_size: int, - dp_group_id: int, - pp_rank: int, - pp_size: int, - dp_size: int, - dp_rank: int, layerwise_eventfd_socket: str, num_blocks_per_file: int, use_ce_transfer_h2d: bool = False, @@ -99,7 +84,6 @@ def __init__(self, h2d_cta_num: int = 4, d2h_cta_num: int = 4, enable_eventfd: bool = True, - is_nsa_cp: bool = False, indexer_gpu_blocks: Optional[List[List[TensorSharedHandle]]] = None, indexer_cpu_blocks: Optional[Union[torch.Tensor, HugePageTensorHandle]] = None, indexer_gpu_kv_layouts: Optional[List[KVCacheLayout]] = None, @@ -110,8 +94,7 @@ def __init__(self, indexer_num_blocks_per_file: int = 0) -> None: flexkv_logger.debug( f"[LayerwiseWorker] __init__ started: worker_id={worker_id}, " - f"tp_group_size={tp_group_size}, dp_group_id={dp_group_id}, " - f"pp_rank={pp_rank}, pp_size={pp_size}, " + f"tp_group_size={tp_group_size}, " f"enable_eventfd={enable_eventfd}, " f"num_gpu_blocks={[len(b) for b in gpu_blocks]}") super().__init__(worker_id, 
transfer_conn, finished_ops_queue, op_buffer_tensor) @@ -129,17 +112,11 @@ def __init__(self, self.num_gpus = len(self.gpu_blocks) self.tp_group_size = tp_group_size - self.pp_rank = pp_rank - self.pp_size = pp_size if pp_size > 0 else 1 - self.dp_group_id = dp_group_id - self.dp_size = dp_size if dp_size > 0 else 1 - self.dp_rank = dp_rank # Pre-computed UDS socket path. Both ends (this worker and the # sglang connector) derive the path from the same ModelConfig # fields (pp_rank / dp_rank / node_rank / is_multinode_tp), so no # env-var plumbing between processes is required. self.layerwise_eventfd_socket = layerwise_eventfd_socket - self.is_nsa_cp = is_nsa_cp # initialize GPU storage self.num_layers = gpu_kv_layouts[0].num_layer @@ -183,17 +160,11 @@ def __init__(self, self.cpu_kv_stride_in_bytes = cpu_kv_layout.get_kv_stride() * self.dtype.itemsize self.cpu_layer_stride_in_bytes = cpu_kv_layout.get_layer_stride() * self.dtype.itemsize # TP-divided CPU strides (for CPU->GPU, each rank reads its own portion) - if self.is_nsa_cp: - # CP: no head partitioning, every rank gets the full KV cache - cpu_kv_layout_tp = cpu_kv_layout - self.cpu_tp_stride_in_bytes = 0 + if cpu_kv_layout.type == KVCacheLayoutType.BLOCKFIRST and not self.is_mla: + cpu_kv_layout_tp = cpu_kv_layout.div_head(self.tp_group_size) else: - # TP: partition by heads, each rank reads a different head slice - if cpu_kv_layout.type == KVCacheLayoutType.BLOCKFIRST and not self.is_mla: - cpu_kv_layout_tp = cpu_kv_layout.div_head(self.tp_group_size) - else: - cpu_kv_layout_tp = cpu_kv_layout - self.cpu_tp_stride_in_bytes = self.cpu_block_stride_in_bytes // self.tp_group_size + cpu_kv_layout_tp = cpu_kv_layout + self.cpu_tp_stride_in_bytes = self.cpu_block_stride_in_bytes // self.tp_group_size self.h2d_cpu_kv_stride_in_bytes = cpu_kv_layout_tp.get_kv_stride() * self.dtype.itemsize self.h2d_cpu_layer_stride_in_bytes = cpu_kv_layout_tp.get_layer_stride() * self.dtype.itemsize @@ -330,7 +301,7 @@ def 
__init__(self, self.layerwise_transfer_group = LayerwiseTransferGroup( self.num_gpus, self.gpu_blocks, cpu_blocks, ssd_files, - dp_group_id, self.num_layers, + self.num_layers, gpu_kv_strides_tensor, gpu_block_strides_tensor, gpu_layer_strides_tensor, gpu_chunk_sizes_tensor, GLOBAL_CONFIG_FROM_ENV.iouring_entries, @@ -348,15 +319,6 @@ def _receive_eventfds_from_sglang(self, tp_group_size: int, """Receive eventfds from SGLang via Unix socket (FlexKV as server).""" socket_path = self.layerwise_eventfd_socket - rank_parts = [] - if int(self.tp_group_size) > 1: - rank_parts.append("tp_rank=0") - if int(self.pp_size) > 1: - rank_parts.append(f"pp_rank={int(self.pp_rank)}") - if int(self.dp_size) > 1: - rank_parts.append(f"dp_rank={int(self.dp_rank)}") - rank_label = f" [{', '.join(rank_parts)}]" if rank_parts else "" - def cleanup_socket(): try: if os.path.exists(socket_path): @@ -374,11 +336,11 @@ def cleanup_socket(): server_sock.listen(tp_group_size * 3) os.chmod(socket_path, 0o777) flexkv_logger.info( - f"[LayerwiseWorker] Eventfd server created{rank_label}: " + f"[LayerwiseWorker] Eventfd server created: " f"socket={socket_path}, waiting for {tp_group_size} connection(s)") except Exception as e: flexkv_logger.error( - f"[LayerwiseWorker] Failed to bind/listen on {socket_path}{rank_label}: {e}") + f"[LayerwiseWorker] Failed to bind/listen on {socket_path}: {e}") server_sock.close() return torch.empty(0, dtype=torch.int32) @@ -397,7 +359,7 @@ def cleanup_socket(): while len(all_rank_eventfds) < tp_group_size: if time.time() > total_deadline: flexkv_logger.error( - f"[LayerwiseWorker] Deadline exceeded on {socket_path}{rank_label}, " + f"[LayerwiseWorker] Deadline exceeded on {socket_path}, " f"received {len(all_rank_eventfds)}/{tp_group_size} ranks") break @@ -410,41 +372,34 @@ def cleanup_socket(): flexkv_logger.info( f"[LayerwiseWorker] Accepted connection " f"{conn_idx} (registered {len(all_rank_eventfds)}/{tp_group_size}) " - f"on {socket_path}{rank_label}") + 
f"on {socket_path}") except socket.timeout: flexkv_logger.warning( - f"[LayerwiseWorker] Timeout waiting for connection on {socket_path}{rank_label}, " + f"[LayerwiseWorker] Timeout waiting for connection on {socket_path}, " f"registered {len(all_rank_eventfds)}/{tp_group_size}, retrying...") continue try: with conn: - # Accept both 16-byte (legacy: tp_rank, tp_size, num_layers, num_counters) - # and 24-byte (new: tp_rank, tp_size, cp_rank, cp_size, num_layers, num_counters) - metadata = conn.recv(24) + # Receive 16-byte metadata: tp_rank_per_node, tp_size_per_node, + # num_layers, num_counters + metadata = conn.recv(16) if len(metadata) < 16: flexkv_logger.error( - f"[LayerwiseWorker] Incomplete metadata on {socket_path}{rank_label}: " - f"{len(metadata)} bytes") + f"[LayerwiseWorker] Incomplete metadata on {socket_path}: " + f"expected 16 bytes, got {len(metadata)}") continue - if len(metadata) >= 24: - tp_rank, _, cp_rank, cp_size, recv_num_layers, recv_num_counters = \ - struct.unpack("iiiiii", metadata[:24]) - else: - tp_rank, _, recv_num_layers, recv_num_counters = \ - struct.unpack("iiii", metadata[:16]) - cp_rank, cp_size = 0, 1 - - # Use cp_rank as the connection key when CP is active, - # otherwise use tp_rank - rank_key = cp_rank if cp_size > 1 else tp_rank + rank_key, tp_size_per_node_recv, recv_num_layers, recv_num_counters = \ + struct.unpack("iiii", metadata[:16]) + if not all_rank_eventfds: num_layers, num_counters = recv_num_layers, recv_num_counters flexkv_logger.debug( f"[LayerwiseWorker] Connection {conn_idx}: " - f"tp_rank={tp_rank}, cp_rank={cp_rank}, cp_size={cp_size}, " + f"tp_rank_per_node={rank_key}, " + f"tp_size_per_node={tp_size_per_node_recv}, " f"num_layers={recv_num_layers}, " f"num_counters={recv_num_counters}") @@ -455,7 +410,7 @@ def cleanup_socket(): rank_eventfds[counter_id] = fds flexkv_logger.debug( f"[LayerwiseWorker] Received counter_id={counter_id}, " - f"num_fds={len(fds)} from rank_key={rank_key}") + f"num_fds={len(fds)} 
from tp_rank_per_node={rank_key}") all_rank_eventfds[rank_key] = rank_eventfds # Send ACK to client so it knows the fds were received @@ -464,8 +419,8 @@ def cleanup_socket(): except Exception: pass flexkv_logger.info( - f"[LayerwiseWorker] Received all eventfds from rank_key={rank_key} " - f"(tp_rank={tp_rank}, cp_rank={cp_rank}) on {socket_path}") + f"[LayerwiseWorker] Received all eventfds from tp_rank_per_node={rank_key} " + f"on {socket_path}") except Exception as e: # Send NACK so client knows to retry try: @@ -474,19 +429,19 @@ def cleanup_socket(): pass flexkv_logger.warning( f"[LayerwiseWorker] Failed to receive eventfds from connection {conn_idx} " - f"on {socket_path}{rank_label}: {e}. " + f"on {socket_path}: {e}. " f"Client will retry, continuing accept loop...") continue except Exception as e: flexkv_logger.error( - f"[LayerwiseWorker] Fatal error in accept loop on {socket_path}{rank_label}: {e}") + f"[LayerwiseWorker] Fatal error in accept loop on {socket_path}: {e}") finally: server_sock.close() cleanup_socket() if not all_rank_eventfds: flexkv_logger.warning( - f"[LayerwiseWorker] No connections received on {socket_path}{rank_label}") + f"[LayerwiseWorker] No connections received on {socket_path}") return torch.empty(0, dtype=torch.int32) # Build tensor: [num_counters, tp_size, num_layers] @@ -498,9 +453,9 @@ def cleanup_socket(): tensor = torch.tensor(eventfds_list, dtype=torch.int32) flexkv_logger.info( - f"[LayerwiseWorker] Eventfd setup complete{rank_label}: " + f"[LayerwiseWorker] Eventfd setup complete: " f"socket={socket_path}, tensor_shape={tensor.shape}, " - f"counters={num_counters}, tp_size={tp_group_size}, layers={num_layers}" + f"counters={num_counters}, tp_size_per_rank={tp_group_size}, layers={num_layers}" ) return tensor @@ -621,7 +576,7 @@ def launch_transfer(self, transfer_op: WorkerLayerwiseTransferOp) -> bool: kv_dim = 2 if not self.is_mla else 1 transfer_size = self.cpu_chunk_size_in_bytes * self.num_layers * num_h2d_blocks * 
kv_dim - if self.is_nsa_cp or self.is_mla: + if self.is_mla: transfer_size *= self.tp_group_size self._log_transfer_performance( diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index e4779932f6..84fe86b786 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -18,7 +18,7 @@ import multiprocessing as mp import selectors import os -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import contextlib import nvtx @@ -27,7 +27,7 @@ from flexkv.common.debug import flexkv_logger from flexkv.common.storage import StorageHandle -from flexkv.common.transfer import TransferOp, TransferOpGraph, TransferType, CompletedOp +from flexkv.common.transfer import TransferOp, TransferOpGraph, TransferType, CompletedOp, WorkerKey from flexkv.common.transfer import get_nvtx_range_color from flexkv.transfer.scheduler import TransferScheduler from flexkv.transfer.worker import ( @@ -87,13 +87,13 @@ def free_op_from_buffer(op: TransferOp, pin_buffer: SharedOpPool) -> None: class TransferEngine: def __init__(self, - gpu_handles: Dict[int, List[StorageHandle]], + gpu_handles: Dict[WorkerKey, List[StorageHandle]], model_config: ModelConfig, cache_config: CacheConfig, cpu_handle: Optional[StorageHandle] = None, ssd_handle: Optional[StorageHandle] = None, remote_handle: Optional[StorageHandle] = None, - indexer_gpu_handles: Optional[Dict[int, List[StorageHandle]]] = None, + indexer_gpu_handles: Optional[Dict[WorkerKey, List[StorageHandle]]] = None, indexer_cpu_handle: Optional[StorageHandle] = None, indexer_ssd_handle: Optional[StorageHandle] = None, indexer_remote_handle: Optional[StorageHandle] = None): @@ -101,7 +101,7 @@ def __init__(self, Initialize transfer engine Args: - gpu_handles: Dict mapping dp_client_id -> list of GPU handles for that TP group + gpu_handles: Dict mapping WorkerKey(dp_rank, pp_rank) -> list of GPU handles for that TP group cpu_handle: CPU 
handle ssd_handle: Optional SSD handle remote_handle: Optional remote handle @@ -123,7 +123,7 @@ def __init__(self, # Create shutdown pipe for zero-latency selector self.shutdown_read_fd, self.shutdown_write_fd = os.pipe() - self.gpu_handle_groups = gpu_handles # dp_client_id -> list of GPU handles for that TP group + self.gpu_handle_groups = gpu_handles # WorkerKey -> list of GPU handles for that TP group self._cpu_handle = cpu_handle self._ssd_handle = ssd_handle self._remote_handle = remote_handle @@ -143,7 +143,7 @@ def __init__(self, self.op_id_to_nvtx_range: Dict[int, str] = {} # self.dp_size = model_config.dp_size - self.tp_size = model_config.tp_size + self.tp_size_per_node = model_config.tp_size_per_node self.num_gpu_groups = len(self.gpu_handle_groups) self._running = False self._has_indexer = False @@ -151,10 +151,14 @@ def __init__(self, self._indexer_op_to_parent_op: Dict[int, int] = {} self._indexer_op_map: Dict[int, TransferOp] = {} + # Same-node PP layerwise fan-out: replica op_id → parent op_id + self._pp_replica_to_parent_op: Dict[int, int] = {} + self._pp_replica_op_map: Dict[int, TransferOp] = {} + def _init_workers(self) -> None: if self._running: return - self._worker_map: Dict[TransferType, Union[WorkerHandle, List[WorkerHandle]]] = {} + self._worker_map: Dict[TransferType, Union[WorkerHandle, Dict[WorkerKey, WorkerHandle]]] = {} assert self._cpu_handle is not None _enable_layerwise = GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer @@ -163,9 +167,9 @@ def _init_workers(self) -> None: # H2D worker if not _enable_layerwise: - if self.tp_size == 1: - self.h2d_workers: List[WorkerHandle] = [ - GPUCPUTransferWorker.create_worker( + if self.tp_size_per_node == 1: + self.h2d_workers: Dict[WorkerKey, WorkerHandle] = { + worker_key: GPUCPUTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), @@ -180,11 +184,11 @@ def _init_workers(self) -> None: 
transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, ) - for _, gpu_handles in self.gpu_handle_groups.items() - ] + for worker_key, gpu_handles in self.gpu_handle_groups.items() + } else: - self.h2d_workers = [ - tpGPUCPUTransferWorker.create_worker( + self.h2d_workers = { + worker_key: tpGPUCPUTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), @@ -193,23 +197,20 @@ def _init_workers(self) -> None: gpu_kv_layouts=[gpu_handle.kv_layout for gpu_handle in gpu_handles], cpu_kv_layout=self._cpu_handle.kv_layout, dtype=gpu_handles[0].dtype, - tp_group_size=self.tp_size, - dp_group_id=dp_client_id, - is_nsa_cp=self.model_config.is_nsa_cp, - cp_size=self.model_config.cp_size, + tp_group_size=self.tp_size_per_node, use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, ) - for dp_client_id, gpu_handles in self.gpu_handle_groups.items() - ] + for worker_key, gpu_handles in self.gpu_handle_groups.items() + } self._worker_map[TransferType.H2D] = self.h2d_workers # D2H worker - if self.tp_size == 1: - self.d2h_workers: List[WorkerHandle] = [ - GPUCPUTransferWorker.create_worker( + if self.tp_size_per_node == 1: + self.d2h_workers: Dict[WorkerKey, WorkerHandle] = { + worker_key: GPUCPUTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), @@ -224,11 +225,11 @@ def _init_workers(self) -> None: transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, ) - for _, gpu_handles in self.gpu_handle_groups.items() - ] + for worker_key, gpu_handles in 
self.gpu_handle_groups.items() + } else: - self.d2h_workers = [ - tpGPUCPUTransferWorker.create_worker( + self.d2h_workers = { + worker_key: tpGPUCPUTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), @@ -237,17 +238,14 @@ def _init_workers(self) -> None: gpu_kv_layouts=[gpu_handle.kv_layout for gpu_handle in gpu_handles], cpu_kv_layout=self._cpu_handle.kv_layout, dtype=gpu_handles[0].dtype, - tp_group_size=self.tp_size, - dp_group_id=dp_client_id, - is_nsa_cp=self.model_config.is_nsa_cp, - cp_size=self.model_config.cp_size, + tp_group_size=self.tp_size_per_node, use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, ) - for dp_client_id, gpu_handles in self.gpu_handle_groups.items() - ] + for worker_key, gpu_handles in self.gpu_handle_groups.items() + } self._worker_map[TransferType.D2H] = self.d2h_workers if self._ssd_handle is not None and self._cpu_handle is not None: @@ -308,9 +306,9 @@ def _init_workers(self) -> None: self._worker_map[TransferType.H2REMOTE] = self.remotecpu_write_worker self._worker_map[TransferType.REMOTE2H] = self.remotecpu_read_worker if self.cache_config.enable_gds: - if self.tp_size == 1: - self.gds_workers = [ - GDSTransferWorker.create_worker( + if self.tp_size_per_node == 1: + self.gds_workers: Dict[WorkerKey, WorkerHandle] = { + worker_key: GDSTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), @@ -322,11 +320,11 @@ def _init_workers(self) -> None: dtype=self._ssd_handle.dtype, gpu_device_id=gpu_handles[0].gpu_device_id, ) - for _, gpu_handles in self.gpu_handle_groups.items() - ] + for worker_key, gpu_handles in self.gpu_handle_groups.items() + } else: - 
self.gds_workers = [ - tpGDSTransferWorker.create_worker( + self.gds_workers = { + worker_key: tpGDSTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self.finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), @@ -336,25 +334,16 @@ def _init_workers(self) -> None: gpu_kv_layouts=[gpu_handle.kv_layout for gpu_handle in gpu_handles], ssd_kv_layout=self._ssd_handle.kv_layout, dtype=self._ssd_handle.dtype, - tp_group_size=self.tp_size, - dp_group_id=dp_client_id, + tp_group_size=self.tp_size_per_node, ) - for dp_client_id, gpu_handles in self.gpu_handle_groups.items() - ] + for worker_key, gpu_handles in self.gpu_handle_groups.items() + } self._worker_map[TransferType.DISK2D] = self.gds_workers self._worker_map[TransferType.D2DISK] = self.gds_workers if GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer: ssd_files = {} if self._ssd_handle is None else self._ssd_handle.get_file_list() ssd_kv_layout = None if self._ssd_handle is None else self._ssd_handle.kv_layout num_blocks_per_file = 0 if self._ssd_handle is None else self._ssd_handle.num_blocks_per_file - _is_nsa_cp = self.model_config.is_nsa_cp - _cp_size = self.model_config.cp_size - # For CP, each CP rank connects via eventfd; for TP, each TP rank connects. 
- _eventfd_group_size = _cp_size if _is_nsa_cp and _cp_size > 1 else self.tp_size - - _layerwise_eventfd_socket = build_layerwise_eventfd_socket_path( - self.model_config - ) # Prepare indexer handles for fused layerwise transfer has_indexer_for_layerwise = ( @@ -362,12 +351,18 @@ def _init_workers(self) -> None: self._indexer_cpu_handle is not None ) - self.layerwise_workers = [] - for dp_client_id, gpu_handles in self.gpu_handle_groups.items(): - # Resolve indexer handles for this dp_client_id + self.layerwise_workers: Dict[WorkerKey, WorkerHandle] = {} + for worker_key, gpu_handles in self.gpu_handle_groups.items(): + _layerwise_eventfd_socket = build_layerwise_eventfd_socket_path( + pp_rank=worker_key.pp_rank, + dp_rank=worker_key.dp_rank, + pp_size=self.model_config.pp_size, + dp_size=self.model_config.dp_size, + ) + # Resolve indexer handles for this WorkerKey idx_handles = None if has_indexer_for_layerwise: - idx_handles = self._indexer_gpu_handles.get(dp_client_id) + idx_handles = self._indexer_gpu_handles.get(worker_key) worker = LayerwiseTransferWorker.create_worker( mp_ctx=self.mp_ctx, @@ -380,19 +375,13 @@ def _init_workers(self) -> None: cpu_kv_layout=self._cpu_handle.kv_layout, ssd_kv_layout=ssd_kv_layout, dtype=gpu_handles[0].dtype, - tp_group_size=_eventfd_group_size, - dp_group_id=dp_client_id, - pp_rank=self.model_config.pp_rank, - pp_size=self.model_config.pp_size, - dp_size=self.model_config.dp_size, - dp_rank=dp_client_id, + tp_group_size=self.tp_size_per_node, layerwise_eventfd_socket=_layerwise_eventfd_socket, num_blocks_per_file=num_blocks_per_file, use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, h2d_cta_num=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, d2h_cta_num=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, - is_nsa_cp=_is_nsa_cp, indexer_gpu_blocks=[h.get_tensor_handle_list() for h in idx_handles] if idx_handles else None, 
indexer_cpu_blocks=self._indexer_cpu_handle.get_worker_tensor() if idx_handles else None, indexer_gpu_kv_layouts=[h.kv_layout for h in idx_handles] if idx_handles else None, @@ -402,11 +391,11 @@ def _init_workers(self) -> None: indexer_ssd_kv_layout=self._indexer_ssd_handle.kv_layout if (idx_handles and self._indexer_ssd_handle) else None, indexer_num_blocks_per_file=self._indexer_ssd_handle.num_blocks_per_file if (idx_handles and self._indexer_ssd_handle) else 0, ) - self.layerwise_workers.append(worker) + self.layerwise_workers[worker_key] = worker flexkv_logger.debug( - f"[TransferEngine] Created layerwise worker for dp_client_id={dp_client_id}: " - f"tp_size={self.tp_size}, has_indexer={idx_handles is not None}, " + f"[TransferEngine] Created layerwise worker for {worker_key}: " + f"tp_size_per_node={self.tp_size_per_node}, has_indexer={idx_handles is not None}, " f"has_ssd={len(ssd_files) > 0}") self._worker_map[TransferType.LAYERWISE] = self.layerwise_workers @@ -442,12 +431,12 @@ def _init_workers(self) -> None: if (self._indexer_gpu_handles is not None and self._indexer_cpu_handle is not None): self._indexer_finished_ops_queue = self.mp_ctx.Queue() - self._indexer_worker_map: Dict[TransferType, Union[WorkerHandle, List[WorkerHandle]]] = {} + self._indexer_worker_map: Dict[TransferType, Union[WorkerHandle, Dict[WorkerKey, WorkerHandle]]] = {} # H2D indexer worker if not _enable_layerwise: - if self.tp_size == 1: - self._indexer_h2d_workers = [ - GPUCPUTransferWorker.create_worker( + if self.tp_size_per_node == 1: + self._indexer_h2d_workers: Dict[WorkerKey, WorkerHandle] = { + worker_key: GPUCPUTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self._indexer_finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), @@ -462,11 +451,11 @@ def _init_workers(self) -> None: transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, ) - for _, 
indexer_gpu_handles_list in self._indexer_gpu_handles.items() - ] + for worker_key, indexer_gpu_handles_list in self._indexer_gpu_handles.items() + } else: - self._indexer_h2d_workers = [ - tpGPUCPUTransferWorker.create_worker( + self._indexer_h2d_workers = { + worker_key: tpGPUCPUTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self._indexer_finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), @@ -475,23 +464,20 @@ def _init_workers(self) -> None: gpu_kv_layouts=[h.kv_layout for h in indexer_gpu_handles_list], cpu_kv_layout=self._indexer_cpu_handle.kv_layout, dtype=indexer_gpu_handles_list[0].dtype, - tp_group_size=self.tp_size, - dp_group_id=dp_client_id, - is_nsa_cp=self.model_config.is_nsa_cp, - cp_size=self.model_config.cp_size, + tp_group_size=self.tp_size_per_node, use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, ) - for dp_client_id, indexer_gpu_handles_list in self._indexer_gpu_handles.items() - ] + for worker_key, indexer_gpu_handles_list in self._indexer_gpu_handles.items() + } self._indexer_worker_map[TransferType.H2D] = self._indexer_h2d_workers # D2H indexer worker - if self.tp_size == 1: - self._indexer_d2h_workers = [ - GPUCPUTransferWorker.create_worker( + if self.tp_size_per_node == 1: + self._indexer_d2h_workers: Dict[WorkerKey, WorkerHandle] = { + worker_key: GPUCPUTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self._indexer_finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), @@ -506,11 +492,11 @@ def _init_workers(self) -> None: transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, ) - for _, indexer_gpu_handles_list in self._indexer_gpu_handles.items() - ] + for worker_key, 
indexer_gpu_handles_list in self._indexer_gpu_handles.items() + } else: - self._indexer_d2h_workers = [ - tpGPUCPUTransferWorker.create_worker( + self._indexer_d2h_workers = { + worker_key: tpGPUCPUTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self._indexer_finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), @@ -519,17 +505,14 @@ def _init_workers(self) -> None: gpu_kv_layouts=[h.kv_layout for h in indexer_gpu_handles_list], cpu_kv_layout=self._indexer_cpu_handle.kv_layout, dtype=indexer_gpu_handles_list[0].dtype, - tp_group_size=self.tp_size, - dp_group_id=dp_client_id, - is_nsa_cp=self.model_config.is_nsa_cp, - cp_size=self.model_config.cp_size, + tp_group_size=self.tp_size_per_node, use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, transfer_num_cta_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_h2d, transfer_num_cta_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_num_cta_d2h, ) - for dp_client_id, indexer_gpu_handles_list in self._indexer_gpu_handles.items() - ] + for worker_key, indexer_gpu_handles_list in self._indexer_gpu_handles.items() + } self._indexer_worker_map[TransferType.D2H] = self._indexer_d2h_workers if self._indexer_ssd_handle is not None and self._indexer_cpu_handle is not None: # H2DISK indexer worker @@ -590,7 +573,7 @@ def _init_workers(self) -> None: self._indexer_worker_map[TransferType.REMOTE2H] = self._indexer_remote2h_worker flexkv_logger.info("TransferEngine: indexer Remote workers initialized") if self.cache_config.enable_gds and self._indexer_ssd_handle is not None: - if self.tp_size == 1: + if self.tp_size_per_node == 1: self._indexer_gds_workers = [ GDSTransferWorker.create_worker( mp_ctx=self.mp_ctx, @@ -618,10 +601,9 @@ def _init_workers(self) -> None: gpu_kv_layouts=[h.kv_layout for h in indexer_gpu_handles_list], ssd_kv_layout=self._indexer_ssd_handle.kv_layout, dtype=self._indexer_ssd_handle.dtype, - tp_group_size=self.tp_size, - 
dp_group_id=dp_client_id, + tp_group_size=self.tp_size_per_node, ) - for dp_client_id, indexer_gpu_handles_list in self._indexer_gpu_handles.items() + for _, indexer_gpu_handles_list in self._indexer_gpu_handles.items() ] self._indexer_worker_map[TransferType.DISK2D] = self._indexer_gds_workers self._indexer_worker_map[TransferType.D2DISK] = self._indexer_gds_workers @@ -662,27 +644,27 @@ def _init_workers(self) -> None: raise ValueError("No workers initialized, please check the config") # Wait for all main KV workers to ready for transfer_type, worker in self._worker_map.items(): - if isinstance(worker, List): - for w in worker: - flexkv_logger.info(f"waiting for {transfer_type.name} worker {w.worker_id} to ready") + if isinstance(worker, dict): + for w in worker.values(): + flexkv_logger.debug(f"waiting for {transfer_type.name} worker {w.worker_id} to ready") w.ready_event.wait() - flexkv_logger.info(f"{transfer_type.name} worker {w.worker_id} is ready") + flexkv_logger.debug(f"{transfer_type.name} worker {w.worker_id} is ready") else: - flexkv_logger.info(f"waiting for {transfer_type.name} worker {worker.worker_id} to ready") + flexkv_logger.debug(f"waiting for {transfer_type.name} worker {worker.worker_id} to ready") worker.ready_event.wait() - flexkv_logger.info(f"{transfer_type.name} worker {worker.worker_id} is ready") + flexkv_logger.debug(f"{transfer_type.name} worker {worker.worker_id} is ready") # Wait for all indexer workers to ready if self._has_indexer: for transfer_type, worker in self._indexer_worker_map.items(): - if isinstance(worker, List): - for w in worker: - flexkv_logger.info(f"waiting for indexer {transfer_type.name} worker {w.worker_id} to ready") + if isinstance(worker, (dict, list)): + for w in (worker.values() if isinstance(worker, dict) else worker): + flexkv_logger.debug(f"waiting for indexer {transfer_type.name} worker {w.worker_id} to ready") w.ready_event.wait() - flexkv_logger.info(f"indexer 
{transfer_type.name} worker {w.worker_id} is ready") else: - flexkv_logger.info(f"waiting for indexer {transfer_type.name} worker {worker.worker_id} to ready") + flexkv_logger.debug(f"waiting for indexer {transfer_type.name} worker {worker.worker_id} to ready") worker.ready_event.wait() - flexkv_logger.info(f"indexer {transfer_type.name} worker {worker.worker_id} is ready") + flexkv_logger.debug(f"indexer {transfer_type.name} worker {worker.worker_id} is ready") # Startup assertions: verify layerwise mode worker map consistency if _enable_layerwise: assert TransferType.H2D not in self._worker_map, \ @@ -760,10 +742,28 @@ def _scheduler_loop(self) -> None: while True: try: op_id = self.finished_ops_queue.get_nowait() - op = self.op_id_to_op[op_id] - op.pending_count -= 1 - if op.pending_count == 0: - self._finalize_op(op, finished_ops) + # Check if this is a PP-replica op (same-node PP fan-out) + if op_id in self._pp_replica_to_parent_op: + # PP-replica op: decrement parent's pending_count + replica_op = self._pp_replica_op_map.pop(op_id) + parent_op_id = self._pp_replica_to_parent_op.pop(op_id) + parent_op = self.op_id_to_op[parent_op_id] + parent_op.pending_count -= 1 + # Clean up replica from op_id_to_op and NVTX + del self.op_id_to_op[op_id] + if op_id in self.op_id_to_nvtx_range: + nvtx.end_range(self.op_id_to_nvtx_range[op_id]) + self.op_id_to_nvtx_range.pop(op_id) + if parent_op.pending_count == 0: + self._finalize_op(parent_op, finished_ops) + flexkv_logger.debug( + f"[TransferEngine] PP replica op {op_id} completed, " + f"parent op {parent_op_id} pending_count={parent_op.pending_count}") + else: + op = self.op_id_to_op[op_id] + op.pending_count -= 1 + if op.pending_count == 0: + self._finalize_op(op, finished_ops) except queue.Empty: break nvtx.end_range(nvtx_r2) @@ -834,7 +834,7 @@ def _finalize_op(self, op: TransferOp, finished_ops: List[TransferOp]) -> None: free_op_from_buffer(op, self.pin_buffer) # Compute transfer metrics for this completed op 
num_blocks = len(op.src_block_ids) if op.src_block_ids is not None else 0 - num_bytes = num_blocks * self.cache_config.tokens_per_block * self.model_config.token_size_in_bytes + num_bytes = num_blocks * self.cache_config.tokens_per_block * self.model_config.token_size_in_bytes_per_pp_stage transfer_type_str = op.transfer_type.value if op.transfer_type != TransferType.VIRTUAL else None self.completed_queue.put(CompletedOp( graph_id=op.graph_id, @@ -846,6 +846,85 @@ def _finalize_op(self, op: TransferOp, finished_ops: List[TransferOp]) -> None: finished_ops.append(op) del self.op_id_to_op[op.op_id] + def _assign_layerwise_op_to_workers(self, op: TransferOp) -> None: + """Fan-out a LAYERWISE op to all PP-stage layerwise workers on the same dp_rank. + + In cross-node PP, the remote TransferManagerOnRemote handles this by + rebinding WorkerKey. In same-node PP there is no remote TM, so we + replicate the op here for every PP-stage worker under the same dp_rank. + + Replicas are tracked via ``_pp_replica_to_parent_op`` so that their + completion decrements the parent op's ``pending_count`` (identical to + how indexer ops are tracked). 
+ """ + from flexkv.common.transfer import LayerwiseTransferOp + assert isinstance(op, LayerwiseTransferOp) + + worker_map = self._worker_map[TransferType.LAYERWISE] + assert isinstance(worker_map, dict), \ + "LAYERWISE worker map must be a Dict[WorkerKey, WorkerHandle]" + + # Find all layerwise workers sharing the same dp_rank + sibling_keys = [wk for wk in worker_map if wk.dp_rank == op.dp_rank] + + if not sibling_keys: + raise ValueError( + f"No layerwise worker found for dp_rank={op.dp_rank}, pp_rank={op.pp_rank}") + + # Submit to the original pp_rank's worker + primary_key = WorkerKey(dp_rank=op.dp_rank, pp_rank=op.pp_rank) + if primary_key in worker_map: + worker_map[primary_key].submit_transfer(op) + else: + # Original worker not found — this shouldn't happen, but handle gracefully + raise ValueError( + f"No layerwise worker found for primary key {primary_key}") + + # If there's only one worker for this dp_rank, no fan-out needed + if len(sibling_keys) <= 1: + return + + # Create replicas for every other pp_rank under the same dp_rank + for wk in sibling_keys: + if wk == primary_key: + continue + + # Create a replica LayerwiseTransferOp with the target pp_rank + replica = LayerwiseTransferOp( + graph_id=op.graph_id, + src_block_ids_h2d=op.src_block_ids_h2d.copy(), + dst_block_ids_h2d=op.dst_block_ids_h2d.copy(), + src_block_ids_disk2h=op.src_block_ids_disk2h.copy(), + dst_block_ids_disk2h=op.dst_block_ids_disk2h.copy(), + layer_id=op.layer_id, + layer_granularity=op.layer_granularity, + dp_rank=op.dp_rank, + pp_rank=wk.pp_rank, + counter_id=op.counter_id, + indexer_src_block_ids=op.indexer_src_block_ids.copy(), + indexer_dst_block_ids=op.indexer_dst_block_ids.copy(), + ) + + # Track replica → parent so that completion decrements parent's pending_count + self._pp_replica_to_parent_op[replica.op_id] = op.op_id + self._pp_replica_op_map[replica.op_id] = replica + op.pending_count += 1 + + # Register in op_id_to_op so scheduler can find it on completion + 
self.op_id_to_op[replica.op_id] = replica + self.op_id_to_nvtx_range[replica.op_id] = nvtx.start_range( + f"schedule LAYERWISE_REPLICA op_id: {replica.op_id}, " + f"graph_id: {replica.graph_id}, pp_rank={wk.pp_rank}", + color=get_nvtx_range_color(replica.graph_id)) + + worker_map[wk].submit_transfer(replica) + + flexkv_logger.debug( + f"[TransferEngine] === Layerwise PP Replica Dispatched ===" + f"\n parent_op_id={op.op_id}, replica_op_id={replica.op_id}" + f"\n dp_rank={op.dp_rank}, pp_rank={wk.pp_rank}" + f"\n pending_count={op.pending_count}") + def _assign_op_to_worker(self, op: TransferOp) -> None: self.op_id_to_nvtx_range[op.op_id] = nvtx.start_range(f"schedule {op.transfer_type.name} " f"op_id: {op.op_id}, " @@ -858,6 +937,18 @@ def _assign_op_to_worker(self, op: TransferOp) -> None: if op.transfer_type not in self._worker_map: raise ValueError(f"Unsupported transfer type: {op.transfer_type}") + # --- Same-node PP fan-out for LAYERWISE ops --- + # In cross-node PP, TransferManagerOnRemote rebinds WorkerKey so that + # PP1+ workers receive the op. In same-node PP, only one local + # TransferManager exists, so we fan-out here: for every layerwise + # worker that shares the same dp_rank but has a different pp_rank, we + # create a replica op with the correct pp_rank. + if op.transfer_type == TransferType.LAYERWISE: + self._assign_layerwise_op_to_workers(op) + return + + worker_key = WorkerKey(dp_rank=op.dp_rank, pp_rank=op.pp_rank) + if self._has_indexer and op.transfer_type in self._indexer_worker_map: # Indexer maps 1:1 with main KV blocks, use block_ids directly. 
src_page_ids = op.src_block_ids @@ -874,7 +965,8 @@ def _assign_op_to_worker(self, op: TransferOp) -> None: dst_block_ids=dst_page_ids, layer_id=op.layer_id, layer_granularity=op.layer_granularity, - dp_id=op.dp_id, + dp_rank=op.dp_rank, + pp_rank=op.pp_rank, ) register_op_to_buffer(indexer_op, self.pin_buffer) self._indexer_op_to_parent_op[indexer_op.op_id] = op.op_id @@ -884,18 +976,18 @@ def _assign_op_to_worker(self, op: TransferOp) -> None: flexkv_logger.debug( f"[TransferEngine] === Indexer Op Dispatched (non-layerwise) ===" f"\n parent_op_id={op.op_id}, indexer_op_id={indexer_op.op_id}" - f"\n type={op.transfer_type.name}, dp_id={op.dp_id}" + f"\n type={op.transfer_type.name}, dp_rank={op.dp_rank}, pp_rank={op.pp_rank}" f"\n num_pages={num_pages}, pending_count={op.pending_count}") indexer_worker = self._indexer_worker_map[op.transfer_type] - if isinstance(indexer_worker, List): - indexer_worker[op.dp_id].submit_transfer(indexer_op) + if isinstance(indexer_worker, dict): + indexer_worker[worker_key].submit_transfer(indexer_op) else: indexer_worker.submit_transfer(indexer_op) worker = self._worker_map[op.transfer_type] - if isinstance(worker, List): - worker[op.dp_id].submit_transfer(op) + if isinstance(worker, dict): + worker[worker_key].submit_transfer(op) else: worker.submit_transfer(op) @@ -964,15 +1056,15 @@ def shutdown(self) -> None: # shutdown indexer workers first if self._has_indexer: for worker in self._indexer_worker_map.values(): - if isinstance(worker, List): - for w in worker: + if isinstance(worker, dict): + for w in worker.values(): w.shutdown() else: worker.shutdown() # shutdown main KV workers for worker in self._worker_map.values(): - if isinstance(worker, List): - for w in worker: + if isinstance(worker, dict): + for w in worker.values(): w.shutdown() else: worker.shutdown() diff --git a/flexkv/transfer/worker.py b/flexkv/transfer/worker.py index 534a315c7f..bae3b7bcc9 100644 --- a/flexkv/transfer/worker.py +++ b/flexkv/transfer/worker.py 
@@ -419,9 +419,6 @@ def __init__(self, cpu_kv_layout: KVCacheLayout, dtype: torch.dtype, tp_group_size: int, - dp_group_id: int, - is_nsa_cp: bool = False, - cp_size: int = 1, use_ce_transfer_h2d: bool = False, use_ce_transfer_d2h: bool = False, transfer_num_cta_h2d: int = 4, @@ -440,12 +437,9 @@ def __init__(self, self.gpu_blocks = imported_gpu_blocks self.dtype = dtype # note this should be quantized data type self.is_mla = gpu_kv_layouts[0].is_mla - self.is_nsa_cp = is_nsa_cp - self.cp_size = cp_size self.num_gpus = len(self.gpu_blocks) self.tp_group_size = tp_group_size - self.dp_group_id = dp_group_id flexkv_logger.info(f"Pinning CPU Memory: {cpu_blocks.numel() * cpu_blocks.element_size() / (1024 ** 3):.2f} GB") cudaHostRegister(cpu_blocks) @@ -497,7 +491,6 @@ def __init__(self, gpu_block_ptrs_flat, num_tensors_per_gpu, cpu_blocks_ptr, - dp_group_id, self.num_layers, self.gpu_kv_strides_in_bytes, self.gpu_block_strides_in_bytes, @@ -551,7 +544,6 @@ def _transfer_impl(self, layer_id, layer_granularity, self.is_mla, - self.is_nsa_cp and self.cp_size > 1, ) @@ -1145,7 +1137,6 @@ def __init__( ssd_kv_layout: KVCacheLayout, dtype: torch.dtype, tp_group_size: int, - dp_group_id: int, ) -> None: """ Initialize TP GDS Transfer Worker @@ -1161,7 +1152,6 @@ def __init__( ssd_kv_layout: Layout of SSD KV cache dtype: Data type tp_group_size: Size of tensor parallel group - dp_group_id: Data parallel group ID """ # Initialize base class first super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) @@ -1182,7 +1172,6 @@ def __init__( self.is_mla = gpu_kv_layouts[0].is_mla self.num_gpus = len(self.gpu_blocks) self.tp_group_size = tp_group_size - self.dp_group_id = dp_group_id # Layout information self.num_layers = gpu_kv_layouts[0].num_layer @@ -1205,7 +1194,7 @@ def __init__( # SSD layout calculations self.ssd_layer_stride_in_bytes = ssd_kv_layout_per_file.get_layer_stride() * self.dtype.itemsize self.ssd_kv_stride_in_bytes = 
ssd_kv_layout_per_file.get_kv_stride() * self.dtype.itemsize - self.ssd_tp_stride_in_bytes = self.ssd_block_stride_in_bytes // self.tp_group_size if not self.is_mla else self.ssd_block_stride_in_bytes + self.ssd_tp_stride_in_bytes = self.ssd_block_stride_in_bytes // self.tp_size_per_node if not self.is_mla else self.ssd_block_stride_in_bytes # Resolve pointers in Python (where storage is valid); pass them to C++ so we avoid # "Tensor that doesn't have storage" when C++ calls .data_ptr() on tensors passed @@ -1224,7 +1213,6 @@ def __init__( gpu_block_ptrs_flat, num_tensors_per_gpu, ssd_files, - dp_group_id, self.num_layers, self.gpu_kv_strides_in_bytes, self.gpu_block_strides_in_bytes, diff --git a/flexkv/transfer_manager.py b/flexkv/transfer_manager.py index 5b1a55ee10..276873a36b 100644 --- a/flexkv/transfer_manager.py +++ b/flexkv/transfer_manager.py @@ -19,7 +19,7 @@ import pickle import sys -from flexkv.common.transfer import TransferOpGraph, CompletedOp +from flexkv.common.transfer import TransferOpGraph, CompletedOp, WorkerKey from flexkv.common.config import CacheConfig, ModelConfig, GLOBAL_CONFIG_FROM_ENV from flexkv.common.debug import flexkv_logger from flexkv.common.memory_handle import TensorSharedHandle @@ -43,12 +43,12 @@ def __init__(self, # Multi-instance support: get instance_num from environment self.instance_num = GLOBAL_CONFIG_FROM_ENV.instance_num - # Calculate total expected GPUs across all instances - self.expected_gpus = self.instance_num * model_config.tp_size * model_config.dp_size + # Calculate total expected GPUs on this node across all instances + self.expected_gpus = self.instance_num * model_config.gpus_per_node self.all_gpu_layouts: Dict[int, KVCacheLayout] = {} self.all_gpu_blocks: Dict[int, List[TensorSharedHandle]] = {} # device_id -> gpu_blocks - self.gpu_client_mapping: Dict[int, int] = {} # device_id -> dp_client_id + self.gpu_worker_key_mapping: Dict[int, WorkerKey] = {} # device_id -> WorkerKey(dp_rank, pp_rank) # Indexer GPU 
registration data self.all_indexer_gpu_blocks: Dict[int, List[TensorSharedHandle]] = {} # device_id -> indexer_gpu_blocks @@ -72,7 +72,7 @@ def _handle_gpu_blocks_registration(self, req: RegisterTPClientRequest) -> None: try: self.all_gpu_blocks[device_id] = req.handles self.all_gpu_layouts[device_id] = req.gpu_layout - self.gpu_client_mapping[device_id] = req.dp_client_id + self.gpu_worker_key_mapping[device_id] = WorkerKey(dp_rank=req.dp_rank, pp_rank=req.pp_rank) # Store indexer GPU data if present if req.indexer_handles is not None: self.all_indexer_gpu_blocks[device_id] = req.indexer_handles @@ -87,8 +87,9 @@ def _register_gpu_blocks_via_socket(self) -> None: try: flexkv_logger.info(f"GPU tensor registration server started on port {self.gpu_register_port}, " f"expected {self.expected_gpus} GPUs to register " - f"(instance_num={self.instance_num}, tp={self.model_config.tp_size}, " - f"dp={self.model_config.dp_size})") + f"(instance_num={self.instance_num}, gpus_per_node={self.model_config.gpus_per_node}, " + f"total_gpus={self.model_config.total_gpus}, pp_rank={self.model_config.pp_rank}, " + f"node_rank={self.model_config.node_rank}, nnodes={self.model_config.nnodes})") last_log_time = time.time() while len(self.all_gpu_blocks) < self.expected_gpus: try: @@ -109,7 +110,8 @@ def _register_gpu_blocks_via_socket(self) -> None: continue if isinstance(req, RegisterTPClientRequest): - flexkv_logger.info(f"Received GPU blocks registration request: {type(req)}") + flexkv_logger.info(f"Received GPU blocks registration request: {type(req)}, " + f"device_id={req.device_id}, dp_rank={req.dp_rank}") self._handle_gpu_blocks_registration(req) flexkv_logger.info(f"GPU {req.device_id} registered successfully, " f"waiting for {self.expected_gpus - len(self.all_gpu_blocks)} GPUs to register") @@ -153,13 +155,13 @@ def initialize_transfer_engine(self) -> None: indexer_dtype=indexer_dtype, ) - # Group GPU handles by dp_client_id - grouped_gpu_handles: Dict[int, List] = {} + # 
Group GPU handles by WorkerKey + grouped_gpu_handles: Dict[WorkerKey, List] = {} for device_id in sorted(self.all_gpu_blocks.keys()): - dp_client_id = self.gpu_client_mapping[device_id] - if dp_client_id not in grouped_gpu_handles: - grouped_gpu_handles[dp_client_id] = [] - grouped_gpu_handles[dp_client_id].append( + worker_key = self.gpu_worker_key_mapping[device_id] + if worker_key not in grouped_gpu_handles: + grouped_gpu_handles[worker_key] = [] + grouped_gpu_handles[worker_key].append( self.storage_engine.get_storage_handle(DeviceType.GPU, device_id)) cpu_handle = self.storage_engine.get_storage_handle(DeviceType.CPU) \ @@ -172,15 +174,15 @@ def initialize_transfer_engine(self) -> None: else None ) - indexer_gpu_handles: Optional[Dict[int, List]] = None + indexer_gpu_handles: Optional[Dict[WorkerKey, List]] = None if self.storage_engine.has_storage_handle(DeviceType.CPU, is_indexer=True): indexer_gpu_handles = {} for device_id in sorted(self.all_gpu_blocks.keys()): if self.storage_engine.has_storage_handle(DeviceType.GPU, device_id, is_indexer=True): - dp_client_id = self.gpu_client_mapping[device_id] - if dp_client_id not in indexer_gpu_handles: - indexer_gpu_handles[dp_client_id] = [] - indexer_gpu_handles[dp_client_id].append( + worker_key = self.gpu_worker_key_mapping[device_id] + if worker_key not in indexer_gpu_handles: + indexer_gpu_handles[worker_key] = [] + indexer_gpu_handles[worker_key].append( self.storage_engine.get_storage_handle(DeviceType.GPU, device_id, is_indexer=True)) indexer_cpu_handle = ( self.storage_engine.get_storage_handle(DeviceType.CPU, is_indexer=True) @@ -210,7 +212,23 @@ def initialize_transfer_engine(self) -> None: indexer_ssd_handle=indexer_ssd_handle, indexer_remote_handle=indexer_remote_handle, ) - flexkv_logger.info("Initialized TransferEngine successfully") + + # Derive local (dp_rank, pp_rank) from GPU registrations rather than + # model_config. 
In cross-node PP, TransferManagerOnRemote receives + # model_config from the PP0 master (pp_rank=0), but local GPUs + # register with their true pp_rank (e.g. pp_rank=1). + worker_keys = set(self.gpu_worker_key_mapping.values()) + self._local_dp_rank = self.model_config.dp_rank + self._local_pp_rank = self.model_config.pp_rank + if len(worker_keys) == 1: + wk = next(iter(worker_keys)) + self._local_dp_rank = wk.dp_rank + self._local_pp_rank = wk.pp_rank + + flexkv_logger.info(f"Initialized TransferEngine successfully, " + f"grouped_gpu_handles keys={list(grouped_gpu_handles.keys())}, " + f"num_gpu_groups={len(grouped_gpu_handles)}, " + f"local_dp_rank={self._local_dp_rank}, local_pp_rank={self._local_pp_rank}") def submit(self, transfer_graph: TransferOpGraph) -> None: self.transfer_engine.submit_transfer_graph(transfer_graph) @@ -289,6 +307,11 @@ def __init__(self, mode: str = "Default", master_host: Optional[str] = None): self._active_graphs: Dict[int, int] = {} self._active_graphs_lock = threading.Lock() + # Pending matching for cross-node PP: graph arrives before or after slot_mapping + self._pending_graphs: Dict[int, TransferOpGraph] = {} + self._pending_slot_mappings: Dict[int, np.ndarray] = {} + self._pending_lock = threading.Lock() + self._worker_thread: threading.Thread | None = None self._connect_to_master_transfer_manager() @@ -352,21 +375,21 @@ def _polling_worker(self) -> None: task_end_op_id = message.get('task_end_op_id', -1) if graph is not None: - graph_id = graph.graph_id - - with self._active_graphs_lock: - self._active_graphs[graph_id] = task_end_op_id - - self.submit(graph) + self._handle_submit(graph, task_end_op_id) else: flexkv_logger.warning("Received submit message without graph") elif msg_type == 'submit_batch': graphs = message.get('graphs', []) for graph in graphs: + self._rebind_graph_to_local_worker(graph) graph_id = graph.graph_id with self._active_graphs_lock: self._active_graphs[graph_id] = -1 self.submit(graph) + elif msg_type == 
'set_slot_mapping': + task_id = message.get('task_id') + slot_mapping = message.get('slot_mapping') + self._handle_set_slot_mapping(task_id, slot_mapping) else: flexkv_logger.warning(f"Unexpected command message: {message}") else: @@ -415,11 +438,90 @@ def _polling_worker(self) -> None: poller.unregister(self.command_socket) poller.unregister(self.query_socket) + def _handle_set_slot_mapping(self, task_id: int, slot_mapping: np.ndarray) -> None: + """Handle set_slot_mapping message from FlexKVConnector. + + When the graph (with cleared GPU blocks) arrived earlier, we can immediately + set_gpu_blocks and submit. Otherwise, store the slot_mapping and wait + for the graph to arrive later. + """ + with self._pending_lock: + if task_id in self._pending_graphs: + # Graph already arrived, set GPU blocks and submit + graph = self._pending_graphs.pop(task_id) + graph.set_gpu_blocks(slot_mapping) + self._rebind_graph_to_local_worker(graph) + self.submit(graph) + flexkv_logger.debug( + f"[TransferManagerOnRemote] set_slot_mapping: " + f"graph for task_id={task_id} submitted (graph arrived first)" + ) + else: + # Graph not yet arrived, store slot_mapping for later matching + self._pending_slot_mappings[task_id] = slot_mapping + flexkv_logger.debug( + f"[TransferManagerOnRemote] set_slot_mapping: " + f"slot_mapping stored for task_id={task_id}, waiting for graph" + ) + + def _handle_submit(self, graph: TransferOpGraph, task_end_op_id: int = -1) -> None: + """Handle submit message with pending matching support. + + If slot_mapping already arrived, set_gpu_blocks and submit immediately. + Otherwise, store graph in pending_graphs for later matching. 
+ """ + task_id = graph.graph_id # Use graph_id as task_id for matching + with self._pending_lock: + if task_id in self._pending_slot_mappings: + # slot_mapping already arrived, set GPU blocks and submit + slot_mapping = self._pending_slot_mappings.pop(task_id) + graph.set_gpu_blocks(slot_mapping) + self._rebind_graph_to_local_worker(graph) + flexkv_logger.debug( + f"[TransferManagerOnRemote] submit: " + f"graph for task_id={task_id} submitted (slot_mapping arrived first)" + ) + else: + # slot_mapping not yet arrived, store graph for later matching + self._pending_graphs[task_id] = graph + flexkv_logger.debug( + f"[TransferManagerOnRemote] submit: " + f"graph stored for task_id={task_id}, waiting for slot_mapping" + ) + return # Don't submit yet, wait for slot_mapping + + # Submit graph to transfer engine + with self._active_graphs_lock: + self._active_graphs[graph.graph_id] = task_end_op_id + self.submit(graph) + + def _rebind_graph_to_local_worker(self, graph: TransferOpGraph) -> None: + """Rebind transfer graph ops to the local WorkerKey. + + In cross-node PP setups, the master (PP0) creates transfer graphs with + its own (dp_rank=0, pp_rank=0). When these graphs are sent to a remote + node (e.g. PP1), the ops must be rebound to the local (dp_rank, pp_rank) + so the TransferEngine can find the correct workers. 
+ """ + model_pp_rank = self.model_config.pp_rank + model_dp_rank = self.model_config.dp_rank + + if model_pp_rank == self._local_pp_rank and model_dp_rank == self._local_dp_rank: + return # No rebinding needed + + old_key = WorkerKey(dp_rank=model_dp_rank, pp_rank=model_pp_rank) + new_key = WorkerKey(dp_rank=self._local_dp_rank, pp_rank=self._local_pp_rank) + graph.bind_to_worker(self._local_dp_rank, self._local_pp_rank) + flexkv_logger.debug( + f"[TransferManagerOnRemote] Rebound graph {graph.graph_id} " + f"from {old_key} to {new_key}" + ) + def start(self) -> None: self.initialize_transfer_engine() super().start() - self._is_ready = true + self._is_ready = True self._worker_thread = threading.Thread( target=self._polling_worker, daemon=True @@ -455,9 +557,6 @@ def __del__(self) -> None: @classmethod def create_process(cls, **kwargs: Any) -> Process: - import tempfile - import os - # Serialize the class and kwargs cls_data = pickle.dumps(cls) kwargs_data = pickle.dumps(kwargs) @@ -538,7 +637,6 @@ def cleanup_files(): except Exception: pass - import threading cleanup_thread = threading.Thread(target=cleanup_files, daemon=True) cleanup_thread.start() @@ -647,6 +745,11 @@ def _start_process(self) -> None: if self.process is not None and self.process.is_alive(): return + flexkv_logger.debug( + f"Spawning TransferManager subprocess: " + f"pp_rank={self.model_config.pp_rank}, node_rank={self.model_config.node_rank}, " + f"tp_size={self.model_config.tp_size}, dp_size={self.model_config.dp_size}, " + f"gpu_register_port={self.gpu_register_port}") self.process = self.mp_ctx.Process( target=self._process_worker, args=(self.model_config, @@ -659,6 +762,7 @@ def _start_process(self) -> None: daemon=False ) self.process.start() + flexkv_logger.debug(f"TransferManager subprocess spawned, pid={self.process.pid}") def _process_worker(self, model_config: ModelConfig, @@ -682,11 +786,15 @@ def _reap_children(signum, frame): break signal.signal(signal.SIGCHLD, _reap_children) try: 
+ flexkv_logger.debug(f"_process_worker started, pid={os.getpid()}, " + f"gpu_register_port={gpu_register_port}, " + f"pp_rank={model_config.pp_rank}, node_rank={model_config.node_rank}") start_event.set() os.environ['MPI4PY_RC_INITIALIZE'] = 'false' transfer_manager = TransferManager(model_config, cache_config, gpu_register_port) transfer_manager.initialize_transfer_engine() transfer_manager.start() + flexkv_logger.debug("TransferEngine started successfully, setting ready_event") ready_event.set() # Setup selector for event-driven processing (complete zero polling!) @@ -825,7 +933,7 @@ def __del__(self): self.shutdown() -class TranserManagerMultiNodeHandle(TransferManagerHandleBase): +class TransferManagerMultiNodeHandle(TransferManagerHandleBase): def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, @@ -1023,6 +1131,10 @@ def __init__(self, gpu_register_port: Optional[str] = None, mode: str = "process", **kwargs): # process or thread or remote + flexkv_logger.debug( + f"Creating TransferManagerHandle: mode={mode}, " + f"pp_rank={model_config.pp_rank}, node_rank={model_config.node_rank}, " + f"gpu_register_port={gpu_register_port}") if gpu_register_port is None: gpu_register_port = f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" if mode == "process": @@ -1036,7 +1148,7 @@ def __init__(self, elif mode == "remote": master_host = kwargs["master_host"] master_ports = kwargs["master_ports"] - self._handle: TransferManagerHandleBase = TranserManagerMultiNodeHandle( + self._handle: TransferManagerHandleBase = TransferManagerMultiNodeHandle( model_config, cache_config, gpu_register_port, master_host, master_ports ) else: diff --git a/tests/replay_from_tracer.py b/tests/replay_from_tracer.py index 2887d0723b..ba7f557820 100644 --- a/tests/replay_from_tracer.py +++ b/tests/replay_from_tracer.py @@ -147,7 +147,6 @@ def parse_config_event(self, event: Dict[str, Any]): num_remote_blocks=cache_config_data['num_remote_blocks'], 
ssd_cache_dir=cache_config_data['ssd_cache_dir'], gds_cache_dir=cache_config_data['gds_cache_dir'], - remote_cache_size_mode=cache_config_data['remote_cache_size_mode'], remote_file_size=cache_config_data['remote_file_size'], remote_file_num=cache_config_data['remote_file_num'], remote_file_prefix=cache_config_data['remote_file_prefix'], @@ -234,7 +233,8 @@ def register_gpu_blocks_to_kvmanager(self, gpu_register_port: str): # Create registration request register_req = RegisterTPClientRequest( - dp_client_id=gpu_id // self.model_config.tp_size, # DP client ID + dp_rank=gpu_id // self.model_config.tp_size, # DP client ID + pp_rank=0, # single PP stage for replay device_id=gpu_id, handles=handles, gpu_layout=self.gpu_layout @@ -305,7 +305,8 @@ def replay_request_event(self, event: Dict[str, Any]) -> int: slot_mapping = np.array(data['slot_mapping'], dtype=np.int64) token_mask = np.array(data['token_mask'], dtype=bool) if data['token_mask'] else None layer_granularity = data.get('layer_granularity', -1) - dp_id = data.get('dp_id', 0) + dp_rank = data.get('dp_id', 0) + pp_rank = data.get('pp_rank', 0) self.log(f"Replaying {request_type} request with {len(token_ids)} tokens") @@ -319,7 +320,8 @@ def replay_request_event(self, event: Dict[str, Any]) -> int: slot_mapping=slot_mapping, token_mask=token_mask, layer_granularity=layer_granularity, - dp_id=dp_id + dp_rank=dp_rank, + pp_rank=pp_rank ) elif request_type == "PUT": print(f"✅✅✅PUT token_ids: {token_ids[:128]}") @@ -330,7 +332,8 @@ def replay_request_event(self, event: Dict[str, Any]) -> int: token_ids=token_ids, slot_mapping=slot_mapping, token_mask=token_mask, - dp_id=dp_id + dp_rank=dp_rank, + pp_rank=pp_rank ) elif request_type == "GET_MATCH": print(f"🔍📝GET_MATCH token_ids: {token_ids[:128]}") @@ -341,7 +344,8 @@ def replay_request_event(self, event: Dict[str, Any]) -> int: token_ids=token_ids, token_mask=token_mask, layer_granularity=layer_granularity, - dp_id=dp_id + dp_rank=dp_rank, + pp_rank=pp_rank ) elif 
request_type == "PUT_MATCH": print(f"✅📝PUT_MATCH token_ids: {token_ids[:128]}") @@ -351,7 +355,8 @@ def replay_request_event(self, event: Dict[str, Any]) -> int: task_id, return_mask = self.kvmanager.put_match( token_ids=token_ids, token_mask=token_mask, - dp_id=dp_id + dp_rank=dp_rank, + pp_rank=pp_rank ) else: raise ValueError(f"Unknown request type: {request_type}") diff --git a/tests/test_kvmanager.py b/tests/test_kvmanager.py index c446dda667..706c10575d 100644 --- a/tests/test_kvmanager.py +++ b/tests/test_kvmanager.py @@ -40,7 +40,7 @@ def _fp8_cuda_ops_unavailable(): except NotImplementedError: return True -def run_tp_client(dp_client_id, +def run_tp_client(dp_rank, tp_rank, server_recv_port, model_config, @@ -50,8 +50,8 @@ def run_tp_client(dp_client_id, gpu_layout_type): """Run tp_client process""" try: - device_id = tp_rank + dp_client_id * model_config.tp_size - tp_client = KVTPClient(server_recv_port, dp_client_id, device_id) + device_id = tp_rank + dp_rank * model_config.tp_size + tp_client = KVTPClient(server_recv_port, dp_rank, device_id) gpu_kv_layout = create_gpu_kv_layout(model_config, cache_config, num_gpu_blocks, gpu_layout_type) @@ -242,14 +242,14 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): initial_write_num = int(num_requests * initial_write_ratio) print("writing initial data...") put_ids = [] - for token_ids, block_ids, dp_id in request_pairs[:initial_write_num]: + for token_ids, block_ids, dp_rank in request_pairs[:initial_write_num]: if gpu_kv_verifier is not None: gpu_kv_verifier.fill_gpu_blocks(token_ids, block_ids) write_request = kvmanager.put_async( token_ids=token_ids, slot_mapping=block_ids_2_slot_mapping(block_ids, tokens_per_block), token_mask=None, - dp_id=dp_id, + dp_rank=dp_rank, namespace=namespace, ) kvmanager.wait([write_request], completely=True) @@ -261,7 +261,7 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): token_ids=torch.randint(0, 100, size=(8,), 
dtype=torch.int64), slot_mapping=block_ids_2_slot_mapping(torch.arange(0,1, dtype=torch.int64), tokens_per_block, actual_length=8), token_mask=None, - dp_id=0, + dp_rank=0, namespace=namespace, ) kvmanager.wait([write_request], completely=True) @@ -272,7 +272,7 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): # token_ids=torch.randint(0, 100, size=(16,), dtype=torch.int64), # slot_mapping=block_ids_2_slot_mapping(torch.arange(0,1, dtype=torch.int64), tokens_per_block, actual_length=8), # token_mask=my_mask, - # dp_id=0, + # dp_rank=0, #) #kvmanager.wait_for_graph_finished(write_request) @@ -289,13 +289,13 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): for i in range(initial_write_num, num_requests): print(f"performing mixed read/write {i} / {num_requests} ...") read_idx = i - initial_write_num - token_ids, block_ids, dp_id = request_pairs[read_idx] + token_ids, block_ids, dp_rank = request_pairs[read_idx] slot_mapping = block_ids_2_slot_mapping(block_ids, tokens_per_block) request_id, _ = kvmanager.get_match( token_ids=token_ids, layer_granularity=-1, token_mask=None, - dp_id=dp_id, + dp_rank=dp_rank, namespace=namespace, ) kvmanager.launch(request_id, slot_mapping) @@ -303,14 +303,14 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): running_get_requests.append(request_id) req_id2block_ids[request_id] = block_ids req_id2token_ids[request_id] = token_ids - token_ids, block_ids, dp_id = request_pairs[i] + token_ids, block_ids, dp_rank = request_pairs[i] if gpu_kv_verifier is not None: gpu_kv_verifier.fill_gpu_blocks(token_ids, block_ids) request_id = kvmanager.put_async( token_ids=token_ids, slot_mapping=block_ids_2_slot_mapping(block_ids, tokens_per_block), token_mask=None, - dp_id=dp_id, + dp_rank=dp_rank, namespace=namespace, ) req_id2block_ids[request_id] = block_ids @@ -378,14 +378,14 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): # 
Create multiple get_match requests for i in range(batch_size): - token_ids, block_ids, dp_id = request_pairs[random.randint(0, num_requests - 1)] + token_ids, block_ids, dp_rank = request_pairs[random.randint(0, num_requests - 1)] slot_mapping = block_ids_2_slot_mapping(block_ids, tokens_per_block) request_id, return_mask = kvmanager.get_match( token_ids=token_ids, layer_granularity=-1, token_mask=None, - dp_id=dp_id, + dp_rank=dp_rank, namespace=namespace, ) batched_get_task_ids.append(request_id) @@ -580,7 +580,7 @@ def verify_gpu_blocks(self, block_ids, main_kv_tokens_per_block, token_ids) -> b return verification_passed -def run_tp_client_with_indexer(dp_client_id, +def run_tp_client_with_indexer(dp_rank, tp_rank, server_recv_port, model_config, @@ -593,7 +593,7 @@ def run_tp_client_with_indexer(dp_client_id, Indexer configuration is read from cache_config.indexer (IndexerCacheConfig). """ try: - device_id = tp_rank + dp_client_id * model_config.tp_size + device_id = tp_rank + dp_rank * model_config.tp_size gpu_kv_layout = create_gpu_kv_layout(model_config, cache_config, num_gpu_blocks, gpu_layout_type) @@ -647,7 +647,7 @@ def run_tp_client_with_indexer(dp_client_id, # Use KVTPClient directly with indexer buffers (shadow transfer mode) tp_client = KVTPClient( gpu_register_port=server_recv_port + "_gpu_register", - dp_client_id=dp_client_id, + dp_rank=dp_rank, device_id=device_id, ) tp_client.register_to_server( @@ -786,7 +786,7 @@ def _run_indexer_test(model_config, cache_config, test_config, gpu_layout_type, initial_write_num = int(num_requests * initial_write_ratio) print(f"[Test] Testing put flow ({test_label})...") - for token_ids, block_ids, dp_id in request_pairs[:initial_write_num]: + for token_ids, block_ids, dp_rank in request_pairs[:initial_write_num]: if gpu_kv_verifier is not None: gpu_kv_verifier.fill_gpu_blocks(token_ids, block_ids) if indexer_kv_verifier is not None: @@ -795,7 +795,7 @@ def _run_indexer_test(model_config, cache_config, 
test_config, gpu_layout_type, token_ids=token_ids, slot_mapping=block_ids_2_slot_mapping(block_ids, tokens_per_block), token_mask=None, - dp_id=dp_id, + dp_rank=dp_rank, ) put_results = kvmanager.wait([write_request], completely=True) assert put_results[write_request].status == KVResponseStatus.SUCCESS @@ -816,13 +816,13 @@ def _run_indexer_test(model_config, cache_config, test_config, gpu_layout_type, batch_slot_mappings = [] for i in range(min(initial_write_num, num_requests)): - token_ids, block_ids, dp_id = request_pairs[i] + token_ids, block_ids, dp_rank = request_pairs[i] slot_mapping = block_ids_2_slot_mapping(block_ids, tokens_per_block) request_id, _ = kvmanager.get_match( token_ids=token_ids, layer_granularity=-1, token_mask=None, - dp_id=dp_id, + dp_rank=dp_rank, ) batch_task_ids.append(request_id) batch_slot_mappings.append(slot_mapping) @@ -889,7 +889,7 @@ def _run_indexer_test(model_config, cache_config, test_config, gpu_layout_type, print(f"[Test] Testing try_wait flow ({test_label})...") if initial_write_num < num_requests: - token_ids, block_ids, dp_id = request_pairs[initial_write_num] + token_ids, block_ids, dp_rank = request_pairs[initial_write_num] if gpu_kv_verifier is not None: gpu_kv_verifier.fill_gpu_blocks(token_ids, block_ids) if indexer_kv_verifier is not None: @@ -898,7 +898,7 @@ def _run_indexer_test(model_config, cache_config, test_config, gpu_layout_type, token_ids=token_ids, slot_mapping=block_ids_2_slot_mapping(block_ids, tokens_per_block), token_mask=None, - dp_id=dp_id, + dp_rank=dp_rank, ) finished = {} for _ in range(200): @@ -1018,11 +1018,8 @@ def _mock_sglang_eventfd_client(socket_path: str, f"after {max_retries} attempts") return - # Send 24-byte metadata: tp_rank, tp_size, cp_rank, cp_size, - # num_layers, num_counters - metadata = struct.pack("iiiiii", + metadata = struct.pack("iiii", tp_rank, tp_size, - 0, 1, # cp_rank=0, cp_size=1 num_layers, num_counters) sock.sendall(metadata) diff --git 
a/tests/test_transfer_engine_atomic_eviction.py b/tests/test_transfer_engine_atomic_eviction.py index d08e646081..d380df14ed 100644 --- a/tests/test_transfer_engine_atomic_eviction.py +++ b/tests/test_transfer_engine_atomic_eviction.py @@ -300,7 +300,7 @@ def test_layerwise_op_pending_count_incremented_for_indexer(self): engine, main_worker, indexer_worker = self._make_engine_stub_with_indexer(enable_layerwise=True) op = _make_op(TransferType.LAYERWISE) - op.dp_id = 0 + op.dp_rank = 0 engine.op_id_to_op[op.op_id] = op initial_pending_count = op.pending_count # should be 1 @@ -324,7 +324,7 @@ def test_layerwise_op_submitted_to_both_main_and_indexer_workers(self): engine, main_worker, indexer_worker = self._make_engine_stub_with_indexer(enable_layerwise=True) op = _make_op(TransferType.LAYERWISE) - op.dp_id = 0 + op.dp_rank = 0 engine.op_id_to_op[op.op_id] = op with patch('flexkv.transfer.transfer_engine.register_op_to_buffer'), \ @@ -358,7 +358,7 @@ def test_layerwise_op_no_indexer_pending_count_stays_one(self): engine._worker_map[TransferType.LAYERWISE] = [main_layerwise_worker] op = _make_op(TransferType.LAYERWISE) - op.dp_id = 0 + op.dp_rank = 0 engine.op_id_to_op[op.op_id] = op initial_pending_count = op.pending_count # should be 1 From 2f572d8c216bc5186170474a70c428e867e03e10 Mon Sep 17 00:00:00 2001 From: zittozhang Date: Thu, 30 Apr 2026 16:49:33 +0800 Subject: [PATCH 58/59] fix: improve transfer manager and engine robustness for PP --- .../tensorrt_llm/trtllm_adapter.py | 2 +- flexkv/transfer/transfer_engine.py | 18 +++--- flexkv/transfer_manager.py | 63 +++++++++++-------- tests/test_kvmanager.py | 3 +- tests/test_transfer_engine_atomic_eviction.py | 61 ++++++++++++------ 5 files changed, 90 insertions(+), 57 deletions(-) diff --git a/flexkv/integration/tensorrt_llm/trtllm_adapter.py b/flexkv/integration/tensorrt_llm/trtllm_adapter.py index 940e682e6d..db925fe6b7 100644 --- a/flexkv/integration/tensorrt_llm/trtllm_adapter.py +++ 
b/flexkv/integration/tensorrt_llm/trtllm_adapter.py @@ -543,7 +543,7 @@ def __init__(self, config: ExecutorConfig): flexkv_logger.info(f"TransferManagerOnRemote process created, PID: {self.remote_process.pid}") flexkv_logger.info(f"Start init FlexKVWorkerConnector to {flexkv_config.gpu_register_port}, dp_rank: {dp_rank}") - self.tp_client = KVTPClient(flexkv_config.gpu_register_port, dp_rank=dp_rank, pp_rank=0, device_id=current_device_id) + self.tp_client = KVTPClient(flexkv_config.gpu_register_port, dp_rank=dp_rank, pp_rank=flexkv_config.model_config.pp_rank, device_id=current_device_id) flexkv_logger.info("Finish init FlexKVWorkerConnector") def _need_to_create_remote_process(self) -> bool: diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index 84fe86b786..092baa214d 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -334,7 +334,7 @@ def _init_workers(self) -> None: gpu_kv_layouts=[gpu_handle.kv_layout for gpu_handle in gpu_handles], ssd_kv_layout=self._ssd_handle.kv_layout, dtype=self._ssd_handle.dtype, - tp_group_id=self.tp_size_per_node, + tp_group_size=self.tp_size_per_node, ) for worker_key, gpu_handles in self.gpu_handle_groups.items() } @@ -574,8 +574,8 @@ def _init_workers(self) -> None: flexkv_logger.info("TransferEngine: indexer Remote workers initialized") if self.cache_config.enable_gds and self._indexer_ssd_handle is not None: if self.tp_size_per_node == 1: - self._indexer_gds_workers = [ - GDSTransferWorker.create_worker( + self._indexer_gds_workers: Dict[WorkerKey, WorkerHandle] = { + worker_key: GDSTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self._indexer_finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), @@ -587,11 +587,11 @@ def _init_workers(self) -> None: dtype=self._indexer_ssd_handle.dtype, gpu_device_id=indexer_gpu_handles_list[0].gpu_device_id, ) - for _, indexer_gpu_handles_list in self._indexer_gpu_handles.items() - ] + 
for worker_key, indexer_gpu_handles_list in self._indexer_gpu_handles.items() + } else: - self._indexer_gds_workers = [ - tpGDSTransferWorker.create_worker( + self._indexer_gds_workers = { + worker_key: tpGDSTransferWorker.create_worker( mp_ctx=self.mp_ctx, finished_ops_queue=self._indexer_finished_ops_queue, op_buffer_tensor=self.pin_buffer.get_buffer(), @@ -603,8 +603,8 @@ def _init_workers(self) -> None: dtype=self._indexer_ssd_handle.dtype, tp_group_size=self.tp_size_per_node, ) - for _, indexer_gpu_handles_list in self._indexer_gpu_handles.items() - ] + for worker_key, indexer_gpu_handles_list in self._indexer_gpu_handles.items() + } self._indexer_worker_map[TransferType.DISK2D] = self._indexer_gds_workers self._indexer_worker_map[TransferType.D2DISK] = self._indexer_gds_workers flexkv_logger.info("TransferEngine: indexer GDS workers initialized") diff --git a/flexkv/transfer_manager.py b/flexkv/transfer_manager.py index 276873a36b..d8afeabd09 100644 --- a/flexkv/transfer_manager.py +++ b/flexkv/transfer_manager.py @@ -213,17 +213,19 @@ def initialize_transfer_engine(self) -> None: indexer_remote_handle=indexer_remote_handle, ) - # Derive local (dp_rank, pp_rank) from GPU registrations rather than - # model_config. In cross-node PP, TransferManagerOnRemote receives - # model_config from the PP0 master (pp_rank=0), but local GPUs - # register with their true pp_rank (e.g. pp_rank=1). + # Derive local pp_rank from GPU registrations rather than model_config. + # In cross-node PP, TransferManagerOnRemote receives model_config from + # the PP0 master (pp_rank=0), but local GPUs register with their true + # pp_rank (e.g. pp_rank=1). All local workers share the same pp_rank + # because they belong to the same PP stage on this node. 
worker_keys = set(self.gpu_worker_key_mapping.values()) self._local_dp_rank = self.model_config.dp_rank self._local_pp_rank = self.model_config.pp_rank - if len(worker_keys) == 1: - wk = next(iter(worker_keys)) - self._local_dp_rank = wk.dp_rank - self._local_pp_rank = wk.pp_rank + if len(worker_keys) >= 1: + pp_ranks = set(wk.pp_rank for wk in worker_keys) + assert len(pp_ranks) == 1, \ + f"Expected all local workers to share the same pp_rank, got {pp_ranks}" + self._local_pp_rank = pp_ranks.pop() flexkv_logger.info(f"Initialized TransferEngine successfully, " f"grouped_gpu_handles keys={list(grouped_gpu_handles.keys())}, " @@ -308,7 +310,8 @@ def __init__(self, mode: str = "Default", master_host: Optional[str] = None): self._active_graphs_lock = threading.Lock() # Pending matching for cross-node PP: graph arrives before or after slot_mapping - self._pending_graphs: Dict[int, TransferOpGraph] = {} + # _pending_graphs stores (graph, task_end_op_id) tuples + self._pending_graphs: Dict[int, Tuple[TransferOpGraph, int]] = {} self._pending_slot_mappings: Dict[int, np.ndarray] = {} self._pending_lock = threading.Lock() @@ -445,13 +448,14 @@ def _handle_set_slot_mapping(self, task_id: int, slot_mapping: np.ndarray) -> No set_gpu_blocks and submit. Otherwise, store the slot_mapping and wait for the graph to arrive later. 
""" + graph = None + task_end_op_id = -1 with self._pending_lock: if task_id in self._pending_graphs: - # Graph already arrived, set GPU blocks and submit - graph = self._pending_graphs.pop(task_id) + # Graph already arrived, set GPU blocks and prepare for submit + graph, task_end_op_id = self._pending_graphs.pop(task_id) graph.set_gpu_blocks(slot_mapping) self._rebind_graph_to_local_worker(graph) - self.submit(graph) flexkv_logger.debug( f"[TransferManagerOnRemote] set_slot_mapping: " f"graph for task_id={task_id} submitted (graph arrived first)" @@ -463,6 +467,12 @@ def _handle_set_slot_mapping(self, task_id: int, slot_mapping: np.ndarray) -> No f"[TransferManagerOnRemote] set_slot_mapping: " f"slot_mapping stored for task_id={task_id}, waiting for graph" ) + return + + # Submit graph to transfer engine + with self._active_graphs_lock: + self._active_graphs[graph.graph_id] = task_end_op_id + self.submit(graph) def _handle_submit(self, graph: TransferOpGraph, task_end_op_id: int = -1) -> None: """Handle submit message with pending matching support. @@ -482,8 +492,8 @@ def _handle_submit(self, graph: TransferOpGraph, task_end_op_id: int = -1) -> No f"graph for task_id={task_id} submitted (slot_mapping arrived first)" ) else: - # slot_mapping not yet arrived, store graph for later matching - self._pending_graphs[task_id] = graph + # slot_mapping not yet arrived, store graph and task_end_op_id for later matching + self._pending_graphs[task_id] = (graph, task_end_op_id) flexkv_logger.debug( f"[TransferManagerOnRemote] submit: " f"graph stored for task_id={task_id}, waiting for slot_mapping" @@ -496,25 +506,26 @@ def _handle_submit(self, graph: TransferOpGraph, task_end_op_id: int = -1) -> No self.submit(graph) def _rebind_graph_to_local_worker(self, graph: TransferOpGraph) -> None: - """Rebind transfer graph ops to the local WorkerKey. + """Rebind transfer graph ops to the local pp_rank. 
In cross-node PP setups, the master (PP0) creates transfer graphs with - its own (dp_rank=0, pp_rank=0). When these graphs are sent to a remote - node (e.g. PP1), the ops must be rebound to the local (dp_rank, pp_rank) - so the TransferEngine can find the correct workers. - """ - model_pp_rank = self.model_config.pp_rank - model_dp_rank = self.model_config.dp_rank + its own pp_rank=0. When these graphs are sent to a remote node (e.g. PP1), + the ops' pp_rank must be updated to the local pp_rank so the + TransferEngine can find the correct workers. - if model_pp_rank == self._local_pp_rank and model_dp_rank == self._local_dp_rank: + Each op's dp_rank is preserved — in multi-DP scenarios, different ops + may belong to different dp_ranks and should remain bound to their + original DP group. + """ + if self.model_config.pp_rank == self._local_pp_rank: return # No rebinding needed - old_key = WorkerKey(dp_rank=model_dp_rank, pp_rank=model_pp_rank) - new_key = WorkerKey(dp_rank=self._local_dp_rank, pp_rank=self._local_pp_rank) - graph.bind_to_worker(self._local_dp_rank, self._local_pp_rank) + for op in graph._op_map.values(): + op.pp_rank = self._local_pp_rank + flexkv_logger.debug( f"[TransferManagerOnRemote] Rebound graph {graph.graph_id} " - f"from {old_key} to {new_key}" + f"pp_rank from {self.model_config.pp_rank} to {self._local_pp_rank}" ) def start(self) -> None: diff --git a/tests/test_kvmanager.py b/tests/test_kvmanager.py index 706c10575d..eff1c48e2c 100644 --- a/tests/test_kvmanager.py +++ b/tests/test_kvmanager.py @@ -51,7 +51,7 @@ def run_tp_client(dp_rank, """Run tp_client process""" try: device_id = tp_rank + dp_rank * model_config.tp_size - tp_client = KVTPClient(server_recv_port, dp_rank, device_id) + tp_client = KVTPClient(server_recv_port, dp_rank, pp_rank=0, device_id=device_id) gpu_kv_layout = create_gpu_kv_layout(model_config, cache_config, num_gpu_blocks, gpu_layout_type) @@ -648,6 +648,7 @@ def run_tp_client_with_indexer(dp_rank, tp_client = 
KVTPClient( gpu_register_port=server_recv_port + "_gpu_register", dp_rank=dp_rank, + pp_rank=0, device_id=device_id, ) tp_client.register_to_server( diff --git a/tests/test_transfer_engine_atomic_eviction.py b/tests/test_transfer_engine_atomic_eviction.py index d380df14ed..83b8ca4831 100644 --- a/tests/test_transfer_engine_atomic_eviction.py +++ b/tests/test_transfer_engine_atomic_eviction.py @@ -16,7 +16,7 @@ import numpy as np -from flexkv.common.transfer import TransferOp, TransferType, CompletedOp +from flexkv.common.transfer import TransferOp, TransferType, CompletedOp, LayerwiseTransferOp, WorkerKey # --------------------------------------------------------------------------- @@ -25,6 +25,8 @@ def _make_op(transfer_type: TransferType = TransferType.D2H) -> TransferOp: """Create a minimal TransferOp for testing.""" + if transfer_type == TransferType.LAYERWISE: + return _make_layerwise_op() return TransferOp( graph_id=0, transfer_type=transfer_type, @@ -33,6 +35,19 @@ def _make_op(transfer_type: TransferType = TransferType.D2H) -> TransferOp: ) +def _make_layerwise_op(**kwargs) -> LayerwiseTransferOp: + """Create a minimal LayerwiseTransferOp for testing.""" + defaults = dict( + graph_id=0, + src_block_ids_h2d=np.array([0, 1], dtype=np.int64), + dst_block_ids_h2d=np.array([2, 3], dtype=np.int64), + src_block_ids_disk2h=np.array([], dtype=np.int64), + dst_block_ids_disk2h=np.array([], dtype=np.int64), + ) + defaults.update(kwargs) + return LayerwiseTransferOp(**defaults) + + # --------------------------------------------------------------------------- # Tests – TransferOp.pending_count field # --------------------------------------------------------------------------- @@ -261,7 +276,9 @@ def _make_engine_stub_with_indexer(self, enable_layerwise: bool = True): engine._worker_map[TransferType.H2D] = [MagicMock()] engine._worker_map[TransferType.D2H] = [MagicMock()] if enable_layerwise: - engine._worker_map[TransferType.LAYERWISE] = [main_layerwise_worker] + # 
LAYERWISE worker map must be Dict[WorkerKey, WorkerHandle] + wk0 = WorkerKey(dp_rank=0, pp_rank=0) + engine._worker_map[TransferType.LAYERWISE] = {wk0: main_layerwise_worker} # Create mock workers for indexer indexer_h2d_worker = MagicMock() @@ -271,6 +288,10 @@ def _make_engine_stub_with_indexer(self, enable_layerwise: bool = True): if enable_layerwise: engine._indexer_worker_map[TransferType.LAYERWISE] = [indexer_layerwise_worker] + # PP replica tracking (needed by _assign_layerwise_op_to_workers) + engine._pp_replica_to_parent_op = {} + engine._pp_replica_op_map = {} + return engine, main_layerwise_worker, indexer_layerwise_worker def test_indexer_worker_map_contains_layerwise_when_enabled(self): @@ -289,10 +310,11 @@ def test_indexer_worker_map_no_layerwise_when_disabled(self): engine, _, _ = self._make_engine_stub_with_indexer(enable_layerwise=False) self.assertNotIn(TransferType.LAYERWISE, engine._indexer_worker_map) - def test_layerwise_op_pending_count_incremented_for_indexer(self): + def test_layerwise_op_pending_count_not_incremented_for_single_pp_stage(self): """ - WHEN _assign_op_to_worker processes a LAYERWISE op with _has_indexer=True - THEN op.pending_count SHALL be incremented by 1 before submitting to indexer (req 2.2). + WHEN _assign_op_to_worker processes a LAYERWISE op with only one PP stage + THEN op.pending_count SHALL remain 1 (no fan-out needed). + LAYERWISE indexer is fused inside the worker, not dispatched separately. 
""" from flexkv.transfer.transfer_engine import register_op_to_buffer import nvtx @@ -301,6 +323,7 @@ def test_layerwise_op_pending_count_incremented_for_indexer(self): op = _make_op(TransferType.LAYERWISE) op.dp_rank = 0 + op.pp_rank = 0 engine.op_id_to_op[op.op_id] = op initial_pending_count = op.pending_count # should be 1 @@ -309,15 +332,14 @@ def test_layerwise_op_pending_count_incremented_for_indexer(self): patch('nvtx.start_range', return_value=MagicMock()): engine._assign_op_to_worker(op) - # pending_count should have been incremented by 1 (for indexer) before submission - # After _assign_op_to_worker: pending_count = initial + 1 = 2 - self.assertEqual(op.pending_count, initial_pending_count + 1) + # With a single PP stage, no fan-out → pending_count stays at 1 + self.assertEqual(op.pending_count, initial_pending_count) - def test_layerwise_op_submitted_to_both_main_and_indexer_workers(self): + def test_layerwise_op_submitted_to_main_worker(self): """ - WHEN _assign_op_to_worker processes a LAYERWISE op with _has_indexer=True - THEN op SHALL be submitted to main KV worker, and a separate indexer_op - SHALL be submitted to the indexer layerwise worker (req 2.1). + WHEN _assign_op_to_worker processes a LAYERWISE op + THEN op SHALL be submitted to the main KV layerwise worker. + LAYERWISE indexer is fused inside the worker, not dispatched separately. 
""" from flexkv.transfer.transfer_engine import register_op_to_buffer @@ -325,20 +347,15 @@ def test_layerwise_op_submitted_to_both_main_and_indexer_workers(self): op = _make_op(TransferType.LAYERWISE) op.dp_rank = 0 + op.pp_rank = 0 engine.op_id_to_op[op.op_id] = op with patch('flexkv.transfer.transfer_engine.register_op_to_buffer'), \ patch('nvtx.start_range', return_value=MagicMock()): engine._assign_op_to_worker(op) - # Main KV worker should have received the original op + # Main KV layerwise worker should have received the op main_worker.submit_transfer.assert_called_once_with(op) - # Indexer worker should have received a separate indexer_op (not the same op object) - indexer_worker.submit_transfer.assert_called_once() - indexer_op = indexer_worker.submit_transfer.call_args[0][0] - self.assertIsNot(indexer_op, op, "Indexer worker must receive a separate op, not the original") - self.assertEqual(indexer_op.graph_id, op.graph_id) - self.assertEqual(indexer_op.transfer_type, op.transfer_type) def test_layerwise_op_no_indexer_pending_count_stays_one(self): """ @@ -353,12 +370,16 @@ def test_layerwise_op_no_indexer_pending_count_stays_one(self): engine._indexer_worker_map = {} engine.op_id_to_op = {} engine.op_id_to_nvtx_range = {} + engine._pp_replica_to_parent_op = {} + engine._pp_replica_op_map = {} main_layerwise_worker = MagicMock() - engine._worker_map[TransferType.LAYERWISE] = [main_layerwise_worker] + wk0 = WorkerKey(dp_rank=0, pp_rank=0) + engine._worker_map[TransferType.LAYERWISE] = {wk0: main_layerwise_worker} op = _make_op(TransferType.LAYERWISE) op.dp_rank = 0 + op.pp_rank = 0 engine.op_id_to_op[op.op_id] = op initial_pending_count = op.pending_count # should be 1 From 16fa4503829d8bd434f9f3551718e64fefc85f4a Mon Sep 17 00:00:00 2001 From: phaedonsun Date: Wed, 29 Apr 2026 21:01:16 +0800 Subject: [PATCH 59/59] refactor(DistributedRadixTree): enforce single-node matching constraint - Remove CMatchResult.block_node_ids tensor; add matched_node_id (int32) 
- RefRadixTree::match_prefix now stops when encountering a different node_id - Simplify pybind bindings: expose matched_node_id, remove block_node_ids - Python MatchResultAccel: add matched_node_id field, broadcast to per-block arrays for backward compat in downstream worker/transfer paths - Update hie_cache_engine.py and cache_engine.py to derive per-block arrays from single matched_node_id instead of reading C++ tensor - Add unit tests for CMatchResult and MatchResultAccel matched_node_id This simplifies the distributed matching and transfer paths: - No need for shared_transfer_kv_blocks_remote_read multi-node grouping - Lease management is simpler (no cross-node cascade invalidation) - Better fault isolation (single node failure domain) - Cleaner integration with CP/PP/TP cooperative GET flow --- csrc/bindings.cpp | 11 ++++- csrc/dist/distributed_radix_tree.cpp | 30 ++++++++----- csrc/radix_tree.cpp | 3 +- csrc/radix_tree.h | 7 +-- flexkv/cache/cache_engine.py | 17 ++++--- flexkv/cache/hie_cache_engine.py | 45 +++++++++---------- flexkv/common/type.py | 6 ++- tests/test_cache_engine.py | 66 ++++++++++++++++++++++++++++ 8 files changed, 135 insertions(+), 50 deletions(-) diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index 659802cdd8..80aa938309 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -660,11 +660,18 @@ PYBIND11_MODULE(c_ext, m) { py::class_>( m, "CMatchResult") .def(py::init()) + torch::Tensor, int32_t>(), + py::arg("num_ready_matched_blocks"), + py::arg("num_matched_blocks"), + py::arg("last_node_matched_length"), + py::arg("last_ready_node"), + py::arg("last_node"), + py::arg("physical_blocks"), + py::arg("matched_node_id") = -1) .def_readonly("last_ready_node", &flexkv::CMatchResult::last_ready_node) .def_readonly("last_node", &flexkv::CMatchResult::last_node) .def_readonly("physical_blocks", &flexkv::CMatchResult::physical_blocks) - .def_readonly("block_node_ids", &flexkv::CMatchResult::block_node_ids) + .def_readonly("matched_node_id", 
&flexkv::CMatchResult::matched_node_id) .def_readonly("num_ready_matched_blocks", &flexkv::CMatchResult::num_ready_matched_blocks) .def_readonly("num_matched_blocks", diff --git a/csrc/dist/distributed_radix_tree.cpp b/csrc/dist/distributed_radix_tree.cpp index 1831896861..04a3b38b35 100644 --- a/csrc/dist/distributed_radix_tree.cpp +++ b/csrc/dist/distributed_radix_tree.cpp @@ -354,8 +354,7 @@ std::shared_ptr DistributedRadixTree::match_prefix( if (idx == nullptr) { // Remote index not yet built - this is normal at startup auto empty_i64 = torch::empty({0}, torch::dtype(torch::kInt64)); - auto empty_u32 = torch::empty({0}, torch::dtype(torch::kInt32)); - return std::make_shared(0, 0, 0, nullptr, nullptr, empty_i64, empty_u32); + return std::make_shared(0, 0, 0, nullptr, nullptr, empty_i64); } // Safely increment reference count while holding the lock @@ -563,8 +562,7 @@ std::shared_ptr RefRadixTree::match_prefix( if (root == nullptr) { auto empty_i64 = torch::empty({0}, torch::dtype(torch::kInt64)); - auto empty_u32 = torch::empty({0}, torch::dtype(torch::kInt32)); - return std::make_shared(0, 0, 0, nullptr, nullptr, empty_i64, empty_u32); + return std::make_shared(0, 0, 0, nullptr, nullptr, empty_i64); } auto current_node = root; @@ -578,10 +576,10 @@ std::shared_ptr RefRadixTree::match_prefix( auto block_hashes_ptr = block_hashes.data_ptr(); HashType child_hash; - // node ids stored as int32 tensor (PyTorch lacks uint32 dtype) - auto node_ids_tensor = torch::empty({num_blocks}, torch::dtype(torch::kInt32)); - auto *ni_out = node_ids_tensor.data_ptr(); - int32_t ni_write = 0; + // Single-node matching constraint: all matched blocks must come from the + // same peer node_id. We lock the node_id on the first valid block and + // stop matching when a different node_id is encountered. 
+ int32_t matched_node_id = -1; // -1 = not yet determined // now in ms struct timeval now_tv; gettimeofday(&now_tv, nullptr); @@ -638,9 +636,20 @@ std::shared_ptr RefRadixTree::match_prefix( if (bnis == nullptr || bnis->size() != pbs.size()) break; + // Single-node constraint: stop at the first block whose node_id + // differs from the already-locked matched_node_id. + int actually_copied = 0; for (int i = 0; i < matched; ++i) { + int32_t block_nid = static_cast((*bnis)[i]); + if (matched_node_id == -1) { + matched_node_id = block_nid; // lock the first node_id + } else if (block_nid != matched_node_id) { + // Different node_id encountered - stop matching here + matched = actually_copied; + break; + } pb_out[pb_write++] = pbs[i]; - ni_out[ni_write++] = (*bnis)[i]; + actually_copied++; } if (current_node->is_ready()) { @@ -672,10 +681,9 @@ std::shared_ptr RefRadixTree::match_prefix( } auto physical_blocks = physical_blocks_tensor.narrow(0, 0, pb_write); - auto node_ids = node_ids_tensor.narrow(0, 0, ni_write); return std::make_shared(prefix_blocks_num, prefix_blocks_num, last_node_matched_length, - last_ready_node, current_node, physical_blocks, node_ids); + last_ready_node, current_node, physical_blocks, matched_node_id); } // Helper function to clean up an orphan tree (not attached to main tree) diff --git a/csrc/radix_tree.cpp b/csrc/radix_tree.cpp index 04d3429f6b..b27ac43920 100644 --- a/csrc/radix_tree.cpp +++ b/csrc/radix_tree.cpp @@ -520,9 +520,8 @@ CRadixTreeIndex::match_prefix(torch::Tensor &block_hashes, int num_blocks, } auto physical_blocks = physical_blocks_tensor.narrow(0, 0, pb_write); - auto empty_uint32 = torch::Tensor(); return std::make_shared(ready_prefix_blocks_num, prefix_blocks_num, last_node_matched_length, - last_ready_node, current_node, physical_blocks, empty_uint32); + last_ready_node, current_node, physical_blocks); } } // namespace flexkv diff --git a/csrc/radix_tree.h b/csrc/radix_tree.h index 65bad5dc12..4c7c3c5d86 100644 --- 
a/csrc/radix_tree.h +++ b/csrc/radix_tree.h @@ -227,17 +227,18 @@ class CMatchResult { CRadixNode *last_ready_node; CRadixNode *last_node; torch::Tensor physical_blocks; - torch::Tensor block_node_ids; + int32_t matched_node_id; // single node_id for all matched blocks (-1 = no match) CMatchResult(int _num_ready_matched_blocks, int _num_matched_blocks, int _last_node_matched_length, CRadixNode *_last_ready_node, CRadixNode *_last_node, torch::Tensor blocks, - torch::Tensor block_node_ids = torch::Tensor()) + int32_t matched_node_id = -1) : num_ready_matched_blocks(_num_ready_matched_blocks), num_matched_blocks(_num_matched_blocks), last_node_matched_length(_last_node_matched_length), last_ready_node(_last_ready_node), last_node(_last_node), - physical_blocks(blocks), block_node_ids(block_node_ids) {} + physical_blocks(blocks), + matched_node_id(matched_node_id) {} ~CMatchResult() {} }; diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index 39e6cba16f..1b78d39272 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -90,15 +90,13 @@ def match(self, sequence_meta: SequenceMeta) -> MatchResultAccel: sequence_meta.num_blocks, True) # physical blocks (torch.Tensor -> numpy, zero-copy on CPU) phys = match_result.physical_blocks.cpu().numpy() - # optional block_node_ids - try: - bnis = getattr(match_result, "block_node_ids", None) - if isinstance(bnis, torch.Tensor) and bnis.numel() > 0: - bnids_np = bnis.cpu().numpy() - else: - bnids_np = None - except Exception: - bnids_np = None + # Extract single matched_node_id (single-node constraint) + raw_nid = getattr(match_result, "matched_node_id", -1) + single_node_id = int(raw_nid) if raw_nid is not None and raw_nid >= 0 else None + # Broadcast matched_node_id to per-block array for downstream compat + bnids_np = None + if single_node_id is not None and len(phys) > 0: + bnids_np = np.full(len(phys), single_node_id, dtype=np.uint32) return MatchResultAccel( 
num_ready_matched_blocks=match_result.num_ready_matched_blocks, num_matched_blocks=match_result.num_matched_blocks, @@ -106,6 +104,7 @@ def match(self, sequence_meta: SequenceMeta) -> MatchResultAccel: last_node=match_result.last_node, last_node_matched_length=match_result.last_node_matched_length, physical_blocks=phys, + matched_node_id=single_node_id, block_node_ids=bnids_np, matched_pos="remote" if self.device_type == DeviceType.REMOTE else "local", ) diff --git a/flexkv/cache/hie_cache_engine.py b/flexkv/cache/hie_cache_engine.py index 41f37ee801..fabe741cd9 100644 --- a/flexkv/cache/hie_cache_engine.py +++ b/flexkv/cache/hie_cache_engine.py @@ -201,37 +201,37 @@ def match_all(self, sequence_meta: SequenceMeta, gpu_matched_blocks: int = 0) -> # physical blocks bnids_np = None + single_node_id = None if chosen is mr_remote: - #try to use DistributedRadixTree's block_node_ids - #if check fails, use LocalRadixTree's match result - nids = chosen.block_node_ids - nps = chosen.physical_blocks - # Convert tensors to numpy views (CPU) if present - if isinstance(nids, torch.Tensor) and nids.numel() > 0: - # For P2P mode (CPU/SSD), no PCFS conversion is needed - # Only convert to PCFS file_nodeids if device_type is REMOTE - if self.device_type == DeviceType.REMOTE: - bnids_np = self.nodeids_to_file_nodeids(nids.cpu().numpy(), nps.cpu().numpy()) - if bnids_np is None: - chosen = mr_local - matched_pos = "local" # Update matched_pos after fallback - else: - # For P2P mode, use node_ids directly - bnids_np = nids.cpu().numpy().astype(np.uint32) - #print(f"[REMOTE_MATCH {self.device_type.name}] Using remote data: block_ids={nps.cpu().numpy()[:min(4, len(nps))]}, node_ids={bnids_np[:min(4, len(bnids_np))]}") + # Extract single matched_node_id from CMatchResult (single-node constraint) + raw_node_id = getattr(chosen, "matched_node_id", -1) + if raw_node_id is not None and raw_node_id >= 0: + single_node_id = int(raw_node_id) + nps = chosen.physical_blocks + num_blocks = 
nps.shape[0] if isinstance(nps, torch.Tensor) else len(nps) + if num_blocks > 0: + # Broadcast single node_id to per-block array for downstream compat + raw_nids = np.full(num_blocks, single_node_id, dtype=np.uint32) + if self.device_type == DeviceType.REMOTE: + bnids_np = self.nodeids_to_file_nodeids(raw_nids, nps.cpu().numpy()) + if bnids_np is None: + chosen = mr_local + matched_pos = "local" + single_node_id = None + else: + bnids_np = raw_nids else: - bnids_np = None + # No valid matched_node_id → fall back to local if mr_remote.num_matched_blocks > 0: - #print(f"[REMOTE_MATCH {self.device_type.name}] Warning: remote matched but block_node_ids is empty, falling back to local") chosen = mr_local - matched_pos = "local" # Update matched_pos after fallback + matched_pos = "local" + single_node_id = None phys_np = chosen.physical_blocks.cpu().numpy() #maybe we should always not insert if self.device_type == DeviceType.CPU and matched_pos == "remote" and mr_local.num_matched_blocks > 0: insert_to_local_cpu_index = False else: insert_to_local_cpu_index = True - #TODO A big question is how to get the node id for peer_cpu and peer_ssd? 
return MatchResultAccel( num_ready_matched_blocks=int(chosen.num_ready_matched_blocks), num_matched_blocks=int(chosen.num_matched_blocks), @@ -239,9 +239,10 @@ def match_all(self, sequence_meta: SequenceMeta, gpu_matched_blocks: int = 0) -> last_node=chosen.last_node, last_node_matched_length=int(chosen.last_node_matched_length), physical_blocks=phys_np, + matched_node_id=single_node_id, block_node_ids=bnids_np, matched_pos=matched_pos, - matched_node_ids=bnids_np, # Set matched_node_ids for P2P transfer + matched_node_ids=bnids_np, # deprecated: kept for backward compat insert_to_local_cpu_index=insert_to_local_cpu_index, ) diff --git a/flexkv/common/type.py b/flexkv/common/type.py index 8b893f2eb5..25f17d2d93 100644 --- a/flexkv/common/type.py +++ b/flexkv/common/type.py @@ -11,9 +11,13 @@ class MatchResultAccel: last_node: Optional['CRadixNode'] = None last_node_matched_length: int = 0 physical_blocks: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) + # Single node_id for all matched blocks (single-node matching constraint). + # None means no remote match (the C++ CMatchResult uses -1). Preferred over the deprecated per-block arrays. + matched_node_id: Optional[int] = None + # deprecated: kept for backward compat; prefer matched_node_id block_node_ids: Optional[np.ndarray] = None matched_pos: Optional[str] = None - matched_node_ids: Optional[np.ndarray] = None #TODO id or ids? should we allow one req match results on multiple nodes? 
+ matched_node_ids: Optional[np.ndarray] = None # deprecated: prefer matched_node_id insert_to_local_cpu_index: bool = True def __post_init__(self) -> None: diff --git a/tests/test_cache_engine.py b/tests/test_cache_engine.py index 6b1cc5779d..012d91a1a4 100644 --- a/tests/test_cache_engine.py +++ b/tests/test_cache_engine.py @@ -757,3 +757,69 @@ def test_eviction_policy_reinsert_after_eviction(engine_cls): assert engine.match(seqs[2]).num_matched_blocks == 0, ( "C should be evicted to make room for re-inserted B" ) + + +# --------------------------------------------------------------------------- +# Tests – MatchResultAccel matched_node_id field +# --------------------------------------------------------------------------- +class TestMatchResultAccelNodeId: + """Verify the single-node matching constraint data structures.""" + + def test_matched_node_id_default_none(self): + """matched_node_id defaults to None when not set.""" + from flexkv.common.type import MatchResultAccel + result = MatchResultAccel( + num_ready_matched_blocks=0, + num_matched_blocks=0, + physical_blocks=np.array([], dtype=np.int64), + ) + assert result.matched_node_id is None + + def test_matched_node_id_set(self): + """matched_node_id can be set to a single integer.""" + from flexkv.common.type import MatchResultAccel + result = MatchResultAccel( + num_ready_matched_blocks=5, + num_matched_blocks=5, + physical_blocks=np.arange(5, dtype=np.int64), + matched_node_id=42, + ) + assert result.matched_node_id == 42 + assert isinstance(result.matched_node_id, int) + + def test_backward_compat_block_node_ids(self): + """block_node_ids (deprecated) still works alongside matched_node_id.""" + from flexkv.common.type import MatchResultAccel + bnids = np.array([42, 42, 42], dtype=np.uint32) + result = MatchResultAccel( + num_ready_matched_blocks=3, + num_matched_blocks=3, + physical_blocks=np.arange(3, dtype=np.int64), + matched_node_id=42, + block_node_ids=bnids, + ) + assert result.matched_node_id == 
42 + assert np.all(result.block_node_ids == 42) + + +# --------------------------------------------------------------------------- +# Tests – CMatchResult matched_node_id field (C++ binding) +# --------------------------------------------------------------------------- +class TestCMatchResultNodeId: + """Verify the C++ CMatchResult exposes matched_node_id.""" + + def test_cmatch_result_default_node_id(self): + """CMatchResult.matched_node_id defaults to -1.""" + import torch + from flexkv.c_ext import CMatchResult + result = CMatchResult(0, 0, 0, None, None, torch.empty(0, dtype=torch.int64)) + assert result.matched_node_id == -1 + + def test_cmatch_result_with_node_id(self): + """CMatchResult.matched_node_id can be set via constructor.""" + import torch + from flexkv.c_ext import CMatchResult + blocks = torch.arange(3, dtype=torch.int64) + result = CMatchResult(3, 3, 0, None, None, blocks, 7) + assert result.matched_node_id == 7 + assert result.physical_blocks.shape[0] == 3