Skip to content

Commit 187cbda

Browse files
committed
fix(linear-att): fix latent prefix-cache ref/buffer leaks
Four latent defects in LinearAttPagedRadixCache / _linear_att_free_req, found via an adversarial audit and property-based fuzzing:

- Root in eviction set: _add_node added the root to _evict_tree_set when the tree emptied (the root is a leaf at that point), contradicting _evict's own 'assert node is not self.root_node'. The root is now excluded.
- Root ref leak on miss / trim-to-empty: match_prefix takes a root ref on descent; the two 'no match' returns handed None to the caller without releasing that ref, so root.ref_counter drifted up on every miss. Both returns now release it.
- Root ref leak in deref_to_first_big_page_node: the big-page downgrade path leaked the same root ref when it bottomed out at the root. Fixed.
- Big-page state-buffer leak / assert crash: big-page state ids accumulated in req.linear_att_len_to_big_page_id during chunked prefill were neither inserted nor freed when a request was paused or aborted mid-prefill (the fallback branch of _linear_att_free_req). This tripped free_a_req_mem's assert (a worker crash) or, with asserts disabled, leaked slots. The ids are now released on the non-insert exit paths.

New CPU-only tests (no GPU required): property-based invariant fuzzers for the small-page and big-page regimes, plus regression tests for the pause/abort big-page release.

The three root-ref issues are latent (the root carries zero tokens and is never evictable). The big-page leak is a reachable worker crash for long-context serving with --linear_att_page_block_num set.
1 parent 3a15cb0 commit 187cbda

5 files changed

Lines changed: 821 additions & 1 deletion

File tree

lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,10 @@ def _discard_node(self, node: LinearAttPagedTreeNode):
163163
return
164164

165165
def _add_node(self, node: LinearAttPagedTreeNode):
166-
if node.is_leaf():
166+
# root 永远不参与回收:当树为空时 root 自身也满足 is_leaf(),若加入 _evict_tree_set,
167+
# 会与 _evict 中 "node is not self.root_node" 的断言相矛盾(当前仅靠 root 的 ref_counter>=1
168+
# 和回收水位 guard 掩盖)。这里显式排除,使数据结构与回收逻辑的意图一致。
169+
if node.is_leaf() and node is not self.root_node:
167170
self._evict_tree_set.add(node)
168171
if node.small_page_buffer_idx is not None:
169172
self._evict_tree_set_for_linear_att.add(node)
@@ -362,12 +365,18 @@ def match_prefix(
362365
ans_node_list=ans_node_list,
363366
update_refs=update_refs,
364367
)
368+
# _match_prefix_helper 进入时一定对 root 自增了一次 ref_counter。命中链非空时,调用方最终会
369+
# 通过 dec_node_ref_counter(ans_node) 沿父链回收(含 root),增减平衡;但下面两个 "命中为空"
370+
# 的提前返回会把 None 交给调用方,调用方不会再回收,root 自增就无人抵消,导致 root.ref_counter
371+
# 在每次 miss / trim 到空时持续漂移。这里显式补偿这一次 root 自增。
365372
if len(ans_node_list) == 0:
373+
self.dec_node_ref_counter(self.root_node)
366374
return None, 0, None
367375

368376
# 判定真正可以用的匹配节点。
369377
ans_node_list = self._trim_unusable_match_tail(ans_node_list)
370378
if len(ans_node_list) == 0:
379+
self.dec_node_ref_counter(self.root_node)
371380
return None, 0, None
372381

373382
ans_node = ans_node_list[-1]
@@ -482,6 +491,9 @@ def deref_to_first_big_page_node(self, node: LinearAttPagedTreeNode) -> Optional
482491
iter_node = iter_node.parent
483492

484493
if iter_node is self.root_node:
494+
# 没有可承接的 big-page 节点交给调用方释放:root 在 match 阶段同样被 +1,
495+
# 这里必须补偿,否则与 match_prefix miss 路径同类的 root ref 漂移。
496+
self.dec_node_ref_counter(self.root_node)
485497
return None
486498
else:
487499
return iter_node

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,18 @@ def _full_att_free_req(self, free_token_index: List, req: "InferReq"):
149149
req.shared_kv_node = None
150150
return
151151

