Commit cb0295b: Refactor eval framework POC
Parent: 093ab89

24 files changed: 2823 additions & 1 deletion

src/dependencies/requirements/base_requirements/requirements.txt

Lines changed: 4 additions & 0 deletions

```diff
@@ -1,10 +1,12 @@
 absl-py
+aiohttp
 aqtp
 array-record
 cloud-accelerator-diagnostics
 cloud-tpu-diagnostics
 datasets
 drjax
+evaluate
 flax
 gcsfs
 google-api-python-client
@@ -20,10 +22,12 @@ jsonlines
 math-verify
 ml-collections
 ml-goodput-measurement
+nltk
 numpy
 omegaconf
 optax
 orbax-checkpoint
+pandas
 pathwaysutils
 pillow
 pre-commit
```

src/maxtext/eval/README.md

Lines changed: 169 additions & 0 deletions
# MaxText vLLM Eval Framework

A vLLM-native evaluation framework for MaxText models, supporting harness-based evals (lm-eval, evalchemy) as well as custom datasets.

## Quick Start

All runners share a single entry point:

```bash
python -m maxtext.eval.runner.run --runner <eval|lm_eval|evalchemy> [flags]
```
### Custom dataset (MLPerf OpenOrca with ROUGE scoring, and others)

```bash
python -m maxtext.eval.runner.run \
  --runner eval \
  --config src/maxtext/eval/configs/mlperf.yml \
  --checkpoint_path gs://<bucket>/checkpoints/0/items \
  --model_name llama3.1-8b \
  --hf_path meta-llama/Llama-3.1-8B-Instruct \
  --base_output_directory gs://<bucket>/ \
  --run_name eval_run \
  --max_model_len 8192 \
  --hf_token $HF_TOKEN
```
HF safetensors mode (no MaxText checkpoint):

```bash
python -m maxtext.eval.runner.run \
  --runner eval \
  --config src/maxtext/eval/configs/mlperf.yml \
  --hf_path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
  --model_name tinyllama \
  --base_output_directory gs://<bucket>/ \
  --run_name eval_test \
  --hf_mode \
  --num_samples 20 \
  --max_model_len 2048 \
  --tensor_parallel_size 1
```
### LM Eval

Requires: `pip install "lm_eval[api]"`

```bash
python -m maxtext.eval.runner.run \
  --runner lm_eval \
  --checkpoint_path gs://<bucket>/checkpoints/0/items \
  --model_name qwen3-30b-a3b \
  --hf_path Qwen/Qwen3-30B-A3B \
  --tasks gsm8k \
  --base_output_directory gs://<bucket>/ \
  --run_name my_run \
  --max_model_len 8192 \
  --tensor_parallel_size 8 \
  --expert_parallel_size 8 \
  --hf_token $HF_TOKEN
```
### Evalchemy

Requires: `pip install git+https://github.com/mlfoundations/evalchemy.git`

```bash
python -m maxtext.eval.runner.run \
  --runner evalchemy \
  --checkpoint_path gs://<bucket>/checkpoints/0/items \
  --model_name llama3.1-8b \
  --hf_path meta-llama/Llama-3.1-8B-Instruct \
  --tasks ifeval math500 gpqa_diamond \
  --base_output_directory gs://<bucket>/ \
  --run_name eval_run \
  --max_model_len 8192 \
  --tensor_parallel_size 4 \
  --hf_token $HF_TOKEN
```
## Common Flags

| Flag | Description |
|---|---|
| `--checkpoint_path` | MaxText Orbax checkpoint path. Enables `MaxTextForCausalLM` mode. |
| `--model_name` | MaxText model name (e.g. `llama3.1-8b`). |
| `--hf_path` | HF model ID or local path. |
| `--max_model_len` | vLLM max context length. |
| `--tensor_parallel_size` | Chips per model replica. |
| `--expert_parallel_size` | Chips for the expert mesh axis. |
| `--data_parallel_size` | Number of model replicas. |
| `--hbm_memory_utilization` | Fraction of HBM reserved for the KV cache. |
| `--hf_token` | HF token (or set the `HF_TOKEN` env var). |
| `--hf_mode` | HF safetensors mode; no MaxText checkpoint loading. |
| `--server_host` / `--server_port` | vLLM server address (default: `localhost:8000`). |
| `--max_num_batched_tokens` | vLLM tokens per scheduler step. |
| `--max_num_seqs` | vLLM max concurrent sequences. |
| `--gcs_results_path` | GCS path to upload the results JSON. |
| `--log_level` | Logging verbosity (default: `INFO`). |
Flags specific to the custom `eval` runner:

| Flag | Description |
|---|---|
| `--config` | Benchmark YAML config (required). |
| `--num_samples` | Limit the number of eval samples. |
| `--max_tokens` | Max tokens per generation. |
| `--temperature` | Sampling temperature (default: 0.0). |
| `--concurrency` | HTTP request concurrency (default: 64). |
Flags specific to the `lm_eval` / `evalchemy` harness runners:

| Flag | Description |
|---|---|
| `--tasks` | Space-separated task names. |
| `--num_fewshot` | Few-shot examples per task (default: 0). |
| `--num_samples` | Limit samples per task (default: full dataset). |
## Eval on RL Checkpoints

Example (Qwen3-30B-A3B, v6e-8):

```bash
STEP=244
MODEL=qwen3-30b-a3b
HF_PATH=Qwen/Qwen3-30B-A3B
CHECKPOINT=gs://<bucket>/run/checkpoints/actor/${STEP}/model_params
OUTPUT=gs://<bucket>/eval/

python -m maxtext.eval.runner.run \
  --runner lm_eval \
  --checkpoint_path ${CHECKPOINT} \
  --model_name ${MODEL} \
  --hf_path ${HF_PATH} \
  --tasks gsm8k \
  --base_output_directory ${OUTPUT} \
  --run_name rl_${MODEL}_step${STEP} \
  --max_model_len 4096 \
  --tensor_parallel_size 8 \
  --expert_parallel_size 8 \
  --num_samples 20 \
  --hf_token $HF_TOKEN
```
## Adding a Custom Benchmark

1. Implement a `BenchmarkDataset` subclass in `src/maxtext/eval/datasets/`:

```python
from maxtext.eval.datasets.base import BenchmarkDataset, SampleRequest


class MyDataset(BenchmarkDataset):
  name = "my_benchmark"

  def sample_requests(self, num_samples, tokenizer) -> list[SampleRequest]:
    # Load the dataset, build prompts, and return a list of SampleRequests.
    ...
```

2. Register it in `src/maxtext/eval/datasets/registry.py`:

```python
from maxtext.eval.datasets.my_dataset import MyDataset

DATASET_REGISTRY["my_benchmark"] = MyDataset
```

3. Add a scorer in `src/maxtext/eval/scoring/` and register it in `src/maxtext/eval/scoring/registry.py`.
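The scorer interface from step 3 is not part of this commit's diff, so the sketch below is hypothetical: the `ExactMatchScorer` class, its `score` signature, and the `SCORER_REGISTRY` dict are assumptions that simply mirror the dataset-registry pattern from step 2.

```python
# Hypothetical scorer sketch. The real scoring interface lives in
# src/maxtext/eval/scoring/ and is not shown in this commit; everything
# below is an illustrative stand-in, written standalone so it runs as-is.
from __future__ import annotations


class ExactMatchScorer:
  """Scores each prediction 1.0 if it equals its reference, else 0.0."""

  name = "exact_match"

  def score(self, predictions: list[str], references: list[str]) -> dict[str, float]:
    # Whitespace-insensitive exact match, averaged over all samples.
    hits = sum(p.strip() == r.strip() for p, r in zip(predictions, references))
    return {"exact_match": hits / max(len(references), 1)}


SCORER_REGISTRY = {"exact_match": ExactMatchScorer}

scorer = SCORER_REGISTRY["exact_match"]()
print(scorer.score(["4", "9"], ["4", "10"]))  # {'exact_match': 0.5}
```

If the real scoring registry differs, only the registration step should need to change.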
Lines changed: 8 additions & 0 deletions

```yaml
# Base evaluation configuration.

temperature: 0.0
concurrency: 64
server_host: "localhost"
server_port: 8000
tensor_parallel_size: 4
num_samples: null
```
src/maxtext/eval/configs/mlperf.yml

Lines changed: 5 additions & 0 deletions

```yaml
# MLPerf OpenOrca evaluation config.

benchmark: "mlperf_openorca"
max_tokens: 1024
num_samples: 5000
```

src/maxtext/eval/datasets/base.py

Lines changed: 57 additions & 0 deletions

```python
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Abstract base classes for benchmark datasets."""

from __future__ import annotations

import abc
from typing import NamedTuple


class SampleRequest(NamedTuple):
  """A single inference request with its ground-truth reference.

  Attributes:
    prompt: The full text prompt to send to the model (after chat templating).
    reference: Ground-truth answer/label used by the scorer.
    metadata: Optional dict of extra fields forwarded to the scorer
      (e.g. {"subject": "college_math"} for per-subject MMLU stats).
  """

  prompt: str
  reference: str
  metadata: dict | None = None


class BenchmarkDataset(abc.ABC):
  """Abstract base class for benchmark datasets."""

  name: str

  @abc.abstractmethod
  def sample_requests(
      self,
      num_samples: int | None,
      tokenizer,
  ) -> list[SampleRequest]:
    """Load the dataset and return a list of SampleRequests.

    Args:
      num_samples: If not None, truncate to this number of samples.
      tokenizer: A HuggingFace tokenizer used for chat templating. Implementations
        that do not require tokenization may ignore this parameter.

    Returns:
      List of SampleRequest objects ready for inference.
    """
```
src/maxtext/eval/datasets/mlperf.py

Lines changed: 63 additions & 0 deletions

```python
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MLPerf OpenOrca summarisation dataset."""

from __future__ import annotations

from maxtext.eval.datasets.base import BenchmarkDataset, SampleRequest

_SYSTEM_PROMPT = (
    "You are a helpful assistant. Summarize the following conversation."
)


class MlperfOpenOrcaDataset(BenchmarkDataset):
  """MLPerf OpenOrca — summarisation benchmark used in MLPerf Inference.

  Uses the Open-Orca/OpenOrca HuggingFace dataset.
  """

  name = "mlperf_openorca"

  def sample_requests(self, num_samples, tokenizer) -> list[SampleRequest]:
    # pylint: disable=import-outside-toplevel
    import datasets as hf_datasets

    ds = hf_datasets.load_dataset("Open-Orca/OpenOrca", split="train", streaming=True)

    requests = []
    for row in ds:
      if not row.get("response", "").strip():
        continue

      system_prompt = row.get("system_prompt", _SYSTEM_PROMPT) or _SYSTEM_PROMPT
      question = row["question"]
      reference = row["response"]

      if tokenizer is not None:
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
        ]
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
      else:
        prompt = f"{system_prompt}\n\nUser: {question}\nAssistant:"

      requests.append(SampleRequest(prompt=prompt, reference=reference))

      if num_samples is not None and len(requests) >= num_samples:
        break

    return requests
```
src/maxtext/eval/datasets/registry.py

Lines changed: 60 additions & 0 deletions

```python
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Registry mapping benchmark names to BenchmarkDataset classes.

This can be used to define custom dataset loaders for benchmarks not covered
by lm_eval and evalchemy.
"""

from __future__ import annotations

from maxtext.eval.datasets.base import BenchmarkDataset
from maxtext.eval.datasets.mlperf import MlperfOpenOrcaDataset

DATASET_REGISTRY: dict[str, type[BenchmarkDataset]] = {
    "mlperf_openorca": MlperfOpenOrcaDataset,
    "openorca": MlperfOpenOrcaDataset,
}


def get_dataset(benchmark_name: str) -> BenchmarkDataset:
  """Instantiate and return the dataset registered for benchmark_name.

  Args:
    benchmark_name: Benchmark identifier (e.g. "mlperf_openorca").

  Returns:
    An instance of the corresponding BenchmarkDataset subclass.

  Raises:
    KeyError: If no dataset is registered for the given name.
  """
  key = benchmark_name.lower()
  if key not in DATASET_REGISTRY:
    raise KeyError(
        f"No dataset registered for benchmark '{benchmark_name}'. "
        f"Available: {sorted(DATASET_REGISTRY)}. "
        f"For MMLU/GPQA/MATH use lm_eval_runner or evalchemy_runner instead."
    )
  return DATASET_REGISTRY[key]()


def register_dataset(benchmark_name: str, dataset_cls: type[BenchmarkDataset]) -> None:
  """Register a custom dataset class for benchmark_name.

  Args:
    benchmark_name: Lowercase benchmark identifier.
    dataset_cls: A BenchmarkDataset subclass.
  """
  DATASET_REGISTRY[benchmark_name.lower()] = dataset_cls
```
