Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .reauthor-marker.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

<!-- reauthored 2026-05-02T09:09:09Z -->
10 changes: 10 additions & 0 deletions benchmarks/yaml/eb45-21b-a3b-32k-bf16-kv50-512s.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# T53 bench workload — KV-bound (not slot-bound); gate: FD_HEAD_WISE_KV_CACHE=1
# max_num_seqs raised to 512 so the KV pool, not the slot count, is the bottleneck.
# kv_cache_ratio lowered to 0.50 to shrink the pool and accelerate KV pressure.
# Use with: INPUT_LEN=8192 OUTPUT_LEN=4096 REQUEST_RATE=8
#
max_model_len: 32768
max_num_seqs: 512
kv_cache_ratio: 0.50
tensor_parallel_size: 1
max_num_batched_tokens: 32768
130 changes: 129 additions & 1 deletion fastdeploy/cache_manager/prefix_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
local_data_parallel_id=0,
):
"""
initialize the PrefixCacheManager
initialize the PrefixCacheManager.
"""

self.metrics = CacheMetrics()
Expand All @@ -79,6 +79,27 @@ def __init__(
self.num_gpu_blocks = self.cache_config.prefill_kvcache_block_num
self.num_cpu_blocks = self.cache_config.num_cpu_blocks

# Head-wise KV cache (Hackathon 10th Spring No.53, mirrors PR #6702 contract).
# Default-off: behavior is bit-identical to mainline unless FD_HEAD_WISE_KV_CACHE=1.
# T53: per-rank KV head count for free-list sizing (TP-aware).
kv_num_heads_global = int(
getattr(getattr(self.cache_config, "model_cfg", None), "num_key_value_heads", 1) or 1
)
tp_size = int(self.tensor_parallel_size or 1)
self.kv_num_heads = max(1, kv_num_heads_global // tp_size) if kv_num_heads_global >= tp_size else 1
_enable_prefix_caching = bool(getattr(self.cache_config, "enable_prefix_caching", False))
if bool(envs.FD_HEAD_WISE_KV_CACHE) and _enable_prefix_caching:
raise ValueError(
"FD_HEAD_WISE_KV_CACHE is mutually exclusive with enable_prefix_caching " "(matches PR #6702 contract)"
)
self.head_wise = bool(envs.FD_HEAD_WISE_KV_CACHE) and not _enable_prefix_caching
self.total_head_wise_cache_ids = 0
# Head-wise free list lives in its OWN attribute so the legacy
# gpu_free_block_list (consumed by allocate_gpu_blocks) keeps its
# [0, num_gpu_blocks) ID space. Aliasing the two lists corrupts the
# legacy allocator with OOB cache ids → CUDA error 700 at decode.
self.gpu_free_head_wise_block_list = []

self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1, -1))
if self.num_cpu_blocks > 0:
self.cpu_free_block_list = list(range(self.num_cpu_blocks - 1, -1, -1))
Expand Down Expand Up @@ -172,6 +193,9 @@ def _get_kv_cache_shape(self, max_block_num):

@property
def available_gpu_resource(self):
    """
    Return the free share of GPU KV-cache blocks as a ratio of ``num_gpu_blocks``.

    When ``head_wise`` is set, the free head-wise cache ids are first converted
    to whole blocks (floor division by ``kv_num_heads``, clamped to at least 1)
    before being divided by ``num_gpu_blocks``. Otherwise the length of the
    legacy ``gpu_free_block_list`` is used. Returns ``0.0`` when
    ``num_gpu_blocks`` is not positive.
    """
    total_blocks = self.num_gpu_blocks
    if total_blocks <= 0:
        # No GPU pool configured: nothing can be free.
        return 0.0
    if getattr(self, "head_wise", False):
        free_ids = getattr(self, "gpu_free_head_wise_block_list", [])
        heads = max(1, self.kv_num_heads)
        # Only complete groups of `heads` ids count as one free block.
        free_block_equiv = len(free_ids) // heads
        return free_block_equiv / total_blocks
    return len(self.gpu_free_block_list) / total_blocks

def launch_cache_manager(
Expand Down Expand Up @@ -468,6 +492,29 @@ def update_cache_config(self, cache_config):
main_process_metrics.free_gpu_block_num.set(self.num_gpu_blocks)
main_process_metrics.available_gpu_resource.set(1.0)

if getattr(self, "head_wise", False):
self._init_head_wise_free_list()

def _init_head_wise_free_list(self):
    """
    (Re)build the head-wise free heap covering ``num_gpu_blocks * kv_num_heads`` ids.

    The ids ``0 .. total-1`` are stored as a min-heap on the dedicated
    ``gpu_free_head_wise_block_list`` attribute; the legacy
    ``gpu_free_block_list`` is not touched here. ``total_head_wise_cache_ids``
    records the size of the id space, and the free-block / available-resource
    metrics are refreshed from the new heap.

    Ids are handed out and returned through
    :meth:`allocate_gpu_blocks_head_wise` and
    :meth:`recycle_gpu_blocks_head_wise`.
    """
    heads_per_rank = max(1, self.kv_num_heads)
    id_count = self.num_gpu_blocks * heads_per_rank

    # Same descending seed as range(id_count - 1, -1, -1), then heap-ordered.
    head_wise_heap = list(reversed(range(id_count)))
    heapq.heapify(head_wise_heap)

    self.gpu_free_head_wise_block_list = head_wise_heap
    self.total_head_wise_cache_ids = id_count

    main_process_metrics.free_gpu_block_num.set(len(head_wise_heap))
    main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)

def can_allocate_gpu_blocks(self, num_blocks: int, try_free_gpu_blocks: bool = True):
"""
Check if num_blocks gpu blocks can be allocated.
Expand Down Expand Up @@ -532,6 +579,87 @@ def recycle_gpu_blocks(self, gpu_block_ids, req_id=None):
main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)

def allocate_gpu_blocks_head_wise(self, num_blocks, req_id=None):
    """
    Allocate ``num_blocks`` GPU blocks per KV head.

    Pops ``num_blocks * kv_num_heads`` cache ids from the head-wise free heap
    (``gpu_free_head_wise_block_list``) and returns them head-major. Mirrors
    :meth:`allocate_gpu_blocks` but operates on the head-wise free list built
    by :meth:`_init_head_wise_free_list`.

    Active only when ``FD_HEAD_WISE_KV_CACHE=1`` (default-off; mainline
    behavior is unchanged).

    Args:
        num_blocks: number of blocks to allocate for each KV head.
        req_id: request identifier, used only in log messages.

    Returns:
        Nested list of cache ids with shape ``[kv_num_heads][num_blocks]``;
        row ``h`` holds the ids assigned to KV head ``h``.

    Raises:
        AssertionError: if the free heap holds fewer than
            ``num_blocks * kv_num_heads`` ids. Nothing is popped in that case.
    """
    kv_num_heads = max(1, self.kv_num_heads)
    needed = num_blocks * kv_num_heads
    free_list = self.gpu_free_head_wise_block_list

    # Review finding: this capacity check used to be an `assert`, which is
    # stripped under `python -O`; heappop would then fail on an exhausted heap
    # with a context-free IndexError. Raise explicitly so the check always
    # runs. The exception type and message are kept as before so existing
    # callers observe no difference.
    if needed > len(free_list):
        raise AssertionError(f"head-wise gpu free block num: {len(free_list)} < needed number {needed}")

    logger.debug(f"{req_id} start allocate (head-wise)...")
    flat = [heapq.heappop(free_list) for _ in range(needed)]
    # Head-major reshape: row h contains the num_blocks cache ids assigned to KV head h.
    allocated = [flat[h * num_blocks : (h + 1) * num_blocks] for h in range(kv_num_heads)]
    logger.info(
        f"req_id:{req_id} allocate_gpu_blocks_head_wise: {allocated}, "
        f"len(gpu_free_head_wise_block_list) {len(free_list)}"
    )
    main_process_metrics.free_gpu_block_num.set(len(free_list))
    main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
    return allocated

def recycle_gpu_blocks_head_wise(self, cache_ids, req_id=None):
    """
    Recycle head-wise cache ids back into the free heap.

    Accepts a single id, a flat list/tuple of ids, or a nested list-of-lists
    (head-major shape from :meth:`allocate_gpu_blocks_head_wise`); rows and
    bare ids may be mixed. Ids repeated within one call, ids outside
    ``[0, total_head_wise_cache_ids)`` and values that cannot be converted to
    ``int`` are logged and skipped (never raised) so a single bad caller
    cannot poison the heap.

    Mirrors the ``prefix_tree_status_signal`` early-return guarded by
    :meth:`recycle_gpu_blocks`: when the signal exists and is not
    ``PrefixTreeStatus.NORMAL`` nothing is recycled.

    Args:
        cache_ids: id, list/tuple of ids, or list/tuple of id rows.
        req_id: request identifier, used only in log messages.

    Returns:
        None. The free heap and the GPU metrics are updated in place.
    """
    if (
        hasattr(self, "prefix_tree_status_signal")
        and self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL
    ):
        logger.warning("Prefix tree is not normal, skip recycle gpu blocks (head-wise)")
        return

    # Flatten one level of nesting. The container type is checked BEFORE any
    # indexing: a bare non-zero int used to reach `cache_ids[0]` and raise
    # TypeError, and nesting was only detected from the first element.
    if isinstance(cache_ids, (list, tuple)):
        flat = []
        for item in cache_ids:
            if isinstance(item, (list, tuple)):
                flat.extend(item)
            else:
                flat.append(item)
    else:
        flat = [cache_ids]

    total = self.total_head_wise_cache_ids
    seen = set()
    valid = []
    for cid in flat:
        # Normalise first so unhashable / non-numeric values are skipped
        # instead of raising from the set lookup or the range check.
        try:
            idx = int(cid)
        except (TypeError, ValueError, OverflowError):
            logger.warning(f"req_id:{req_id} head-wise recycle: non-integer cache id {cid!r} dropped")
            continue
        if idx in seen:
            logger.warning(f"req_id:{req_id} head-wise recycle: duplicate cache id {cid} dropped")
            continue
        if not (0 <= idx < total):
            logger.warning(
                f"req_id:{req_id} head-wise recycle: out-of-range cache id {cid} "
                f"(valid range [0, {total})) dropped"
            )
            continue
        seen.add(idx)
        valid.append(idx)

    # NOTE(review): duplicates are only detected within this call; an id that
    # is already present in the free heap is pushed again. Confirm callers
    # never recycle the same id twice.
    free_list = self.gpu_free_head_wise_block_list
    for idx in valid:
        heapq.heappush(free_list, idx)
    logger.info(
        f"req_id:{req_id} recycle_gpu_blocks_head_wise: pushed {len(valid)} ids, "
        f"len(gpu_free_head_wise_block_list) {len(free_list)}"
    )
    main_process_metrics.free_gpu_block_num.set(len(free_list))
    main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)

def allocate_cpu_blocks(self, num_blocks):
"""
allocate cpu blocks.
Expand Down
21 changes: 20 additions & 1 deletion fastdeploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@

# Some model suffixes are based on auto classes from Transformers:
# https://huggingface.co/docs/transformers/en/model_doc/auto
# NOTE: Items higher on this list priority over lower ones
# NOTE: Items higher on this list take priority over lower ones.
_SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [
("ForCausalLM", ("generate", "none")),
("ForConditionalGeneration", ("generate", "none")),
Expand Down Expand Up @@ -2043,6 +2043,25 @@ def __init__(
self.read_from_config()
self.postprocess()
self.init_pd_info()
# T53 PR1 — engine-main FDConfig fixture for per-head SWA block recycle.
# ResourceManagerV1._should_use_head_wise_swa (resource_manager_v1.py:298-305)
# reads model_config.head_wise_swa_ratio from the engine-main FDConfig instance.
# The worker-side mutation at paddleformers/base.py:793-804 sets the same attrs
# on a DIFFERENT FDConfig copy (worker process). This block mirrors that mutation
# in the engine-main process so the dispatcher gate is not dormant.
# Guards are identical to the worker side — idempotent if already set.
if envs.FD_T53_HEAD_WISE_SWA_FIXTURE:
cfg = self.model_config
n_kv = getattr(cfg, "num_key_value_heads", 1) or 1
ratio = envs.FD_T53_HEAD_WISE_SWA_RATIO if envs.FD_T53_HEAD_WISE_SWA_RATIO is not None else (1.0 / n_kv)
if getattr(cfg, "window_size", None) is None:
cfg.window_size = 4096
if getattr(cfg, "sink_size", None) is None:
cfg.sink_size = 0
if getattr(cfg, "window_attn_skip_freq", None) is None:
cfg.window_attn_skip_freq = 1
if getattr(cfg, "head_wise_swa_ratio", None) is None:
cfg.head_wise_swa_ratio = ratio
if test_mode:
return
self.check()
Expand Down
Loading
Loading