152+
def _release_pending_linear_att_big_page_ids(self, req: "InferReq"):
    """Free the big-page state buffers that ``req`` still holds outside the radix cache.

    During prefill a request records big-page state ids in
    ``req.linear_att_len_to_big_page_id``. Those ids are only left over when the
    request is released without going through an insert branch (small-page or
    big-page insert); the typical case is a pause/abort in big-page mode after
    prefill has crossed a big-page boundary but before it reached the end.
    Without this release the big-page state slots leak, and the
    "dict is empty" assertion in ``free_a_req_mem`` fires.
    """
    pending = req.linear_att_len_to_big_page_id
    if not pending:
        # Nothing was allocated (or everything was already handed over).
        return
    big_page_ids = list(pending.values())
    self.radix_cache.linear_att_big_page_buffers.free_state_cache(big_page_ids)
    pending.clear()
163+
152164
def _linear_att_free_req(self, free_token_index: List, req: "InferReq"):
153165
assert g_infer_context.is_linear_att_mixed_model is True
154166
args = get_env_start_args()
@@ -164,6 +176,7 @@ def _linear_att_free_req(self, free_token_index: List, req: "InferReq"):
164176
assert req.linear_att_cache_len <= req.cur_kv_len
165177

166178
if req.cur_kv_len == 0:
179+
self._release_pending_linear_att_big_page_ids(req)
167180
return
168181

169182
if req.linear_att_cache_len <= req.cur_kv_len and req.tail_linear_att_small_page_buffer_id is not None:
@@ -232,6 +245,9 @@ def _linear_att_free_req(self, free_token_index: List, req: "InferReq"):
232245
assert req.shared_kv_node.node_prefix_total_len == req.cur_kv_len
233246
self.radix_cache.dec_node_ref_counter(req.shared_kv_node)
234247
req.shared_kv_node = None
248+
# 该分支不会把 prefill 阶段累积的 big page id 插入 radix cache(典型为 pause/abort
249+
# 在 prefill 跨过 big page 边界后、到达末尾前触发),需在此显式释放,避免泄漏。
250+
self._release_pending_linear_att_big_page_ids(req)
235251
return
236252

237253
assert False, f"error state: cur_kv_len: {req.cur_kv_len}"
Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
"""Big-page-regime coverage + invariant fuzz for LinearAttPagedRadixCache.
2+
3+
Active in production only when --linear_att_page_block_num is set (e.g. the GSM8K
4+
launch scripts use 8). Here big_page_num is small so inserts create big-page nodes
5+
plus an optional small tail, mirroring _linear_att_free_req's two insert calls and
6+
copy_linear_att_state_to_cache_buffer's len_to_big_page_id construction.
7+
"""
8+
import uuid
9+
10+
import numpy as np
11+
import pytest
12+
import torch
13+
from sortedcontainers import SortedDict
14+
15+
from lightllm.server.router.dynamic_prompt.linear_att_radix_cache import LinearAttPagedRadixCache
16+
from lightllm.utils.kv_cache_utils import compute_token_list_hash
17+
18+
# PAGE is passed to the cache as hash_page_size (tokens per hash page) and BIGN
# as big_page_num (hash pages per big page); BIG_TOKENS is therefore the token
# length of one big-page node.
PAGE, BIGN = 4, 2
BIG_TOKENS = BIGN * PAGE
21+
22+
23+
class FakePool:
    """In-memory stand-in for a state-buffer pool that hands out integer slot ids.

    ``order`` is the FIFO queue of free ids (oldest free id is allocated first)
    and ``free_set`` mirrors the same ids as a set so tests can check
    conservation invariants cheaply.
    """

    def __init__(self, size):
        self.size = size
        self.free_set = set(range(size))
        self.order = list(range(size))

    def alloc_one_state_cache(self):
        """Return the oldest free id, or ``None`` when the pool is exhausted."""
        if not self.order:
            return None
        i = self.order.pop(0)
        self.free_set.discard(i)
        return i

    def free_state_cache(self, free_indexes):
        """Return every id in ``free_indexes`` to the pool.

        Raises:
            AssertionError: if an id is ``None``, lies outside ``[0, size)``,
                or is already free (double free).
        """
        for i in free_indexes:
            # Reject ids this pool never issued; otherwise the pool would grow
            # past ``size`` and later hand out a bogus slot.
            assert i is not None and 0 <= i < self.size, f"invalid state id {i}"
            assert i not in self.free_set, f"double free {i}"
            self.free_set.add(i)
            self.order.append(i)

    def get_free_cache_num(self):
        """Return how many ids are currently free."""
        return len(self.order)
