diff --git a/fastdeploy/cache_manager/cache_data.py b/fastdeploy/cache_manager/cache_data.py index 82911eccfa3..9fd48cec2ce 100644 --- a/fastdeploy/cache_manager/cache_data.py +++ b/fastdeploy/cache_manager/cache_data.py @@ -14,13 +14,35 @@ # limitations under the License. """ +from dataclasses import dataclass from enum import Enum +from typing import Any, Optional from fastdeploy.utils import get_logger logger = get_logger("prefix_cache_manager", "cache_manager.log") +@dataclass +class AuxBlockDataSpec: + """ + Describes a type of auxiliary data bound to KVCache blocks. + CacheTransferManager iterates registered specs during swap/storage + to perform corresponding data transfers. + """ + + name: str + num_layers: int + per_token_size: int = 0 + block_size: int = 0 + dtype: str = "uint8" + swap_buffer: Optional[Any] = None + enabled: bool = True + + def get_storage_key(self, key_prefix: str, block_hash: str, rank: int) -> str: + return f"prefix{key_prefix}_{block_hash}_{rank}_{self.name}" + + class CacheStatus(Enum): """ cache status enum class @@ -56,6 +78,7 @@ def __init__( cache_status=CacheStatus.GPU, is_persistent=False, persistent_shared_count=0, + aux_data_names=None, ): """ Args: @@ -89,6 +112,7 @@ def __init__( self.cache_status = cache_status self.is_persistent = is_persistent self.persistent_shared_count = persistent_shared_count + self.aux_data_names = aux_data_names or [] self.req_id_set = set() def __lt__(self, other): @@ -102,7 +126,7 @@ def __lt__(self, other): else: return self.depth > other.depth - def __str__(self): + def __str__(self) -> str: """ return node info """ diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 8c5499cafde..5983266693b 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -24,6 +24,7 @@ import threading import time import traceback +import weakref from typing import List import numpy as np @@ -48,7 +49,7 @@ FileStore, MooncakeStore, ) -from fastdeploy.config import CacheConfig, SpeculativeConfig +from fastdeploy.config import CacheConfig, RoutingReplayConfig, SpeculativeConfig from fastdeploy.engine.request import ControlRequest, ControlResponse from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus from fastdeploy.inter_communicator.fmq import FMQ @@ -129,7 +130,11 @@ def parse_args(): ) parser.add_argument("--model_path", type=str, help="The path of model") + # Routing replay (R3) — single JSON arg, mirrors SpeculativeConfig pattern + parser.add_argument("--routing_replay_config", type=json.loads, default="{}", help="Routing replay config JSON") + args = parser.parse_args() + args.routing_replay_config = RoutingReplayConfig(args.routing_replay_config) return args @@ -241,6 +246,25 @@ def __init__(self, args): self._init_cpu_cache() if self.storage_backend_type is not None: self._init_storage(args) + + # Initialize auxiliary data specs (e.g., routing replay) + self.aux_data_specs = {} + self.routing_host_view = None + self.routing_swap_buffer = None + self.routing_replay_config = args.routing_replay_config + self.engine_worker_queue_port = args.engine_worker_queue_port + self._init_routing_aux_data() + + # Register finalizer to release routing SharedMemory on process exit. + # Must use a static method — callback must NOT hold a reference to self, + # otherwise the object can never be GC'd and the finalizer won't fire. + self._finalizer = weakref.finalize( + self, + CacheTransferManager._cleanup_routing_resources, + self.routing_swap_buffer, + self.routing_host_view, + ) + self._init_control() cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32) @@ -307,6 +331,185 @@ def __init__(self, args): ) self.cache_transfer_inited_signal.value[self.rank] = 1 + def _init_routing_aux_data(self): + """Initialize routing auxiliary data buffers for swap sync.""" + routing_replay_config = self.routing_replay_config + if not routing_replay_config.enable_routing_replay: + return + + try: + from fastdeploy.cache_manager.cache_data import AuxBlockDataSpec + from fastdeploy.cache_manager.routing_cache_manager import ( + RoutingHostBufferView, + RoutingSwapBuffer, + ) + + num_moe_layers = routing_replay_config.num_moe_layers + moe_top_k = routing_replay_config.moe_top_k + routing_dtype = routing_replay_config.routing_dtype + + if num_moe_layers == 0 or moe_top_k == 0: + return + + spec = AuxBlockDataSpec( + name="routing", + num_layers=num_moe_layers, + per_token_size=moe_top_k, + block_size=self.block_size, + dtype=routing_dtype, + ) + + # Create routing swap buffer (for CPU blocks). + # Only rank 0 needs it — _swap_routing() only runs on rank 0. + if self.num_cpu_blocks > 0 and self.rank == 0: + dp_suffix = str(self.engine_worker_queue_port) + self.routing_swap_buffer = RoutingSwapBuffer( + num_cpu_blocks=self.num_cpu_blocks, + block_size=self.block_size, + num_moe_layers=num_moe_layers, + top_k=moe_top_k, + dtype=routing_dtype, + dp_suffix=dp_suffix, + ) + spec.swap_buffer = self.routing_swap_buffer + + # Attach to routing host buffer (SharedMemory created by Engine) + dp_suffix = str(self.engine_worker_queue_port) + shm_name = f"routing_host_buffer.{dp_suffix}" + max_num_kv_tokens = self.num_gpu_blocks * self.block_size + shape = (max_num_kv_tokens, num_moe_layers, moe_top_k) + try: + self.routing_host_view = RoutingHostBufferView(shape=shape, dtype=routing_dtype, shm_name=shm_name) + logger.info(f"[R3] CTM attached to RoutingHostBuffer: {shm_name}") + except FileNotFoundError: + logger.warning(f"[R3] CTM RoutingHostBuffer {shm_name} not found") + + self.aux_data_specs["routing"] = spec + logger.info(f"[R3] CTM registered routing aux data: layers={num_moe_layers}, top_k={moe_top_k}") + + except Exception as e: + logger.warning(f"[R3] CTM failed to init routing aux data: {e}") + + @staticmethod + def _cleanup_routing_resources(routing_swap_buffer, routing_host_view): + """Release routing SharedMemory on process exit. Called by weakref.finalize.""" + if routing_swap_buffer is not None: + routing_swap_buffer.close() + if routing_host_view is not None: + routing_host_view.close() + + def _swap_routing(self, gpu_block_ids, cpu_block_ids, direction): + """ + Swap routing data between routing_host_buffer and routing_swap_buffer. + Pure CPU-to-CPU numpy memcpy, no GPU DMA. + Only rank 0 performs this (routing buffers are cross-rank SharedMemory). + """ + if self.routing_host_view is None or self.routing_swap_buffer is None: + logger.warning( + f"[R3] _swap_routing skipped: host_view={self.routing_host_view is not None}, " + f"swap_buffer={self.routing_swap_buffer is not None}" + ) + return + if self.rank > 0: + return + bs = self.block_size + for gpu_bid, cpu_bid in zip(gpu_block_ids, cpu_block_ids): + gpu_start = gpu_bid * bs + gpu_end = gpu_start + bs + cpu_start = cpu_bid * bs + cpu_end = cpu_start + bs + if direction == "to_cpu": + self.routing_swap_buffer.buffer[cpu_start:cpu_end] = self.routing_host_view.buffer[gpu_start:gpu_end] + elif direction == "to_gpu": + self.routing_host_view.buffer[gpu_start:gpu_end] = self.routing_swap_buffer.buffer[cpu_start:cpu_end] + else: + raise ValueError(f"[R3] _swap_routing: unknown direction '{direction}', expected 'to_cpu' or 'to_gpu'") + logger.info( + f"[R3] _swap_routing {direction}: {len(gpu_block_ids)} blocks, " + f"gpu_ids={gpu_block_ids[:3]}{'...' if len(gpu_block_ids) > 3 else ''}, " + f"cpu_ids={cpu_block_ids[:3]}{'...' if len(cpu_block_ids) > 3 else ''}" + ) + + def _write_routing_to_storage(self, task_keys, gpu_block_ids): + """ + Write routing data from routing_host_buffer to storage backend. + Only for mooncake/file backends; only tp_rank=0 writes routing. + """ + if self.routing_host_view is None or self.rank != 0: + return + if self.storage_backend_type not in ("mooncake", "file"): + return + + try: + spec = self.aux_data_specs.get("routing") + if spec is None or not spec.enabled: + return + + bs = self.block_size + routing_keys = [] + routing_ptrs = [] + routing_sizes = [] + per_block_bytes = bs * spec.num_layers * spec.per_token_size * np.dtype(spec.dtype).itemsize + + for block_hash, gpu_bid in zip(task_keys, gpu_block_ids): + key = spec.get_storage_key(self.key_prefix, block_hash, self.rank) + start = gpu_bid * bs + end = start + bs + block_data = self.routing_host_view.buffer[start:end] + if not block_data.flags["C_CONTIGUOUS"]: + block_data = np.ascontiguousarray(block_data) + routing_keys.append(key) + routing_ptrs.append(block_data.ctypes.data) + routing_sizes.append(per_block_bytes) + + if routing_keys: + self.storage_backend.batch_set( + keys=routing_keys, target_locations=routing_ptrs, target_sizes=routing_sizes + ) + logger.debug(f"[R3] Wrote {len(routing_keys)} routing blocks to storage") + except Exception as e: + logger.warning(f"[R3] Failed to write routing to storage: {e}") + + def _read_routing_from_storage(self, task_keys, gpu_block_ids): + """ + Read routing data from storage backend into routing_host_buffer. + Only for mooncake/file backends; only tp_rank=0 reads routing. + """ + if self.routing_host_view is None or self.rank != 0: + return + if self.storage_backend_type not in ("mooncake", "file"): + return + + try: + spec = self.aux_data_specs.get("routing") + if spec is None or not spec.enabled: + return + + bs = self.block_size + per_block_bytes = bs * spec.num_layers * spec.per_token_size * np.dtype(spec.dtype).itemsize + + for block_hash, gpu_bid in zip(task_keys, gpu_block_ids): + key = spec.get_storage_key(self.key_prefix, block_hash, self.rank) + start = gpu_bid * bs + end = start + bs + target_slice = self.routing_host_view.buffer[start:end] + if not target_slice.flags["C_CONTIGUOUS"]: + # Need contiguous target for ctypes pointer + tmp = np.ascontiguousarray(target_slice) + result = self.storage_backend.get( + key=key, target_location=tmp.ctypes.data, target_size=per_block_bytes + ) + if result is not None and result >= 0: + self.routing_host_view.buffer[start:end] = tmp + else: + self.storage_backend.get( + key=key, target_location=target_slice.ctypes.data, target_size=per_block_bytes + ) + + logger.debug(f"[R3] Read {len(task_keys)} routing blocks from storage") + except Exception as e: + logger.warning(f"[R3] Failed to read routing from storage: {e}") + def _init_control(self): dp_rank = self.local_data_parallel_id tp_rank = self.rank @@ -809,6 +1012,9 @@ def read_storage_task(self, task: ReadStorageTask): logger.info( f"Successfully read {len(valid_gpu_block_ids)} blocks from cache storage for task {task.task_id}" ) + # Read routing data from storage for matched blocks + matched_keys = task.keys[: len(valid_gpu_block_ids)] + self._read_routing_from_storage(matched_keys, valid_gpu_block_ids) except Exception as e: logger.error( f"Failed to read cache for task {task.task_id}, error: {e}, traceback: {traceback.format_exc()}" @@ -1000,6 +1206,9 @@ def write_back_storage_task(self, task: WriteStorageTask): logger.info( f"Successfully wrote {write_block_num} blocks to cache storage for task {task.task_id}" ) + # Write routing data to storage (shares dedup with KVCache) + remaining_keys = task.keys[match_block_num:] + self._write_routing_to_storage(remaining_keys, gpu_block_ids) except Exception as e: logger.error(f"Error in write back storage task: {e}, traceback:{traceback.format_exc()}") gpu_block_ids = [] @@ -1375,6 +1584,10 @@ def _transfer_data( 0, ) + # Routing: routing_host_buffer → routing_swap_buffer + if "routing" in self.aux_data_specs: + self._swap_routing(gpu_block_ids, cpu_block_ids, "to_cpu") + elif event_type.value == CacheStatus.SWAP2GPU.value: swap_cache_all_layers( self.gpu_cache_k_tensors, @@ -1413,6 +1626,11 @@ def _transfer_data( self.device, 1, ) + + # Routing: routing_swap_buffer → routing_host_buffer + if "routing" in self.aux_data_specs: + self._swap_routing(gpu_block_ids, cpu_block_ids, "to_gpu") + else: logger.warning( f"transfer data: Get unexpected event type {event_type}, only SWAP2CPU and SWAP2GPU supported" diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 3d88e199d27..e03a07baba9 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -293,6 +293,13 @@ def launch_cache_manager( else: storage_arg_str = " " + # Compute routing replay args for CTM — single JSON arg + routing_replay_config = getattr(self.config, "routing_replay_config", None) + if routing_replay_config is not None and routing_replay_config.enable_routing_replay: + routing_arg_str = f" --routing_replay_config '{routing_replay_config.to_json_string()}'" + else: + routing_arg_str = "" + if self.cache_config.num_cpu_blocks > 0 or self.cache_config.kvcache_storage_backend: for i in range(tensor_parallel_size): launch_cmd = ( @@ -324,6 +331,7 @@ def launch_cache_manager( + f" --write_policy {cache_config.write_policy}" + f" --max_model_len {self.config.model_config.max_model_len}" + f" --model_path {self.config.model_config.model}" + + routing_arg_str + f" >{log_dir}/launch_cache_transfer_manager_{int(device_ids[i])}.log 2>&1" ) logger.info(f"Launch cache transfer manager, command:{launch_cmd}") diff --git a/fastdeploy/cache_manager/routing_cache_manager.py b/fastdeploy/cache_manager/routing_cache_manager.py new file mode 100644 index 00000000000..51242b7a1c5 --- /dev/null +++ b/fastdeploy/cache_manager/routing_cache_manager.py @@ -0,0 +1,283 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import math +import multiprocessing +import multiprocessing.shared_memory +from typing import Optional + +import numpy as np +from paddleformers.utils.log import logger + + +class RoutingHostBuffer: + """ + Manages routing_host_buffer (corresponds to KVCache GPU cache). + Indexed by gpu_block_id * block_size + offset. + Shared across processes via POSIX SharedMemory. + Each DP rank creates its own instance; name includes dp_suffix. + """ + + def __init__( + self, num_gpu_blocks: int, block_size: int, num_moe_layers: int, top_k: int, dtype: str, dp_suffix: str = "" + ): + max_num_gpu_tokens = num_gpu_blocks * block_size + self.shape = (max_num_gpu_tokens, num_moe_layers, top_k) + self.dtype = np.dtype(dtype) + self.block_size = block_size + total_bytes = int(np.prod(self.shape)) * self.dtype.itemsize + + self.shm_name = f"routing_host_buffer.{dp_suffix}" + # Clean up stale SharedMemory from previous crashed process + try: + stale = multiprocessing.shared_memory.SharedMemory(name=self.shm_name, create=False) + stale.close() + stale.unlink() + logger.warning(f"[R3] Cleaned up stale SharedMemory: {self.shm_name}") + except FileNotFoundError: + pass + self.shm = multiprocessing.shared_memory.SharedMemory( + create=True, size=max(total_bytes, 1), name=self.shm_name + ) + self.buffer = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf) + self.buffer[:] = -1 # unsigned wrap: uint8→255, uint16→65535, uint32→4294967295 + + self._owner = True + logger.info( + f"[R3] Created RoutingHostBuffer: shape={self.shape}, " + f"size={total_bytes / 1024:.1f} KB, name={self.shm_name}" + ) + + def close(self): + """Close and unlink SharedMemory. Only the owner (creator) unlinks.""" + self.shm.close() + if self._owner: + self.shm.unlink() + self._owner = False + + +class RoutingHostBufferView: + """Read/write view of routing_host_buffer (cross-process, does not own).""" + + def __init__(self, shape, dtype: str, shm_name: str): + self.shm = multiprocessing.shared_memory.SharedMemory(name=shm_name, create=False) + self.dtype = np.dtype(dtype) + self.buffer = np.ndarray(shape, dtype=self.dtype, buffer=self.shm.buf) + + def scatter(self, slot_mapping: np.ndarray, data: np.ndarray): + """Scatter GPU buffer data to corresponding slots (Worker calls this).""" + self.buffer[slot_mapping] = data + + def gather(self, slot_mapping: np.ndarray) -> np.ndarray: + """Gather data from specified slots (TokenProcessor calls this).""" + return self.buffer[slot_mapping].copy() + + def close(self): + self.shm.close() + + +class RoutingSwapBuffer: + """ + Manages routing_swap_buffer (corresponds to KVCache CPU cache). + Indexed by cpu_block_id * block_size + offset. + CacheTransferManager creates this; shared via SharedMemory. + """ + + def __init__( + self, num_cpu_blocks: int, block_size: int, num_moe_layers: int, top_k: int, dtype: str, dp_suffix: str = "" + ): + max_num_cpu_tokens = num_cpu_blocks * block_size + self.shape = (max_num_cpu_tokens, num_moe_layers, top_k) + self.dtype = np.dtype(dtype) + self.block_size = block_size + total_bytes = int(np.prod(self.shape)) * self.dtype.itemsize + + self.shm_name = f"routing_swap_buffer.{dp_suffix}" + # Clean up stale SharedMemory from previous crashed process + try: + stale = multiprocessing.shared_memory.SharedMemory(name=self.shm_name, create=False) + stale.close() + stale.unlink() + logger.warning(f"[R3] Cleaned up stale SharedMemory: {self.shm_name}") + except FileNotFoundError: + pass + self.shm = multiprocessing.shared_memory.SharedMemory( + create=True, size=max(total_bytes, 1), name=self.shm_name + ) + self.buffer = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf) + self.buffer[:] = -1 # unsigned wrap: uint8→255, uint16→65535, uint32→4294967295 + + self._owner = True + logger.info( + f"[R3] Created RoutingSwapBuffer: shape={self.shape}, " + f"size={total_bytes / 1024:.1f} KB, name={self.shm_name}" + ) + + def close(self): + """Close and unlink SharedMemory. Only the owner (creator) unlinks.""" + self.shm.close() + if self._owner: + self.shm.unlink() + self._owner = False + + +class RoutingSwapBufferView: + """Read/write view of routing_swap_buffer (cross-process, does not own).""" + + def __init__(self, shape, dtype: str, shm_name: str): + self.shm = multiprocessing.shared_memory.SharedMemory(name=shm_name, create=False) + self.dtype = np.dtype(dtype) + self.buffer = np.ndarray(shape, dtype=self.dtype, buffer=self.shm.buf) + + def close(self): + self.shm.close() + + +def split_request_id(request_id: str) -> str: + """ + Split the request id to get rollout id. + + request_id: "chatcmpl-request.user-uuid" + rollout_id: "request.user" + example: "chatcmpl-xxx_xxx_epoch_15:2:2:1-d9f16c5c-65f6-4815-b44d-14e2c581907c_0" + -> "xxx_xxx_epoch_15:2:2:1" + """ + chat_type, tmp_str = request_id.split("-", 1) + assert ( + chat_type == "chatcmpl" + ), "Rollout Routing Replay only supports chatcmpl. Please check request type and userid settings." + reversed_tmp_str = tmp_str[::-1].split("-", 5) + rollout_id = reversed_tmp_str[-1][::-1] + return rollout_id + + +class RoutingCacheManager: + """ + Engine-side stateless routing data manager. + Does NOT maintain request mapping — request state is fully managed by Scheduler. + Responsible for: SharedMemory creation/destruction, routing data gather, return mode dispatch. + """ + + def __init__(self, fd_config, num_gpu_blocks: int): + routing_replay_config = fd_config.routing_replay_config + self.num_moe_layers = routing_replay_config.num_moe_layers + self.moe_top_k = routing_replay_config.moe_top_k + self.routing_dtype = routing_replay_config.routing_dtype + self.only_last_turn = routing_replay_config.only_last_turn + self.use_fused_put = routing_replay_config.use_fused_put + self.block_size = fd_config.cache_config.block_size + self.return_mode = ( + routing_replay_config.routing_store_type + ) # "local" / "rdma" → p2pstore; "response" → attach to RequestOutput + + dp_suffix = str(fd_config.parallel_config.local_engine_worker_queue_port) + + # Create SharedMemory routing_host_buffer + self.host_buffer = RoutingHostBuffer( + num_gpu_blocks=num_gpu_blocks, + block_size=self.block_size, + num_moe_layers=self.num_moe_layers, + top_k=self.moe_top_k, + dtype=self.routing_dtype, + dp_suffix=dp_suffix, + ) + + # Host view for gather operations + self.host_view = RoutingHostBufferView( + shape=self.host_buffer.shape, + dtype=self.routing_dtype, + shm_name=self.host_buffer.shm_name, + ) + + # Initialize store wrapper for p2pstore mode + self._store_wrapper = None + if self.return_mode in ("local", "rdma"): + from fastdeploy.cache_manager.routing_store import StoreWrapper + + self._store_wrapper = StoreWrapper(fd_config=fd_config) + self._store_wrapper.start_store_warpper() + + logger.info( + f"[R3] RoutingCacheManager initialized: return_mode={self.return_mode}, " + f"host_buffer shape={self.host_buffer.shape}" + ) + + def gather_routing_for_request(self, block_table, seq_len: int) -> np.ndarray: + """ + Gather complete routing data for a request from routing_host_buffer. + + Args: + block_table: List of block IDs for the request + seq_len: Total sequence length + + Returns: + routing_data: [seq_len, num_moe_layers, top_k] numpy array + """ + num_blocks = math.ceil(seq_len / self.block_size) + block_ids = block_table[:num_blocks] + positions = np.arange(seq_len) + block_indices = positions // self.block_size + offsets = positions % self.block_size + slot_mapping = np.array(block_ids)[block_indices] * self.block_size + offsets + return self.host_view.gather(slot_mapping) + + def on_request_finished(self, request_id: str, block_table, seq_len: int) -> Optional[np.ndarray]: + """ + Unified entry point when a request finishes. Called by TokenProcessor on EOS detection. + Scheduler/TokenProcessor passes request_id, block_table, seq_len. + + Returns: + - "response" mode: routing_data numpy array (caller attaches to RequestOutput) + - "local"/"rdma" mode: None (submitted to StoreWrapper internally) + """ + routing_data = self.gather_routing_for_request(block_table, seq_len) + + if self._store_wrapper is not None: + # P2PStore mode: submit to store + rollout_id = split_request_id(request_id) + # Transpose to [num_moe_layers, seq_len, top_k] for store compatibility + # TODO(gongshaotian): Delete redundant transpose + routing_data = np.ascontiguousarray(routing_data.transpose(1, 0, 2)) + + if self.use_fused_put: + self._store_wrapper.submit_put_task(routing_indices=routing_data, rollout_id=rollout_id) + if self.only_last_turn: + self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id) + else: + for layer_id in range(self.num_moe_layers): + layer_buffer = routing_data[layer_id] + self._store_wrapper.submit_put_task( + routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id + ) + if self.only_last_turn: + self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id, layer_idx=layer_id) + return None + else: + # Response mode: return data for caller to attach to RequestOutput + return routing_data + + def reset(self): + """Reset SharedMemory buffer. Used during RL round cleanup.""" + self.host_buffer.buffer[:] = -1 + + def close(self): + """Clean up SharedMemory resources.""" + if self.host_view is not None: + self.host_view.close() + self.host_view = None + if self.host_buffer is not None: + self.host_buffer.close() + self.host_buffer = None diff --git a/fastdeploy/cache_manager/routing_store.py b/fastdeploy/cache_manager/routing_store.py new file mode 100644 index 00000000000..4c16fd10ad0 --- /dev/null +++ b/fastdeploy/cache_manager/routing_store.py @@ -0,0 +1,512 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import asyncio +import atexit +import functools +import multiprocessing +import os +import shutil +import threading +import time +import traceback +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor +from multiprocessing import Process, Queue +from typing import Optional, TypedDict + +import numpy as np +import paddle +from paddleformers.utils.log import logger + +from fastdeploy.config import RoutingReplayConfig + + +class StoreTask(TypedDict): + task_type: str + key: str + data: np.ndarray + + +class StoreWrapper(object): + def __init__(self, fd_config) -> None: + super().__init__() + self.fd_config = fd_config + + # Initialize task queue + moe_layer_num = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index + max_num_seqs = fd_config.scheduler_config.max_num_seqs + self.queue_max_size = moe_layer_num * max_num_seqs * 1000 + + self.manager = multiprocessing.Manager() + self._task_queue = self.manager.Queue(maxsize=self.queue_max_size) + + self._monitor_thread: threading.Thread = None + self._stop_monitor = threading.Event() + + # Initialize consumer process + self._routing_store_process = StoreProcess( + task_queue=self._task_queue, + routing_replay_config=self.fd_config.routing_replay_config, + max_model_len=self.fd_config.model_config.max_model_len, + ) + self._store_process_running = False + + # Register atexit handler + atexit.register(self.shutdown) + + def shutdown(self): + """ """ + if not self._store_process_running: + return + self._store_process_running = False + + # Stop the monitor thread + self._stop_monitor.set() + if self._monitor_thread and self._monitor_thread.is_alive(): + self._monitor_thread.join(timeout=3.0) + + # Put a sentinel value to signal the consumer to stop + if self._routing_store_process and self._routing_store_process.is_alive(): + try: + self._task_queue.put_nowait(None) + except Exception as e: + logger.info(f"Could not put sentinel into queue: {e}") + + if self._routing_store_process and self._routing_store_process.is_alive(): + # Wait for all tasks to be processed + self._routing_store_process.join(timeout=10.0) + if self._routing_store_process.is_alive(): + self._routing_store_process.close() + self._routing_store_process.join() + + self._task_queue.join() + self.manager.shutdown() + self._store_process_running = False + + def start_store_warpper(self): + """ """ + if self._store_process_running: + return + self._store_process_running = True + + # Start monitor thread + self._stop_monitor.clear() + self._monitor_thread = threading.Thread(target=self._monitor_queue_load, daemon=True) + self._monitor_thread.start() + + # Start Routing Store Wrapper in sub process + self._routing_store_process.start() + + def _monitor_queue_load(self): + """ """ + while not self._stop_monitor.is_set(): + time.sleep(2.0) + if not self._store_process_running: + break + qsize = self._task_queue.qsize() + + # Alarm when the task exceeds 80% of the queue capacity + if qsize > self.queue_max_size * 0.8: + logger.warning( + f"[Monitor] Queue load is HIGH: {qsize}/{self.queue_max_size}. " + "Consider increasing max_workers or queue_max_size." + ) + logger.debug(f"[Monitor] Queue load: {qsize}/{self.queue_max_size}") + + def submit_put_task(self, routing_indices: np.ndarray, rollout_id: str, layer_idx: int = None) -> None: + """Submit a put task to the task queue""" + if not self._store_process_running: + raise RuntimeError("Store not started.") + + start_time = time.perf_counter() + if layer_idx is not None: + rdma_rollout_key = f"{rollout_id}_{layer_idx}" + else: + rdma_rollout_key = rollout_id + + task: StoreTask = {"task_type": "put", "key": rdma_rollout_key, "data": routing_indices} + + try: + self._task_queue.put_nowait(task) + except Exception: + raise RuntimeError(f"Queue is FULL. Dropping put task for key: {rdma_rollout_key}. ") + logger.info(f"[R3] Submit put task for key: {rdma_rollout_key}, cost time: {time.perf_counter()-start_time} s") + + def submit_clear_store_task(self) -> None: + """Submit clear store task""" + if not self._store_process_running: + raise RuntimeError("Store not started.") + + start_time = time.perf_counter() + task: StoreTask = {"task_type": "clear_store", "key": None, "data": None} + + try: + self._task_queue.put_nowait(task) + # Wait for the task to be processed + self._task_queue.join() + except Exception: + raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ") + logger.info(f"[R3] Submit clear task, cost time: {time.perf_counter()-start_time} s") + + def submit_clear_prefix_batch_task(self, rollout_id, layer_idx: int = None) -> None: + """Submit clear prefix batch task""" + if not self._store_process_running: + raise RuntimeError("Store not started.") + prefix_batch_id = self.get_needed_clear_ids(rollout_id) + if prefix_batch_id is None: + return + start_time = time.perf_counter() + if layer_idx is not None: + rdma_rollout_key = f"{prefix_batch_id}_{layer_idx}" + else: + rdma_rollout_key = prefix_batch_id + + task: StoreTask = {"task_type": "clear_prefix_batch", "key": rdma_rollout_key, "data": None} + try: + self._task_queue.put_nowait(task) + except Exception: + raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ") + logger.info( + f"[R3] Submit clear prefix batch task for key: {prefix_batch_id}, cost time: {time.perf_counter()-start_time} s" + ) + + def get_needed_clear_ids(self, rollout_id: str) -> Optional[str]: + """ + Generate the prefix IDs for all closed multi-round tasks. + rollout_id: "xxx_xxx_epoch_15:2:2:1" + example: xxx_xxx_data_id:gen_id:turn_id:segment_id + """ + reversed_segment_id, reversed_turn_id, reversed_prefix_gen_id = rollout_id[::-1].split(":", 2) + prefix_gen_id = reversed_prefix_gen_id[::-1] + turn_id = eval(reversed_turn_id[::-1]) + segment_id = eval(reversed_segment_id[::-1]) + + assert turn_id >= 0 and segment_id >= 0 + prefix_batch = None + if turn_id > 0: + prefix_batch = f"{prefix_gen_id}:{(turn_id-1)}:{segment_id}" + return prefix_batch + + +class StoreProcess(Process): + def __init__(self, task_queue: Queue, routing_replay_config: RoutingReplayConfig, max_model_len: int) -> None: + super().__init__() + self.max_model_len = max_model_len + self._task_queue = task_queue + self.routing_replay_config = routing_replay_config + self.max_workers = 5 + self._closed = False + + # Note: _routing_store and _event_loop_thread must be initialized in run() + # because they cannot be properly inherited after fork() + self._routing_store = None + self._event_loop_thread = None + + def run(self): + logger.info(f"[R3] Start Running Store Wrapper in sub process {os.getpid()}") + + # Initialize routing store in subprocess + self._routing_store = get_routing_store(routing_replay_config=self.routing_replay_config) + + # Initialize event loop thread in subprocess + self._event_loop_thread = AsyncEventLoopThread() + self._event_loop_thread.start() + if not self._event_loop_thread._started_event.wait(timeout=5.0): + raise RuntimeError("Failed to start async event loop thread in subprocess") + + clear_store_task = StoreTask({"task_type": "clear_store", "key": None, "data": None}) + self._task_queue.put_nowait(clear_store_task) + + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + while not self._closed: + try: + task = self._task_queue.get() + if task is None: # Sentinel + self._task_queue.task_done() + break + + if task["task_type"] == "put": + future = executor.submit(self.process_put_task, task) + future.add_done_callback(lambda f: self._task_queue.task_done()) + elif task["task_type"] == "clear_store": + future = executor.submit(self.process_clear_store_task, task) + future.add_done_callback(lambda f: self._task_queue.task_done()) + elif task["task_type"] == "clear_prefix_batch": + future = executor.submit(self.process_clear_prefix_batch_task, task) + future.add_done_callback(lambda f: self._task_queue.task_done()) + except Exception as e: + self._task_queue.task_done() + raise RuntimeError(f"Error during processing task. {e}") + + logger.info("RoutingReplay Consumer Process Shutdown.") + + def process_put_task(self, store_task: StoreTask) -> None: + try: + # TODO(gongshaotian): delete this after trainer support dynamic len + store_task["data"] = self.pad_routing_indices(store_task["data"]) + coro_obj = self._routing_store.put(routing_key=store_task["key"], routing_indices=store_task["data"]) + future = self._event_loop_thread.submit_coroutine( + coro_obj, callback=functools.partial(self._on_async_task_completed, store_task) + ) + return future + except Exception as e: + logger.error(f"Error submitting put task: {e}") + traceback.print_exc() + raise + + def process_clear_store_task(self, store_task: StoreTask) -> None: + try: + coro_obj = self._routing_store.clear_store() + future = self._event_loop_thread.submit_coroutine( + coro_obj, callback=functools.partial(self._on_async_task_completed, store_task) + ) + return future + except Exception as e: + logger.error(f"Error during processing clear store task. {e}") + traceback.print_exc() + raise + + def process_clear_prefix_batch_task(self, store_task: StoreTask) -> None: + try: + coro_obj = self._routing_store.clear_prefix_batch(routing_prefix_key=store_task["key"]) + future = self._event_loop_thread.submit_coroutine( + coro_obj, callback=functools.partial(self._on_async_task_completed, store_task) + ) + return future + except Exception as e: + logger.error(f"Error submitting clear_prefix_batch task: {e}") + traceback.print_exc() + raise + + def _on_async_task_completed(self, task, future): + """ """ + try: + # result = future.result() + logger.info(f"[R3] Async task completed: {task['task_type']}, key: {task['key']}") + except Exception as e: + logger.error(f"[R3] Async task failed: {task['task_type']}, key: {task['key']}, error: {e}") + traceback.print_exc() + raise + + def close(self): + """Close the store process""" + self._closed = True + if hasattr(self, "_event_loop_thread"): + self._event_loop_thread.stop() + + def pad_routing_indices(self, routing_indices: np.ndarray) -> np.ndarray: + """Pad routing indices of the request levevl to max model len""" + routing_shape = routing_indices.shape + if len(routing_shape) == 2: # [token, topk] + pad_array = np.full( + shape=[(self.max_model_len - routing_indices.shape[0]), routing_indices.shape[1]], + fill_value=-1, + dtype=routing_indices.dtype, + ) + return np.concatenate([routing_indices, pad_array], axis=0) + + elif len(routing_shape) == 3: # [layer, token, topk] + pad_array = np.full( + shape=[ + routing_indices.shape[0], + (self.max_model_len - routing_indices.shape[1]), + routing_indices.shape[2], + ], + fill_value=-1, + dtype=routing_indices.dtype, + ) + return np.concatenate([routing_indices, pad_array], axis=1) + else: + raise ValueError(f"Invalid routing indices shape: {routing_shape}") + + +class AsyncEventLoopThread(threading.Thread): + def __init__(self): + super().__init__(daemon=True) + self._loop = None + self._started_event = threading.Event() + self._closed = False + + def run(self): + """Run the async event loop""" + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + + # Set the event loop to be started + self._started_event.set() + logger.info("[EventLoopThread] Event loop started, running forever...") + + try: + self._loop.run_forever() + logger.info("[EventLoopThread] Event loop stopped") + except Exception as e: + logger.error(f"[EventLoopThread] Event loop exception: {e}") + traceback.print_exc() + finally: + logger.info("[EventLoopThread] Closing event loop") + self._loop.close() + + def submit_coroutine(self, coro, callback=None): + """Thread safely submit coroutine to event loop""" + if self._closed: + raise RuntimeError("Event loop thread is closed") + if not self._started_event.wait(timeout=5.0): + raise RuntimeError("Event loop failed to start within 5 seconds") + + future = asyncio.run_coroutine_threadsafe(coro, self._loop) + + if callback: + + def wrapped_callback(f): + try: + callback(f) + except Exception as e: + logger.error(f"Error in callback: {e}") + traceback.print_exc() + + future.add_done_callback(wrapped_callback) + return future + + def stop(self): + """Stop the event loop""" + if not self._closed: + self._closed = True + if self._loop: + self._loop.call_soon_threadsafe(self._loop.stop) + + +class RoutingStoreBase(ABC): + """Base class for routing store""" + + def __init__(self, routing_replay_config: RoutingReplayConfig) -> None: + self.routing_replay_config = routing_replay_config + + @abstractmethod + async def put(self, routing_key: str, routing_indices: np.ndarray) -> None: + """Put the routing indices into store""" + raise NotImplementedError + + @abstractmethod + async def clear_store( + self, + ): + """Clear the routing indices store""" + raise NotImplementedError + + @abstractmethod + async def clear_prefix_batch(self, routing_prefix_key: str): + """Clear the routing indices""" + raise NotImplementedError + + +class RoutingStoreLocal(RoutingStoreBase): + """Routing Store using local memory""" + + def __init__(self, routing_replay_config) -> None: + super().__init__(routing_replay_config=routing_replay_config) + self.local_store_dir = routing_replay_config.local_store_dir + os.makedirs(self.local_store_dir, exist_ok=True) + + async def put( + self, + routing_key: str, + routing_indices: np.ndarray, + ) -> None: + """Put the routing indices into store""" + # TODO(gongshaotian) covert ./store_dir/routing_key/layer_id.pdtensor to ./store_dir/routing_key.pdtensor + time_before_put = time.perf_counter() + + if len(routing_indices.shape) == 2: + re_layer_id, re_rollout_id = routing_key[::-1].split("_", 1) + rollout_id = re_rollout_id[::-1] + layer_id = re_layer_id[::-1] + request_path = os.path.join(self.local_store_dir, rollout_id) + file_path = os.path.join(request_path, f"layer_{layer_id}.pdtensor") + elif len(routing_indices.shape) == 3: + request_path = os.path.join(self.local_store_dir, routing_key) + file_path = os.path.join(request_path, f"{routing_key}.pdtensor") + else: + raise ValueError(f"Invalid routing indices shape: {routing_indices.shape}") + + paddle.save(routing_indices, file_path) + logger.info(f"[R3] The routing key {routing_key} put cost is {time.perf_counter()-time_before_put}s") + + async def clear_store(self): + """Clear the routing indices store""" + if os.path.isdir(self.local_store_dir): + shutil.rmtree(self.local_store_dir) + + logger.info("[R3] Clear routing store.") + + async def clear_prefix_batch(self, routing_prefix_key: str): + """Clear the routing indices""" + raise NotImplementedError + + +class RoutingStoreRDMA(RoutingStoreBase): + """Routing Store using RDMA""" + + def __init__(self, routing_replay_config) -> None: + super().__init__(routing_replay_config=routing_replay_config) + try: + # Only used in RLHF + from p2pstore import P2PClient, P2PConfig + except ModuleNotFoundError: + raise ModuleNotFoundError(" RoutingStoreRDMA and p2pstore only support in RLHF. ") + + rdma_store_server = routing_replay_config.rdma_store_server + p2pConfig = P2PConfig(metadata_server=rdma_store_server) + self.p2p_client = P2PClient(p2pConfig) + + async def put(self, routing_key: str, routing_indices: np.ndarray) -> None: + """Put the routing indices into store""" + time_before_put = time.perf_counter() + if len(routing_indices.shape) == 3: + # NOTE(gongshaotian) Fused put with bytes data + routing_bytes = routing_indices.tobytes() + result = await self.p2p_client.put(routing_key, routing_bytes) + else: + result = await self.p2p_client.put(routing_key, routing_indices) + logger.info(f"[R3] The routing key {routing_key}, put cost is {time.perf_counter()-time_before_put}s") + return result + + async def clear_prefix_batch(self, routing_prefix_key: str): + time_before_clear = time.perf_counter() + result = await self.p2p_client.delete_batch([routing_prefix_key]) + logger.info( + f"[R3] The clear routing prefix key {routing_prefix_key}, cost is {time.perf_counter()-time_before_clear}s" + ) + return result + + async def clear_store(self): + """Clear the routing indices store""" + time_before_clear = time.perf_counter() + result = await self.p2p_client.clear() + logger.info(f"[R3] Clear routing store cost is {time.perf_counter()-time_before_clear}s.") + return result + + +def get_routing_store(routing_replay_config: RoutingReplayConfig) -> RoutingStoreBase: + if routing_replay_config.routing_store_type == "local": + return RoutingStoreLocal(routing_replay_config=routing_replay_config) + elif routing_replay_config.routing_store_type == "rdma": + return RoutingStoreRDMA(routing_replay_config=routing_replay_config) + else: + raise ValueError( + f"Invalid routing store type: '{routing_replay_config.routing_store_type}'. " + "Valid types are: 'local', 'rdma'" + ) diff --git a/fastdeploy/config.py b/fastdeploy/config.py index d59d5694cbf..f6214e952d8 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1665,7 +1665,7 @@ def __init__(self, args) -> None: self.enable_routing_replay: bool = False - # Routing store type: local/rdma + # Routing return mode: "local" (file store) / "rdma" (P2PStore) / "response" (attach to RequestOutput) self.routing_store_type: str = "local" # Local routing store @@ -1680,11 +1680,37 @@ def __init__(self, args) -> None: # Fused routing of all layers self.use_fused_put: bool = False + # Auto-filled by FDConfig from ModelConfig (do not set manually) + self.routing_dtype: str = "" # "uint8" / "uint16" / "uint32" + self.num_moe_layers: int = 0 + self.moe_top_k: int = 0 + if args is not None: for key, value in args.items(): if hasattr(self, key) and value != "None": setattr(self, key, value) + def postprocess(self, model_config: "ModelConfig") -> None: + """Fill computed fields from ModelConfig. Must be called after model-specific + field unification (e.g. GLM's first_k_dense_replace → moe_layer_start_index).""" + if not self.enable_routing_replay: + return + self.num_moe_layers = model_config.num_hidden_layers - model_config.moe_layer_start_index + if model_config.architectures[0] == "Glm4MoeForCausalLM": + self.moe_top_k = model_config.num_experts_per_tok + else: + self.moe_top_k = model_config.moe_k + num_experts = model_config.moe_num_experts + model_config.moe_num_shared_experts + total_number = num_experts + 1 # +1 for reserved fill value + if total_number <= 255: + self.routing_dtype = "uint8" + elif total_number <= 65535: + self.routing_dtype = "uint16" + elif total_number <= 4294967295: + self.routing_dtype = "uint32" + else: + raise ValueError(f"num_experts {num_experts} exceeds uint32 range") + def to_json_string(self): """ Convert routing replay config to json string. @@ -1747,6 +1773,8 @@ def __init__( self.router_config: RouterConfig = router_config self.routing_replay_config = routing_replay_config self.deploy_modality: DeployModality = deploy_modality if deploy_modality is not None else DeployModality.MIXED + + # Initialize cuda graph capture list max_capture_shape = self.scheduler_config.max_num_seqs if self.speculative_config is not None and self.speculative_config.method in ["mtp", "suffix"]: max_capture_shape = self.scheduler_config.max_num_seqs * ( @@ -1905,6 +1933,9 @@ def postprocess(self): # The first moe layer id of GLM4.5 model self.model_config.moe_layer_start_index = self.model_config.first_k_dense_replace + if self.routing_replay_config is not None: + self.routing_replay_config.postprocess(self.model_config) + if self.parallel_config.tensor_parallel_size <= self.worker_num_per_node or self.node_rank == 0: self.is_master = True self.master_ip = "0.0.0.0" diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 22f0a278874..edbea927dce 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -2398,10 +2398,47 @@ def _stop_profile(self): num_gpu_blocks = self.get_profile_block_num_signal.value[0] self.cfg.cache_config.reset(num_gpu_blocks) self.resource_manager.reset_cache_config(self.cfg.cache_config) + + # Create RoutingCacheManager (SharedMemory) after num_gpu_blocks is known + self.routing_cache_manager = None + if self.cfg.routing_replay_config.enable_routing_replay: + self._init_routing_cache_manager(num_gpu_blocks) + if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed": device_ids = self.cfg.parallel_config.device_ids.split(",") self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix) + def _init_routing_cache_manager(self, num_gpu_blocks: int): + """Create RoutingCacheManager (includes SharedMemory host buffer) after profiling.""" + from fastdeploy.cache_manager.routing_cache_manager import ( + RoutingCacheManager, + RoutingHostBufferView, + ) + + self.routing_cache_manager = RoutingCacheManager( + fd_config=self.cfg, + num_gpu_blocks=num_gpu_blocks, + ) + + # Pass routing_cache_manager to TokenProcessor for local/rdma store dispatch + self.token_processor.routing_cache_manager = self.routing_cache_manager + + # Set routing_host_view on resource_manager for PD disaggregation (D side) + if hasattr(self, "resource_manager") and hasattr(self.resource_manager, "routing_host_view"): + rrc = self.cfg.routing_replay_config + dp_suffix = str(self.cfg.parallel_config.local_engine_worker_queue_port) + shm_name = f"routing_host_buffer.{dp_suffix}" + max_num_kv_tokens = num_gpu_blocks * self.cfg.cache_config.block_size + shape = (max_num_kv_tokens, rrc.num_moe_layers, rrc.moe_top_k) + try: + self.resource_manager.routing_host_view = RoutingHostBufferView( + shape=shape, dtype=rrc.routing_dtype, shm_name=shm_name + ) + except FileNotFoundError: + self.llm_logger.warning( + f"[R3] RoutingHostBuffer SharedMemory {shm_name} not found for resource_manager" + ) + def check_health(self, time_interval_threashold=30): """ Check the health of the model server by checking whether all workers are alive. diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index e0f27cb0509..25199aa5af9 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -140,6 +140,12 @@ def start(self, api_server_pid=None): self.engine.create_data_processor() self.data_processor = self.engine.data_processor + # Create RoutingCacheManager when skipping profiling (num_gpu_blocks_override is set) + if not self.do_profile and self.cfg.routing_replay_config.enable_routing_replay: + num_gpu_blocks = self.cfg.cache_config.num_gpu_blocks_override + if num_gpu_blocks is not None: + self.engine._init_routing_cache_manager(num_gpu_blocks) + # If block numer is specified and model is deployed in mixed mode, start cache manager first if not self.do_profile and self.cfg.scheduler_config.splitwise_role != "mixed": if not current_platform.is_intel_hpu(): @@ -449,6 +455,11 @@ def _exit_sub_services(self): if hasattr(self, "zmq_server") and self.zmq_server is not None: self.zmq_server.close() + if hasattr(self, "engine") and hasattr(self.engine, "routing_cache_manager"): + if self.engine.routing_cache_manager is not None: + self.engine.routing_cache_manager.close() + self.engine.routing_cache_manager = None + if hasattr(self, "dp_processed"): for p in self.dp_processed: console_logger.info(f"Waiting for worker {p.pid} to exit") @@ -726,6 +737,11 @@ def _stop_profile(self): num_gpu_blocks = self.get_profile_block_num_signal.value[0] self.cfg.cache_config.reset(num_gpu_blocks) self.engine.resource_manager.reset_cache_config(self.cfg.cache_config) + + # Create RoutingCacheManager (SharedMemory) before starting cache service + if self.cfg.routing_replay_config.enable_routing_replay: + self.engine._init_routing_cache_manager(num_gpu_blocks) + if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed": if not current_platform.is_intel_hpu(): device_ids = self.cfg.parallel_config.device_ids.split(",") diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 6a3cf152d66..f916a9f5f4d 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -1017,6 +1017,7 @@ def __init__( self.ic_req_data = ic_req_data self.prompt_token_ids_len = prompt_token_ids_len self.trace_carrier = trace_carrier + self.routing_data = None # Optional[np.ndarray], [seq_len, num_moe_layers, top_k] if prompt_token_ids is None: self.prompt_token_ids = [] @@ -1112,12 +1113,15 @@ def from_dict(cls, d: dict): d.pop("metrics", None) metrics = None trace_carrier = d.pop("trace_carrier", {}) - return RequestOutput(**d, outputs=completion_output, metrics=metrics, trace_carrier=trace_carrier) + routing_data = d.pop("routing_data", None) + obj = RequestOutput(**d, outputs=completion_output, metrics=metrics, trace_carrier=trace_carrier) + obj.routing_data = routing_data + return obj def to_dict(self): """convert RequestOutput into a serializable dict""" - return { + d = { "request_id": self.request_id, "prompt": self.prompt, "prompt_token_ids": self.prompt_token_ids, @@ -1135,6 +1139,9 @@ def to_dict(self): "prompt_token_ids_len": self.prompt_token_ids_len, "trace_carrier": self.trace_carrier, } + if self.routing_data is not None: + d["routing_data"] = self.routing_data + return d def get(self, key: str, default_value=None): if hasattr(self, key): diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 47abce202a0..40d7f8f492a 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -208,6 +208,7 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l self.encoder_cache = EncoderCacheManager(config.cache_config.max_encoder_cache) self.processor_cache = None + self.routing_host_view = None # Set by Engine after RoutingHostBuffer creation if config.enable_mm_runtime and config.cache_config.max_processor_cache > 0: max_processor_cache_in_bytes = int(config.cache_config.max_processor_cache * 1024 * 1024 * 1024) self.processor_cache = ProcessorCacheManager(max_processor_cache_in_bytes) @@ -1362,8 +1363,38 @@ def add_prefilled_request(self, request_output: RequestOutput): request_output.metrics.decode_recv_req_time = request.metrics.decode_recv_req_time request_output.metrics.decode_preallocate_req_time = request.metrics.decode_preallocate_req_time request.metrics = request_output.metrics + + # [R3] Write P's prefill routing data into D's routing_host_buffer + if ( + self.routing_host_view is not None + and hasattr(request_output, "routing_data") + and request_output.routing_data is not None + ): + try: + self._write_prefill_routing_to_host_buffer(request, request_output.routing_data) + except Exception as e: + llm_logger.warning(f"[R3] Failed to write prefill routing for {request_output.request_id}: {e}") + self.running.append(request) + def _write_prefill_routing_to_host_buffer(self, request, routing_data): + """ + Write P's prefill routing data into D's routing_host_buffer. + Uses D's block_tables to compute slot_mapping. + """ + import math + + seq_len = routing_data.shape[0] + block_size = self.config.cache_config.block_size + num_blocks_needed = math.ceil(seq_len / block_size) + block_ids = request.block_tables[:num_blocks_needed] + + positions = np.arange(seq_len) + block_indices = positions // block_size + offsets = positions % block_size + slot_mapping = np.array(block_ids)[block_indices] * block_size + offsets + self.routing_host_view.scatter(slot_mapping, routing_data) + def _free_blocks(self, request: Request): if self.config.cache_config.enable_prefix_caching: self.cache_manager.release_block_ids(request) diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index e02ce5343be..e61148a1cf4 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -285,6 +285,7 @@ class ChatCompletionResponse(BaseModel): model: str choices: List[ChatCompletionResponseChoice] usage: UsageInfo + routed_experts: Optional[str] = None class LogProbEntry(BaseModel): @@ -390,6 +391,7 @@ class CompletionResponse(BaseModel): model: str choices: List[CompletionResponseChoice] usage: UsageInfo + routed_experts: Optional[str] = None class CompletionLogprobs(BaseModel): diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index 59e85e4d541..68f57b75bf4 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -585,6 +585,7 @@ async def chat_completion_full_generator( sampling_mask_list = [[] for _ in range(num_choices)] speculate_metrics = [None for _ in range(num_choices)] choices = [] + routing_data_result = None while num_choices > 0: if self.engine_client.check_model_weight_status(): return ErrorResponse( @@ -705,6 +706,15 @@ async def chat_completion_full_generator( speculate_metrics=speculate_metrics[idx], ) choices.append(choice) + if data.get("routing_data") is not None: + import base64 + + import numpy as np + + rd = data["routing_data"] + if not isinstance(rd, np.ndarray): + rd = np.array(rd) + routing_data_result = base64.b64encode(rd.tobytes()).decode("utf-8") finally: tracing.trace_req_finish(request_id) await self.engine_client.connection_manager.cleanup_request(request_id) @@ -735,6 +745,7 @@ async def chat_completion_full_generator( model=model_name, choices=choices, usage=usage, + routed_experts=routing_data_result, ) trace_print(LoggingEventName.POSTPROCESSING_END, request_id, getattr(request, "user", "")) api_server_logger.info(f"Chat response: {res.model_dump_json()}") diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index 7bd04f4ecab..73f9dbf2af2 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -773,12 +773,24 @@ def request_output_to_completion_response( ) del request + routed_experts = None + if final_res_batch and final_res_batch[-1].get("routing_data") is not None: + import base64 + + import numpy as np + + rd = final_res_batch[-1]["routing_data"] + if not isinstance(rd, np.ndarray): + rd = np.array(rd) + routed_experts = base64.b64encode(rd.tobytes()).decode("utf-8") + return CompletionResponse( id=request_id, created=created_time, model=model_name, choices=choices, usage=usage, + routed_experts=routed_experts, ) async def _call_process_response_dict(self, res, request, stream): diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index a5b4bf225d4..3993eb2439c 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -208,8 +208,6 @@ "FD_PD_REORDER": lambda: int(os.getenv("FD_PD_REORDER", "0")), # Whether to probe MoE routing probabilities and use Fleet's fused SwiGLU kernel. "FD_MOE_PROB_IN_ADVANCE": lambda: bool(int(os.getenv("FD_MOE_PROB_IN_ADVANCE", "0"))), - # Suspend rollouting routing replay - "FD_SUSPEND_ROUTING_REPLAY": lambda: bool(int(os.getenv("FD_SUSPEND_ROUTING_REPLAY", "0"))), # Whether to enable v1 weight updating, which utilizes ZMQ/EngineWorkerQueue/EngineCacheQueue/FMQs # to pass control requests and responses. # When v1 is enabled, the legacy /clear_load_weight and /update_model_weight diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index d625326c06f..f72eb7cbf7c 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -146,8 +146,8 @@ class ForwardMeta: caches: Optional[list[paddle.Tensor]] = None # Flag of profile run is_dummy_or_profile_run: bool = False - # Routing Replay table buffer - routing_replay_table: Optional[paddle.Tensor] = None + # GPU transient routing buffer [max_num_batched_tokens, num_moe_layers, top_k] + gpu_routing_buffer: Optional[paddle.Tensor] = None # chunked MoE related moe_num_chunk: int = 1 diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 95024f3c7a5..61e20f77b74 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -28,7 +28,7 @@ ) from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.layers.moe.routing_indices_cache import ( - save_routing_to_buffer, + save_routing_to_buffer_v2, ) from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.model_executor.utils import h2d_copy, slice_fn @@ -719,14 +719,11 @@ def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = topk_ids_hookfunc = None if self.enable_routing_replay: # When execute empty_input_forward forward_meta is None. When execute mtp layer routing_replay_table is None. - if forward_meta is not None and forward_meta.routing_replay_table is not None: + if forward_meta is not None and forward_meta.gpu_routing_buffer is not None: moe_layer_idx = self.layer_idx - self.fd_config.model_config.moe_layer_start_index topk_ids_hookfunc = partial( - save_routing_to_buffer, - routing_replay_table=forward_meta.routing_replay_table, - batch_id_per_token=forward_meta.batch_id_per_token, - seq_lens_decoder=forward_meta.seq_lens_decoder, - cu_seqlens_q=forward_meta.cu_seqlens_q, + save_routing_to_buffer_v2, + gpu_routing_buffer=forward_meta.gpu_routing_buffer, layer_idx=moe_layer_idx, tp_size=self.fd_config.parallel_config.tensor_parallel_size, ep_size=self.fd_config.parallel_config.expert_parallel_size, diff --git a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py index efd29477f2b..089e94d2c58 100644 --- a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py +++ b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py @@ -14,20 +14,6 @@ # limitations under the License. """ -import asyncio -import atexit -import functools -import multiprocessing -import os -import shutil -import threading -import time -import traceback -from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor -from multiprocessing import Process, Queue -from typing import Dict, Optional, TypedDict - import numpy as np import paddle import paddle.distributed as dist @@ -35,7 +21,8 @@ import triton.language as tl from paddleformers.utils.log import logger -from fastdeploy.config import FDConfig, RoutingReplayConfig +from fastdeploy.cache_manager.routing_cache_manager import RoutingHostBufferView +from fastdeploy.config import FDConfig from fastdeploy.model_executor.ops.triton_ops.triton_utils import ( enable_compat_on_triton_kernel, ) @@ -43,71 +30,39 @@ @enable_compat_on_triton_kernel @triton.jit -def _save_routing_kernel( - ROUTING_REPLAY_TABLE_PTR, +def _save_routing_kernel_v2( + GPU_ROUTING_BUFFER_PTR, TOPK_IDS_PTR, - BATCH_ID_PER_TOKEN_PTR, - CU_SEQLENS_Q_PTR, - SEQ_LENS_DECODER_PTR, LAYER_IDX, TOKEN_NUM, TOP_K, - NUM_HIDDEN_LAYERS, - MAX_MODEL_LEN, - MAX_NUM_SEQS, + NUM_MOE_LAYERS, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): pid_m = tl.program_id(axis=0) - token_offsets = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) token_mask = token_offsets < TOKEN_NUM - k_offsets = tl.arange(0, BLOCK_SIZE_K) k_mask = k_offsets < TOP_K - topk_ids_ptrs = TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :] load_mask = token_mask[:, None] & k_mask[None, :] - topk_vals = tl.load(topk_ids_ptrs, mask=load_mask, other=-1) - - batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask, other=-1) - - batch_mask = (batch_ids >= 0) & (batch_ids < MAX_NUM_SEQS) - pad_mask = token_mask & (batch_ids != -1) & batch_mask - - start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask, other=0) - token_relative_index = token_offsets - start_offsets - - len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask, other=0) - token_seq_pos = len_decoder + token_relative_index - - STRIDE_BUF_SEQ = tl.cast(MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K, tl.int64) - STRIDE_BUF_TOKEN = tl.cast(NUM_HIDDEN_LAYERS * TOP_K, tl.int64) - STRIDE_BUF_LAYER = TOP_K + topk_vals = tl.load( + TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :], + mask=load_mask, + ) + STRIDE_TOKEN = NUM_MOE_LAYERS * TOP_K + STRIDE_LAYER = TOP_K output_ptrs = ( - ROUTING_REPLAY_TABLE_PTR - + tl.cast(batch_ids[:, None], tl.int64) * STRIDE_BUF_SEQ - + tl.cast(token_seq_pos[:, None], tl.int64) * STRIDE_BUF_TOKEN - + tl.cast(LAYER_IDX, tl.int64) * STRIDE_BUF_LAYER - + k_offsets[None, :] + GPU_ROUTING_BUFFER_PTR + token_offsets[:, None] * STRIDE_TOKEN + LAYER_IDX * STRIDE_LAYER + k_offsets[None, :] ) + tl.store(output_ptrs, topk_vals, mask=load_mask) - pos_mask = (token_seq_pos >= 0) & (token_seq_pos < MAX_MODEL_LEN) - pos_mask = pos_mask & pad_mask - pos_mask = pos_mask[:, None] & k_mask[None, :] - - final_mask = load_mask & pos_mask - tl.store(output_ptrs, topk_vals, mask=final_mask) - - -def save_routing_to_buffer( - routing_replay_table: paddle.Tensor, # [max_num_seqs, num_layers, max_len, top_k] - topk_ids: paddle.Tensor, # [token_num, top_k] - batch_id_per_token: paddle.Tensor, # [token_num, 1] - seq_lens_decoder: paddle.Tensor, # [max_num_seqs, 1] - cu_seqlens_q: paddle.Tensor, # [max_num_seqs + 1, 1] +def save_routing_to_buffer_v2( + gpu_routing_buffer: paddle.Tensor, + topk_ids: paddle.Tensor, layer_idx: int, tp_size: int, ep_size: int, @@ -119,124 +74,137 @@ def save_routing_to_buffer( if tp_size > 1 and ep_size > 1: topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype) paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group) - topk_ids = topk_ids_all[: batch_id_per_token.shape[0], :] + topk_ids = topk_ids_all[:token_num_per_rank, :] token_num, top_k = topk_ids.shape - max_num_seqs, max_model_len, num_hidden_layers, _ = routing_replay_table.shape - assert token_num > 0 - assert topk_ids.shape[1] == routing_replay_table.shape[3], (topk_ids.shape[1], routing_replay_table.shape[3]) - assert batch_id_per_token.shape[0] == token_num, (batch_id_per_token.shape[0], token_num) - assert seq_lens_decoder.shape[0] == max_num_seqs, (seq_lens_decoder.shape[0], max_num_seqs) + buf_max_tokens, num_moe_layers, buf_top_k = gpu_routing_buffer.shape - BLOCK_SIZE_M = 128 - BLOCK_SIZE_K = triton.next_power_of_2(top_k) # top_k + assert ( + token_num <= buf_max_tokens + ), f"[R3] token_num={token_num} exceeds gpu_routing_buffer capacity={buf_max_tokens}" + assert top_k == buf_top_k, f"[R3] top_k mismatch: topk_ids.top_k={top_k} vs gpu_routing_buffer.top_k={buf_top_k}" + assert 0 <= layer_idx < num_moe_layers, f"[R3] layer_idx={layer_idx} out of range [0, {num_moe_layers})" + BLOCK_SIZE_M = 128 + BLOCK_SIZE_K = triton.next_power_of_2(top_k) grid = (triton.cdiv(token_num, BLOCK_SIZE_M),) - _save_routing_kernel[grid]( - routing_replay_table, + _save_routing_kernel_v2[grid]( + gpu_routing_buffer, topk_ids, - batch_id_per_token, - cu_seqlens_q, - seq_lens_decoder, LAYER_IDX=layer_idx, TOKEN_NUM=token_num, TOP_K=top_k, - NUM_HIDDEN_LAYERS=num_hidden_layers, - MAX_MODEL_LEN=max_model_len, - MAX_NUM_SEQS=max_num_seqs, + NUM_MOE_LAYERS=num_moe_layers, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_K=BLOCK_SIZE_K, ) -class RoutingReplayManager: - """Request level routing replay table manager""" +class RoutedExpertsCapturer: + """ + Worker-side routing capture: manages GPU transient buffer and GPU→CPU scatter. + Does NOT manage request lifecycle — that is handled by RoutingCacheManager on the Engine side. + """ def __init__(self, fd_config: FDConfig, block_table, total_block_num): self.fd_config = fd_config self.block_table = block_table self.max_num_seqs = fd_config.scheduler_config.max_num_seqs - self.max_model_len = fd_config.model_config.max_model_len - self.num_moe_layers = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index - self.only_last_turn = fd_config.routing_replay_config.only_last_turn - self.use_fused_put = fd_config.routing_replay_config.use_fused_put - logger.info(f"[R3] Rollout Routing Replay Congfig: {fd_config.routing_replay_config}") - if fd_config.model_config.architectures[0] == "Glm4MoeForCausalLM": - self.moe_top_k = fd_config.model_config.num_experts_per_tok - else: - self.moe_top_k = fd_config.model_config.moe_k + + # Read routing params from centralized config + rrc = fd_config.routing_replay_config + self.num_moe_layers = rrc.num_moe_layers + self.moe_top_k = rrc.moe_top_k + self.routing_dtype = rrc.routing_dtype self.tp_rank = fd_config.parallel_config.tensor_parallel_rank - # Initialize the routing replay table and routing cache - self.routing_batch_to_request: Dict[int, str] = {} - num_experts = fd_config.model_config.moe_num_experts + fd_config.model_config.moe_num_shared_experts - self.routing_dtype = self.get_routing_dtype(num_experts=num_experts) + logger.info(f"[R3] RoutedExpertsCapturer config: {rrc}") + self._init_routing_cache(dtype=self.routing_dtype, total_block_num=total_block_num) self.pending_update_positions = None - # Initialize routing store wrapper - if self.tp_rank == 0: - self._store_wrapper = StoreWrapper( - fd_config=fd_config, - ) - self._store_wrapper.start_store_warpper() - - # Suspend Routing Replay - self.suspend_routing_replay = False - self.update_suspend_routing_replay() - - def update_suspend_routing_replay(self): - """Allow RL to use R3 in different training rounds""" - # TODO(gongshaotian): Delete this func - suspend_routing_replay = os.environ.get("FD_SUSPEND_ROUTING_REPLAY", "0") - self.suspend_routing_replay = bool(int(suspend_routing_replay)) - logger.info(f"[R3] Update FD_SUSPEND_ROUTING_REPLAY: {self.suspend_routing_replay}") - def _init_routing_cache(self, dtype: str, total_block_num: int): - """Initialize the device buffer and host buffer.""" - + """Initialize GPU transient buffer and prepare lazy SharedMemory attach.""" max_num_kv_tokens = total_block_num * self.fd_config.cache_config.block_size - self._host_cache = paddle.full( - shape=[max_num_kv_tokens, self.num_moe_layers, self.moe_top_k], fill_value=-1, dtype=dtype, device="cpu" - ) - - self.routing_replay_table = paddle.full( - shape=[self.max_num_seqs, self.max_model_len, self.num_moe_layers, self.moe_top_k], + # Small GPU transient buffer: only current step's token routing + max_num_batched_tokens = self.fd_config.scheduler_config.max_num_batched_tokens + self.gpu_routing_buffer = paddle.full( + shape=[max_num_batched_tokens, self.num_moe_layers, self.moe_top_k], fill_value=-1, dtype=dtype, ) + + # Lazy attach to SharedMemory routing_host_buffer (created by Engine after profiling) + self.routing_host_view = None + self._routing_host_view_attach_attempted = False + self._routing_host_view_shm_name = ( + f"routing_host_buffer.{str(self.fd_config.parallel_config.local_engine_worker_queue_port)}" + ) + self._routing_host_view_shape = (max_num_kv_tokens, self.num_moe_layers, self.moe_top_k) + self._routing_host_view_dtype = dtype + + gpu_buffer_bytes = int(np.prod(self.gpu_routing_buffer.shape)) * np.dtype(dtype).itemsize logger.info( - f"[R3] The host cache size is:{self._host_cache.shape}, device cache size is: {self.routing_replay_table.shape}" + f"[R3] GPU transient routing buffer: {self.gpu_routing_buffer.shape} " + f"({gpu_buffer_bytes / 1024:.1f} KB)" ) - def get_routing_dtype(self, num_experts: int, reserved_fill_value: int = 1) -> str: - """Calculate the minimum number of bits required for storage routing.""" - if num_experts <= 0: - raise ValueError(f"num_experts must be greater than 0 but got {num_experts}, please check model config.") - dtype = "uint8" - total_number = num_experts + reserved_fill_value - if total_number <= 255: # uint8: 0~255 - dtype = "uint8" - elif total_number <= 65535: # uint16: 0~65,535 - dtype = "uint16" - elif total_number <= 4294967295: # uint32: 0~4,294,967,295 - dtype = "uint32" - else: - raise ValueError( - f"The number of experts {num_experts} exceeds the representation range of uint32, please check model config." + def _try_attach_routing_host_view(self): + """Lazily attach to SharedMemory routing_host_buffer on first use.""" + if self._routing_host_view_attach_attempted: + return + self._routing_host_view_attach_attempted = True + try: + self.routing_host_view = RoutingHostBufferView( + shape=self._routing_host_view_shape, + dtype=self._routing_host_view_dtype, + shm_name=self._routing_host_view_shm_name, + ) + logger.info(f"[R3] Attached to RoutingHostBuffer SharedMemory: {self._routing_host_view_shm_name}") + except FileNotFoundError: + logger.warning( + f"[R3] RoutingHostBuffer SharedMemory {self._routing_host_view_shm_name} not found. " + "Routing capture will be skipped." ) - logger.info(f"[R3] Routing replay table dtype: {dtype}") - return dtype - def update_host_cache(self, positions: paddle.Tensor, slot_mapping: paddle.Tensor): - """Update the host cache with new tokens""" - for batch_id, position in enumerate(positions): - if len(position) > 0 and len(slot_mapping[batch_id]) > 0: - routing_ids = self.routing_replay_table[batch_id, position, :, :].contiguous() - routing_ids = routing_ids.cpu() + def save_captured_routing(self, num_tokens: int, slot_mapping: np.ndarray): + """ + After forward, scatter GPU buffer routing data to routing_host_buffer. + Called in step gap (post_process), not during forward. CUDAGraph compatible. + """ + if num_tokens == 0: + return - self._host_cache[slot_mapping[batch_id], :, :] = routing_ids + # Lazy attach to SharedMemory (Engine creates it after profiling completes) + if self.routing_host_view is None and not self._routing_host_view_attach_attempted: + self._try_attach_routing_host_view() + + if self.routing_host_view is None: + return + + # D2H copy: GPU → CPU numpy, then scatter to SharedMemory + data = self.gpu_routing_buffer[:num_tokens].cpu().numpy() + self.routing_host_view.scatter(slot_mapping, data) + + def compute_slot_mapping_flat(self, positions) -> np.ndarray: + """ + Compute flat slot_mapping for all tokens in the step. + Returns a 1D numpy array of slot indices. + """ + all_slots = [] + block_size = self.fd_config.cache_config.block_size + for batch_id, position in enumerate(positions): + if len(position) == 0: + continue + block_table_indices = position // block_size + token_block_ids = self.block_table[batch_id, block_table_indices] + block_offset = position % block_size + token_cache_ids = np.array(token_block_ids) * block_size + block_offset + all_slots.append(token_cache_ids) + if all_slots: + return np.concatenate(all_slots) + return np.array([], dtype=np.int64) def get_token_positions(self, seq_lens_decoder, seq_lens_this_time): """Get token position of each sequence in a batch.""" @@ -253,640 +221,14 @@ def get_token_positions(self, seq_lens_decoder, seq_lens_this_time): return positions - def compute_slot_mapping(self, positions: np.ndarray): - """Compute the mapping between token ids and kvcache slots""" - slot_mapping = [] - for batch_id, position in enumerate(positions): - if len(position) == 0: - slot_mapping.append([]) - continue - block_table_indices = position // self.fd_config.cache_config.block_size - token_block_ids = self.block_table[batch_id, block_table_indices] - block_offset = position % self.fd_config.cache_config.block_size - - token_cache_ids = np.array(token_block_ids) * self.fd_config.cache_config.block_size + block_offset - slot_mapping.append(token_cache_ids) - - return slot_mapping - - def _get_routing_from_cache(self, finished_batch_ids, seq_lens_decoder): - """ - When request is finished or cleared the length of the request is recorded at seq_lens_decoder - 1. finish the step: after update input, lens = seq_lens_decoder_buffer - 2. clear parameter: after update input, lens = seq_lens_decoder_buffer - """ - # Get the slot mapping of the request cache. - current_token_nums = seq_lens_decoder.numpy()[:, 0] - positions = [] - for batch_id in range(self.max_num_seqs): - position = [] - if batch_id in finished_batch_ids: - position = np.arange(0, current_token_nums[batch_id]) - positions.append(position) - - # Collection the cached routing information - token_cache_ids = self.compute_slot_mapping(positions=positions) - for slot_map in token_cache_ids: - if len(slot_map) > 0: - token_cached_routing = self._host_cache[slot_map, :, :] - return paddle.transpose(token_cached_routing, [1, 0, 2]) - raise ValueError("No cached routing found") - - def put_finished_batch( - self, - finished_batch_ids, - seq_lens_decoder, - ): - finished_batch_ids_list = finished_batch_ids.cpu().tolist() - for batch_id, finished in enumerate(finished_batch_ids_list): - if finished: - assert batch_id in self.routing_batch_to_request.keys() - # Deregister the request - request_id = self._deregister_request(batch_id) - # Put the routing of finished request to store - self._put_request_to_store( - batch_id=batch_id, - request_id=request_id, - seq_lens_decoder=seq_lens_decoder, - ) - # Clear the slot of the finished batch - self._clear_table_slot(batch_id) - - def register_request(self, batch_id: int, request_id: str): - """ - Register a new request to routing replay table - Args: - batch_id: The batch ID of this request - request_id: The global ID of the request is usually executed by the training process in RL - """ - # The chunked prefill tasks will be registered repeatedly - if batch_id in self.routing_batch_to_request: - if self.routing_batch_to_request[batch_id] == request_id: - logger.warning(f"[R3] Request {request_id} has been registered at {batch_id}.") - return - else: - raise RuntimeError( - f"[R3] The Batch {batch_id} has been registered by request {self.routing_batch_to_request[batch_id]}, now robed by {request_id}," - ) - - # Register the new request - self.routing_batch_to_request[batch_id] = request_id - logger.info(f"[R3] Register request {request_id} with batch id {batch_id}") - - def _deregister_request(self, batch_id: int) -> str: - """ - Deregister a request from routing replay table - """ - assert batch_id in self.routing_batch_to_request - return self.routing_batch_to_request.pop(batch_id) - - def _put_request_to_store( - self, - batch_id: int, - request_id: str, - seq_lens_decoder, - ): - if self.tp_rank == 0: - # TODO(gongshaotian): Delete the suspend func - if self.suspend_routing_replay: - logger.info(f"[R3] Suspend Routing Replay is enabled, skip putting request {request_id} to store") - return - - before_put_request_time = time.perf_counter() - - # Collect the routing of finished request - batch_buffer = self._get_routing_from_cache( - finished_batch_ids=[batch_id], seq_lens_decoder=seq_lens_decoder - ) - rollout_id = self.split_request_id(request_id) - - if self.use_fused_put: - self._store_wrapper.submit_put_task(routing_indices=batch_buffer, rollout_id=rollout_id) - # Only store the routing of last turn - if self.only_last_turn: - self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id) - - else: - for layer_id in range(self.num_moe_layers): - layer_buffer = batch_buffer[layer_id] - self._store_wrapper.submit_put_task( - routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id - ) - # Only store the routing of last turn - if self.only_last_turn: - self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id, layer_idx=layer_id) - - logger.info(f"[R3] Submit {request_id} time cost: {time.perf_counter() - before_put_request_time}") - - def clear_request(self, batch_id: int): - """Clear the routing indices of the request""" - self._clear_table_slot(batch_id) - self.routing_batch_to_request.pop(batch_id, None) - - def _clear_table_slot(self, batch_id: int): - assert 0 <= batch_id < self.max_num_seqs - self.routing_replay_table[batch_id].fill_(-1) - - def get_routing_table(self) -> paddle.Tensor: - return self.routing_replay_table - - def split_request_id(self, request_id: str): - """ - Split the request id to get rollout id. - - request_id: "chatcmpl-request.user-uuid" - rollout_id: "request.user" - example: "chatcmpl-xxx_xxx_epoch_15:2:2:1-d9f16c5c-65f6-4815-b44d-14e2c581907c_0" -> "xxx_xxx_epoch_15:2:2:1" - """ - chat_type, tmp_str = request_id.split("-", 1) - # NOTE(gongshaotian): only support chatcmpl now - assert ( - chat_type == "chatcmpl" - ), "Rollout Routing Replay only supports chatcmpl. Please check whether the request type and userid settings are correct." - reversed_tmp_str = tmp_str[::-1].split("-", 5) - rollout_id = reversed_tmp_str[-1][::-1] - return rollout_id - - def clear_all_request(self): - """Clear all requests""" - self.routing_replay_table.fill_(-1) - self.routing_batch_to_request = {} - - -class StoreWrapper(object): - def __init__(self, fd_config: False) -> None: - super().__init__() - self.fd_config = fd_config - - # Initialize task queue - moe_layer_num = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index - max_num_seqs = fd_config.scheduler_config.max_num_seqs - self.queue_max_size = moe_layer_num * max_num_seqs * 1000 - - self.manager = multiprocessing.Manager() - self._task_queue = self.manager.Queue(maxsize=self.queue_max_size) - - self._monitor_thread: threading.Thread = None - self._stop_monitor = threading.Event() - - # Initialize consumer process - self._routing_store_process = StoreProcess( - task_queue=self._task_queue, - routing_replay_config=self.fd_config.routing_replay_config, - max_model_len=self.fd_config.model_config.max_model_len, - ) - self._sotre_process_running = False - - # Register atexit handler - atexit.register(self.shutdown) - - def shutdown(self): - """ """ - if not self._sotre_process_running: - return - self._sotre_process_running = False - - # Stop the monitor thread - self._stop_monitor.set() - if self._monitor_thread and self._monitor_thread.is_alive(): - self._monitor_thread.join(timeout=3.0) - - # Put a sentinel value to signal the consumer to stop - if self._routing_store_process and self._routing_store_process.is_alive(): - try: - self._task_queue.put_nowait(None) - except Exception as e: - logger.info(f"Could not put sentinel into queue: {e}") - - if self._routing_store_process and self._routing_store_process.is_alive(): - # Wait for all tasks to be processed - self._routing_store_process.join(timeout=10.0) - if self._routing_store_process.is_alive(): - self._routing_store_process.close() - self._routing_store_process.join() - - self._task_queue.join() - self.manager.shutdown() - self._sotre_process_running = False - - def start_store_warpper(self): - """ """ - if self._sotre_process_running: - return - self._sotre_process_running = True - - # Start monitor thread - self._stop_monitor.clear() - self._monitor_thread = threading.Thread(target=self._monitor_queue_load, daemon=True) - self._monitor_thread.start() - - # Start Routing Store Wrapper in sub process - self._routing_store_process.start() - - def _monitor_queue_load(self): - """ """ - while not self._stop_monitor.is_set(): - time.sleep(2.0) - if not self._sotre_process_running: - break - qsize = self._task_queue.qsize() - - # Alarm when the task exceeds 80% of the queue capacity - if qsize > self.queue_max_size * 0.8: - logger.warning( - f"[Monitor] Queue load is HIGH: {qsize}/{self.queue_max_size}. " - "Consider increasing max_workers or queue_max_size." - ) - logger.debug(f"[Monitor] Queue load: {qsize}/{self.queue_max_size}") - - def submit_put_task(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int = None) -> None: - """Submit a put task to the task queue""" - if not self._sotre_process_running: - raise RuntimeError("Store not started.") - - start_time = time.perf_counter() - if layer_idx is not None: - rdma_rollout_key = f"{rollout_id}_{layer_idx}" - else: - rdma_rollout_key = rollout_id - - routing_indices_np = routing_indices.numpy() - - task: StoreTask = {"task_type": "put", "key": rdma_rollout_key, "data": routing_indices_np} - - try: - self._task_queue.put_nowait(task) - except Exception: - raise RuntimeError(f"Queue is FULL. Dropping put task for key: {rdma_rollout_key}. ") - logger.info(f"[R3] Submit put task for key: {rdma_rollout_key}, cost time: {time.perf_counter()-start_time} s") - - def submit_clear_store_task(self) -> None: - """Submit clear store task""" - if not self._sotre_process_running: - raise RuntimeError("Store not started.") - - start_time = time.perf_counter() - task: StoreTask = {"task_type": "clear_store", "key": None, "data": None} - - try: - self._task_queue.put_nowait(task) - # Wait for the task to be processed - self._task_queue.join() - except Exception: - raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ") - logger.info(f"[R3] Submit clear task, cost time: {time.perf_counter()-start_time} s") - - def submit_clear_prefix_batch_task(self, rollout_id, layer_idx: int = None) -> None: - """Submit clear prefix batch task""" - if not self._sotre_process_running: - raise RuntimeError("Store not started.") - prefix_batch_id = self.get_needed_clear_ids(rollout_id) - if prefix_batch_id is None: - return - start_time = time.perf_counter() - if layer_idx is not None: - rdma_rollout_key = f"{prefix_batch_id}_{layer_idx}" - else: - rdma_rollout_key = prefix_batch_id - - task: StoreTask = {"task_type": "clear_prefix_batch", "key": rdma_rollout_key, "data": None} - try: - self._task_queue.put_nowait(task) - except Exception: - raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ") - logger.info( - f"[R3] Submit clear prefix batch task for key: {prefix_batch_id}, cost time: {time.perf_counter()-start_time} s" - ) - - def get_needed_clear_ids(self, roullout_id: str) -> Optional[str]: - """ - Generate the prefix IDs for all closed multi-round tasks. - rollout_id: "xxx_xxx_epoch_15:2:2:1" - example: xxx_xxx_data_id:gen_id:turn_id:segment_id - """ - reversed_segment_id, reversed_turn_id, reversed_prefix_gen_id = roullout_id[::-1].split(":", 2) - prefix_gen_id = reversed_prefix_gen_id[::-1] - turn_id = eval(reversed_turn_id[::-1]) - segment_id = eval(reversed_segment_id[::-1]) - - assert turn_id >= 0 and segment_id >= 0 - prefix_batch = None - if turn_id > 0: - prefix_batch = f"{prefix_gen_id}:{(turn_id-1)}:{segment_id}" - return prefix_batch - - -class StoreTask(TypedDict): - task_type: str - key: str - data: np.ndarray - - -class StoreProcess(Process): - def __init__(self, task_queue: Queue, routing_replay_config: RoutingReplayConfig, max_model_len: int) -> None: - super().__init__() - self.max_model_len = max_model_len - self._task_queue = task_queue - self.routing_replay_config = routing_replay_config - self.max_workers = 5 - self._closed = False - - # Note: _routing_store and _event_loop_thread must be initialized in run() - # because they cannot be properly inherited after fork() - self._routing_store = None - self._event_loop_thread = None - - def run(self): - logger.info(f"[R3] Start Running Store Wrapper in sub process {os.getpid()}") - - # Initialize routing store in subprocess - self._routing_store = get_routing_store(routing_replay_config=self.routing_replay_config) - - # Initialize event loop thread in subprocess - self._event_loop_thread = AsyncEventLoopThread() - self._event_loop_thread.start() - if not self._event_loop_thread._started_event.wait(timeout=5.0): - raise RuntimeError("Failed to start async event loop thread in subprocess") - - clear_store_task = StoreTask({"task_type": "clear_store", "key": None, "data": None}) - self._task_queue.put_nowait(clear_store_task) - - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - while not self._closed: - try: - task = self._task_queue.get() - if task is None: # Sentinel - self._task_queue.task_done() - break - - if task["task_type"] == "put": - future = executor.submit(self.process_put_task, task) - future.add_done_callback(lambda f: self._task_queue.task_done()) - elif task["task_type"] == "clear_store": - future = executor.submit(self.process_clear_store_task, task) - future.add_done_callback(lambda f: self._task_queue.task_done()) - elif task["task_type"] == "clear_prefix_batch": - future = executor.submit(self.process_clear_prefix_batch_task, task) - future.add_done_callback(lambda f: self._task_queue.task_done()) - except Exception as e: - self._task_queue.task_done() - raise RuntimeError(f"Error during processing task. {e}") - - logger.info("RoutingReplay Consumer Process Shutdown.") - - def process_put_task(self, store_task: StoreTask) -> None: - try: - # TODO(gongshaotian): delete this after trainer support dynamic len - store_task["data"] = self.pad_routing_indices(store_task["data"]) - coro_obj = self._routing_store.put(routing_key=store_task["key"], routing_indices=store_task["data"]) - future = self._event_loop_thread.submit_coroutine( - coro_obj, callback=functools.partial(self._on_async_task_completed, store_task) - ) - return future - except Exception as e: - logger.error(f"Error submitting put task: {e}") - traceback.print_exc() - raise - - def process_clear_store_task(self, store_task: StoreTask) -> None: - try: - coro_obj = self._routing_store.clear_store() - future = self._event_loop_thread.submit_coroutine( - coro_obj, callback=functools.partial(self._on_async_task_completed, store_task) - ) - return future - except Exception as e: - logger.error(f"Error during processing clear store task. {e}") - traceback.print_exc() - raise - - def process_clear_prefix_batch_task(self, store_task: StoreTask) -> None: - try: - coro_obj = self._routing_store.clear_prefix_batch(routing_prefix_key=store_task["key"]) - future = self._event_loop_thread.submit_coroutine( - coro_obj, callback=functools.partial(self._on_async_task_completed, store_task) - ) - return future - except Exception as e: - logger.error(f"Error submitting clear_prefix_batch task: {e}") - traceback.print_exc() - raise - - def _on_async_task_completed(self, task, future): - """ """ - try: - # result = future.result() - logger.info(f"[R3] Async task completed: {task['task_type']}, key: {task['key']}") - except Exception as e: - logger.error(f"[R3] Async task failed: {task['task_type']}, key: {task['key']}, error: {e}") - traceback.print_exc() - raise - - def close(self): - """Close the store process""" - self._closed = True - if hasattr(self, "_event_loop_thread"): - self._event_loop_thread.stop() - - def pad_routing_indices(self, routing_indices: np.ndarray) -> np.ndarray: - """Pad routing indices of the request levevl to max model len""" - routing_shape = routing_indices.shape - if len(routing_shape) == 2: # [token, topk] - pad_array = np.full( - shape=[(self.max_model_len - routing_indices.shape[0]), routing_indices.shape[1]], - fill_value=-1, - dtype=routing_indices.dtype, - ) - return np.concatenate([routing_indices, pad_array], axis=0) - - elif len(routing_shape) == 3: # [layer, token, topk] - pad_array = np.full( - shape=[ - routing_indices.shape[0], - (self.max_model_len - routing_indices.shape[1]), - routing_indices.shape[2], - ], - fill_value=-1, - dtype=routing_indices.dtype, - ) - return np.concatenate([routing_indices, pad_array], axis=1) - else: - raise ValueError(f"Invalid routing indices shape: {routing_shape}") - - -class AsyncEventLoopThread(threading.Thread): - def __init__(self): - super().__init__(daemon=True) - self._loop = None - self._started_event = threading.Event() - self._closed = False + def get_gpu_routing_buffer(self) -> paddle.Tensor: + return self.gpu_routing_buffer - def run(self): - """Run the async event loop""" - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) + def clear(self): + """Clear GPU buffer and pending positions. Used during RL round cleanup.""" + self.gpu_routing_buffer.fill_(-1) + self.pending_update_positions = None - # Set the event loop to be started - self._started_event.set() - logger.info("[EventLoopThread] Event loop started, running forever...") - try: - self._loop.run_forever() - logger.info("[EventLoopThread] Event loop stopped") - except Exception as e: - logger.error(f"[EventLoopThread] Event loop exception: {e}") - traceback.print_exc() - finally: - logger.info("[EventLoopThread] Closing event loop") - self._loop.close() - - def submit_coroutine(self, coro, callback=None): - """Thread safely submit coroutine to event loop""" - if self._closed: - raise RuntimeError("Event loop thread is closed") - if not self._started_event.wait(timeout=5.0): - raise RuntimeError("Event loop failed to start within 5 seconds") - - future = asyncio.run_coroutine_threadsafe(coro, self._loop) - - if callback: - - def wrapped_callback(f): - try: - callback(f) - except Exception as e: - logger.error(f"Error in callback: {e}") - traceback.print_exc() - - future.add_done_callback(wrapped_callback) - return future - - def stop(self): - """Stop the event loop""" - if not self._closed: - self._closed = True - if self._loop: - self._loop.call_soon_threadsafe(self._loop.stop) - - -class RoutingStoreBase(ABC): - """Base class for routing store""" - - def __init__(self, routing_replay_config: RoutingReplayConfig) -> None: - self.routing_replay_config = routing_replay_config - - @abstractmethod - async def put(self, routing_key: str, routing_indices: np.ndarray) -> None: - """Put the routing indices into store""" - raise NotImplementedError - - @abstractmethod - async def clear_store( - self, - ): - """Clear the routing indices store""" - raise NotImplementedError - - @abstractmethod - async def clear_prefix_batch(self, routing_prefix_key: str): - """Clear the routing indices""" - raise NotImplementedError - - -class RoutingStoreLocal(RoutingStoreBase): - """Routing Store using local memory""" - - def __init__(self, routing_replay_config) -> None: - super().__init__(routing_replay_config=routing_replay_config) - self.local_store_dir = routing_replay_config.local_store_dir - os.makedirs(self.local_store_dir, exist_ok=True) - - async def put( - self, - routing_key: str, - routing_indices: np.ndarray, - ) -> None: - """Put the routing indices into store""" - # TODO(gongshaotian) covert ./store_dir/routing_key/layer_id.pdtensor to ./store_dir/routing_key.pdtensor - time_before_put = time.perf_counter() - - if len(routing_indices.shape) == 2: - re_layer_id, re_rollout_id = routing_key[::-1].split("_", 1) - rollout_id = re_rollout_id[::-1] - layer_id = re_layer_id[::-1] - request_path = os.path.join(self.local_store_dir, rollout_id) - file_path = os.path.join(request_path, f"layer_{layer_id}.pdtensor") - elif len(routing_indices.shape) == 3: - request_path = os.path.join(self.local_store_dir, routing_key) - file_path = os.path.join(request_path, f"{routing_key}.pdtensor") - else: - raise ValueError(f"Invalid routing indices shape: {routing_indices.shape}") - - paddle.save(routing_indices, file_path) - logger.info(f"[R3] The routing key {routing_key} put cost is {time.perf_counter()-time_before_put}s") - - async def clear_store(self): - """Clear the routing indices store""" - if os.path.isdir(self.local_store_dir): - shutil.rmtree(self.local_store_dir) - - logger.info("[R3] Clear routing store.") - - async def clear_prefix_batch(self, routing_prefix_key: str): - """Clear the routing indices""" - raise NotImplementedError - - -class RoutingStoreRDMA(RoutingStoreBase): - """Routing Store using RDMA""" - - def __init__(self, routing_replay_config) -> None: - super().__init__(routing_replay_config=routing_replay_config) - try: - # Only used in RLHF - from p2pstore import P2PClient, P2PConfig - except ModuleNotFoundError: - raise ModuleNotFoundError(" RoutingStoreRDMA and p2pstore only support in RLHF. ") - - rdma_store_server = routing_replay_config.rdma_store_server - p2pConfig = P2PConfig(metadata_server=rdma_store_server) - self.p2p_client = P2PClient(p2pConfig) - - async def put(self, routing_key: str, routing_indices: np.ndarray) -> None: - """Put the routing indices into store""" - time_before_put = time.perf_counter() - if len(routing_indices.shape) == 3: - # NOTE(gongshaotian) Fused put with bytes data - routing_bytes = routing_indices.tobytes() - result = await self.p2p_client.put(routing_key, routing_bytes) - else: - result = await self.p2p_client.put(routing_key, routing_indices) - logger.info(f"[R3] The routing key {routing_key}, put cost is {time.perf_counter()-time_before_put}s") - return result - - async def clear_prefix_batch(self, routing_prefix_key: str): - time_before_clear = time.perf_counter() - result = await self.p2p_client.delete_batch([routing_prefix_key]) - logger.info( - f"[R3] The clear routing prefix key {routing_prefix_key}, cost is {time.perf_counter()-time_before_clear}s" - ) - return result - - async def clear_store(self): - """Clear the routing indices store""" - time_before_clear = time.perf_counter() - result = await self.p2p_client.clear() - logger.info(f"[R3] Clear routing store cost is {time.perf_counter()-time_before_clear}s.") - return result - - -def get_routing_store(routing_replay_config: RoutingReplayConfig) -> RoutingStoreBase: - if routing_replay_config.routing_store_type == "local": - return RoutingStoreLocal(routing_replay_config=routing_replay_config) - elif routing_replay_config.routing_store_type == "rdma": - return RoutingStoreRDMA(routing_replay_config=routing_replay_config) - else: - raise ValueError( - f"Invalid routing store type: '{routing_replay_config.routing_store_type}'. " - "Valid types are: 'local', 'rdma'" - ) +# Backward compatibility alias +RoutingReplayManager = RoutedExpertsCapturer diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index e37b52a41c8..001f291b09f 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -313,18 +313,18 @@ def post_process_normal( # Routing replay if routing_replay_manager is not None: - # Update host cache - slot_mapping = routing_replay_manager.compute_slot_mapping( + # Trigger lazy SharedMemory attach if not yet attempted + routing_replay_manager._try_attach_routing_host_view() + # GPU transient buffer → SharedMemory routing_host_buffer + slot_mapping_flat = routing_replay_manager.compute_slot_mapping_flat( positions=routing_replay_manager.pending_update_positions ) - routing_replay_manager.update_host_cache( - positions=routing_replay_manager.pending_update_positions, slot_mapping=slot_mapping - ) - - # Put routing of finished requests to store - finished_batch_ids = paddle.flatten(paddle.isin(sampler_output.sampled_token_ids, model_output.eos_token_id)) - context_lens = model_output.seq_lens_decoder + model_output.seq_lens_encoder - routing_replay_manager.put_finished_batch(finished_batch_ids=finished_batch_ids, seq_lens_decoder=context_lens) + num_tokens = len(slot_mapping_flat) + if routing_replay_manager.tp_rank == 0: + routing_replay_manager.save_captured_routing( + num_tokens=num_tokens, + slot_mapping=slot_mapping_flat, + ) # 2. Update the input buffer of the model with paddle.framework._no_check_dy2st_diff(): @@ -488,27 +488,18 @@ def post_process_specualate( # Routing replay if routing_replay_manager is not None: - # Update host cache - slot_mapping = routing_replay_manager.compute_slot_mapping( + # Trigger lazy SharedMemory attach if not yet attempted + routing_replay_manager._try_attach_routing_host_view() + # GPU transient buffer → SharedMemory routing_host_buffer + slot_mapping_flat = routing_replay_manager.compute_slot_mapping_flat( positions=routing_replay_manager.pending_update_positions ) - routing_replay_manager.update_host_cache( - positions=routing_replay_manager.pending_update_positions, slot_mapping=slot_mapping - ) - - # Put routing of finished requests to store - last_accept_token = paddle.full_like(model_output.accept_tokens, -1) - col_indices = paddle.arange(model_output.accept_tokens.shape[1], dtype=model_output.accept_num.dtype) - mask = col_indices < paddle.unsqueeze(model_output.accept_num, 1) - last_accept_token[mask] = model_output.accept_tokens[mask] - eos_tokens_flat = model_output.eos_token_id.flatten() - isin_mask = paddle.isin(last_accept_token, eos_tokens_flat) - finished_batch_ids = isin_mask.any(axis=-1) - context_lens = model_output.seq_lens_encoder + model_output.seq_lens_decoder - routing_replay_manager.put_finished_batch( - finished_batch_ids=finished_batch_ids, - seq_lens_decoder=context_lens, - ) + num_tokens = len(slot_mapping_flat) + if routing_replay_manager.tp_rank == 0: + routing_replay_manager.save_captured_routing( + num_tokens=num_tokens, + slot_mapping=slot_mapping_flat, + ) speculate_update( model_output.seq_lens_encoder, diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 2d0b9f7a743..2449e982544 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -139,6 +139,65 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.health_lock = threading.Lock() self.engine_output_token_hang = False + # Routing replay: attach to SharedMemory routing_host_buffer (lazy init after profiling) + self.routing_host_view = None + self._routing_host_view_init_attempted = False + self.routing_cache_manager = None # Set by Engine after profiling for local/rdma store dispatch + + def _init_routing_host_view(self): + """Attach to SharedMemory routing_host_buffer created by Engine. Called lazily.""" + self._routing_host_view_init_attempted = True + if not self.cfg.routing_replay_config.enable_routing_replay: + return + try: + from fastdeploy.cache_manager.routing_cache_manager import ( + RoutingHostBufferView, + ) + + rrc = self.cfg.routing_replay_config + cache_config = self.cfg.cache_config + + dp_suffix = str(self.cfg.parallel_config.local_engine_worker_queue_port) + shm_name = f"routing_host_buffer.{dp_suffix}" + num_gpu_blocks = cache_config.total_block_num + max_num_kv_tokens = num_gpu_blocks * cache_config.block_size + shape = (max_num_kv_tokens, rrc.num_moe_layers, rrc.moe_top_k) + + self.routing_host_view = RoutingHostBufferView(shape=shape, dtype=rrc.routing_dtype, shm_name=shm_name) + self._routing_block_size = cache_config.block_size + llm_logger.info(f"[R3] TokenProcessor attached to RoutingHostBuffer: {shm_name}") + except FileNotFoundError: + llm_logger.warning("[R3] RoutingHostBuffer SharedMemory not found, routing gather disabled.") + except Exception as e: + llm_logger.warning(f"[R3] Failed to attach to RoutingHostBuffer: {e}") + + def _gather_routing_for_finished_request(self, task, seq_len: int): + """ + Gather complete routing data for a finished request from routing_host_buffer. + + Args: + task: Request task with block_tables + seq_len: Total sequence length + + Returns: + numpy array [seq_len, num_moe_layers, top_k] or None + """ + if self.routing_host_view is None and not self._routing_host_view_init_attempted: + self._init_routing_host_view() + if self.routing_host_view is None: + return None + + import math + + block_size = self._routing_block_size + block_ids = task.block_tables[: math.ceil(seq_len / block_size)] + positions = np.arange(seq_len) + block_indices = positions // block_size + offsets = positions % block_size + slot_mapping = np.array(block_ids)[block_indices] * block_size + offsets + + return self.routing_host_view.gather(slot_mapping) + def healthy(self): """ whether token processor is healthy @@ -273,6 +332,7 @@ def _process_per_token(self, task, batch_id: int, token_ids: np.ndarray, result: self._compute_speculative_status() if not is_prefill: self._record_completion_metrics(task, current_time) + self._finalize_routing(task_id, task, result, is_prefill) self._recycle_resources(task_id, batch_id, task, result, is_prefill) break return result @@ -336,6 +396,7 @@ def _process_batch_output_use_zmq(self, receive_datas): prompt_token_ids=task.prompt_token_ids, outputs=PoolingOutput(data=pooler_output), ) + self._finalize_routing(task_id, task, result, False) self._recycle_resources(task_id, i, task, result, False) batch_result.append(result) else: @@ -522,6 +583,47 @@ def postprocess(self, batch_result: List[RequestOutput], mtype=3): except Exception as e: llm_logger.error(f"Error in TokenProcessor's postprocess: {e}, {str(traceback.format_exc())}") + def _finalize_routing(self, task_id, task, result, is_prefill=False): + """ + Gather routing data before blocks are freed. + Must be called before _recycle_resources so that block_tables are still valid. + + - PD P node (is_prefill=True): gather prefill-only routing, attach to result for sending to D. + - Non-PD / D node (result.finished): gather full routing (prompt + output), + either attach to result ("response" mode) or dispatch to store ("local"/"rdma" mode). + """ + if not self.cfg.routing_replay_config.enable_routing_replay: + return + if result is None: + return + + try: + if is_prefill: + if result.error_code == 200: + seq_len = task.prompt_token_ids_len + routing_data = self._gather_routing_for_finished_request(task, seq_len) + if routing_data is not None: + result.routing_data = routing_data + elif result.finished: + store_type = self.cfg.routing_replay_config.routing_store_type + seq_len = ( + task.prompt_token_ids_len + len(task.output_token_ids) + if hasattr(task, "output_token_ids") + else task.prompt_token_ids_len + ) + if store_type == "response": + routing_data = self._gather_routing_for_finished_request(task, seq_len) + if routing_data is not None: + result.routing_data = routing_data + elif self.routing_cache_manager is not None: + self.routing_cache_manager.on_request_finished( + request_id=task_id, + block_table=task.block_tables, + seq_len=seq_len, + ) + except Exception as e: + llm_logger.warning(f"[R3] Failed to finalize routing for {task_id}: {e}") + def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False): """ recycle resources @@ -977,6 +1079,7 @@ def _process_batch_output(self): self.resource_manager.cache_output_tokens( task ) # when enable prefix caching, cache kv cache for output tokens + self._finalize_routing(task_id, task, result, is_prefill) self._recycle_resources(task_id, i, task, result, is_prefill) llm_logger.info(f"eos token {task_id} Recycle end.") break @@ -1098,6 +1201,7 @@ def clear_data(self): ), ) is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill" + self._finalize_routing(task.request_id, task, result, is_prefill) self._recycle_resources(task.request_id, i, task, result, is_prefill) llm_logger.warning(f"clear data for task {task.request_id}") diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index db698391d91..acd1cbcce12 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -831,11 +831,6 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = prompt_token_ids = request.prompt_token_ids self.proposer.start_request(idx, request.request_id, prompt_token_ids) - # Routing Replay - if self.fd_config.routing_replay_config.enable_routing_replay: - # 1.prefix task(need regist) 2. chunkend task(not need regist) - self.routing_replay_manager.register_request(batch_id=idx, request_id=request.request_id) - if ( self.fd_config.scheduler_config.splitwise_role == "decode" ): # In PD, we continue to decode after P generate first token @@ -876,10 +871,6 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self.in_progress_prompt_logprobs.pop(request.request_id, None) self.forward_batch_reqs_list[idx] = None - # Routing Replay - if self.fd_config.routing_replay_config.enable_routing_replay: - self.routing_replay_manager.clear_request(batch_id=idx) - continue assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens @@ -1467,9 +1458,9 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): Initialize forward meta, attention meta data and update some config. """ # Initialize forward meta - routing_replay_table = None + gpu_routing_buffer = None if self.routing_replay_manager is not None: - routing_replay_table = self.routing_replay_manager.get_routing_table() + gpu_routing_buffer = self.routing_replay_manager.get_gpu_routing_buffer() self.forward_meta = ForwardMeta( ids_remove_padding=self.share_inputs["ids_remove_padding"], rotary_embs=self.share_inputs["rope_emb"], @@ -1496,7 +1487,7 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): kv_batch_ids=self.share_inputs["kv_batch_ids"], kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"], kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"], - routing_replay_table=routing_replay_table, + gpu_routing_buffer=gpu_routing_buffer, ) dist_status = self.collect_distributed_status() @@ -2855,7 +2846,7 @@ def clear_requests(self): # Routing Replay if self.routing_replay_manager: - self.routing_replay_manager.clear_all_request() + self.routing_replay_manager.clear() def update_parameters(self, pid): """Dynamic model loader use to update parameters use for RL""" @@ -2873,10 +2864,6 @@ def update_parameters(self, pid): # Recapture CUDAGraph if self.use_cudagraph: self.capture_model() - # Rollout Routing Replay - if self.fd_config.routing_replay_config.enable_routing_replay: - # TODO(gongshaotian): Delete suspend func - self.routing_replay_manager.update_suspend_routing_replay() # Send single self.dynamic_weight_manager.finalize_update(pid) diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py index 77944b3a2cf..88d1a6a74db 100644 --- a/fastdeploy/worker/metax_model_runner.py +++ b/fastdeploy/worker/metax_model_runner.py @@ -51,9 +51,6 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, ) -from fastdeploy.model_executor.layers.moe.routing_indices_cache import ( - RoutingReplayManager, -) from fastdeploy.model_executor.layers.pool.metadata import PoolingMetadata from fastdeploy.model_executor.layers.rotary_embedding import get_rope_3d from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata @@ -204,8 +201,6 @@ def __init__( # Rollout routing replay config self.routing_replay_manager = None - if self.fd_config.routing_replay_config.enable_routing_replay: - self.routing_replay_manager = RoutingReplayManager(fd_config=self.fd_config) self.zmq_client = None self.async_output_queue = None @@ -769,11 +764,6 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self.forward_batch_reqs_list[idx] = request has_prefill_task = True - # Routing Replay - if self.fd_config.routing_replay_config.enable_routing_replay: - if prefill_start_index == 0: - self.routing_replay_manager.register_request(batch_id=idx, request_id=request.request_id) - if ( self.fd_config.scheduler_config.splitwise_role == "decode" ): # In PD, we continue to decode after P generate first token @@ -805,10 +795,6 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self.in_progress_prompt_logprobs.pop(request.request_id, None) self.forward_batch_reqs_list[idx] = None - # Routing Replay - if self.fd_config.routing_replay_config.enable_routing_replay: - self.routing_replay_manager.clear_request(batch_id=idx) - continue assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens @@ -1367,9 +1353,9 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): Initialize forward meta, attention meta data and update some config. """ # Initialize forward meta - routing_replay_table = None + gpu_routing_buffer = None if self.routing_replay_manager is not None: - routing_replay_table = self.routing_replay_manager.get_routing_table() + gpu_routing_buffer = self.routing_replay_manager.get_gpu_routing_buffer() self.forward_meta = ForwardMeta( ids_remove_padding=self.share_inputs["ids_remove_padding"], rotary_embs=self.share_inputs["rope_emb"], @@ -1396,7 +1382,8 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): kv_batch_ids=self.share_inputs["kv_batch_ids"], kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"], kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"], - routing_replay_table=routing_replay_table, + routing_replay_table=None, + gpu_routing_buffer=gpu_routing_buffer, ) dist_status = self.collect_distributed_status() @@ -1919,8 +1906,8 @@ def _dummy_run( # only need to capture prefill break - if self.fd_config.routing_replay_config.enable_routing_replay: - self.routing_replay_manager.clear_routing_table() + if self.fd_config.routing_replay_config.enable_routing_replay and self.routing_replay_manager is not None: + self.routing_replay_manager.clear() def _update_chunked_prefill(self, tasks): """ @@ -2521,7 +2508,7 @@ def _postprocess( and self.share_inputs["is_block_step"].sum() == 0 and self.share_inputs["is_chunk_step"].sum() == 0 ): - self.routing_replay_manager.put_table_to_store() + pass # Routing store submission now handled by RoutingCacheManager on Engine side return model_output_data, sampler_output, post_process_done def _save_model_output( @@ -2749,8 +2736,8 @@ def clear_requests(self): self.prompt_logprobs_reqs.clear() self.in_progress_prompt_logprobs.clear() self.forward_batch_reqs_list = [None for _ in range(self.scheduler_config.max_num_seqs)] - if self.fd_config.routing_replay_config.enable_routing_replay: - self.routing_replay_manager.put_table_to_store() + if self.fd_config.routing_replay_config.enable_routing_replay and self.routing_replay_manager is not None: + self.routing_replay_manager.clear() def update_parameters(self, pid): """Dynamic model loader use to update parameters use for RL""" diff --git a/tests/cache_manager/test_cache_transfer_manager.py b/tests/cache_manager/test_cache_transfer_manager.py index 05621af03bb..7be8487d001 100644 --- a/tests/cache_manager/test_cache_transfer_manager.py +++ b/tests/cache_manager/test_cache_transfer_manager.py @@ -45,6 +45,7 @@ class Args: kvcache_storage_backend = None write_policy = "write_through" model_path = "test_model" + routing_replay_config = MagicMock(enable_routing_replay=False) # ========================== diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py index 9398e07d9f5..b47344470de 100644 --- a/tests/output/test_process_batch_output.py +++ b/tests/output/test_process_batch_output.py @@ -65,6 +65,7 @@ class CacheConfig: model_config = ModelConfig() scheduler_config = SchedulerConfig() cache_config = CacheConfig() + routing_replay_config = MagicMock(enable_routing_replay=False) class MockTask: diff --git a/tests/output/test_token_processor.py b/tests/output/test_token_processor.py index c0609094a2b..ced98b07002 100644 --- a/tests/output/test_token_processor.py +++ b/tests/output/test_token_processor.py @@ -64,6 +64,7 @@ def __init__( ) self.max_num_seqs = max_num_seqs self.splitwise_version = "v1" + self.routing_replay_config = types.SimpleNamespace(enable_routing_replay=False) class _DummyResourceManager: