Merged
22 changes: 21 additions & 1 deletion benchmarks/README.md
@@ -75,6 +75,7 @@ python benchmark_serving.py \
```

## Benchmark with openorca dataset (openorca is used by MLPerf inference for LLaMA2 models)

```
python JetStream/benchmarks/benchmark_serving.py \
--tokenizer ~/maxtext/assets/tokenizer.llama2 \
@@ -93,6 +94,7 @@ python JetStream/benchmarks/benchmark_serving.py \
The benchmark performs better if it first conducts a warmup of the JetStream server. We currently support `sampled` and `full` warmup modes: `full` mode warms up the JetStream server with all the input requests, while `sampled` mode warms it up with a sampling of the input requests across different bucket sizes of input lengths.

Example to run benchmark with `full` warmup mode:

```
python JetStream/benchmarks/benchmark_serving.py \
--tokenizer ~/maxtext/assets/tokenizer.llama2 \
@@ -115,7 +117,25 @@ python eval_accuracy.py outputs.json
```

With the openorca dataset and llama2-chat models (used by MLPerf), here are the reference accuracy numbers:

```
llama2-7b-chat {'rouge1': 42.0706, 'rouge2': 19.8021, 'rougeL': 26.8474, 'rougeLsum': 39.5952, 'gen_len': 1146679, 'gen_num': 998}
llama2-70b-chat {'rouge1': 44.4312, 'rouge2': 22.0352, 'rougeL': 28.6162}
```

## Benchmark prefix cache

Benchmark with mock input requests that share a common prefix; use it to test prefix caching.

Every prompt has length `max-input-length`, and the prompts share a common prefix whose per-prompt length is drawn from a normal distribution with mean `--prefix-cache-test-common-len`, clipped to the prompt length.

```
python JetStream/benchmarks/benchmark_serving.py \
--tokenizer prefix_cache_test \
--dataset prefix_cache_test \
--warmup-mode full \
--num-prompts 100 \
--max-input-length 16000 \
--prefix-cache-test-common-len 9000 \
--max-output-length 50
```
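As a sketch of how this mock dataset behaves (assuming, as the character-level test tokenizer does, that string length equals token length), each prompt's common-prefix length is drawn from a normal distribution centered at `--prefix-cache-test-common-len` and clipped to the prompt length. The variable names below are illustrative, mirroring the flags in the command above:

```python
import numpy as np

# Illustrative values mirroring the flags above.
max_input_length = 16000
common_len = 9000
num_prompts = 100

# Mean at the requested common length; out-of-range draws are clipped,
# then rounded to whole token counts.
prefix_lens = np.random.normal(
    loc=common_len, scale=max_input_length / 3.0, size=num_prompts
)
prefix_lens = np.clip(prefix_lens, 0, max_input_length).round().astype(int)

assert prefix_lens.min() >= 0 and prefix_lens.max() <= max_input_length
```

Prompts with longer shared prefixes should hit the cache more often, so sweeping `--prefix-cache-test-common-len` lets you measure how throughput changes with cache reuse.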
89 changes: 89 additions & 0 deletions benchmarks/benchmark_prefix_cache.sh
@@ -0,0 +1,89 @@
#!/bin/bash
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -e

NUM_PROMPTS=${NUM_PROMPTS:-100}
MAX_OUTPUT_LENGTH=${MAX_OUTPUT_LENGTH:-50}

# Test combinations of prompt lengths and common prefix lengths.
# Each length should be at most max_input_length minus 1, leaving room for the BOS token.
BENCHMARK_PROMPT_LENGTHS=${BENCHMARK_PROMPT_LENGTHS:-8000,16000}
BENCHMARK_PROMPT_COMMON_PREFIX_LENGTHS=${BENCHMARK_PROMPT_COMMON_PREFIX_LENGTHS:-4000,6000,8000,10000,12000,14000,16000}

benchmark_serving_with_prefix_cache() {
  echo "Starting prefix cache benchmark..."
  echo "Benchmark serving script: ${BENCHMARK_SERVING_SCRIPT_PATH}"
  echo "Prompt lengths to test: ${BENCHMARK_PROMPT_LENGTHS}"
  echo "Common prefix lengths to test: ${BENCHMARK_PROMPT_COMMON_PREFIX_LENGTHS}"
  echo "Number of prompts per run: ${NUM_PROMPTS}"
  echo "Max output length per prompt: ${MAX_OUTPUT_LENGTH}"
  echo "Base output directory for results: ${OUTPUTS_DIR_BASE}"
  echo "Warmup mode: ${WARMUP_MODE}"

  # Convert comma-separated strings to arrays for iteration.
  IFS=',' read -r -a prompt_lengths_arr <<< "$BENCHMARK_PROMPT_LENGTHS"
  IFS=',' read -r -a common_prefix_lengths_arr <<< "$BENCHMARK_PROMPT_COMMON_PREFIX_LENGTHS"

  for prompt_len in "${prompt_lengths_arr[@]}"; do
    for common_len in "${common_prefix_lengths_arr[@]}"; do
      if [ "${common_len}" -gt "${prompt_len}" ]; then
        echo "Skipping: Common prefix length ${common_len} is greater than prompt length ${prompt_len}."
        continue
      fi

      echo "----------------------------------------------------------------------"
      echo "Running benchmark: Prompt Length=${prompt_len}, Common Prefix Length=${common_len}"
      echo "----------------------------------------------------------------------"
      echo "Warm up twice"
      echo "----------------------------------------------------------------------"

      # With warmup-mode full, the warmup pass runs twice.
      python3 ./benchmark_serving.py \
        --tokenizer "prefix_cache_test" \
        --dataset "prefix_cache_test" \
        --num-prompts 10 \
        --max-output-length "${MAX_OUTPUT_LENGTH}" \
        --warmup-mode "full" \
        --max-input-length "${prompt_len}" \
        --prefix-cache-test-common-len "${common_len}"

      echo "Warm up done"
      echo "----------------------------------------------------------------------"

      python3 ./benchmark_serving.py \
        --tokenizer "prefix_cache_test" \
        --dataset "prefix_cache_test" \
        --num-prompts "${NUM_PROMPTS}" \
        --max-output-length "${MAX_OUTPUT_LENGTH}" \
        --warmup-mode "none" \
        --max-input-length "${prompt_len}" \
        --prefix-cache-test-common-len "${common_len}"

      echo "Benchmark finished for Prompt Length=${prompt_len}, Common Prefix Length=${common_len}"
      echo "----------------------------------------------------------------------"
      echo
    done
  done
  echo "All benchmark runs completed."
}

main() {
  benchmark_serving_with_prefix_cache
  echo "Script finished."
  exit 0
}

main "$@"
128 changes: 128 additions & 0 deletions benchmarks/benchmark_serving.py
@@ -209,6 +209,25 @@ def to_dict(self):
    }


class PrefixCacheTestTokenizer:
  """A simple tokenizer for testing prefix caching.

  This tokenizer converts each character in a string to its integer ordinal
  value during encoding, and converts a list of integer ordinals back to
  a string during decoding. It's designed for testing scenarios, particularly
  those involving prefix caching, where a basic, predictable tokenizer is
  needed.
  """

  def encode(self, s: str, **kwargs) -> list[int]:
    del kwargs
    return [ord(c) for c in s]

  def decode(self, token_ids: list[int], **kwargs) -> str:
    del kwargs
    return "".join([chr(token_id) for token_id in token_ids])

def get_tokenizer(
    model_id: str,
    tokenizer_name: str,
@@ -219,6 +238,9 @@ def get_tokenizer(
  if tokenizer_name == "test":
    print("Using test tokenizer")
    return "test"
  elif tokenizer_name == "prefix_cache_test":
    print("Using prefix_cache_test tokenizer")
    return PrefixCacheTestTokenizer()
  elif use_hf_tokenizer:
    # Please accept agreement to access private/gated models in HF, and
    # follow up instructions below to set up access token
@@ -329,6 +351,98 @@ def load_mmlu_dataset_csv(dataset_path: str) -> tuple[Any, dict[str, str]]:
  return combined_dataset, prompts_per_subject


def load_mock_prefix_cache_test_input_requests(
    prompt_len: int,
    output_len: int,
    common_prefix_len: int,
    num_samples: int,
) -> list[InputRequest]:
  """Generates a mock dataset for testing prefix cache.

  The prefix part of each prompt is a sub-string of a single master string.
  The length of this prefix part for each sample is drawn from a normal
  distribution with its mean set to `common_prefix_len`, and values are
  clipped to the range [0, `prompt_len`].
  The tokenizer is assumed to treat each character as a token.

  Args:
    prompt_len: The total length of each generated prompt string.
    output_len: The length of each generated output string.
    common_prefix_len: The target mean for the length of the prefix part of
      each prompt. These prefixes are derived from a shared master string.
    num_samples: The number of (prompt, output) pairs to generate.

  Returns:
    A list of InputRequest objects.
  """
  if not 0 <= common_prefix_len <= prompt_len:
    raise ValueError(
        "Target mean common_prefix_len must be between 0 and prompt_len,"
        f" inclusive. Got common_prefix_len={common_prefix_len}, "
        f"prompt_len={prompt_len}"
    )
  if any(arg <= 0 for arg in [prompt_len, output_len, num_samples]):
    raise ValueError(
        "prompt_len, output_len, and num_samples cannot be 0 or negative."
    )

  input_requests: list[InputRequest] = []

  # Generate a master string from which all prefixes will be derived.
  # This ensures that prefixes of the same length are identical,
  # and shorter prefixes are actual prefixes of longer ones.
  master_potential_prefix = "".join(
      random.choices("ABCDEFGHIJKLMNOPQRSTUVWXYZ", k=prompt_len)
  )

  # Generate prefix lengths for each sample from a normal distribution.
  scale = prompt_len / 3.0  # Standard deviation for the normal distribution.

  generated_prefix_lengths = np.random.normal(
      loc=common_prefix_len, scale=scale, size=num_samples
  )
  generated_prefix_lengths = (
      np.clip(generated_prefix_lengths, 0, prompt_len).round().astype(int)
  )

  for idx in range(num_samples):
    current_actual_prefix_len = generated_prefix_lengths[idx]

    actual_prefix_for_sample = master_potential_prefix[
        :current_actual_prefix_len
    ]

    current_unique_len = prompt_len - current_actual_prefix_len
    # This should not happen if generated_prefix_lengths is clipped correctly.
    if current_unique_len < 0:
      current_unique_len = 0  # Safeguard
      current_actual_prefix_len = prompt_len
      actual_prefix_for_sample = master_potential_prefix[
          :current_actual_prefix_len
      ]

    unique_suffix_str = "".join(
        random.choices(
            "abcdefghijklmnopqrstuvwxyz0123456789", k=current_unique_len
        )
    )

    prompt_str = actual_prefix_for_sample + unique_suffix_str

    output_str = "".join(random.choices("!@#$%^&*()_+", k=output_len))

    request = InputRequest(
        prompt=prompt_str,
        prompt_len=len(prompt_str),
        output=output_str,
        output_len=len(output_str),
        sample_idx=idx,
    )
    input_requests.append(request)
  return input_requests
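A tiny self-contained sketch (illustrative lengths, not values from the benchmark) of the shared-master-string property this function relies on: prefixes cut from one string always nest, which is exactly the reuse a prefix cache exploits:

```python
import random

# Hypothetical small lengths for illustration.
prompt_len = 32
master = "".join(random.choices("ABCDEFGHIJKLMNOPQRSTUVWXYZ", k=prompt_len))

# Prefixes cut from one master string nest: the shorter one is always a
# literal prefix of the longer one, so KV-cache state computed for the
# short prefix can be reused when serving a prompt with the longer one.
short, long = master[:8], master[:20]
assert long.startswith(short)
```

Had each prompt instead drawn its prefix independently, two prompts with different prefix lengths would share no common tokens and the cache would never hit.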


def gen_mmlu_qa(data: Any, mmlu_method: str = "") -> str:

  output = ""
@@ -893,6 +1007,7 @@ def parse_args() -> argparse.Namespace:
          "mmlu",
          "math500",
          "longcontext",
          "prefix_cache_test",
      ],
      help="The dataset name.",
  )
@@ -1086,6 +1201,12 @@ def parse_args() -> argparse.Namespace:
      choices=["HELM", "Harness", ""],
      help="mmlu method/format to generate shots",
  )
  parser.add_argument(
      "--prefix-cache-test-common-len",
      type=int,
      default=64,
      help="Common prefix length for the prefix cache test dataset.",
  )
  return parser.parse_args()


@@ -1112,6 +1233,13 @@ def main(args: argparse.Namespace):
    input_requests = mock_requests(
        args.total_mock_requests
    )  # e.g. [("AB", 2, "AB", 3)]
  elif args.dataset == "prefix_cache_test":
    input_requests = load_mock_prefix_cache_test_input_requests(
        prompt_len=args.max_input_length,
        output_len=args.max_output_length,
        common_prefix_len=args.prefix_cache_test_common_len,
        num_samples=args.num_prompts,
    )
  else:
    dataset = []
    if args.dataset == "openorca":