44+
45+
46+
class FakeAllocator:
    """Allocator stub that only tracks capacity and currently usable size."""

    def __init__(self, size):
        # A fresh allocator is entirely free, so both counters start at ``size``.
        self.size = self.can_use_mem_size = size
50+
51+
52+
class FakeMem:
    """KV-cache memory-manager stub: an allocator counter plus the big-page pool."""

    def __init__(self, size, big_pool):
        self.linear_att_big_page_buffers = big_pool
        self.allocator = FakeAllocator(size)

    def free(self, mem_index):
        """Credit ``len(mem_index)`` token slots back to the allocator counter."""
        released = len(mem_index)
        self.allocator.can_use_mem_size += released
59+
60+
61+
def build(small_size=32, big_size=64, mem=400_000):
    """Create a LinearAttPagedRadixCache wired to fake pools.

    Returns:
        ``(cache, small_pool, big_pool, mem_manager)``. The big pool reaches the
        cache through ``mem_manager.linear_att_big_page_buffers``.
    """
    small_pool = FakePool(small_size)
    big_pool = FakePool(big_size)
    mem_manager = FakeMem(mem, big_pool)
    cache_kwargs = dict(
        # Random suffix so each test gets its own cache name.
        unique_name=f"bp_{uuid.uuid4().hex[:8]}",
        total_token_num=mem,
        rank_in_node=0,
        hash_page_size=PAGE,
        big_page_num=BIGN,
        kv_cache_mem_manager=mem_manager,
        linear_att_small_page_buffers=small_pool,
    )
    radix_cache = LinearAttPagedRadixCache(**cache_kwargs)
    return radix_cache, small_pool, big_pool, mem_manager
75+
76+
77+
def walk(cache):
    """Return every non-root node of ``cache`` in stack-based depth-first order."""
    visited = []
    pending = [*cache.root_node.children.values()]
    while pending:
        node = pending.pop()
        visited.append(node)
        # Children go on top of the stack, so a subtree is finished before
        # the traversal moves on to the node's earlier siblings.
        pending += node.children.values()
    return visited
85+
86+
87+
def page_tokens(pid):
    """Return the ``PAGE`` consecutive token ids that make up page ``pid``."""
    first = pid * PAGE
    return [first + offset for offset in range(PAGE)]
89+
90+
91+
def hashes_for(pids):
    """Return the block hashes for the token sequence formed by pages ``pids``."""
    flat_tokens = [tok for pid in pids for tok in page_tokens(pid)]
    # A trailing -1 token is appended before hashing.
    # NOTE(review): presumably needed so the last full page gets hashed — confirm
    # against compute_token_list_hash.
    flat_tokens.append(-1)
    return compute_token_list_hash(flat_tokens, PAGE)
97+
98+
99+
def check(cache, small, big):
    """Assert the structural, accounting and buffer-conservation invariants of ``cache``.

    Args:
        cache: the radix cache under test.
        small: the fake pool backing small-page state buffers.
        big: the fake pool backing big-page state buffers.
    """
    nodes = walk(cache)

    # Per-node structure, plus token accounting accumulated along the way.
    token_total = 0
    token_refed = 0
    for node in nodes:
        parent = node.parent
        assert parent is not None
        assert node.node_prefix_total_len == parent.node_prefix_total_len + node.node_value_len
        assert node.ref_counter >= 0
        assert node.node_value_len == len(node.token_mem_index_value)
        if not node.is_big_page_node():
            assert node.page_num == 1 and node.node_value_len == PAGE
            assert node.big_page_buffer_idx is None
        else:
            assert node.page_num == BIGN and node.node_value_len == BIG_TOKENS
            assert node.big_page_buffer_idx is not None
            assert node.small_page_buffer_idx is None
        token_total += node.node_value_len
        if node.ref_counter > 0:
            token_refed += node.node_value_len
        for page_hash, child in node.children.items():
            assert child.page_hash == page_hash and child.parent is node
    assert cache.get_tree_total_tokens_num() == token_total
    assert cache.get_refed_tokens_num() == token_refed

    # The eviction set must be exactly the non-root leaves.
    leaf_ids = {id(node) for node in nodes if node.is_leaf()}
    evictable_ids = {id(node) for node in cache._evict_tree_set}
    assert evictable_ids == leaf_ids
    assert id(cache.root_node) not in evictable_ids

    # The buffer-eviction set must be exactly the holders of a small-page buffer.
    small_holder_ids = {id(node) for node in nodes if node.small_page_buffer_idx is not None}
    assert {id(node) for node in cache._evict_tree_set_for_linear_att} == small_holder_ids

    # Big-page ids: unique in the tree, never simultaneously free, none lost.
    big_in_tree = [node.big_page_buffer_idx for node in nodes if node.is_big_page_node()]
    big_ids = set(big_in_tree)
    assert len(big_in_tree) == len(big_ids), "big-page id reused by two nodes"
    assert big_ids.isdisjoint(big.free_set)
    assert big_ids | big.free_set == set(range(big.size)), "big-page id leaked"

    # Small-page ids: the same three conservation rules.
    small_in_tree = [
        node.small_page_buffer_idx for node in nodes if node.small_page_buffer_idx is not None
    ]
    small_ids = set(small_in_tree)
    assert len(small_in_tree) == len(small_ids)
    assert small_ids.isdisjoint(small.free_set)
    assert small_ids | small.free_set == set(range(small.size)), "small-page id leaked"
