diff --git a/src/dependencies/requirements/base_requirements/requirements.txt b/src/dependencies/requirements/base_requirements/requirements.txt index 52b68dd289..bca4242971 100644 --- a/src/dependencies/requirements/base_requirements/requirements.txt +++ b/src/dependencies/requirements/base_requirements/requirements.txt @@ -1,4 +1,5 @@ absl-py +aiohttp aqtp array-record chex @@ -6,6 +7,7 @@ cloud-accelerator-diagnostics cloud-tpu-diagnostics!=1.1.14 datasets drjax +evaluate flax gcsfs google-api-python-client @@ -21,6 +23,7 @@ jsonlines math-verify ml-collections ml-goodput-measurement +nltk numpy omegaconf optax diff --git a/src/maxtext/eval/README.md b/src/maxtext/eval/README.md new file mode 100644 index 0000000000..0388fb0569 --- /dev/null +++ b/src/maxtext/eval/README.md @@ -0,0 +1,169 @@ +# MaxText vLLM Eval Framework + +A vLLM-native evaluation framework for MaxText models supporting harness-based eval (lm-eval, evalchemy) and custom datasets. + +## Quick Start + +All runners share a single entry point: + +```bash +python -m maxtext.eval.runner.run --runner [flags] +``` + +### Custom dataset (MLPerf OpenOrca, ROUGE scoring, Other) + +```bash +python -m maxtext.eval.runner.run \ + --runner eval \ + --config src/maxtext/eval/configs/mlperf.yml \ + --checkpoint_path gs:///checkpoints/0/items \ + --model_name llama3.1-8b \ + --hf_path meta-llama/Llama-3.1-8B-Instruct \ + --base_output_directory gs:/// \ + --run_name eval_run \ + --max_model_len 8192 \ + --hf_token $HF_TOKEN +``` + +HF safetensors mode (no MaxText checkpoint): + +```bash +python -m maxtext.eval.runner.run \ + --runner eval \ + --config src/maxtext/eval/configs/mlperf.yml \ + --hf_path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \ + --model_name tinyllama \ + --base_output_directory gs:/// \ + --run_name eval_test \ + --hf_mode \ + --num_samples 20 \ + --max_model_len 2048 \ + --tensor_parallel_size 1 +``` + +### LM Eval + +Requires: `pip install "lm_eval[api]"` + +```bash +python -m 
maxtext.eval.runner.run \ + --runner lm_eval \ + --checkpoint_path gs:///checkpoints/0/items \ + --model_name qwen3-30b-a3b \ + --hf_path Qwen/Qwen3-30B-A3B \ + --tasks gsm8k \ + --base_output_directory gs:/// \ + --run_name my_run \ + --max_model_len 8192 \ + --tensor_parallel_size 8 \ + --expert_parallel_size 8 \ + --hf_token $HF_TOKEN +``` + +### Evalchemy + +Requires: `pip install git+https://github.com/mlfoundations/evalchemy.git` + +```bash +python -m maxtext.eval.runner.run \ + --runner evalchemy \ + --checkpoint_path gs:///checkpoints/0/items \ + --model_name llama3.1-8b \ + --hf_path meta-llama/Llama-3.1-8B-Instruct \ + --tasks ifeval math500 gpqa_diamond \ + --base_output_directory gs:/// \ + --run_name eval_run \ + --max_model_len 8192 \ + --tensor_parallel_size 4 \ + --hf_token $HF_TOKEN +``` + +## Common Flags + +| Flag | Description | +|---|---| +| `--checkpoint_path` | MaxText Orbax checkpoint path. Enables `MaxTextForCausalLM` mode. | +| `--model_name` | MaxText model name (e.g. `llama3.1-8b`) | +| `--hf_path` | HF model ID or local path | +| `--max_model_len` | vLLM max context length. 
| +| `--tensor_parallel_size` | Chips per model replica | +| `--expert_parallel_size` | Chips for the expert mesh axis | +| `--data_parallel_size` | Number of model replicas | +| `--hbm_memory_utilization` | Fraction of HBM reserved for KV cache | +| `--hf_token` | HF token (or set `HF_TOKEN` env var) | +| `--hf_mode` | HF safetensors mode, no MaxText checkpoint loading | +| `--server_host` / `--server_port` | vLLM server address (default: localhost:8000) | +| `--max_num_batched_tokens` | vLLM tokens per scheduler step | +| `--max_num_seqs` | vLLM max concurrent sequences | +| `--gcs_results_path` | GCS path to upload results JSON | +| `--log_level` | Logging verbosity (default: INFO) | + + Custom `eval` specific: + +| Flag | Description | +|---|---| +| `--config` | Benchmark YAML config (required) | +| `--num_samples` | Limit eval samples | +| `--max_tokens` | Max tokens per generation | +| `--temperature` | Sampling temperature (default: 0.0) | +| `--concurrency` | HTTP request concurrency (default: 64) | + +Harness `lm_eval` / `evalchemy` specific: + +| Flag | Description | +|---|---| +| `--tasks` | Space-separated task names | +| `--num_fewshot` | Few-shot examples per task (default: 0) | +| `--num_samples` | Limit samples per task (default: full dataset) | + +## Eval on RL Checkpoints + + + +Example (Qwen3-30B-A3B, v6e-8): + +```bash +STEP=244 +MODEL=qwen3-30b-a3b +HF_PATH=Qwen/Qwen3-30B-A3B +CHECKPOINT=gs:///run/checkpoints/actor/${STEP}/model_params +OUTPUT=gs:///eval/ + +python -m maxtext.eval.runner.run \ + --runner lm_eval \ + --checkpoint_path ${CHECKPOINT} \ + --model_name ${MODEL} \ + --hf_path ${HF_PATH} \ + --tasks gsm8k \ + --base_output_directory ${OUTPUT} \ + --run_name rl_${MODEL}_step${STEP} \ + --max_model_len 4096 \ + --tensor_parallel_size 8 \ + --expert_parallel_size 8 \ + --num_samples 20 \ + --hf_token $HF_TOKEN +``` + + +## Adding a Custom Benchmark + +1. 
Implement `BenchmarkDataset` in `src/maxtext/eval/datasets/`: + +```python +from maxtext.eval.datasets.base import BenchmarkDataset, SampleRequest + +class MyDataset(BenchmarkDataset): + name = "my_benchmark" + + def sample_requests(self, num_samples, tokenizer) -> list[SampleRequest]: + # load dataset, build prompts, return SampleRequest list +``` + +2. Register in `src/maxtext/eval/datasets/registry.py`: + +```python +from maxtext.eval.datasets.my_dataset import MyDataset +DATASET_REGISTRY["my_benchmark"] = MyDataset +``` + +3. Add a scorer in `src/maxtext/eval/scoring/` and register it in `src/maxtext/eval/scoring/registry.py`. diff --git a/src/maxtext/eval/configs/base_eval.yml b/src/maxtext/eval/configs/base_eval.yml new file mode 100644 index 0000000000..10e26e4128 --- /dev/null +++ b/src/maxtext/eval/configs/base_eval.yml @@ -0,0 +1,8 @@ +# Base evaluation configuration. + +temperature: 0.0 +concurrency: 64 +server_host: "localhost" +server_port: 8000 +tensor_parallel_size: 4 +num_samples: null diff --git a/src/maxtext/eval/configs/mlperf.yml b/src/maxtext/eval/configs/mlperf.yml new file mode 100644 index 0000000000..863f21c6c2 --- /dev/null +++ b/src/maxtext/eval/configs/mlperf.yml @@ -0,0 +1,5 @@ +# MLPerf OpenOrca evaluation config. + +benchmark: "mlperf_openorca" +max_tokens: 1024 +num_samples: 5000 diff --git a/src/maxtext/eval/datasets/base.py b/src/maxtext/eval/datasets/base.py new file mode 100644 index 0000000000..28e0d51b61 --- /dev/null +++ b/src/maxtext/eval/datasets/base.py @@ -0,0 +1,57 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class SampleRequest(NamedTuple):
    """One prompt to run through the model, paired with its expected answer.

    Attributes:
        prompt: Full prompt text sent to the model (already chat-templated
            by the dataset that built the request).
        reference: Ground-truth answer or label that the scorer compares
            the model output against.
        metadata: Optional dict of extra fields forwarded to the scorer,
            e.g. ``{"subject": "college_math"}`` for per-subject statistics.
            Defaults to ``None``.
    """

    prompt: str
    reference: str
    metadata: dict | None = None


class BenchmarkDataset(abc.ABC):
    """Interface implemented by every benchmark dataset loader.

    Concrete subclasses assign ``name`` and implement ``sample_requests``;
    the class cannot be instantiated until that method is overridden.
    """

    # Benchmark identifier; only declared here, subclasses assign the value.
    name: str

    @abc.abstractmethod
    def sample_requests(self, num_samples: int | None, tokenizer) -> list[SampleRequest]:
        """Load the benchmark and turn it into inference requests.

        Args:
            num_samples: Upper bound on the number of requests to return;
                ``None`` means no truncation.
            tokenizer: HuggingFace tokenizer used for chat templating.
                Implementations that need no tokenization may ignore it.

        Returns:
            A list of SampleRequest objects ready to be sent for inference.
        """
class MlperfOpenOrcaDataset(BenchmarkDataset):
    """MLPerf OpenOrca — summarisation benchmark used in MLPerf Inference.

    Uses Open-Orca/OpenOrca HuggingFace dataset.
    """

    name = "mlperf_openorca"

    def sample_requests(self, num_samples, tokenizer) -> list[SampleRequest]:
        """Build prompts and references from the Open-Orca/OpenOrca train split.

        Rows whose ``response`` is missing, ``None`` or whitespace-only are
        skipped because they provide no reference to score against.

        Args:
            num_samples: Maximum number of requests to return. ``None`` keeps
                iterating until the dataset is exhausted; a value <= 0 yields
                an empty list without touching the dataset.
            tokenizer: Object exposing ``apply_chat_template``. When ``None``
                a plain "User:/Assistant:" text prompt is built instead.

        Returns:
            List of SampleRequest, at most ``num_samples`` long.
        """
        # The size check inside the loop only runs after an append, so a
        # non-positive limit has to be handled before any row is read.
        if num_samples is not None and num_samples <= 0:
            return []

        # pylint: disable=import-outside-toplevel
        import datasets as hf_datasets

        # streaming=True is passed so rows are iterated lazily
        # (presumably avoids downloading the full dataset — confirm).
        ds = hf_datasets.load_dataset("Open-Orca/OpenOrca", split="train", streaming=True)

        requests = []
        for row in ds:
            # `or ""` also covers rows where the key exists but holds None.
            reference = row.get("response") or ""
            if not reference.strip():
                continue

            # Fall back to the module default for a missing or empty prompt.
            system_prompt = row.get("system_prompt") or _SYSTEM_PROMPT
            question = row["question"]

            if tokenizer is not None:
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": question},
                ]
                prompt = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
            else:
                prompt = f"{system_prompt}\n\nUser: {question}\nAssistant:"

            requests.append(SampleRequest(prompt=prompt, reference=reference))

            if num_samples is not None and len(requests) >= num_samples:
                break

        return requests
