@@ -1,11 +1,13 @@
absl-py
aiohttp
aqtp
array-record
chex
cloud-accelerator-diagnostics
cloud-tpu-diagnostics!=1.1.14
datasets
drjax
evaluate
flax
gcsfs
google-api-python-client
@@ -21,6 +23,7 @@ jsonlines
math-verify
ml-collections
ml-goodput-measurement
nltk
numpy
omegaconf
optax
169 changes: 169 additions & 0 deletions src/maxtext/eval/README.md
@@ -0,0 +1,169 @@
# MaxText vLLM Eval Framework

A vLLM-native evaluation framework for MaxText models, supporting harness-based evals (lm-eval, evalchemy) and custom datasets.

## Quick Start

All runners share a single entry point:

```bash
python -m maxtext.eval.runner.run --runner <eval|lm_eval|evalchemy> [flags]
```

### Custom dataset (e.g. MLPerf OpenOrca with ROUGE scoring)

```bash
python -m maxtext.eval.runner.run \
--runner eval \
--config src/maxtext/eval/configs/mlperf.yml \
--checkpoint_path gs://<bucket>/checkpoints/0/items \
--model_name llama3.1-8b \
--hf_path meta-llama/Llama-3.1-8B-Instruct \
--base_output_directory gs://<bucket>/ \
--run_name eval_run \
--max_model_len 8192 \
--hf_token $HF_TOKEN
```

HF safetensors mode (no MaxText checkpoint):

```bash
python -m maxtext.eval.runner.run \
--runner eval \
--config src/maxtext/eval/configs/mlperf.yml \
--hf_path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
--model_name tinyllama \
--base_output_directory gs://<bucket>/ \
--run_name eval_test \
--hf_mode \
--num_samples 20 \
--max_model_len 2048 \
--tensor_parallel_size 1
```

### LM Eval

Requires: `pip install "lm_eval[api]"`

```bash
python -m maxtext.eval.runner.run \
--runner lm_eval \
--checkpoint_path gs://<bucket>/checkpoints/0/items \
--model_name qwen3-30b-a3b \
--hf_path Qwen/Qwen3-30B-A3B \
--tasks gsm8k \
--base_output_directory gs://<bucket>/ \
--run_name my_run \
--max_model_len 8192 \
--tensor_parallel_size 8 \
--expert_parallel_size 8 \
--hf_token $HF_TOKEN
```

### Evalchemy

Requires: `pip install git+https://github.com/mlfoundations/evalchemy.git`

```bash
python -m maxtext.eval.runner.run \
--runner evalchemy \
--checkpoint_path gs://<bucket>/checkpoints/0/items \
--model_name llama3.1-8b \
--hf_path meta-llama/Llama-3.1-8B-Instruct \
--tasks ifeval math500 gpqa_diamond \
--base_output_directory gs://<bucket>/ \
--run_name eval_run \
--max_model_len 8192 \
--tensor_parallel_size 4 \
--hf_token $HF_TOKEN
```

## Common Flags

| Flag | Description |
|---|---|
| `--checkpoint_path` | MaxText Orbax checkpoint path. Enables `MaxTextForCausalLM` mode. |
| `--model_name` | MaxText model name (e.g. `llama3.1-8b`) |
| `--hf_path` | HF model ID or local path |
| `--max_model_len` | vLLM max context length |
| `--tensor_parallel_size` | Chips per model replica |
| `--expert_parallel_size` | Chips for the expert mesh axis |
| `--data_parallel_size` | Number of model replicas |
| `--hbm_memory_utilization` | Fraction of HBM reserved for KV cache |
| `--hf_token` | HF token (or set `HF_TOKEN` env var) |
| `--hf_mode` | HF safetensors mode, no MaxText checkpoint loading |
| `--server_host` / `--server_port` | vLLM server address (default: localhost:8000) |
| `--max_num_batched_tokens` | vLLM tokens per scheduler step |
| `--max_num_seqs` | vLLM max concurrent sequences |
| `--gcs_results_path` | GCS path to upload results JSON |
| `--log_level` | Logging verbosity (default: INFO) |

Flags specific to the custom `eval` runner:

| Flag | Description |
|---|---|
| `--config` | Benchmark YAML config (required) |
| `--num_samples` | Limit eval samples |
| `--max_tokens` | Max tokens per generation |
| `--temperature` | Sampling temperature (default: 0.0) |
| `--concurrency` | HTTP request concurrency (default: 64) |
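The `--concurrency` flag caps the number of in-flight requests against the vLLM server. The runner's actual HTTP client is not shown in this diff, but the usual pattern is an `asyncio.Semaphore` around each request. A minimal self-contained sketch of that pattern (`run_with_concurrency` and `fake_request` are illustrative stand-ins, not the framework's API):

```python
import asyncio

async def run_with_concurrency(prompts, concurrency=64):
  """Issue one request per prompt with at most `concurrency` in flight."""
  sem = asyncio.Semaphore(concurrency)
  active = 0
  peak = 0

  async def fake_request(prompt):
    # Stand-in for the runner's HTTP call to the vLLM server.
    nonlocal active, peak
    async with sem:
      active += 1
      peak = max(peak, active)
      await asyncio.sleep(0)  # yield, as a real network await would
      active -= 1
    return f"response to {prompt}"

  # gather preserves input order, so results line up with prompts.
  results = await asyncio.gather(*(fake_request(p) for p in prompts))
  return results, peak
```

Running this over 100 prompts with `concurrency=8` never has more than 8 requests outstanding at once.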

Flags specific to the `lm_eval` / `evalchemy` harness runners:

| Flag | Description |
|---|---|
| `--tasks` | Space-separated task names |
| `--num_fewshot` | Few-shot examples per task (default: 0) |
| `--num_samples` | Limit samples per task (default: full dataset) |

## Eval on RL Checkpoints

RL training runs save actor checkpoints under `checkpoints/actor/<step>/model_params`; point `--checkpoint_path` at the step you want to evaluate.

Example (Qwen3-30B-A3B, v6e-8):

```bash
STEP=244
MODEL=qwen3-30b-a3b
HF_PATH=Qwen/Qwen3-30B-A3B
CHECKPOINT=gs://<bucket>/run/checkpoints/actor/${STEP}/model_params
OUTPUT=gs://<bucket>/eval/

python -m maxtext.eval.runner.run \
--runner lm_eval \
--checkpoint_path ${CHECKPOINT} \
--model_name ${MODEL} \
--hf_path ${HF_PATH} \
--tasks gsm8k \
--base_output_directory ${OUTPUT} \
--run_name rl_${MODEL}_step${STEP} \
--max_model_len 4096 \
--tensor_parallel_size 8 \
--expert_parallel_size 8 \
--num_samples 20 \
--hf_token $HF_TOKEN
```


## Adding a Custom Benchmark

1. Implement a `BenchmarkDataset` subclass in `src/maxtext/eval/datasets/`:

```python
from maxtext.eval.datasets.base import BenchmarkDataset, SampleRequest

class MyDataset(BenchmarkDataset):
  name = "my_benchmark"

  def sample_requests(self, num_samples, tokenizer) -> list[SampleRequest]:
    # Load the dataset, build prompts, and return SampleRequest objects.
    rows = [("What is 2 + 2?", "4")]  # stand-in for real dataset loading
    if num_samples is not None:
      rows = rows[:num_samples]
    return [SampleRequest(prompt=q, reference=a) for q, a in rows]
```

2. Register in `src/maxtext/eval/datasets/registry.py`:

```python
from maxtext.eval.datasets.my_dataset import MyDataset
DATASET_REGISTRY["my_benchmark"] = MyDataset
```

3. Add a scorer in `src/maxtext/eval/scoring/` and register it in `src/maxtext/eval/scoring/registry.py`.
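The scorer interface is not part of this diff, so the shapes below (`Scorer`, `score`, `SCORER_REGISTRY`) are assumptions for illustration only. A minimal exact-match scorer, mirroring the registry pattern the datasets use, sketched self-contained (local stubs stand in for `maxtext.eval.scoring`):

```python
# Hypothetical scorer interface; in the real tree this would live under
# src/maxtext/eval/scoring/ and be registered in scoring/registry.py.
import abc

class Scorer(abc.ABC):
  @abc.abstractmethod
  def score(self, predictions: list[str], references: list[str]) -> dict:
    """Return a metric dict, e.g. {"exact_match": 0.5}."""

class ExactMatchScorer(Scorer):
  def score(self, predictions, references):
    hits = sum(p.strip() == r.strip() for p, r in zip(predictions, references))
    return {"exact_match": hits / max(len(references), 1)}

SCORER_REGISTRY: dict[str, type[Scorer]] = {}
SCORER_REGISTRY["my_benchmark"] = ExactMatchScorer
```

Presumably the benchmark name from the YAML config selects both the dataset and the scorer at runtime.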
8 changes: 8 additions & 0 deletions src/maxtext/eval/configs/base_eval.yml
@@ -0,0 +1,8 @@
# Base evaluation configuration.

temperature: 0.0
concurrency: 64
server_host: "localhost"
server_port: 8000
tensor_parallel_size: 4
num_samples: null
5 changes: 5 additions & 0 deletions src/maxtext/eval/configs/mlperf.yml
@@ -0,0 +1,5 @@
# MLPerf OpenOrca evaluation config.

benchmark: "mlperf_openorca"
max_tokens: 1024
num_samples: 5000
57 changes: 57 additions & 0 deletions src/maxtext/eval/datasets/base.py
@@ -0,0 +1,57 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Abstract base classes for benchmark datasets."""

from __future__ import annotations

import abc
from typing import NamedTuple


class SampleRequest(NamedTuple):
  """A single inference request with its ground-truth reference.

  Attributes:
    prompt: The full text prompt to send to the model (after chat templating).
    reference: Ground-truth answer/label used by the scorer.
    metadata: Optional dict of extra fields forwarded to the scorer
      (e.g. {"subject": "college_math"} for per-subject MMLU stats).
  """

  prompt: str
  reference: str
  metadata: dict | None = None


class BenchmarkDataset(abc.ABC):
  """Abstract base class for benchmark datasets."""

  name: str

  @abc.abstractmethod
  def sample_requests(
      self,
      num_samples: int | None,
      tokenizer,
  ) -> list[SampleRequest]:
    """Load the dataset and return a list of SampleRequests.

    Args:
      num_samples: If not None, truncate to this number of samples.
      tokenizer: A HuggingFace tokenizer used for chat templating. Implementations
        that do not require tokenization may ignore this parameter.

    Returns:
      List of SampleRequest objects ready for inference.
    """
63 changes: 63 additions & 0 deletions src/maxtext/eval/datasets/mlperf.py
@@ -0,0 +1,63 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MLPerf OpenOrca summarisation dataset."""

from __future__ import annotations

from maxtext.eval.datasets.base import BenchmarkDataset, SampleRequest

_SYSTEM_PROMPT = (
    "You are a helpful assistant. Summarize the following conversation."
)


class MlperfOpenOrcaDataset(BenchmarkDataset):
  """MLPerf OpenOrca — summarisation benchmark used in MLPerf Inference.

  Uses the Open-Orca/OpenOrca HuggingFace dataset.
  """

  name = "mlperf_openorca"

  def sample_requests(self, num_samples, tokenizer) -> list[SampleRequest]:
    # pylint: disable=import-outside-toplevel
    import datasets as hf_datasets

    ds = hf_datasets.load_dataset("Open-Orca/OpenOrca", split="train", streaming=True)

    requests = []
    for row in ds:
      # Guard against both missing and None responses before stripping.
      if not (row.get("response") or "").strip():
        continue

      system_prompt = row.get("system_prompt", _SYSTEM_PROMPT) or _SYSTEM_PROMPT
      question = row["question"]
      reference = row["response"]

      if tokenizer is not None:
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
        ]
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
      else:
        prompt = f"{system_prompt}\n\nUser: {question}\nAssistant:"

      requests.append(SampleRequest(prompt=prompt, reference=reference))

      if num_samples is not None and len(requests) >= num_samples:
        break

    return requests
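The two prompt-building paths above (chat-templated when a tokenizer is available, plain text otherwise) can be sanity-checked in isolation; `StubTokenizer` and `build_prompt` below are illustrative stand-ins, not part of the framework:

```python
class StubTokenizer:
  """Minimal stand-in for a HuggingFace tokenizer's chat templating."""

  def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
    rendered = "".join(f"<|{m['role']}|>{m['content']}" for m in messages)
    return (rendered + "<|assistant|>") if add_generation_prompt else rendered


def build_prompt(system_prompt, question, tokenizer=None):
  # Mirrors the branch in MlperfOpenOrcaDataset.sample_requests.
  if tokenizer is not None:
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
  return f"{system_prompt}\n\nUser: {question}\nAssistant:"
```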
60 changes: 60 additions & 0 deletions src/maxtext/eval/datasets/registry.py
@@ -0,0 +1,60 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Registry mapping benchmark names to BenchmarkDataset classes.

This can be used to define custom dataset loaders for benchmarks not covered by lm_eval and evalchemy.
"""

from __future__ import annotations

from maxtext.eval.datasets.base import BenchmarkDataset
from maxtext.eval.datasets.mlperf import MlperfOpenOrcaDataset

DATASET_REGISTRY: dict[str, type[BenchmarkDataset]] = {
    "mlperf_openorca": MlperfOpenOrcaDataset,
    "openorca": MlperfOpenOrcaDataset,
}


def get_dataset(benchmark_name: str) -> BenchmarkDataset:
  """Instantiate and return the dataset registered for benchmark_name.

  Args:
    benchmark_name: Benchmark identifier (e.g. "mlperf_openorca").

  Returns:
    An instance of the corresponding BenchmarkDataset subclass.

  Raises:
    KeyError: If no dataset is registered for the given name.
  """
  key = benchmark_name.lower()
  if key not in DATASET_REGISTRY:
    raise KeyError(
        f"No dataset registered for benchmark '{benchmark_name}'. "
        f"Available: {sorted(DATASET_REGISTRY)}. "
        f"For MMLU/GPQA/MATH use lm_eval_runner or evalchemy_runner instead."
    )
  return DATASET_REGISTRY[key]()


def register_dataset(benchmark_name: str, dataset_cls: type[BenchmarkDataset]) -> None:
  """Register a custom dataset class for benchmark_name.

  Args:
    benchmark_name: Lowercase benchmark identifier.
    dataset_cls: A BenchmarkDataset subclass.
  """
  DATASET_REGISTRY[benchmark_name.lower()] = dataset_cls