141+
142+
143+
def make_insert(cache, small, big):
    """Mirror _linear_att_free_req: big-page-aligned prefix (+ optional small tail).

    Returns:
        ``insert(pids, mem_base, with_small_tail)``, which allocates the state
        buffers the path needs from ``big`` / ``small`` and inserts the path
        into ``cache``.
    """

    def insert(pids, mem_base, with_small_tail):
        """Insert the pages ``pids`` with KV slots starting at ``mem_base``.

        When the path is not big-page aligned it is only kept at full length
        if ``with_small_tail`` is set and a small buffer is available for the
        last page; otherwise it is trimmed back to the last big-page boundary.
        Nothing is inserted if the big pool runs dry or trimming leaves no pages.
        """
        num_big = len(pids) // BIGN

        # len_to_big_page_id: one fresh big id per big-page boundary along the path.
        l2b = SortedDict()
        for j in range(1, num_big + 1):
            bid = big.alloc_one_state_cache()
            if bid is None:
                # Big pool exhausted: the real caller would not start this
                # insert, so roll back what was grabbed (in allocation order).
                big.free_state_cache(list(l2b.values()))
                return
            l2b[j * BIG_TOKENS] = bid

        tail_buf = None
        if len(pids) % BIGN != 0:
            if with_small_tail:
                tail_buf = small.alloc_one_state_cache()
            if tail_buf is None:
                # Contract: a non-aligned path must end in a page that owns a
                # small buffer. Without one, trim to the aligned length; this
                # stays aligned for any BIGN, unlike dropping just one page.
                pids = pids[: num_big * BIGN]
                if len(pids) == 0:
                    # Nothing left to insert; num_big is 0 here, so this frees nothing.
                    big.free_state_cache(list(l2b.values()))
                    return

        page_count = len(pids)
        linear_idxs = [None] * page_count
        if tail_buf is not None:
            linear_idxs[-1] = tail_buf
        hashs = hashes_for(pids)
        key = torch.tensor([t for p in pids for t in page_tokens(p)], dtype=torch.int64)
        value = torch.arange(mem_base, mem_base + page_count * PAGE, dtype=torch.int64)
        # Any tail buffer that turns out to be a duplicate is freed by the cache
        # itself, so there is nothing to track after the call.
        cache.insert(key, value, block_hashs=hashs, block_linear_idxs=linear_idxs, len_to_big_page_id=l2b)

    return insert
202+
203+
204+
def test_pure_bigpage_insert_and_match():
    """A big-page-aligned insert is fully matched and ends on a big-page node."""
    cache, small, big, _mm = build()
    insert = make_insert(cache, small, big)
    pids = [1, 2, 3, 4]

    # 4 pages -> 2 big pages, no small tail
    insert(pids, 1000, with_small_tail=False)
    check(cache, small, big)
    assert cache.get_tree_total_tokens_num() == 4 * PAGE

    query = torch.tensor([t for p in pids for t in page_tokens(p)], dtype=torch.int64)
    node, kv, mem = cache.match_prefix(query, block_hashs=hashes_for(pids), update_refs=True)
    assert node is not None and node.is_big_page_node()
    assert kv == 16 and len(mem) == 16
    expected_mem = torch.arange(1000, 1016, dtype=torch.int64)
    assert torch.equal(mem, expected_mem)

    # Drop the reference taken by the match so the invariants hold again.
    cache.dec_node_ref_counter(node)
    check(cache, small, big)