def get_dataset(benchmark_name: str) -> BenchmarkDataset:
    """Create the dataset object registered under ``benchmark_name``.

    The lookup is case-insensitive: the name is lower-cased before it is
    matched against DATASET_REGISTRY.

    Args:
        benchmark_name: Benchmark identifier (e.g. "mlperf_openorca").

    Returns:
        A new instance of the registered BenchmarkDataset subclass, built
        by calling the class with no arguments.

    Raises:
        KeyError: If no dataset is registered for the given name.
    """
    lookup_key = benchmark_name.lower()
    if lookup_key in DATASET_REGISTRY:
        return DATASET_REGISTRY[lookup_key]()

    known_names = sorted(DATASET_REGISTRY)
    raise KeyError(
        f"No dataset registered for benchmark '{benchmark_name}'. "
        f"Available: {known_names}. "
        f"For MMLU/GPQA/MATH use lm_eval_runner or evalchemy_runner instead."
    )
def upload_results(local_path: str, gcs_path: str) -> None:
    """Copy a local results file to Google Cloud Storage.

    Args:
        local_path: Path of the local file to upload.
        gcs_path: Destination GCS path. When it ends with "/", the base name
            of ``local_path`` is appended; otherwise it is used unchanged as
            the destination.
    """
    # A trailing slash means "put it under this prefix, keep the file name".
    gcs_dest = gcs_path + os.path.basename(local_path) if gcs_path.endswith("/") else gcs_path

    upload_blob(gcs_dest, local_path)
    logger.info("Uploaded %s to %s", local_path, gcs_dest)
def write_results(
    benchmark: str,
    model_name: str,
    scores: dict,
    generation_stats: dict,
    config: dict,
    results_path: str = "./eval_results",
) -> dict:
    """Serialise one eval run to JSON, locally or on GCS.

    The file is named ``{benchmark}_{model_name}_{timestamp}.json`` where
    "/" and ":" in the model name are replaced by "_" and the timestamp is
    the current UTC time. For a ``gs://`` destination the file is first
    written into a fresh temporary directory and then uploaded; otherwise
    ``results_path`` is created if needed and the file is written there.

    Args:
        benchmark: Benchmark name (e.g. "mmlu").
        model_name: Model name (e.g. "llama3.1-8b").
        scores: Metric name to value mapping produced by the scorer.
        generation_stats: Generation statistics (timing, token counts, ...).
        config: Merged configuration dict used for the run.
        results_path: Local directory or ``gs://`` prefix to write into.

    Returns:
        Dict with two keys: ``results`` (the payload that was written) and
        ``local_path`` (absolute path of the local JSON file; for GCS
        destinations this is the temporary copy).
    """
    stamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    model_slug = model_name.replace("/", "_").replace(":", "_")
    filename = f"{benchmark}_{model_slug}_{stamp}.json"

    payload = {
        "benchmark": benchmark,
        "model_name": model_name,
        "timestamp_utc": stamp,
        "scores": scores,
        "generation_stats": generation_stats,
        "config": config,
    }

    to_gcs = results_path.startswith("gs://")
    if to_gcs:
        # Imported before the temp dir is created, as the upload helper is
        # only needed for GCS destinations.
        from maxtext.utils.gcs_utils import upload_blob  # pylint: disable=import-outside-toplevel

        out_dir = tempfile.mkdtemp(prefix="eval_results_")
    else:
        os.makedirs(results_path, exist_ok=True)
        out_dir = results_path

    local_path = os.path.join(out_dir, filename)
    with open(local_path, "w") as f:
        json.dump(payload, f, indent=2)

    if to_gcs:
        written_to = f"{results_path.rstrip('/')}/{filename}"
        upload_blob(written_to, local_path)
    else:
        written_to = local_path
    logger.info("Results written to %s", written_to)

    return {"results": payload, "local_path": os.path.abspath(local_path)}
async def generate_batch_async(
    prompts: list[str],
    base_url: str,
    model: str,
    max_tokens: int = _DEFAULT_MAX_TOKENS,
    temperature: float = _DEFAULT_TEMPERATURE,
    concurrency: int = _DEFAULT_CONCURRENCY,
    request_timeout: int = _REQUEST_TIMEOUT_S,
) -> list[GenerationResult]:
    """Send all prompts concurrently and return results in prompt order.

    HTTP, transport, timeout and malformed-response failures are reported
    per prompt as a GenerationResult with a non-empty ``error`` instead of
    being raised, so one bad request does not abort the whole batch.

    Args:
        prompts: Formatted prompt strings.
        base_url: Base URL of the server; a trailing "/" is tolerated.
        model: Model name to send in each request.
        max_tokens: Maximum tokens to generate per response.
        temperature: Sampling temperature.
        concurrency: Maximum number of in-flight requests at once (>= 1).
        request_timeout: Per-request wall-clock timeout in seconds.

    Returns:
        List of GenerationResult in the same order as prompts.

    Raises:
        ValueError: If ``concurrency`` is smaller than 1.
    """
    import aiohttp  # pylint: disable=import-outside-toplevel

    if concurrency < 1:
        # asyncio.Semaphore(0) would make every request wait forever.
        raise ValueError(f"concurrency must be >= 1, got {concurrency}.")
    if not prompts:
        return []

    # Strip a trailing slash so the joined URL never contains "//v1/...".
    api_url = f"{base_url.rstrip('/')}{_COMPLETIONS_PATH}"
    semaphore = asyncio.Semaphore(concurrency)
    timeout = aiohttp.ClientTimeout(total=request_timeout)

    def _failure(reason: str, started: float) -> GenerationResult:
        """Build an error result that still records the elapsed time."""
        return GenerationResult(error=reason, latency_s=time.monotonic() - started)

    async def _generate_one(session: aiohttp.ClientSession, prompt: str) -> GenerationResult:
        payload = {
            "model": model,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        async with semaphore:
            t0 = time.monotonic()
            try:
                async with session.post(api_url, json=payload) as resp:
                    if resp.status != 200:
                        body = await resp.text()
                        return _failure(f"HTTP {resp.status}: {body[:200]}", t0)
                    data = await resp.json()
            except (aiohttp.ClientError, asyncio.TimeoutError, ValueError) as exc:
                # ValueError covers a body that is not valid JSON/text.
                # str() of a timeout is empty, which would make the result
                # look successful, so the exception type is always included.
                detail = str(exc)
                name = type(exc).__name__
                return _failure(f"{name}: {detail}" if detail else name, t0)
            latency = time.monotonic() - t0

        try:
            choice = data["choices"][0]
            text = choice.get("text") or ""
            # "usage" may be absent or null; both mean "no token counts".
            usage = data.get("usage") or {}
            prompt_tokens = usage.get("prompt_tokens", 0)
            completion_tokens = usage.get("completion_tokens", 0)
        except (KeyError, IndexError, TypeError, AttributeError):
            return GenerationResult(
                error=f"Malformed completion response: {str(data)[:200]}",
                latency_s=latency,
            )

        return GenerationResult(
            text=text,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            latency_s=latency,
        )

    async with aiohttp.ClientSession(timeout=timeout) as session:
        # gather() returns results in the order the awaitables were passed.
        return list(await asyncio.gather(*[_generate_one(session, p) for p in prompts]))
def resolve_token(cfg: dict, hf_token: str | None) -> str | None:
    """Pick the HuggingFace token to use for this run.

    Args:
        cfg: Merged configuration dict. It is accepted but not read here.
        hf_token: Explicitly supplied token; wins when non-empty.

    Returns:
        ``hf_token`` if it is non-empty, otherwise the ``HF_TOKEN``
        environment variable if that is non-empty, otherwise ``None``.
    """
    if hf_token:
        return hf_token
    env_token = os.environ.get("HF_TOKEN")
    # An unset or empty environment variable counts as "no token".
    return env_token if env_token else None
def add_server_args(parser: argparse.ArgumentParser) -> None:
    """Register the server/model CLI flags shared by all eval runners.

    The flags are declared as (flag, keyword-arguments) pairs and added to
    ``parser`` in that order, so help output keeps the declaration order.

    Args:
        parser: The argparse parser to extend in place.
    """
    arg_specs = (
        ("--checkpoint_path", {"help": "MaxText orbax checkpoint path (/0/items)."}),
        ("--model_name", {"required": True, "help": "MaxText model name (e.g. llama3.1-8b)."}),
        ("--hf_path", {"required": True, "help": "HF model ID or local tokenizer dir."}),
        (
            "--base_output_directory",
            {"required": True, "help": "Base output directory (local path or gs:///)."},
        ),
        ("--run_name", {"required": True, "help": "Run name/identifier."}),
        ("--max_model_len", {"type": int, "required": True, "help": "vLLM max context length."}),
        ("--tensor_parallel_size", {"type": int, "default": 4, "help": "vLLM tensor parallelism."}),
        ("--server_host", {"default": "localhost", "help": "vLLM server bind host."}),
        ("--server_port", {"type": int, "default": 8000, "help": "vLLM server port."}),
        ("--max_num_batched_tokens", {"type": int, "help": "vLLM tokens per scheduler step."}),
        ("--max_num_seqs", {"type": int, "help": "vLLM max concurrent sequences."}),
        ("--hf_mode", {"action": "store_true", "help": "HF safetensors mode."}),
        (
            "--expert_parallel_size",
            {"type": int, "default": 0, "help": "Chips allocated to the expert mesh axis (EP). "},
        ),
        ("--data_parallel_size", {"type": int, "default": 1, "help": "Number of model replicas (DP)."}),
        (
            "--hbm_memory_utilization",
            {"type": float, "default": 0.3, "help": "Fraction of HBM reserved for KV cache."},
        ),
        ("--hf_token", {"help": "HuggingFace token for gated models."}),
        (
            "--gcs_results_path",
            {"help": "Optional secondary GCS path to upload the results JSON."},
        ),
        (
            "--log_level",
            {
                "default": "INFO",
                "choices": ["DEBUG", "INFO", "WARNING", "ERROR"],
                "help": "Logging level.",
            },
        ),
    )
    for flag, options in arg_specs:
        parser.add_argument(flag, **options)
+""" + +from __future__ import annotations + +import argparse +import logging +import time + +import yaml + +logger = logging.getLogger(__name__) + + +def _load_config(config_path: str) -> dict: + with open(config_path) as f: + return yaml.safe_load(f) or {} + + +def _merge_config(base: dict, overrides: dict) -> dict: + merged = dict(base) + for k, v in overrides.items(): + if v is not None: + merged[k] = v + return merged + + +def _derive_from_maxtext_config(maxtext_config_path: str) -> dict: + raw = _load_config(maxtext_config_path) + prefill_len = raw.get("max_prefill_predict_length") + target_len = raw.get("max_target_length") + + derived: dict = {} + if target_len is not None: + derived["max_model_len"] = int(target_len) + logger.info( + "Derived max_model_len=%d from MaxText config max_target_length.", + derived["max_model_len"], + ) + if prefill_len is not None and target_len is not None: + derived["max_tokens_default"] = int(target_len) - int(prefill_len) + logger.info( + "Derived max_tokens_default=%d from max_target_length - max_prefill_predict_length.", + derived["max_tokens_default"], + ) + for key in ("base_output_directory", "run_name"): + if raw.get(key): + derived[key] = raw[key] + return derived + + +def _build_results_path(cfg: dict) -> str: + base_output_directory = cfg.get("base_output_directory", "").rstrip("/") + run_name = cfg.get("run_name", "") + if not base_output_directory or not run_name: + raise ValueError( + "Cannot build eval results_path." + ) + return f"{base_output_directory}/{run_name}/eval_results" + + + +def run_eval(cfg: dict, hf_token: str | None = None) -> dict: + """Execute all the evaluation steps. + + Args: + cfg: Configuration dict. + + Returns: + Results dict as written to the JSON output file. 
+ """ + # pylint: disable=import-outside-toplevel + from transformers import AutoTokenizer + + from maxtext.eval.datasets.registry import get_dataset + from maxtext.eval.reporting.json_reporter import write_results + from maxtext.eval.runner.async_client import generate_batch + from maxtext.eval.runner.common import build_server_manager, maybe_upload_to_gcs, resolve_token + from maxtext.eval.runner.warmup import warmup_server + from maxtext.eval.scoring.registry import get_scorer + + benchmark = cfg["benchmark"] + model_name = cfg["model_name"] + hf_path = cfg["hf_path"] + results_path = cfg["results_path"] + num_samples = cfg.get("num_samples") + max_tokens = int(cfg.get("max_tokens", 1024)) + temperature = float(cfg.get("temperature", 0.0)) + concurrency = int(cfg.get("concurrency", 64)) + if "max_model_len" not in cfg: + raise ValueError("Error: max_model_len is required.") + gcs_results_path = cfg.get("gcs_results_path") + token = resolve_token(cfg, hf_token) + + # Load tokenizer for prompt formatting. + logger.info("Loading tokenizer from %s.", hf_path) + tokenizer = AutoTokenizer.from_pretrained(hf_path, token=token) + + # Prepare dataset. + logger.info("Loading benchmark dataset: %s", benchmark) + dataset = get_dataset(benchmark) + requests = dataset.sample_requests(num_samples=num_samples, tokenizer=tokenizer) + logger.info("Loaded %d samples.", len(requests)) + + prompts = [r.prompt for r in requests] + references = [r.reference for r in requests] + + # Start vLLM server. + with build_server_manager(cfg, token) as server: + import jax as _jax # pylint: disable=import-outside-toplevel + from jax.experimental import multihost_utils as _multihost_utils # pylint: disable=import-outside-toplevel + is_rank0 = _jax.process_index() == 0 + + if is_rank0: + base_url = server.base_url + + # Warmup server. + warmup_server(base_url=base_url, model=model_name, sample_requests=requests) + + # Generate responses. 
+ logger.info("Generating responses for %d prompts.", len(prompts)) + t0 = time.time() + results = generate_batch( + prompts=prompts, + base_url=base_url, + model=model_name, + max_tokens=max_tokens, + temperature=temperature, + concurrency=concurrency, + ) + elapsed = time.time() - t0 + logger.info("Generation completed in %.1fs (%.1f samples/s).", elapsed, len(prompts) / elapsed) + + # All ranks block here until rank-0 finishes generation. Non-rank-0 hosts + # keep their in-process LLM alive so rank-0's llm.generate() calls can + # complete their tensor-parallel collectives across all hosts. + _multihost_utils.sync_global_devices("eval_runner_complete") + + # All ranks exit the context manager together above (LLM stopped on all). + # Only rank-0 has results/elapsed defined, non-rank-0 return early. + if not is_rank0: + return {} + + # Score. + responses = [r.text for r in results] + errors = [r for r in results if r.error] + if errors: + logger.warning("%d generation errors (out of %d).", len(errors), len(results)) + + scorer = get_scorer(benchmark) + scores = scorer(responses, references) + logger.info("Scores: %s", scores) + + # Write results + generation_stats = { + "total_samples": len(prompts), + "num_errors": len(errors), + "elapsed_s": round(elapsed, 2), + "samples_per_second": round(len(prompts) / elapsed, 2), + "total_prompt_tokens": sum(r.prompt_tokens for r in results), + "total_completion_tokens": sum(r.completion_tokens for r in results), + } + output = write_results( + benchmark=benchmark, + model_name=model_name, + scores=scores, + generation_stats=generation_stats, + config=cfg, + results_path=results_path, + ) + + maybe_upload_to_gcs(output, gcs_results_path) + return output + + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="MaxText model evaluation runner.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--config", required=True, help="Path to eval 
config file.") + parser.add_argument("--base_config", help="Path to maxtext config.") + parser.add_argument("--benchmark", help="Benchmark name.") + parser.add_argument("--checkpoint_path", help="MaxText checkpoint path.") + parser.add_argument("--model_name", help="MaxText model name.") + parser.add_argument("--hf_path", help="HF model ID or tokenizer dir.") + parser.add_argument("--base_output_directory", help="Base output directory.") + parser.add_argument("--run_name", help="Run name/identifier.") + parser.add_argument("--gcs_results_path", help="Optional GCS path to upload results.") + parser.add_argument("--num_samples", type=int, help="Number of eval samples.") + parser.add_argument("--max_tokens", type=int, help="Max tokens per generation.") + parser.add_argument("--temperature", type=float, help="Sampling temperature.") + parser.add_argument("--concurrency", type=int, help="HTTP request concurrency.") + parser.add_argument("--tensor_parallel_size", type=int, help="vLLM tensor parallelism.") + parser.add_argument("--max_model_len", type=int, help="vLLM max context length.") + parser.add_argument("--server_host", help="vLLM server host.") + parser.add_argument("--server_port", type=int, help="vLLM server port.") + parser.add_argument("--hf_mode", action="store_true", help="Use HF safetensors mode.") + parser.add_argument( + "--enable_expert_parallel", + action="store_true", + help=( + "Enable expert parallelism in vLLM. Required for MoE models such as " + "qwen3-30b-a3b, qwen3-235b-a22b, deepseek-v3, etc. Without this flag " + "tpu-inference omits the 'expert' mesh axis and MaxText's MoE sharding " + "raises KeyError." 
+ ), + ) + parser.add_argument("--hf_token", help="HuggingFace token for gated models.") + parser.add_argument( + "--log_level", + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + help="Logging level.", + ) + return parser + + +def main() -> None: + parser = _build_arg_parser() + args = parser.parse_args() + + logging.basicConfig( + level=getattr(logging, args.log_level), + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) + + base_cfg = _load_config(args.config) + + if args.base_config: + maxtext_derived = _derive_from_maxtext_config(args.base_config) + for k, v in maxtext_derived.items(): + if k not in base_cfg: + base_cfg[k] = v + + cli_overrides = { + k: v for k, v in vars(args).items() + if k not in ("config", "base_config", "log_level", "hf_token") + } + cfg = _merge_config(base_cfg, cli_overrides) + + if "max_tokens" not in cfg and "max_tokens_default" in cfg: + cfg["max_tokens"] = cfg["max_tokens_default"] + logger.info("Using max_tokens=%d derived from MaxText config.", cfg["max_tokens"]) + + if "results_path" not in cfg: + cfg["results_path"] = _build_results_path(cfg) + logger.info("Results will be written to %s", cfg["results_path"]) + + required = ["benchmark", "model_name", "hf_path"] + missing = [f for f in required if not cfg.get(f)] + if missing: + parser.error(f"Missing required config field(s): {missing}") + + run_eval(cfg, hf_token=args.hf_token) + + +if __name__ == "__main__": + main() diff --git a/src/maxtext/eval/runner/harness_runner.py b/src/maxtext/eval/runner/harness_runner.py new file mode 100644 index 0000000000..ec1440f3b8 --- /dev/null +++ b/src/maxtext/eval/runner/harness_runner.py @@ -0,0 +1,249 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""lm-evaluation-harness and evalchemy runner for MaxText eval. + +Supports two backends selected via --backend: + + lm_eval (default) + Uses the /v1/completions endpoint (local-completions lm-eval backend). + + evalchemy + Uses the /v1/chat/completions endpoint (local-chat-completions backend) + and imports evalchemy to register its extended task registry. + +Unified entry point: + python -m maxtext.eval.runner.run --runner lm_eval ... + python -m maxtext.eval.runner.run --runner evalchemy ... +""" + +from __future__ import annotations + +import argparse +import logging + +from maxtext.eval.runner.common import ( + add_server_args, + build_server_manager, + maybe_upload_to_gcs, + resolve_token, +) + +logger = logging.getLogger(__name__) + + +def _map_results(raw_results: dict, tasks: list[str]) -> dict: + """Extract per-task accuracy metrics from lm-eval / evalchemy output.""" + scores: dict[str, float] = {} + results_section = raw_results.get("results", {}) + for task in tasks: + task_r = results_section.get(task, {}) + + acc = None + for key in ( + "acc,none", + "exact_match,strict-match", + "exact_match,flexible-extract", + "exact_match,none", + "acc", + "score", + ): + if task_r.get(key) is not None: + acc = task_r[key] + break + + acc_norm = None + for key in ("acc_norm,none", "acc_norm"): + if task_r.get(key) is not None: + acc_norm = task_r[key] + break + + if acc is not None: + scores[f"{task}_accuracy"] = round(float(acc) * 100, 2) + if acc_norm is not None: + scores[f"{task}_accuracy_norm"] = round(float(acc_norm) * 100, 2) + + if acc is None 
and task_r: + logger.warning( + "No known accuracy keys found for task '%s'. Available: %s", + task, + list(task_r.keys()), + ) + + return scores + + +def run_harness(cfg: dict, hf_token: str | None = None) -> dict: + """Run lm-eval or evalchemy benchmarks against a MaxText vLLM server. + + Args: + cfg: Configuration dict. Required keys: model_name, hf_path, tasks, + max_model_len, results_path. Optional: backend (default "lm_eval"), + num_fewshot, num_samples, gcs_results_path, and all server keys handled + by build_server_manager. + hf_token: HuggingFace token for gated tokenizers. + + Returns: + Dict with keys: results, scores, JSON file path. + + Raises: + ImportError: If lm_eval (or evalchemy for that backend) is not installed. + """ + # pylint: disable=import-outside-toplevel + try: + import lm_eval as lm_eval_lib + except ImportError as exc: + raise ImportError("Install lm-eval.") from exc + + from maxtext.eval.reporting.json_reporter import write_results + from maxtext.eval.runner.warmup import warmup_server + + backend = cfg.get("backend", "lm_eval") + if backend == "evalchemy": + try: + import evalchemy as _evalchemy # noqa: F401 registers custom tasks with lm_eval + except ImportError as exc: + raise ImportError( + "Install evalchemy." 
+ ) from exc + + model_name = cfg["model_name"] + hf_path = cfg["hf_path"] + tasks = cfg["tasks"] + results_path = cfg["results_path"] + num_fewshot = cfg.get("num_fewshot", 0) + num_samples = cfg.get("num_samples") + gcs_results_path = cfg.get("gcs_results_path") + token = resolve_token(cfg, hf_token) + + lm_model_type = "local-chat-completions" if backend == "evalchemy" else "local-completions" + + with build_server_manager(cfg, token) as server: + import jax as _jax + from jax.experimental import multihost_utils as _multihost_utils + is_rank0 = _jax.process_index() == 0 + + if is_rank0: + warmup_server(base_url=server.base_url, model=model_name) + + completions_path = ( + "/v1/chat/completions" if backend == "evalchemy" else "/v1/completions" + ) + model_args_parts = [ + f"model={model_name}", + f"base_url={server.base_url}{completions_path}", + "tokenizer_backend=huggingface", + f"tokenizer={hf_path}", + ] + if token: + model_args_parts.append(f"token={token}") + model_args = ",".join(model_args_parts) + + logger.info( + "Running %s tasks %s via %s at %s", + backend, + tasks, + lm_model_type, + server.base_url, + ) + raw_results = lm_eval_lib.simple_evaluate( + model=lm_model_type, + model_args=model_args, + tasks=tasks, + num_fewshot=num_fewshot, + limit=num_samples, + log_samples=False, + ) + + # All ranks block here until rank-0 finishes evaluation. Non-rank-0 hosts + # keep their in-process LLM alive so rank-0's llm.generate() calls can + # complete their tensor-parallel collectives across all hosts. + _multihost_utils.sync_global_devices(f"harness_{backend}_complete") + + # All ranks exit the context manager together (LLM stopped on all). + # Only rank-0 has raw_results defined; non-rank-0 return early. 
+ if not is_rank0: + return {} + + scores = _map_results(raw_results, tasks) + logger.info("%s scores: %s", backend, scores) + + output = write_results( + benchmark="+".join(tasks), + model_name=model_name, + scores=scores, + generation_stats={ + f"{backend}_config": raw_results.get("config", {}), + f"{backend}_results": raw_results.get("results", {}), + }, + config=cfg, + results_path=results_path, + ) + maybe_upload_to_gcs(output, gcs_results_path) + return output + + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="MaxText lm-eval / evalchemy runner.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + add_server_args(parser) + parser.add_argument( + "--backend", + choices=["lm_eval", "evalchemy"], + default="lm_eval", + help=( + "Evaluation backend. 'lm_eval' uses /v1/completions (local-completions); " + "'evalchemy' uses /v1/chat/completions (local-chat-completions) and " + "registers evalchemy's extended task library." + ), + ) + parser.add_argument( + "--tasks", + nargs="+", + default=["mmlu"], + help=( + "lm-eval task names passed directly to simple_evaluate. " + "Any task registered in lm-eval or evalchemy is accepted (e.g. gsm8k, mmlu, gpqa_diamond, ifeval, math_500)." + ), + ) + parser.add_argument( + "--num_fewshot", type=int, default=0, help="Few-shot examples per task." + ) + parser.add_argument( + "--num_samples", type=int, help="Limit samples per task (None = full dataset)." 
+ ) + return parser + + +def main() -> None: + import logging as _logging # pylint: disable=import-outside-toplevel + parser = _build_arg_parser() + args = parser.parse_args() + + _logging.basicConfig( + level=getattr(_logging, args.log_level), + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) + + results_path = f"{args.base_output_directory.rstrip('/')}/{args.run_name}/eval_results" + cfg = {k: v for k, v in vars(args).items() if k not in ("log_level", "hf_token")} + cfg["results_path"] = results_path + + run_harness(cfg, hf_token=args.hf_token) + + +if __name__ == "__main__": + main() diff --git a/src/maxtext/eval/runner/run.py b/src/maxtext/eval/runner/run.py new file mode 100644 index 0000000000..a23e1abfa3 --- /dev/null +++ b/src/maxtext/eval/runner/run.py @@ -0,0 +1,104 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unified CLI entry point for MaxText model evaluation. + +Dispatches to the appropriate runner based on --runner: + + eval Custom dataset runner. Requires --config. + lm_eval lm-evaluation-harness runner. + evalchemy evalchemy runner. + +Both lm_eval and evalchemy dispatch to harness_runner.py. 
+ +Usage:: + + # lm-eval + python -m maxtext.eval.runner.run \ + --runner lm_eval \ + --checkpoint_path gs:///checkpoints/0/items \ + --model_name llama3.1-8b \ + --hf_path meta-llama/Llama-3.1-8B-Instruct \ + --tasks mmlu gpqa \ + --base_output_directory gs:/// \ + --run_name my_run \ + --max_model_len 8192 \ + --tensor_parallel_size 4 \ + --hf_token $HF_TOKEN + + # evalchemy, MoE + python -m maxtext.eval.runner.run \ + --runner evalchemy \ + --checkpoint_path gs:///checkpoints/0/items \ + --model_name qwen3-30b-a3b \ + --hf_path Qwen/Qwen3-30B-A3B \ + --tasks ifeval math500 \ + --base_output_directory gs:/// \ + --run_name my_run \ + --max_model_len 8192 \ + --tensor_parallel_size 8 \ + --enable_expert_parallel \ + --hf_token $HF_TOKEN + + # custom eval_runner + python -m maxtext.eval.runner.run \ + --runner eval \ + --config src/maxtext/eval/configs/mlperf.yml \ + --checkpoint_path gs:///checkpoints/0/items \ + --model_name llama3.1-8b \ + --hf_path meta-llama/Llama-3.1-8B-Instruct \ + --base_output_directory gs:/// \ + --run_name my_run \ + --hf_token $HF_TOKEN + +The individual runner modules (eval_runner, harness_runner) are directly +invocable as well. 
+""" + +from __future__ import annotations + +import argparse +import sys + + +def main() -> None: + pre_parser = argparse.ArgumentParser( + description="MaxText eval — unified entry point.", + add_help=False, + ) + pre_parser.add_argument( + "--runner", + required=True, + choices=["eval", "lm_eval", "evalchemy"], + help="Which evaluation runner to use.", + ) + pre_args, remaining = pre_parser.parse_known_args() + + sys.argv = [sys.argv[0]] + remaining + + if pre_args.runner == "eval": + from maxtext.eval.runner.eval_runner import main as _main # pylint: disable=import-outside-toplevel + _main() + elif pre_args.runner == "lm_eval": + from maxtext.eval.runner.harness_runner import main as _main # pylint: disable=import-outside-toplevel + _main() + else: # evalchemy + if "--backend" not in remaining: + sys.argv += ["--backend", "evalchemy"] + from maxtext.eval.runner.harness_runner import main as _main # pylint: disable=import-outside-toplevel + _main() + + +if __name__ == "__main__": + main() diff --git a/src/maxtext/eval/runner/server_manager.py b/src/maxtext/eval/runner/server_manager.py new file mode 100644 index 0000000000..3287688211 --- /dev/null +++ b/src/maxtext/eval/runner/server_manager.py @@ -0,0 +1,392 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""vLLM-TPU server lifecycle (in-process LLM + thin HTTP wrapper).""" + +from __future__ import annotations + +import logging +import os +import threading +import time +import uuid +from typing import Any + +import requests + +logger = logging.getLogger(__name__) + +_HEALTH_ENDPOINT = "/health" + +def _build_app(llm: Any) -> Any: + """Return a FastAPI app that wraps an in-process vLLM LLM instance.""" + import fastapi # pylint: disable=import-outside-toplevel + from vllm.sampling_params import SamplingParams # pylint: disable=import-outside-toplevel + globals()["fastapi"] = fastapi + + app = fastapi.FastAPI() + + @app.get("/health") + def health(): + return {"status": "ok"} + + @app.post("/v1/completions") + async def completions(request: fastapi.Request): + body = await request.json() + + raw_prompt = body.get("prompt", "") + prompts = raw_prompt if isinstance(raw_prompt, list) else [raw_prompt] + model_name = body.get("model", "") + max_tokens = int(body.get("max_tokens") or 256) + temperature = float(body.get("temperature") or 0.0) + logprobs_n = body.get("logprobs") # int | None + echo = bool(body.get("echo", False)) + stop = body.get("stop") + + sp_kwargs: dict = {"max_tokens": max_tokens, "temperature": temperature} + if logprobs_n is not None: + sp_kwargs["logprobs"] = int(logprobs_n) + if echo and logprobs_n is not None: + sp_kwargs["prompt_logprobs"] = int(logprobs_n) + if stop: + sp_kwargs["stop"] = [stop] if isinstance(stop, str) else list(stop) + + outputs = llm.generate(prompts, SamplingParams(**sp_kwargs)) + tokenizer = llm.get_tokenizer() + + choices = [] + total_prompt_tokens = 0 + total_completion_tokens = 0 + + for idx, output in enumerate(outputs): + gen = output.outputs[0] + total_prompt_tokens += len(output.prompt_token_ids) + total_completion_tokens += len(gen.token_ids) + + logprobs_payload = None + if logprobs_n is not None: + tok_strings: list[str] = [] + tok_lps: list[float | None] = [] + tok_offsets: list[int] = [] + running_offset = 0 
+ + if echo: + prompt_lps = output.prompt_logprobs or [] + for pos, tok_id in enumerate(output.prompt_token_ids): + tok_str = tokenizer.decode([tok_id]) + tok_strings.append(tok_str) + tok_offsets.append(running_offset) + running_offset += len(tok_str) + lp_dict = prompt_lps[pos] if pos < len(prompt_lps) else None + lp_val = lp_dict[tok_id].logprob if (lp_dict and tok_id in lp_dict) else None + tok_lps.append(lp_val) + + gen_lps = gen.logprobs or [] + for pos, tok_id in enumerate(gen.token_ids): + tok_str = tokenizer.decode([tok_id]) + tok_strings.append(tok_str) + tok_offsets.append(running_offset) + running_offset += len(tok_str) + lp_dict = gen_lps[pos] if pos < len(gen_lps) else None + lp_val = lp_dict[tok_id].logprob if (lp_dict and tok_id in lp_dict) else None + tok_lps.append(lp_val) + + logprobs_payload = { + "tokens": tok_strings, + "token_logprobs": tok_lps, + "top_logprobs": None, + "text_offset": tok_offsets, + } + + text_out = (prompts[idx] + gen.text) if echo else gen.text + choices.append({ + "text": text_out, + "index": idx, + "logprobs": logprobs_payload, + "finish_reason": gen.finish_reason or "stop", + }) + + return { + "id": f"cmpl-{uuid.uuid4().hex}", + "object": "text_completion", + "created": int(time.time()), + "model": model_name, + "choices": choices, + "usage": { + "prompt_tokens": total_prompt_tokens, + "completion_tokens": total_completion_tokens, + "total_tokens": total_prompt_tokens + total_completion_tokens, + }, + } + + @app.post("/v1/chat/completions") + async def chat_completions(request: fastapi.Request): # pylint: disable=unused-variable + """OpenAI-compatible chat completions endpoint. + + Used by evalchemy and lm-eval chat tasks. 
+ """ + body = await request.json() + messages = body.get("messages", []) + model_name = body.get("model", "") + max_tokens = int(body.get("max_tokens") or 256) + temperature = float(body.get("temperature") or 0.0) + stop = body.get("stop") + + tokenizer = llm.get_tokenizer() + prompt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + + sp_kwargs: dict = {"max_tokens": max_tokens, "temperature": temperature} + if stop: + sp_kwargs["stop"] = [stop] if isinstance(stop, str) else list(stop) + + outputs = llm.generate([prompt], SamplingParams(**sp_kwargs)) + gen = outputs[0].outputs[0] + prompt_tokens = len(outputs[0].prompt_token_ids) + completion_tokens = len(gen.token_ids) + + return { + "id": f"chatcmpl-{uuid.uuid4().hex}", + "object": "chat.completion", + "created": int(time.time()), + "model": model_name, + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": gen.text}, + "finish_reason": gen.finish_reason or "stop", + "logprobs": None, + }], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + } + + return app + + +class VllmServerManager: + """Manages an in-process vLLM-TPU LLM with an OpenAI-compatible HTTP layer. + + Args: + model_path: HF model ID or local path. + checkpoint_path: MaxText orbax checkpoint path. + maxtext_model_name: MaxText model name (e.g. "llama3.1-8b"). + host: Hostname the HTTP server binds to (rank-0 only). + port: Port the HTTP server listens on. + tensor_parallel_size: Total number of chips. + expert_parallel_size: Chips allocated to the expert mesh axis (EP). + max_model_len: Maximum sequence length. + dtype: Activation dtype string passed to vLLM (e.g. "bfloat16"). + max_num_batched_tokens: Tokens per scheduler step (None = vLLM default). + max_num_seqs: Max concurrent sequences (None = vLLM default). + startup_timeout: Seconds to wait for /health to return healthy. 
+ hbm_memory_utilization: Fraction of HBM reserved for KV cache. + env: Optional environment-variable overrides. + additional_vllm_kwargs: Extra kwargs merged into the vLLM LLM() constructor. + """ + + def __init__( + self, + model_path: str, + checkpoint_path: str | None = None, + maxtext_model_name: str | None = None, + host: str = "localhost", + port: int = 8000, + tensor_parallel_size: int = 4, + expert_parallel_size: int = 1, + data_parallel_size: int = 1, + max_model_len: int = 4096, + dtype: str = "bfloat16", + max_num_batched_tokens: int | None = None, + max_num_seqs: int | None = None, + startup_timeout: int = 600, + hbm_memory_utilization: float = 0.3, + env: dict[str, str] | None = None, + additional_vllm_kwargs: dict | None = None, + ): + if checkpoint_path and not maxtext_model_name: + raise ValueError("maxtext_model_name is required when checkpoint_path is set.") + if tensor_parallel_size % expert_parallel_size != 0: + raise ValueError( + f"tensor_parallel_size ({tensor_parallel_size}) is not divisible by " + f"expert_parallel_size ({expert_parallel_size})." 
+ ) + self.model_path = model_path + self.checkpoint_path = checkpoint_path + self.maxtext_model_name = maxtext_model_name + self.host = host + self.port = port + self.tensor_parallel_size = tensor_parallel_size + self.expert_parallel_size = expert_parallel_size + self.data_parallel_size = data_parallel_size + self.max_model_len = max_model_len + self.dtype = dtype + self.max_num_batched_tokens = max_num_batched_tokens + self.max_num_seqs = max_num_seqs + self.startup_timeout = startup_timeout + self.hbm_memory_utilization = hbm_memory_utilization + self.env = env + self.additional_vllm_kwargs = additional_vllm_kwargs or {} + + self._llm: Any | None = None + self._uvicorn_server: Any | None = None + self._server_thread: threading.Thread | None = None + + @property + def base_url(self) -> str: + return f"http://{self.host}:{self.port}" + + def start(self) -> None: + """Initialise the in-process vLLM LLM and start the HTTP server.""" + # pylint: disable=import-outside-toplevel + from vllm import LLM + + # Disable V1 multiprocessing so EngineCore runs in-process instead. + # V1 engine architecture is otherwise preserved (tpu-inference plugin works), + # and JAX/TPU is initialised exactly once inside LLM() in this process. + os.environ.setdefault("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + os.environ.setdefault("NEW_MODEL_DESIGN", "1") + os.environ.setdefault("SKIP_JAX_PRECOMPILE", "1") + + if self.env: + os.environ.update(self.env) + + # total chips = ici_tensor_parallelism x ici_expert_parallelism. 
+ ici_tp = self.tensor_parallel_size // self.expert_parallel_size + ici_ep = self.expert_parallel_size + + vllm_kwargs: dict = { + "model": self.model_path, + "tensor_parallel_size": ici_tp, + "data_parallel_size": self.data_parallel_size, + "max_model_len": self.max_model_len, + "dtype": self.dtype, + "gpu_memory_utilization": self.hbm_memory_utilization, + } + if self.max_num_batched_tokens is not None: + vllm_kwargs["max_num_batched_tokens"] = self.max_num_batched_tokens + if self.max_num_seqs is not None: + vllm_kwargs["max_num_seqs"] = self.max_num_seqs + + if self.checkpoint_path: + vllm_kwargs["hf_overrides"] = {"architectures": ["MaxTextForCausalLM"]} + vllm_kwargs["additional_config"] = { + "maxtext_config": { + "model_name": self.maxtext_model_name, + "load_parameters_path": self.checkpoint_path, + "log_config": False, + "ici_tensor_parallelism": ici_tp, + "ici_expert_parallelism": ici_ep, + }, + "sharding": { + "sharding_strategy": {}, + }, + } + if ici_ep > 1: + vllm_kwargs["additional_config"]["sharding"]["sharding_strategy"]["expert_parallelism"] = ici_ep + else: + vllm_kwargs["load_format"] = "auto" + + if self.additional_vllm_kwargs: + for _k, _v in self.additional_vllm_kwargs.items(): + if ( + _k == "additional_config" + and isinstance(_v, dict) + and isinstance(vllm_kwargs.get("additional_config"), dict) + ): + for _sub_k, _sub_v in _v.items(): + if isinstance(_sub_v, dict) and isinstance( + vllm_kwargs["additional_config"].get(_sub_k), dict + ): + vllm_kwargs["additional_config"][_sub_k].update(_sub_v) + else: + vllm_kwargs["additional_config"][_sub_k] = _sub_v + else: + vllm_kwargs[_k] = _v + + logger.info( + "Initializing in-process vLLM (tp=%d, ep=%d, dp=%d, max_len=%d)...", + ici_tp, + ici_ep, + self.data_parallel_size, + self.max_model_len, + ) + self._llm = LLM(**vllm_kwargs) + + import jax as _jax # pylint: disable=import-outside-toplevel + logger.info("Rank %d: vLLM LLM ready.", _jax.process_index()) + + if _jax.process_index() == 0: + 
import uvicorn # pylint: disable=import-outside-toplevel + + app = _build_app(self._llm) + config = uvicorn.Config( + app, + host=self.host, + port=self.port, + log_level="warning", + workers=1, + ) + self._uvicorn_server = uvicorn.Server(config) + self._server_thread = threading.Thread( + target=self._uvicorn_server.run, + daemon=True, + name="vllm-http-server", + ) + self._server_thread.start() + self._wait_until_healthy() + + def _wait_until_healthy(self) -> None: + deadline = time.time() + self.startup_timeout + health_url = f"{self.base_url}{_HEALTH_ENDPOINT}" + while time.time() < deadline: + try: + resp = requests.get(health_url, timeout=5) + if resp.status_code == 200: + logger.info("vLLM HTTP server is healthy at %s", self.base_url) + return + except requests.exceptions.ConnectionError: + pass + if self._server_thread is not None and not self._server_thread.is_alive(): + raise RuntimeError("vLLM HTTP server thread died before becoming healthy.") + time.sleep(2) + raise TimeoutError( + f"vLLM HTTP server did not become healthy within {self.startup_timeout}s." 
+ ) + + def stop(self) -> None: + """Stop the HTTP server and release the LLM.""" + if self._uvicorn_server is not None: + logger.info("Stopping vLLM HTTP server...") + self._uvicorn_server.should_exit = True + if self._server_thread is not None: + self._server_thread.join(timeout=30) + if self._server_thread.is_alive(): + logger.warning("vLLM HTTP server thread did not exit within 30 s.") + self._llm = None + self._uvicorn_server = None + self._server_thread = None + logger.info("VllmServerManager stopped.") + + def __enter__(self) -> "VllmServerManager": + self.start() + return self + + def __exit__(self, *_) -> None: + self.stop() diff --git a/src/maxtext/eval/runner/warmup.py b/src/maxtext/eval/runner/warmup.py new file mode 100644 index 0000000000..8a06d480b0 --- /dev/null +++ b/src/maxtext/eval/runner/warmup.py @@ -0,0 +1,100 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Server warmup — triggers XLA compilation before evaluation begins. + +Sends one request per prompt-length bucket (16, 32, 64, 128, 256, 512, 1024 +tokens) so that vLLM compiles all the kernel shapes that will be seen during +eval. Falls back to a single short prompt when no dataset requests are +provided. +""" + +from __future__ import annotations + +import asyncio +import logging + +from maxtext.eval.datasets.base import SampleRequest + +logger = logging.getLogger(__name__) + +_COMPLETIONS_PATH = "/v1/completions" +_SIMPLE_WARMUP_PROMPT = "What is 1 + 1?" 
+_WARMUP_BUCKETS = [0, 16, 32, 64, 128, 256, 512, 1024] + + +def _sample_by_buckets(requests: list[SampleRequest]) -> list[SampleRequest]: + """Return one request per prompt-length bucket.""" + sampled = [] + for start, end in zip(_WARMUP_BUCKETS[:-1], _WARMUP_BUCKETS[1:]): + for req in requests: + approx_len = len(req.prompt.split()) + if start < approx_len <= end: + sampled.append(req) + break + return sampled + + +async def _send_warmup_requests( + api_url: str, + model: str, + pairs: list[tuple[str, int]], +) -> None: + """Send warmup requests concurrently and log any failures.""" + import aiohttp # pylint: disable=import-outside-toplevel + + async def _post(session: aiohttp.ClientSession, prompt: str, max_tokens: int) -> bool: + payload = {"model": model, "prompt": prompt, "max_tokens": max_tokens} + try: + async with session.post(api_url, json=payload) as resp: + return resp.status == 200 + except aiohttp.ClientError: + return False + + async with aiohttp.ClientSession() as session: + results = await asyncio.gather(*[_post(session, p, n) for p, n in pairs]) + + failures = sum(1 for ok in results if not ok) + if failures: + logger.warning("Warmup: %d/%d requests failed.", failures, len(results)) + else: + logger.info("Warmup complete (%d requests).", len(results)) + + +def warmup_server( + base_url: str, + model: str, + sample_requests: list[SampleRequest] | None = None, + max_tokens: int = 16, +) -> None: + """Send warmup requests to trigger XLA compilation before eval. + + Args: + base_url: Base URL of the vLLM server. + model: Model name. + sample_requests: Optional dataset sample requests used to derive + prompt-length buckets. When omitted, a single short prompt is sent. + max_tokens: Maximum tokens per warmup response. 
+ """ + api_url = f"{base_url}{_COMPLETIONS_PATH}" + + if sample_requests: + bucketed = _sample_by_buckets(sample_requests) + logger.info("Running bucketed warmup: %d prompt-length buckets.", len(bucketed)) + pairs = [(r.prompt, max_tokens) for r in bucketed] + else: + logger.info("Running simple warmup (no dataset samples provided).") + pairs = [(_SIMPLE_WARMUP_PROMPT, max_tokens)] + + asyncio.run(_send_warmup_requests(api_url, model, pairs)) diff --git a/src/maxtext/eval/scoring/registry.py b/src/maxtext/eval/scoring/registry.py new file mode 100644 index 0000000000..e49a98eb3c --- /dev/null +++ b/src/maxtext/eval/scoring/registry.py @@ -0,0 +1,58 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Registry mapping dataset/benchmark names to their scorer functions.""" + +from __future__ import annotations + +from typing import Callable + +from maxtext.eval.scoring import rouge_scorer + +# Maps benchmark name to score_batch callable. +SCORER_REGISTRY: dict[str, Callable[..., dict]] = { + "mlperf_openorca": rouge_scorer.score_batch, + "openorca": rouge_scorer.score_batch, +} + + +def get_scorer(benchmark_name: str) -> Callable[..., dict]: + """Return the scorer for benchmark_name. + + Args: + benchmark_name: Benchmark identifier (e.g. "mlperf_openorca"). + + Returns: + The scorer callable. + + Raises: + KeyError: If no scorer is registered for the given name. 
def score_batch(
    responses: list[str],
    references: list[str],
    use_stemmer: bool = True,
) -> dict:
  """Compute ROUGE scores for a batch of generated responses.

  Each response and reference is split into sentences with
  ``nltk.sent_tokenize`` and re-joined with newlines before scoring. Every
  value returned by the metric is averaged, scaled to a percentage and
  rounded to four decimals.

  Args:
    responses: List of model-generated summaries.
    references: List of reference summaries.
    use_stemmer: Whether stemming is applied before n-gram matching. Forwarded
      to the metric's ``compute`` call; the ``evaluate`` ROUGE metric defaults
      to no stemming, so it must be passed explicitly to take effect.

  Returns:
    Dict with keys: rouge1, rouge2, rougeL, rougeLsum, gen_num.

  Raises:
    ValueError: If responses and references have different lengths.
  """
  if len(responses) != len(references):
    raise ValueError(
        f"Length mismatch: {len(responses)} responses vs {len(references)} references."
    )

  import evaluate  # pylint: disable=import-outside-toplevel

  metric = evaluate.load("rouge")

  preds = ["\n".join(nltk.sent_tokenize(resp.strip())) for resp in responses]
  targets = ["\n".join(nltk.sent_tokenize(ref.strip())) for ref in references]

  result = metric.compute(
      predictions=preds,
      references=targets,
      use_stemmer=use_stemmer,
  )
  result = {k: float(round(np.mean(v) * 100, 4)) for k, v in result.items()}
  result["gen_num"] = len(preds)
  return result
