Skip to content

Commit fdfc908

Browse files
authored
[Others] reuse unit test (#7127)
1 parent 6cae9b1 commit fdfc908

4 files changed

Lines changed: 1283 additions & 0 deletions

File tree

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from fastdeploy.cache_manager.multimodal_cache_manager import EncoderCacheManager
16+
from fastdeploy.engine.request import ImagePosition
17+
18+
19+
def test_mm_encoder_cache():
    """Check EncoderCacheManager bookkeeping across apply, re-apply, evict and clear.

    The eviction order asserted here (mm_hash1 first, then mm_hash3 once
    mm_hash2 has been re-applied) is what this test requires of the manager;
    the policy itself is implemented in EncoderCacheManager.
    """
    capacity = 4096
    manager = EncoderCacheManager(max_encoder_cache=capacity)

    def cached_hashes():
        # Snapshot of the cache keys in their current iteration order.
        return list(manager.cache.keys())

    def check_size(expected, mismatch_message):
        # The tracked size must match the running total and stay within capacity.
        assert manager.current_cache_size == expected, mismatch_message
        assert (
            manager.current_cache_size <= capacity
        ), "The cache size should be less than or equal to the max_encoder_cache."

    # Step 1: two items (400 + 800) fit without evicting anything.
    first_items = [ImagePosition(offset=120, length=400), ImagePosition(offset=620, length=800)]
    expected_size = sum(item.length for item in first_items)
    evicted = manager.apply_cache(mm_hashes=["mm_hash1", "mm_hash2"], mm_items=first_items)
    assert evicted == [], "The evicted hashes should be empty."
    assert cached_hashes() == [
        "mm_hash1",
        "mm_hash2",
    ], "The cache should contain mm_hash1 and mm_hash2."
    check_size(expected_size, "The cache size should be the sum of the lengths of mm_hash1 and mm_hash2.")

    # Step 2: two larger items (1204 + 2048) are expected to push out mm_hash1,
    # so its length is subtracted from the running total.
    second_items = [ImagePosition(offset=20, length=1204), ImagePosition(offset=1800, length=2048)]
    expected_size += sum(item.length for item in second_items) - first_items[0].length
    evicted = manager.apply_cache(mm_hashes=["mm_hash3", "mm_hash4"], mm_items=second_items)
    assert evicted == ["mm_hash1"], "The evicted hashes should be mm_hash1."
    assert cached_hashes() == [
        "mm_hash2",
        "mm_hash3",
        "mm_hash4",
    ], "The cache should contain mm_hash2, mm_hash3, and mm_hash4."
    check_size(
        expected_size, "The cache size should be the sum of the lengths of mm_hash2, mm_hash3, and mm_hash4."
    )

    # Step 3: re-applying an already cached hash must not evict or change the size.
    evicted = manager.apply_cache(mm_hashes=["mm_hash2"], mm_items=[ImagePosition(offset=620, length=800)])
    assert evicted == [], "The evicted hashes should be empty."
    check_size(
        expected_size, "The cache size should be the sum of the lengths of mm_hash2, mm_hash3, and mm_hash4."
    )

    # Step 4: explicitly requesting 800 units is expected to evict mm_hash3 (length 1204).
    expected_size -= second_items[0].length
    evicted = manager.evict_cache(needed=800)
    assert evicted == ["mm_hash3"], "The evicted hashes should be mm_hash3."
    assert cached_hashes() == [
        "mm_hash4",
        "mm_hash2",
    ], "The cache should contain mm_hash2 and mm_hash4."
    check_size(expected_size, "The cache size should be the sum of the lengths of mm_hash2 and mm_hash4.")

    # Step 5: clearing leaves an empty cache with zero tracked size.
    manager.clear_cache()
    assert manager.current_cache_size == 0, "The cache size should be 0."
    assert cached_hashes() == [], "The cache should be empty."
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from dataclasses import asdict
16+
from types import SimpleNamespace
17+
18+
from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager
19+
from fastdeploy.config import CacheConfig, FDConfig, ParallelConfig
20+
from fastdeploy.engine.args_utils import EngineArgs
21+
from fastdeploy.engine.request import ImagePosition, Request
22+
from fastdeploy.scheduler import SchedulerConfig
23+
24+
25+
def make_prefix_cache_manager(max_num_seqs, enable_mm=False, num_gpu_blocks_override=100, max_num_batched_tokens=3200):
    """Assemble a PrefixCacheManager on top of a hand-built FDConfig for tests.

    Args:
        max_num_seqs: Forwarded to EngineArgs.
        enable_mm: Stored on the stub model config as ``enable_mm``.
        num_gpu_blocks_override: Forwarded to EngineArgs.
        max_num_batched_tokens: Forwarded to EngineArgs.

    Returns:
        A PrefixCacheManager created with ``tensor_parallel_size=8`` and
        ``splitwise_role="mixed"``.
    """
    engine_args = EngineArgs(
        max_num_seqs=max_num_seqs,
        num_gpu_blocks_override=num_gpu_blocks_override,
        max_num_batched_tokens=max_num_batched_tokens,
    )
    arg_dict = asdict(engine_args)

    cache_config = CacheConfig(arg_dict)
    cache_config.bytes_per_token_per_layer = 1

    # A lightweight stand-in for the real model config; only the attributes
    # set here are provided.
    model_config = SimpleNamespace(
        enable_mm=enable_mm,
        max_model_len=4196,
        print=print,
        architectures=["test_model"],
        mm_max_tokens_per_item=None,
        version=None,  # Required for register_info
    )
    speculative_config = SimpleNamespace(method=None)

    parallel_config = ParallelConfig(arg_dict)
    scheduler_config = SchedulerConfig(arg_dict)
    graph_opt_config = engine_args.create_graph_optimization_config()

    fd_config = FDConfig(
        model_config=model_config,
        cache_config=cache_config,
        parallel_config=parallel_config,
        graph_opt_config=graph_opt_config,
        speculative_config=speculative_config,
        scheduler_config=scheduler_config,
    )
    return PrefixCacheManager(config=fd_config, tensor_parallel_size=8, splitwise_role="mixed")
53+
54+
55+
def test_block_num_limit():
    """Expect an AssertionError when building the manager with 20 GPU blocks for 3 sequences."""
    import pytest

    limited_setup = {"max_num_seqs": 3, "enable_mm": False, "num_gpu_blocks_override": 20}
    with pytest.raises(AssertionError):
        make_prefix_cache_manager(**limited_setup)
60+
61+
62+
def test_normal_case():
    """Run three text-only requests with a common 1600-token prefix through match/allocate/update.

    After each step the test checks how many blocks were matched and how many
    entries remain in ``gpu_free_block_list``.
    """
    block_size = 64
    manager = make_prefix_cache_manager(max_num_seqs=3, enable_mm=False, num_gpu_blocks_override=128)

    def build_request(request_id, token_ids):
        # Every prompt in this test is 3200 tokens long.
        return Request.from_dict(
            {"request_id": request_id, "prompt_token_ids": token_ids, "prompt_token_ids_len": len(token_ids)}
        )

    req1 = build_request("req1", [1] * 3200)
    req2 = build_request("req2", [1] * 1600 + [2] * 1600)
    req3 = build_request("req3", [1] * 1600 + [3] * 1600)

    # req1: nothing is cached yet, so no block can be reused.
    shared_blocks, matched_tokens, _ = manager.request_match_blocks(req1, block_size)
    assert len(shared_blocks) == 0
    assert matched_tokens == 0
    assert len(manager.gpu_free_block_list) == 128
    req1.block_tables.extend(shared_blocks)
    # allocate for req1 inputs
    num_new_block = 50
    req1.block_tables.extend(manager.allocate_gpu_blocks(num_new_block))
    req1.num_computed_tokens += num_new_block * block_size
    manager.update_cache_blocks(req1, block_size, req1.num_computed_tokens)
    assert len(manager.gpu_free_block_list) == 78

    # allocate for req2 inputs: the 1600-token prefix shared with req1 spans 25 blocks.
    shared_blocks, matched_tokens, _ = manager.request_match_blocks(req2, block_size)
    assert len(shared_blocks) == 25
    assert matched_tokens == 25 * block_size
    req2.num_cached_tokens = matched_tokens
    req2.num_computed_tokens = 25 * block_size
    num_new_block = 25
    req2.block_tables.extend(shared_blocks)
    req2.block_tables.extend(manager.allocate_gpu_blocks(num_new_block))
    manager.update_cache_blocks(req2, block_size, req2.num_computed_tokens)

    # allocate for req3 input: same 25-block prefix hit as req2.
    shared_blocks, matched_tokens, _ = manager.request_match_blocks(req3, block_size)
    assert len(shared_blocks) == 25
    assert matched_tokens == 25 * block_size
    req3.num_cached_tokens = matched_tokens
    req3.num_computed_tokens = 25 * block_size
    assert len(manager.gpu_free_block_list) == 53
    req3.block_tables.extend(shared_blocks)
    num_new_block = 25
    assert manager.can_allocate_gpu_blocks(num_new_block)
    req3.block_tables.extend(manager.allocate_gpu_blocks(num_new_block))
    manager.update_cache_blocks(req3, block_size, req3.num_computed_tokens)
    assert len(manager.gpu_free_block_list) == 28
106+
107+
108+
def test_mm_extra_keys():
    """Verify the per-block extra hash keys returned by get_block_hash_extra_keys.

    First a text-only prompt is checked (every block yields no extra keys and
    mm_idx stays 0), then a prompt with two images is walked block by block
    against a table of expected ``(mm_idx, extra_keys)`` pairs.
    """
    block_size = 64
    manager = make_prefix_cache_manager(max_num_seqs=3, enable_mm=True)

    # --- Text-only prompt ---------------------------------------------------
    text_tokens = [1] * 100 + [2] * 100
    text_request = {
        "request_id": "req1",
        "prompt_token_ids": text_tokens,
        "prompt_token_ids_len": len(text_tokens),
    }
    for start in range(0, len(text_tokens), block_size):
        end = min(start + block_size, len(text_tokens))
        mm_idx, extra_keys = manager.get_block_hash_extra_keys(
            request=Request.from_dict(text_request),
            start_idx=start,
            end_idx=end,
            mm_idx=0,
        )
        assert extra_keys == [], f"extra_keys {extra_keys} != [], start_idx {start}, end_idx {end}"
        assert mm_idx == 0, f"mm_idx {mm_idx} != 0, start_idx {start}, end_idx {end}"

    # --- Prompt with two images ---------------------------------------------
    mm_tokens = (
        [1] * 30 + [-1] * 34  # block 1: text, then the start of image1
        + [-1] * 46 + [2] * 18  # block 2: the rest of image1, then text
        + [-1] * 100  # block 3 onwards: image2
        + [3] * 40  # trailing text, ending in block 5
    )
    mm_positions = [ImagePosition(offset=30, length=80), ImagePosition(offset=128, length=100)]
    mm_hashes = ["image1", "image2"]
    # Expected (mm_idx, extra_keys) after processing blocks 1 through 5.
    expected_per_block = [
        (0, ["image1"]),
        (1, ["image1"]),
        (1, ["image2"]),
        (1, ["image2"]),
        (1, []),
    ]
    mm_request = {
        "request_id": "req2",
        "prompt_token_ids": mm_tokens,
        "prompt_token_ids_len": len(mm_tokens),
        "multimodal_inputs": {
            "mm_positions": mm_positions,
            "mm_hashes": mm_hashes,
        },
    }

    mm_idx = 0
    for block_no, start in enumerate(range(0, len(mm_tokens), block_size)):
        end = min(start + block_size, len(mm_tokens))
        # mm_idx is threaded through: each call resumes from the previous result.
        mm_idx, extra_keys = manager.get_block_hash_extra_keys(
            request=Request.from_dict(mm_request),
            start_idx=start,
            end_idx=end,
            mm_idx=mm_idx,
        )

        target_idx, target_keys = expected_per_block[block_no]
        assert mm_idx == target_idx, f"mm_idx {mm_idx} != target_idx {target_idx}, start_idx {start}, end_idx {end}"
        assert (
            extra_keys == target_keys
        ), f"extra_keys {extra_keys} != target_keys {target_keys}, start_idx {start}, end_idx {end}"
178+
179+
180+
def test_mm_prefix_cache():
    """Check prefix-cache block matching for multimodal requests.

    Four requests share the same token layout but differ in their image hashes:

    * req1: one image ("image1") at offset 120, 1440 tokens.
    * req2: req1 plus more text and a second image ("image2"), 2423 tokens.
    * req3: same tokens as req2 but with hashes "image3"/"image4".
    * req4: the first 1792 tokens of that layout with hash "image3".

    The test asserts how many blocks each request can reuse and how many free
    GPU blocks remain after each allocation.

    Each request gets its own ``multimodal_inputs`` dict with fresh lists.
    Deriving later inputs from an earlier dict via ``dict(...)`` only makes a
    shallow copy, so appending to its lists would silently modify the inputs
    of requests that were already constructed.
    """
    block_size = 64
    cache_manager = make_prefix_cache_manager(max_num_seqs=3, enable_mm=True, num_gpu_blocks_override=100)

    # Token layout shared by all requests: 120 text, 1200 image, 120 text ...
    base_tokens = [1] * 120 + [-1] * 1200 + [2] * 120
    # ... optionally followed by 396 text tokens and a 587-token second image.
    two_image_tokens = base_tokens + [3] * 396 + [-1] * 587

    def make_request(request_id, token_ids, positions, hashes):
        # Build every request from independent containers so that no list is
        # shared between requests.
        return Request.from_dict(
            {
                "request_id": request_id,
                "prompt_token_ids": list(token_ids),
                "prompt_token_ids_len": len(token_ids),
                "multimodal_inputs": {
                    "mm_positions": [ImagePosition(offset=offset, length=length) for offset, length in positions],
                    "mm_hashes": list(hashes),
                },
            }
        )

    first_image = (120, 1200)
    second_image = (1836, 587)  # 1440 + 396 tokens precede the second image

    req1 = make_request("req1", base_tokens, [first_image], ["image1"])
    req2 = make_request("req2", two_image_tokens, [first_image, second_image], ["image1", "image2"])
    req3 = make_request("req3", two_image_tokens, [first_image, second_image], ["image3", "image4"])
    req4 = make_request("req4", base_tokens + [3] * 352, [first_image], ["image3"])

    # req1: the cache is empty, so nothing matches.
    (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req1, block_size)
    assert len(common_block_ids) == 0
    assert matched_token_num == 0
    assert len(cache_manager.gpu_free_block_list) == 100
    req1.block_tables.extend(common_block_ids)

    # allocate for req1 inputs
    num_new_block = 22
    req1.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block))
    req1.num_computed_tokens += num_new_block * block_size
    cache_manager.update_cache_blocks(req1, block_size, req1.num_computed_tokens)
    assert len(cache_manager.gpu_free_block_list) == 78

    # allocate for req2 inputs: same tokens and same image hash as req1, so all
    # 22 blocks cached for req1 are reused.
    (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req2, block_size)
    assert len(common_block_ids) == 22
    assert matched_token_num == 22 * block_size
    req2.num_cached_tokens = matched_token_num
    req2.num_computed_tokens = matched_token_num
    num_new_block = 15
    req2.block_tables.extend(common_block_ids)
    req2.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block))
    req2.num_computed_tokens += num_new_block * block_size
    cache_manager.update_cache_blocks(req2, block_size, req2.num_computed_tokens)

    # allocate for req3 input: identical tokens but different image hashes, so
    # only the first (text-only) block is expected to match.
    (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req3, block_size)
    assert len(common_block_ids) == 1
    assert matched_token_num == 1 * block_size
    req3.num_cached_tokens = matched_token_num
    req3.num_computed_tokens = matched_token_num
    assert len(cache_manager.gpu_free_block_list) == 63
    req3.block_tables.extend(common_block_ids)
    num_new_block = 36
    assert cache_manager.can_allocate_gpu_blocks(num_new_block)
    req3.block_tables.extend(cache_manager.allocate_gpu_blocks(num_new_block))
    req3.num_computed_tokens += num_new_block * block_size
    cache_manager.update_cache_blocks(req3, block_size, req3.num_computed_tokens)
    assert len(cache_manager.gpu_free_block_list) == 27

    # allocate for req4 input: shares its whole 1792-token prompt (28 blocks)
    # and the "image3" hash with req3.
    (common_block_ids, matched_token_num, hit_info) = cache_manager.request_match_blocks(req4, block_size)
    assert len(common_block_ids) == 28
    assert matched_token_num == 28 * block_size

0 commit comments

Comments
 (0)