220+
221+
222+
def test_mixed_insert_match_trims_to_bigpage_when_tail_unusable():
    """A match falls back to the last big-page boundary once the tail loses its buffer."""
    cache, small, big, _mm = build(small_size=1)
    insert = make_insert(cache, small, big)
    pids = [1, 2, 3, 4, 5]

    # 5 pages -> 2 big pages (8 tokens *2 =16) + 1 small tail page (4) = 20 tokens
    insert(pids, 2000, with_small_tail=True)
    check(cache, small, big)
    assert cache.get_tree_total_tokens_num() == 20

    # exhaust small pool and steal the tail buffer -> tail page unusable
    while True:
        if small.alloc_one_state_cache() is None:
            break
    cache.free_one_small_page_linear_att_buffer()
    check(cache, small, big)

    query = torch.tensor([t for p in pids for t in page_tokens(p)], dtype=torch.int64)
    node, kv, mem = cache.match_prefix(query, block_hashs=hashes_for(pids), update_refs=True)
    # tail small page has no buffer -> trim back to the last big-page boundary (16)
    assert node is not None and node.is_big_page_node()
    assert kv == 16

    cache.dec_node_ref_counter(node)
    check(cache, small, big)
244+
245+
246+
@pytest.mark.parametrize("seed", list(range(10)))
def test_bigpage_fuzz(seed):
    """Randomly interleave insert/match/deref/steal/evict and check invariants after each op."""
    rng = np.random.default_rng(seed)
    cache, small, big, mm = build(small_size=10, big_size=48, mem=400_000)
    insert = make_insert(cache, small, big)
    held = []  # matched nodes whose reference has not been dropped yet
    next_mem_base = 10_000

    def rand_pids():
        # One draw for the length, then one draw per page id.
        count = int(rng.integers(1, 7))
        return [int(rng.integers(0, 25)) for _ in range(count)]

    def release_small(_mem, buf):
        # Eviction callback: only the small state buffer goes back to its pool.
        if buf is not None:
            small.free_state_cache([buf])

    def do_ins():
        nonlocal next_mem_base
        pids = rand_pids()
        want_tail = bool(rng.integers(0, 2))
        insert(pids, next_mem_base, with_small_tail=want_tail)
        next_mem_base += 100

    def do_match():
        pids = rand_pids()
        query = torch.tensor([t for p in pids for t in page_tokens(p)], dtype=torch.int64)
        node, kv, mem = cache.match_prefix(query, block_hashs=hashes_for(pids), update_refs=True)
        if node is None:
            assert kv == 0 and mem is None
            return
        assert kv == node.node_prefix_total_len == len(mem)
        assert node.is_big_page_node() or node.small_page_buffer_idx is not None
        held.append(node)

    def do_dec():
        if not held:
            return
        victim = int(rng.integers(0, len(held)))
        cache.dec_node_ref_counter(held.pop(victim))

    def do_steal():
        cache.free_one_small_page_linear_att_buffer()

    def do_evict():
        unreferenced = cache.get_tree_total_tokens_num() - cache.get_refed_tokens_num()
        if unreferenced < PAGE:
            return
        pages_to_evict = int(rng.integers(1, unreferenced // PAGE + 1))
        cache._evict(pages_to_evict * PAGE, release_small)

    # Inserts and matches are listed twice so they are picked twice as often.
    ops = [do_ins, do_ins, do_match, do_match, do_dec, do_steal, do_evict]
    for _ in range(400):
        op = ops[int(rng.integers(0, len(ops)))]
        op()
        check(cache, small, big)
        assert cache.root_node.ref_counter == 1 + len(held), "root ref drifted (big-page regime)"

    # Drain: drop every held reference, then evict whatever is left.
    while held:
        cache.dec_node_ref_counter(held.pop())
    assert cache.get_refed_tokens_num() == 0
    remaining = cache.get_tree_total_tokens_num()
    if remaining:
        cache._evict(remaining, release_small)
    assert cache.get_tree_total_tokens_num() == 0
    check(cache, small, big)

0 commit comments

Comments
 (0)