def _make_mock_output(generated_text="hello", prompt_token_ids=(1, 2, 3), generated_token_ids=(4, 5)):
  """Create a stand-in for a vLLM RequestOutput carrying one completion.

  The token-id sequences are copied into fresh lists; both ``prompt_logprobs``
  and the completion's ``logprobs`` start out as ``None`` so individual tests
  can fill them in.
  """
  completion = SimpleNamespace(
      text=generated_text,
      token_ids=list(generated_token_ids),
      logprobs=None,
      finish_reason="stop",
  )
  request_output = SimpleNamespace(
      prompt_token_ids=list(prompt_token_ids),
      prompt_logprobs=None,
      outputs=[completion],
  )
  return request_output
+ """ + mock_output = _make_mock_output( + generated_text=generated_text, + prompt_token_ids=prompt_token_ids, + generated_token_ids=generated_token_ids, + ) + + mock_tokenizer = MagicMock() + mock_tokenizer.decode.side_effect = lambda ids: "".join(f"tok{i}" for i in ids) + mock_tokenizer.apply_chat_template.return_value = "rendered_prompt" + + mock_llm = MagicMock() + mock_llm.generate.return_value = [mock_output] + mock_llm.get_tokenizer.return_value = mock_tokenizer + return mock_llm + + +class TestBuildApp(unittest.TestCase): + """Tests for the FastAPI app returned by _build_app(llm).""" + + def setUp(self): + """Patch SamplingParams at the module level used by server_manager.""" + self.mock_llm = _make_mock_llm() + self.mock_sampling_params_cls = MagicMock(return_value=MagicMock()) + + # Patch at the import location used inside _build_app. + self._sp_patcher = patch( + "vllm.sampling_params.SamplingParams", + self.mock_sampling_params_cls, + ) + self._vllm_patcher = patch.dict( + "sys.modules", + { + "vllm": MagicMock(), + "vllm.sampling_params": MagicMock(SamplingParams=self.mock_sampling_params_cls), + }, + ) + self._vllm_patcher.start() + self._sp_patcher.start() + + from maxtext.eval.runner.server_manager import _build_app + from starlette.testclient import TestClient + + self.app = _build_app(self.mock_llm) + self.client = TestClient(self.app) + + def tearDown(self): + self._sp_patcher.stop() + self._vllm_patcher.stop() + + def test_health_endpoint(self): + resp = self.client.get("/health") + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.json(), {"status": "ok"}) + + def test_completions_basic(self): + resp = self.client.post( + "/v1/completions", + json={"model": "m", "prompt": "hi", "max_tokens": 10}, + ) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertIn("choices", data) + self.assertEqual(len(data["choices"]), 1) + self.assertEqual(data["choices"][0]["text"], "hello") + + def 
test_completions_list_prompt(self): + mock_llm = _make_mock_llm(generated_text="world") + mock_llm.generate.return_value = [ + _make_mock_output(generated_text="alpha"), + _make_mock_output(generated_text="beta"), + ] + mock_llm.get_tokenizer.return_value = self.mock_llm.get_tokenizer() + + from maxtext.eval.runner.server_manager import _build_app + from starlette.testclient import TestClient + + app = _build_app(mock_llm) + client = TestClient(app) + + resp = client.post( + "/v1/completions", + json={"model": "m", "prompt": ["first", "second"], "max_tokens": 5}, + ) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertEqual(len(data["choices"]), 2) + self.assertEqual(data["choices"][0]["text"], "alpha") + self.assertEqual(data["choices"][1]["text"], "beta") + + def test_completions_no_logprobs(self): + resp = self.client.post( + "/v1/completions", + json={"model": "m", "prompt": "test", "max_tokens": 5}, + ) + data = resp.json() + self.assertIsNone(data["choices"][0]["logprobs"]) + + def test_completions_with_logprobs_echo_false(self): + mock_output = _make_mock_output( + generated_text="hi", + prompt_token_ids=[1, 2], + generated_token_ids=[4, 5], + ) + mock_output.outputs[0].logprobs = [ + {4: SimpleNamespace(logprob=-0.5)}, + {5: SimpleNamespace(logprob=-1.2)}, + ] + self.mock_llm.generate.return_value = [mock_output] + + resp = self.client.post( + "/v1/completions", + json={"model": "m", "prompt": "ab", "max_tokens": 5, "logprobs": 1}, + ) + self.assertEqual(resp.status_code, 200) + data = resp.json() + lp = data["choices"][0]["logprobs"] + self.assertIsNotNone(lp) + self.assertEqual(len(lp["tokens"]), 2) + self.assertAlmostEqual(lp["token_logprobs"][0], -0.5, places=4) + self.assertAlmostEqual(lp["token_logprobs"][1], -1.2, places=4) + + def test_completions_with_logprobs_echo_true(self): + mock_output = _make_mock_output( + generated_text=" world", + prompt_token_ids=[1, 2, 3], + generated_token_ids=[4, 5], + ) + 
mock_output.prompt_logprobs = [ + None, + {2: SimpleNamespace(logprob=-0.3)}, + {3: SimpleNamespace(logprob=-0.7)}, + ] + mock_output.outputs[0].logprobs = [ + {4: SimpleNamespace(logprob=-0.9)}, + {5: SimpleNamespace(logprob=-1.1)}, + ] + self.mock_llm.generate.return_value = [mock_output] + + resp = self.client.post( + "/v1/completions", + json={ + "model": "m", + "prompt": "tok1tok2tok3", + "max_tokens": 5, + "logprobs": 1, + "echo": True, + }, + ) + self.assertEqual(resp.status_code, 200) + data = resp.json() + lp = data["choices"][0]["logprobs"] + self.assertIsNotNone(lp) + # echo=True → prompt tokens (3) + generated tokens (2) = 5 total. + self.assertEqual(len(lp["tokens"]), 5) + + def test_chat_completions_basic(self): + resp = self.client.post( + "/v1/chat/completions", + json={ + "model": "m", + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 20, + }, + ) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertIn("choices", data) + self.assertEqual(data["choices"][0]["message"]["role"], "assistant") + self.assertEqual(data["choices"][0]["message"]["content"], "hello") + + def test_chat_completions_applies_template(self): + resp = self.client.post( + "/v1/chat/completions", + json={ + "model": "m", + "messages": [{"role": "user", "content": "ping"}], + "max_tokens": 10, + }, + ) + self.assertEqual(resp.status_code, 200) + tokenizer = self.mock_llm.get_tokenizer() + tokenizer.apply_chat_template.assert_called() + call_args = tokenizer.apply_chat_template.call_args + # The messages list should have been forwarded to apply_chat_template. 
+ passed_messages = call_args[0][0] if call_args[0] else call_args[1].get("conversation") + self.assertIsNotNone(passed_messages) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/eval/test_eval_runner.py b/tests/unit/eval/test_eval_runner.py new file mode 100644 index 0000000000..a26c280512 --- /dev/null +++ b/tests/unit/eval/test_eval_runner.py @@ -0,0 +1,136 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for maxtext.eval.runner.eval_runner.""" + +import os +import tempfile +import textwrap +import unittest + +from maxtext.eval.runner.eval_runner import ( + _build_results_path, + _derive_from_maxtext_config, + _merge_config, +) + + +class TestMergeConfig(unittest.TestCase): + + def test_override_takes_precedence(self): + base = {"concurrency": 64, "temperature": 0.0} + overrides = {"concurrency": 128} + merged = _merge_config(base, overrides) + self.assertEqual(merged["concurrency"], 128) + self.assertEqual(merged["temperature"], 0.0) + + def test_none_override_does_not_overwrite(self): + base = {"concurrency": 64} + overrides = {"concurrency": None} + merged = _merge_config(base, overrides) + self.assertEqual(merged["concurrency"], 64) + + def test_new_key_from_override(self): + merged = _merge_config({}, {"benchmark": "mmlu"}) + self.assertEqual(merged["benchmark"], "mmlu") + + def test_base_preserved_when_no_override(self): + base = {"a": 1, "b": 2} + merged = _merge_config(base, {}) + 
class TestDeriveFromMaxtextConfig(unittest.TestCase):
  """Tests for _derive_from_maxtext_config using temporary YAML files."""

  def _write_yaml(self, content: str) -> str:
    """Write dedented ``content`` to a temporary .yml file and return its path.

    The file is registered for removal via ``addCleanup`` so it is deleted
    even when an assertion in the calling test fails.
    """
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as f:
      f.write(textwrap.dedent(content))
    self.addCleanup(os.unlink, f.name)
    return f.name

  def test_derives_max_model_len(self):
    path = self._write_yaml("max_target_length: 1024\n")
    derived = _derive_from_maxtext_config(path)
    self.assertEqual(derived["max_model_len"], 1024)

  def test_derives_max_tokens_default(self):
    path = self._write_yaml(
        "max_target_length: 1024\nmax_prefill_predict_length: 256\n"
    )
    derived = _derive_from_maxtext_config(path)
    self.assertEqual(derived["max_tokens_default"], 768)

  def test_no_max_tokens_default_without_both_fields(self):
    path = self._write_yaml("max_target_length: 1024\n")
    derived = _derive_from_maxtext_config(path)
    self.assertNotIn("max_tokens_default", derived)

  def test_derives_run_name_and_base_output_directory(self):
    path = self._write_yaml(
        "base_output_directory: gs://b/\nrun_name: myrun\n"
    )
    derived = _derive_from_maxtext_config(path)
    self.assertEqual(derived["base_output_directory"], "gs://b/")
    self.assertEqual(derived["run_name"], "myrun")

  def test_empty_run_name_not_derived(self):
    # base.yml has run_name: '' by default; empty string should not be derived.
    path = self._write_yaml("run_name: ''\n")
    derived = _derive_from_maxtext_config(path)
    self.assertNotIn("run_name", derived)

  def test_empty_yaml_returns_empty_dict(self):
    path = self._write_yaml("")
    derived = _derive_from_maxtext_config(path)
    self.assertEqual(derived, {})
