Commit 3fe08db

feat: Add prefix cache benchmark
This commit introduces a new benchmark to test the performance of prefix caching in JetStream. The benchmark (`benchmark_prefix_cache.sh`) allows testing with various prompt lengths and common prefix lengths. It utilizes a new mock dataset generated by `load_mock_prefix_cache_test_input_requests` in `benchmark_serving.py`, which creates prompts sharing common prefixes of varying lengths based on a normal distribution.

Key changes include:
- New script `benchmarks/benchmark_prefix_cache.sh` to orchestrate prefix cache benchmark runs.
- Added `PrefixCacheTestTokenizer` for simple character-to-ordinal tokenization, suitable for controlled prefix testing.
- Implemented `load_mock_prefix_cache_test_input_requests` in `benchmark_serving.py` to generate test data with shared prefixes.
- Added `prefix_cache_test` as a dataset option and `--prefix-cache-test-common-len` argument to `benchmark_serving.py`.
- Updated `benchmarks/README.md` with instructions on how to run the new prefix cache benchmark.
1 parent 219e5a1 commit 3fe08db

3 files changed

Lines changed: 234 additions & 1 deletion

benchmarks/README.md

Lines changed: 21 additions & 1 deletion
@@ -75,6 +75,7 @@ python benchmark_serving.py \
 ```

 ## Benchmark with openorca dataset (openorca is used by MLPerf inference for LLaMA2 models)
+
 ```
 python JetStream/benchmarks/benchmark_serving.py \
 --tokenizer ~/maxtext/assets/tokenizer.llama2 \
@@ -93,6 +94,7 @@ python JetStream/benchmarks/benchmark_serving.py \
 The benchmark has better performance if it first conducts a warmup of the JetStream server. We currently support `sampled` and `full` warmup modes. `full` mode warms up the JetStream server with all the input requests. `sampled` mode warms up the JetStream server with a sample of the input requests across different bucket sizes of input lengths.

 Example to run benchmark with `full` warmup mode:
+
 ```
 python JetStream/benchmarks/benchmark_serving.py \
 --tokenizer ~/maxtext/assets/tokenizer.llama2 \
@@ -115,7 +117,25 @@ python eval_accuracy.py outputs.json
 ```

 With openorca dataset and llama2-chat models (used by MLPerf), here are the reference accuracy numbers:
+
 ```
 llama2-7b-chat {'rouge1': 42.0706, 'rouge2': 19.8021, 'rougeL': 26.8474, 'rougeLsum': 39.5952, 'gen_len': 1146679, 'gen_num': 998}
 llama2-70b-chat {'rouge1': 44.4312, 'rouge2': 22.0352, 'rougeL': 28.6162}
-```
+```
+
+## Benchmark prefix cache
+
+Benchmark with mock input requests that share a common prefix. Use this to test prefix caching.
+
+Every prompt has length `--max-input-length`. The shared prefix length of each prompt is drawn from a normal distribution with mean `--prefix-cache-test-common-len`.
+
+```
+python JetStream/benchmarks/benchmark_serving.py \
+--tokenizer prefix_cache_test \
+--dataset prefix_cache_test \
+--warmup-mode full \
+--num-prompts 100 \
+--max-input-length 16000 \
+--prefix-cache-test-common-len 9000 \
+--max-output-length 50
+```
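
To make the prompt layout in the README section above concrete, here is a minimal sketch of how the per-prompt prefix lengths could be drawn. It follows the sampling in `load_mock_prefix_cache_test_input_requests` (normal distribution with mean `common_prefix_len` and standard deviation `prompt_len / 3`, clipped to `[0, prompt_len]`), but it uses the standard library's `random.gauss` rather than NumPy. The helper name `sample_prefix_lengths` and the fixed seed are illustrative assumptions, not part of the commit.

```python
import random


def sample_prefix_lengths(prompt_len, common_prefix_len, num_samples, seed=0):
  """Hypothetical helper: draw one shared-prefix length per prompt."""
  rng = random.Random(seed)  # seeded only so the sketch is repeatable
  scale = prompt_len / 3.0  # same standard deviation the benchmark uses
  lengths = []
  for _ in range(num_samples):
    raw = rng.gauss(common_prefix_len, scale)
    # Clip to [0, prompt_len], then round to a whole number of tokens.
    lengths.append(int(round(min(max(raw, 0.0), prompt_len))))
  return lengths


# Values taken from the README example command.
lengths = sample_prefix_lengths(
    prompt_len=16000, common_prefix_len=9000, num_samples=100
)
print(min(lengths), max(lengths), sum(lengths) / len(lengths))
```

Because the standard deviation is a third of the prompt length, the draws are spread widely and some are clipped at either bound, so the realized average prefix length can drift from the requested mean.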
benchmarks/benchmark_prefix_cache.sh

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+#!/bin/bash
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -e
+
+NUM_PROMPTS=${NUM_PROMPTS:-100}
+MAX_OUTPUT_LENGTH=${MAX_OUTPUT_LENGTH:-50}
+
+# Test combination from lengths and common prefix lengths.
+# The length should be shorter than max_input_length minus 1 for bos.
+BENCHMARK_PROMPT_LENGTHS=${BENCHMARK_PROMPT_LENGTHS:-8000,16000}
+BENCHMARK_PROMPT_COMMON_PREFIX_LENGTHS=${BENCHMARK_PROMPT_COMMON_PREFIX_LENGTHS:-4000,6000,8000,10000,12000,14000,16000}
+
+benchmark_serving_with_prefix_cache() {
+  echo "Starting prefix cache benchmark..."
+  echo "Benchmark serving script: ${BENCHMARK_SERVING_SCRIPT_PATH}"
+  echo "Prompt lengths to test: ${BENCHMARK_PROMPT_LENGTHS}"
+  echo "Common prefix lengths to test: ${BENCHMARK_PROMPT_COMMON_PREFIX_LENGTHS}"
+  echo "Number of prompts per run: ${NUM_PROMPTS}"
+  echo "Max output length per prompt: ${MAX_OUTPUT_LENGTH}"
+  echo "Base output directory for results: ${OUTPUTS_DIR_BASE}"
+  echo "Warmup mode: ${WARMUP_MODE}"
+
+  # Convert comma-separated strings to arrays for iteration
+  IFS=',' read -r -a prompt_lengths_arr <<< "$BENCHMARK_PROMPT_LENGTHS"
+  IFS=',' read -r -a common_prefix_lengths_arr <<< "$BENCHMARK_PROMPT_COMMON_PREFIX_LENGTHS"
+
+  for prompt_len in "${prompt_lengths_arr[@]}"; do
+    for common_len in "${common_prefix_lengths_arr[@]}"; do
+      if [ "${common_len}" -gt "${prompt_len}" ]; then
+        echo "Skipping: Common prefix length ${common_len} is greater than prompt length ${prompt_len}."
+        continue
+      fi
+
+      echo "----------------------------------------------------------------------"
+      echo "Running benchmark: Prompt Length=${prompt_len}, Common Prefix Length=${common_len}"
+      echo "----------------------------------------------------------------------"
+      echo "Warm up twice"
+      echo "----------------------------------------------------------------------"
+
+      # With warmup-mode full, it will run twice
+      python3 ./benchmark_serving.py \
+        --tokenizer "prefix_cache_test" \
+        --dataset "prefix_cache_test" \
+        --num-prompts 10 \
+        --max-output-length "${MAX_OUTPUT_LENGTH}" \
+        --warmup-mode "full" \
+        --max-input-length "${prompt_len}" \
+        --prefix-cache-test-common-len "${common_len}"
+
+      echo "Warm up done"
+      echo "----------------------------------------------------------------------"
+
+      python3 ./benchmark_serving.py \
+        --tokenizer "prefix_cache_test" \
+        --dataset "prefix_cache_test" \
+        --num-prompts "${NUM_PROMPTS}" \
+        --max-output-length "${MAX_OUTPUT_LENGTH}" \
+        --warmup-mode "none" \
+        --max-input-length "${prompt_len}" \
+        --prefix-cache-test-common-len "${common_len}"
+
+      echo "Benchmark finished for Prompt Length=${prompt_len}, Common Prefix Length=${common_len}"
+      echo "----------------------------------------------------------------------"
+      echo
+    done
+  done
+  echo "All benchmark runs completed."
+}
+
+main() {
+  benchmark_serving_with_prefix_cache
+  echo "Script finished."
+  exit 0
+}
+
+main "$@"
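
The nested loop in the script sweeps every prompt length against every common prefix length and skips pairs where the prefix would be longer than the prompt. The following sketch enumerates the pairs that the default settings would run; the helper name `planned_runs` is hypothetical, and the two lists are copied from the script's environment-variable fallbacks.

```python
def planned_runs(prompt_lengths, common_prefix_lengths):
  """Hypothetical helper: list the (prompt_len, common_len) pairs the sweep runs."""
  runs = []
  for prompt_len in prompt_lengths:
    for common_len in common_prefix_lengths:
      if common_len > prompt_len:
        continue  # the script skips prefixes longer than the prompt
      runs.append((prompt_len, common_len))
  return runs


# Defaults copied from the script's environment-variable fallbacks.
default_prompt_lengths = [8000, 16000]
default_common_lengths = [4000, 6000, 8000, 10000, 12000, 14000, 16000]

runs = planned_runs(default_prompt_lengths, default_common_lengths)
for prompt_len, common_len in runs:
  print(f"prompt_len={prompt_len} common_len={common_len}")
```

With the defaults this yields ten combinations: three for the 8000-token prompts and seven for the 16000-token prompts. In the script, each combination triggers one warmup invocation with 10 prompts followed by one measured invocation with `NUM_PROMPTS` prompts.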

benchmarks/benchmark_serving.py

Lines changed: 124 additions & 0 deletions
@@ -209,6 +209,22 @@ def to_dict(self):
     }


+class PrefixCacheTestTokenizer:
+  """A simple tokenizer for testing prefix caching.
+
+  This tokenizer converts each character in a string to its integer ordinal
+  value during encoding, and converts a list of integer ordinals back to
+  a string during decoding. It's designed for testing scenarios, particularly
+  those involving prefix caching, where a basic, predictable tokenizer is needed.
+  """
+
+  def encode(self, s: str, **kwargs) -> list[int]:
+    return [ord(c) for c in s]
+
+  def decode(self, token_ids: list[int], **kwargs) -> str:
+    return "".join([chr(token_id) for token_id in token_ids])
+
+
 def get_tokenizer(
     model_id: str,
     tokenizer_name: str,
@@ -219,6 +235,9 @@ def get_tokenizer(
   if tokenizer_name == "test":
     print("Using test tokenizer")
     return "test"
+  elif tokenizer_name == "prefix_cache_test":
+    print("Using prefix_cache_test tokenizer")
+    return PrefixCacheTestTokenizer()
   elif use_hf_tokenizer:
     # Please accept agreement to access private/gated models in HF, and
     # follow up instructions below to set up access token
@@ -329,6 +348,97 @@ def load_mmlu_dataset_csv(dataset_path: str) -> tuple[Any, dict[str, str]]:
   return combined_dataset, prompts_per_subject


+def load_mock_prefix_cache_test_input_requests(
+    prompt_len: int,
+    output_len: int,
+    common_prefix_len: int,
+    num_samples: int,
+) -> list[InputRequest]:
+  """Generates a mock dataset for testing prefix cache.
+
+  The prefix part of each prompt is a sub-string of a single master string.
+  The length of this prefix part for each sample is drawn from a normal
+  distribution with its mean set to `common_prefix_len`, and values are
+  clipped to the range [0, `prompt_len`].
+  The tokenizer is assumed to treat each character as a token.
+
+  Args:
+    prompt_len: The total length of each generated prompt string.
+    output_len: The length of each generated output string.
+    common_prefix_len: The target mean for the length of the prefix part
+      of each prompt. These prefixes are derived from a
+      shared master string.
+    num_samples: The number of (prompt, output) pairs to generate.
+
+  Returns:
+    A list of InputRequest objects.
+  """
+  if not (0 <= common_prefix_len <= prompt_len):
+    raise ValueError(
+        "Target mean common_prefix_len must be between 0 and prompt_len,"
+        f" inclusive. Got common_prefix_len={common_prefix_len}, prompt_len={prompt_len}"
+    )
+  if any(arg <= 0 for arg in [prompt_len, output_len, num_samples]):
+    raise ValueError(
+        "prompt_len, output_len, and num_samples cannot be 0 or negative."
+    )
+
+  input_requests: list[InputRequest] = []
+
+  # Generate a master string from which all prefixes will be derived.
+  # This ensures that prefixes of the same length are identical,
+  # and shorter prefixes are actual prefixes of longer ones.
+  master_potential_prefix = "".join(
+      random.choices("ABCDEFGHIJKLMNOPQRSTUVWXYZ", k=prompt_len)
+  )
+
+  # Generate prefix lengths for each sample from a normal distribution
+  scale = prompt_len / 3.0  # Standard deviation for the normal distribution
+
+  generated_prefix_lengths = np.random.normal(
+      loc=common_prefix_len, scale=scale, size=num_samples
+  )
+  generated_prefix_lengths = (
+      np.clip(generated_prefix_lengths, 0, prompt_len).round().astype(int)
+  )
+
+  for idx in range(num_samples):
+    current_actual_prefix_len = generated_prefix_lengths[idx]
+
+    actual_prefix_for_sample = master_potential_prefix[
+        :current_actual_prefix_len
+    ]
+
+    current_unique_len = prompt_len - current_actual_prefix_len
+    # This should not happen if generated_prefix_lengths is clipped correctly
+    if current_unique_len < 0:
+      current_unique_len = 0  # Safeguard
+      current_actual_prefix_len = prompt_len
+      actual_prefix_for_sample = master_potential_prefix[
+          :current_actual_prefix_len
+      ]
+
+    unique_suffix_str = "".join(
+        random.choices(
+            "abcdefghijklmnopqrstuvwxyz0123456789", k=current_unique_len
+        )
+    )
+
+    prompt_str = actual_prefix_for_sample + unique_suffix_str
+
+    output_str = "".join(random.choices("!@#$%^&*()_+", k=output_len))
+
+    request = InputRequest(
+        prompt=prompt_str,
+        prompt_len=len(prompt_str),
+        output=output_str,
+        output_len=len(output_str),
+        sample_idx=idx,
+    )
+    input_requests.append(request)
+  return input_requests
+
+
 def gen_mmlu_qa(data: Any, mmlu_method: str = "") -> str:

   output = ""
@@ -893,6 +1003,7 @@ def parse_args() -> argparse.Namespace:
           "mmlu",
           "math500",
           "longcontext",
+          "prefix_cache_test",
       ],
       help="The dataset name.",
   )
@@ -1086,6 +1197,12 @@ def parse_args() -> argparse.Namespace:
       choices=["HELM", "Harness", ""],
       help="mmlu method/format to generate shots",
   )
+  parser.add_argument(
+      "--prefix-cache-test-common-len",
+      type=int,
+      default=64,
+      help="Common prefix length for the prefix cache test dataset.",
+  )
   return parser.parse_args()


@@ -1112,6 +1229,13 @@ def main(args: argparse.Namespace):
     input_requests = mock_requests(
         args.total_mock_requests
     )  # e.g. [("AB", 2, "AB", 3)]
+  elif args.dataset == "prefix_cache_test":
+    input_requests = load_mock_prefix_cache_test_input_requests(
+        prompt_len=args.max_input_length,
+        output_len=args.max_output_length,
+        common_prefix_len=args.prefix_cache_test_common_len,
+        num_samples=args.num_prompts,
+    )
   else:
     dataset = []
     if args.dataset == "openorca":
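
The reason a character-level tokenizer suits this dataset is that string prefixes map one-to-one onto token prefixes. The sketch below illustrates this with a stand-in class that copies the `encode`/`decode` behavior of `PrefixCacheTestTokenizer`; the names `CharOrdTokenizer` and `shared_prefix_len` and the short sample strings are assumptions made for illustration and do not appear in the commit.

```python
class CharOrdTokenizer:
  """Stand-in with the same encode/decode behavior as PrefixCacheTestTokenizer."""

  def encode(self, s):
    return [ord(c) for c in s]

  def decode(self, token_ids):
    return "".join(chr(t) for t in token_ids)


def shared_prefix_len(a, b):
  """Hypothetical helper: count leading tokens two sequences have in common."""
  n = 0
  for x, y in zip(a, b):
    if x != y:
      break
    n += 1
  return n


tokenizer = CharOrdTokenizer()
master = "QWERTYUIOP"  # stands in for the random uppercase master string
prompt_a = master[:6] + "ab12"  # prefix of length 6 plus a lowercase/digit suffix
prompt_b = master[:4] + "zz9xk0"  # prefix of length 4 plus a different suffix

tokens_a = tokenizer.encode(prompt_a)
tokens_b = tokenizer.encode(prompt_b)
print(shared_prefix_len(tokens_a, tokens_b))  # → 4
```

Since the generated suffixes use lowercase letters and digits while the master string is uppercase only, two prompts always share at least as many leading tokens as the shorter of their two sampled prefix lengths.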