+ +"""Unit tests for harness_runner._map_results (no server required).""" + +from __future__ import annotations + +import unittest + +from maxtext.eval.runner.harness_runner import _map_results + + +class TestMapResults(unittest.TestCase): + """Tests for _map_results.""" + + def _make_raw(self, task_key: str, **metric_values) -> dict: + """Build a minimal lm-eval results dict for a single task.""" + return {"results": {task_key: metric_values}} + + def test_mmlu_accuracy(self): + raw = self._make_raw("mmlu", **{"acc,none": 0.725}) + scores = _map_results(raw, ["mmlu"]) + self.assertIn("mmlu_accuracy", scores) + self.assertAlmostEqual(scores["mmlu_accuracy"], 72.5, places=1) + + def test_gpqa_diamond(self): + raw = self._make_raw("gpqa_diamond", **{"acc,none": 0.4040}) + scores = _map_results(raw, ["gpqa_diamond"]) + self.assertIn("gpqa_diamond_accuracy", scores) + self.assertAlmostEqual(scores["gpqa_diamond_accuracy"], 40.4, places=1) + + def test_gsm8k_strict_match_key(self): + # GSM8K uses exact_match,strict-match rather than acc,none. 
+ raw = self._make_raw("gsm8k", **{"exact_match,strict-match": 0.80}) + scores = _map_results(raw, ["gsm8k"]) + self.assertIn("gsm8k_accuracy", scores) + self.assertAlmostEqual(scores["gsm8k_accuracy"], 80.0, places=1) + + def test_gsm8k_flexible_extract_key(self): + raw = self._make_raw("gsm8k", **{"exact_match,flexible-extract": 0.75}) + scores = _map_results(raw, ["gsm8k"]) + self.assertIn("gsm8k_accuracy", scores) + self.assertAlmostEqual(scores["gsm8k_accuracy"], 75.0, places=1) + + def test_zero_score_not_dropped(self): + raw = self._make_raw("gsm8k", **{"exact_match,strict-match": 0.0}) + scores = _map_results(raw, ["gsm8k"]) + self.assertIn("gsm8k_accuracy", scores) + self.assertAlmostEqual(scores["gsm8k_accuracy"], 0.0, places=1) + + def test_acc_norm_extracted(self): + raw = self._make_raw("mmlu", **{"acc,none": 0.5, "acc_norm,none": 0.6}) + scores = _map_results(raw, ["mmlu"]) + self.assertIn("mmlu_accuracy", scores) + self.assertIn("mmlu_accuracy_norm", scores) + self.assertAlmostEqual(scores["mmlu_accuracy_norm"], 60.0, places=1) + + def test_multiple_tasks(self): + raw = { + "results": { + "mmlu": {"acc,none": 0.80}, + "gpqa_diamond": {"acc,none": 0.35}, + } + } + scores = _map_results(raw, ["mmlu", "gpqa_diamond"]) + self.assertIn("mmlu_accuracy", scores) + self.assertIn("gpqa_diamond_accuracy", scores) + self.assertAlmostEqual(scores["mmlu_accuracy"], 80.0, places=1) + self.assertAlmostEqual(scores["gpqa_diamond_accuracy"], 35.0, places=1) + + + def test_missing_task_returns_nothing(self): + raw = {"results": {}} + scores = _map_results(raw, ["mmlu"]) + self.assertNotIn("mmlu_accuracy", scores) + + def test_custom_task_name_used_directly(self): + raw = self._make_raw("my_custom_task", **{"acc,none": 0.9}) + scores = _map_results(raw, ["my_custom_task"]) + self.assertIn("my_custom_task_accuracy", scores) + + def test_empty_tasks_list(self): + raw = self._make_raw("mmlu", **{"acc,none": 0.9}) + scores = _map_results(raw, []) + self.assertEqual(scores, 
class TestRougeScorer(unittest.TestCase):
  """Tests for maxtext.eval.scoring.rouge_scorer.score_batch.

  The ``evaluate`` metric is mocked, so these tests check the post-processing
  done by score_batch (percentage scaling, ``gen_num``, input validation), not
  the ROUGE computation itself.
  """

  def _make_mock_rouge_metric(self, rouge1=0.9, rouge2=0.85, rougeL=0.88, rougeLsum=0.88):
    """Return a mock evaluate metric whose compute() returns fixed ROUGE scores."""
    mock_metric = MagicMock()
    mock_metric.compute.return_value = {
        "rouge1": rouge1,
        "rouge2": rouge2,
        "rougeL": rougeL,
        "rougeLsum": rougeLsum,
    }
    return mock_metric

  def _run_score_batch(self, responses, references, mock_metric):
    """Import score_batch with evaluate and nltk.sent_tokenize patched, return result dict."""
    with patch("evaluate.load", return_value=mock_metric), \
        patch("nltk.sent_tokenize", side_effect=lambda s: [s]):
      from maxtext.eval.scoring.rouge_scorer import score_batch
      return score_batch(responses, references)

  def test_perfect_match(self):
    texts = ["The quick brown fox", "Hello world"]
    mock_metric = self._make_mock_rouge_metric(rouge1=1.0, rouge2=1.0, rougeL=1.0, rougeLsum=1.0)
    result = self._run_score_batch(texts, texts, mock_metric)
    # The mocked metric reports 1.0 everywhere; score_batch scales to percent.
    self.assertAlmostEqual(result["rouge1"], 100.0)
    self.assertAlmostEqual(result["rouge2"], 100.0)
    self.assertAlmostEqual(result["rougeL"], 100.0)
    self.assertAlmostEqual(result["rougeLsum"], 100.0)
    # gen_num should match the batch size.
    self.assertEqual(result["gen_num"], 2)

  def test_empty_inputs(self):
    mock_metric = self._make_mock_rouge_metric(rouge1=0.0, rouge2=0.0, rougeL=0.0, rougeLsum=0.0)
    result = self._run_score_batch([], [], mock_metric)
    self.assertEqual(result["gen_num"], 0)

  def test_length_mismatch_raises(self):
    mock_metric = self._make_mock_rouge_metric()
    with self.assertRaises(ValueError):
      self._run_score_batch(["a", "b"], ["only_one"], mock_metric)

  def test_returns_rouge_keys(self):
    mock_metric = self._make_mock_rouge_metric()
    result = self._run_score_batch(["hello"], ["hello"], mock_metric)
    # Must contain the four standard ROUGE keys plus gen_num.
    for key in ("rouge1", "rouge2", "rougeL", "rougeLsum", "gen_num"):
      self.assertIn(key, result)

  def test_partial_overlap(self):
    mock_metric = self._make_mock_rouge_metric(rouge1=0.5, rouge2=0.2, rougeL=0.45, rougeLsum=0.45)
    result = self._run_score_batch(["fox brown quick"], ["quick brown fox"], mock_metric)
    # score_batch multiplies the metric values by 100, so 0.5 -> 50.0.
    self.assertAlmostEqual(result["rouge1"], 50.0)
    self.assertAlmostEqual(result["rouge2"], 20.0)
    self.assertAlmostEqual(result["rougeL"], 45.0)
    self.assertAlmostEqual(result["rougeLsum"], 45.0)
def _make_manager(**kwargs) -> VllmServerManager:
  """Construct a VllmServerManager from test defaults.

  Any keyword argument overrides the matching default or is passed through
  as an additional constructor argument.
  """
  params = {
      "model_path": "/fake/model",
      "host": "localhost",
      "port": 8000,
      "tensor_parallel_size": 4,
      "max_model_len": 4096,
      **kwargs,
  }
  return VllmServerManager(**params)
test_hbm_memory_utilization_forwarded_as_gpu_memory_utilization(self): + mgr = _make_manager(hbm_memory_utilization=0.5) + kwargs = _start_capturing_llm_kwargs(mgr) + self.assertAlmostEqual(kwargs["gpu_memory_utilization"], 0.5) + + def test_expert_parallel_size_divides_tensor_parallel(self): + mgr = _make_manager(tensor_parallel_size=8, expert_parallel_size=4) + kwargs = _start_capturing_llm_kwargs(mgr) + self.assertEqual(kwargs["tensor_parallel_size"], 2) # ici_tp = 8 // 4 + + def test_ep_not_divisible_raises(self): + with self.assertRaises(ValueError): + _make_manager(tensor_parallel_size=8, expert_parallel_size=3) + + def test_maxtext_adapter_mode_sets_hf_overrides(self): + mgr = _make_manager( + checkpoint_path="gs://bucket/run/0/items", + maxtext_model_name="llama3.1-8b", + ) + kwargs = _start_capturing_llm_kwargs(mgr) + self.assertIn("hf_overrides", kwargs) + self.assertEqual(kwargs["hf_overrides"]["architectures"], ["MaxTextForCausalLM"]) + + def test_maxtext_adapter_mode_sets_additional_config(self): + mgr = _make_manager( + checkpoint_path="gs://bucket/run/0/items", + maxtext_model_name="llama3.1-8b", + ) + kwargs = _start_capturing_llm_kwargs(mgr) + add_cfg = kwargs["additional_config"]["maxtext_config"] + self.assertEqual(add_cfg["load_parameters_path"], "gs://bucket/run/0/items") + self.assertEqual(add_cfg["model_name"], "llama3.1-8b") + + def test_hf_mode_sets_load_format_auto(self): + mgr = _make_manager() # no checkpoint_path → HF mode + kwargs = _start_capturing_llm_kwargs(mgr) + self.assertEqual(kwargs.get("load_format"), "auto") + self.assertNotIn("hf_overrides", kwargs) + self.assertNotIn("additional_config", kwargs) + + def test_max_num_batched_tokens_forwarded(self): + mgr = _make_manager(max_num_batched_tokens=2048) + kwargs = _start_capturing_llm_kwargs(mgr) + self.assertEqual(kwargs["max_num_batched_tokens"], 2048) + + def test_max_num_batched_tokens_omitted_when_none(self): + mgr = _make_manager(max_num_batched_tokens=None) + kwargs = 
_start_capturing_llm_kwargs(mgr) + self.assertNotIn("max_num_batched_tokens", kwargs) + + def test_max_num_seqs_forwarded(self): + mgr = _make_manager(max_num_seqs=256) + kwargs = _start_capturing_llm_kwargs(mgr) + self.assertEqual(kwargs["max_num_seqs"], 256) + + def test_max_num_seqs_omitted_when_none(self): + mgr = _make_manager(max_num_seqs=None) + kwargs = _start_capturing_llm_kwargs(mgr) + self.assertNotIn("max_num_seqs", kwargs) + + def test_env_applied_to_os_environ_before_llm_init(self): + mgr = _make_manager(env={"_TEST_EVAL_TOKEN": "abc123"}) + env_at_init = {} + + def capture_env(**kwargs): # pylint: disable=unused-argument + env_at_init.update(os.environ) + return mock.MagicMock() + + mock_llm_cls = mock.MagicMock(side_effect=capture_env) + mock_vllm = mock.MagicMock() + mock_vllm.LLM = mock_llm_cls + + with mock.patch.dict("sys.modules", {"vllm": mock_vllm, "uvicorn": mock.MagicMock()}), \ + mock.patch("jax.process_index", return_value=0), \ + mock.patch("threading.Thread", return_value=mock.MagicMock()), \ + mock.patch("maxtext.eval.runner.server_manager._build_app", return_value=mock.MagicMock()), \ + mock.patch.object(mgr, "_wait_until_healthy"), \ + mock.patch.dict("os.environ", {}, clear=False): + mgr.start() + + self.assertEqual(env_at_init.get("_TEST_EVAL_TOKEN"), "abc123") + + def test_missing_maxtext_model_name_raises(self): + with self.assertRaises(ValueError): + VllmServerManager(model_path="/fake/model", checkpoint_path="gs://bucket/0/items") + + +class TestVllmServerManagerHttp(unittest.TestCase): + """Tests that the HTTP server is started only on rank-0.""" + + def _start_capturing_thread_calls(self, mgr, rank): + mock_llm_cls = mock.MagicMock() + mock_vllm = mock.MagicMock() + mock_vllm.LLM = mock_llm_cls + mock_thread_cls = mock.MagicMock(return_value=mock.MagicMock()) + + with mock.patch.dict("sys.modules", {"vllm": mock_vllm, "uvicorn": mock.MagicMock()}), \ + mock.patch("jax.process_index", return_value=rank), \ + 
class TestVllmServerManagerLifecycle(unittest.TestCase):
  """Tests for VllmServerManager.stop(), context-manager use and base_url."""

  def test_stop_signals_uvicorn_should_exit(self):
    mgr = _make_manager()
    mock_server = mock.MagicMock()
    # Start from an explicit False: an unset MagicMock attribute is itself a
    # truthy MagicMock, which would make the final assertion pass even if
    # stop() never touched should_exit.
    mock_server.should_exit = False
    mock_thread = mock.MagicMock()
    mock_thread.is_alive.return_value = False
    mgr._uvicorn_server = mock_server
    mgr._server_thread = mock_thread
    with mock.patch("jax.process_index", return_value=0):
      mgr.stop()
    self.assertTrue(mock_server.should_exit)

  def test_stop_clears_references(self):
    mgr = _make_manager()
    mgr._llm = mock.MagicMock()
    mgr._uvicorn_server = mock.MagicMock()
    mgr._server_thread = mock.MagicMock()
    mgr._server_thread.is_alive.return_value = False
    with mock.patch("jax.process_index", return_value=0):
      mgr.stop()
    self.assertIsNone(mgr._llm)
    self.assertIsNone(mgr._uvicorn_server)
    self.assertIsNone(mgr._server_thread)

  def test_stop_is_noop_when_not_started(self):
    mgr = _make_manager()
    with mock.patch("jax.process_index", return_value=0):
      mgr.stop()  # should not raise

  def test_stop_called_on_context_exit(self):
    mgr = _make_manager()
    with mock.patch.object(mgr, "start"), mock.patch.object(mgr, "stop") as mock_stop:
      with mgr:
        pass
    mock_stop.assert_called_once()

  def test_stop_called_on_exception_in_context(self):
    mgr = _make_manager()
    with mock.patch.object(mgr, "start"), mock.patch.object(mgr, "stop") as mock_stop:
      # The exception itself is irrelevant here; only the stop() call matters.
      try:
        with mgr:
          raise RuntimeError("boom")
      except RuntimeError:
        pass
    mock_stop.assert_called_once()

  def test_base_url(self):
    mgr = _make_manager(host="0.0.0.0", port=9000)
    self.assertEqual(mgr.base_url, "http://0.0.0.0:9000")