Commit 4e8731b

Fix multihost bugs and unify CLI entry point
1 parent 11a2e61 commit 4e8731b

9 files changed: 593 additions & 733 deletions

src/maxtext/eval/reporting/json_reporter.py

Lines changed: 17 additions & 7 deletions
@@ -20,6 +20,7 @@
import json
import logging
import os
+import tempfile

logger = logging.getLogger(__name__)

@@ -50,13 +51,9 @@ def write_results(
      - results: The full results dict written to disk.
      - local_path: Absolute path of the written file.
  """
-  os.makedirs(results_path, exist_ok=True)
-
  timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%dT%H%M%SZ")
-  # Create filename.
  safe_model = model_name.replace("/", "_").replace(":", "_")
  filename = f"{benchmark}_{safe_model}_{timestamp}.json"
-  local_path = os.path.join(results_path, filename)

  results = {
      "benchmark": benchmark,
@@ -67,7 +64,20 @@ def write_results(
      "config": config,
  }

-  with open(local_path, "w") as f:
-    json.dump(results, f, indent=2)
-  logger.info("Results written to %s", local_path)
+  if results_path.startswith("gs://"):
+    from maxtext.utils.gcs_utils import upload_blob  # pylint: disable=import-outside-toplevel
+    tmp_dir = tempfile.mkdtemp(prefix="eval_results_")
+    local_path = os.path.join(tmp_dir, filename)
+    with open(local_path, "w") as f:
+      json.dump(results, f, indent=2)
+    gcs_dest = f"{results_path.rstrip('/')}/{filename}"
+    upload_blob(gcs_dest, local_path)
+    logger.info("Results written to %s", gcs_dest)
+  else:
+    os.makedirs(results_path, exist_ok=True)
+    local_path = os.path.join(results_path, filename)
+    with open(local_path, "w") as f:
+      json.dump(results, f, indent=2)
+    logger.info("Results written to %s", local_path)
+
  return {"results": results, "local_path": os.path.abspath(local_path)}

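The new branch is a write-locally-then-upload pattern: GCS has no real directories, so the JSON is staged in a temp dir and then copied to the gs:// destination. A minimal standalone sketch of the same pattern, assuming only a generic upload callable (the save_json helper below is illustrative, not the repo's write_results):

import json
import os
import tempfile
from typing import Callable

def save_json(payload: dict, dest: str, filename: str, upload: Callable[[str, str], None]) -> str:
  """Write payload as JSON under dest; dest may be a local dir or a gs:// prefix."""
  if dest.startswith("gs://"):
    # No local directory to create: stage the file in a temp dir, then upload the blob.
    local_path = os.path.join(tempfile.mkdtemp(prefix="results_"), filename)
    with open(local_path, "w") as f:
      json.dump(payload, f, indent=2)
    upload(f"{dest.rstrip('/')}/{filename}", local_path)
  else:
    os.makedirs(dest, exist_ok=True)
    local_path = os.path.join(dest, filename)
    with open(local_path, "w") as f:
      json.dump(payload, f, indent=2)
  return local_path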
src/maxtext/eval/runner/async_client.py

Lines changed: 3 additions & 3 deletions
@@ -89,17 +89,17 @@ async def _generate_one(session: aiohttp.ClientSession, prompt: str) -> GenerationResult:
      "max_tokens": max_tokens,
      "temperature": temperature,
  }
-  t0 = time.monotonic()
  async with semaphore:
+    t0 = time.monotonic()
    try:
      async with session.post(api_url, json=payload) as resp:
        if resp.status != 200:
          body = await resp.text()
          return GenerationResult(error=f"HTTP {resp.status}: {body[:200]}")
        data = await resp.json()
-    except aiohttp.ClientError as exc:
+    except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
      return GenerationResult(error=str(exc))
-  latency = time.monotonic() - t0
+    latency = time.monotonic() - t0

  choice = data["choices"][0]
  usage = data.get("usage", {})

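Moving t0 inside the semaphore matters because the semaphore gates concurrency: starting the clock before acquiring a slot would fold queueing delay into the reported per-request latency, and aiohttp timeouts surface as asyncio.TimeoutError rather than ClientError. A stripped-down sketch of the corrected timing/exception pattern (names here are illustrative, not this module's API):

import asyncio
import time

async def timed_call(semaphore: asyncio.Semaphore, do_request) -> float | None:
  """Return request latency in seconds, excluding time spent waiting for a slot."""
  async with semaphore:
    t0 = time.monotonic()  # start the clock only once a concurrency slot is held
    try:
      await do_request()
    except asyncio.TimeoutError:
      return None  # timed out; no meaningful latency to report
    return time.monotonic() - t0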
src/maxtext/eval/runner/common.py

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Shared helpers for MaxText eval runners."""
+
+from __future__ import annotations
+
+import argparse
+import os
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+  from maxtext.eval.runner.server_manager import VllmServerManager
+
+ENABLE_EXPERT_PARALLEL_HELP = (
+    "Enable expert parallelism in vLLM. Required for MoE models such as "
+    "qwen3-30b-a3b, qwen3-235b-a22b, deepseek-v3, etc. Without this flag "
+    "tpu-inference omits the 'expert' mesh axis and MaxText's MoE sharding "
+    "raises KeyError."
+)
+
+
+def resolve_token(cfg: dict, hf_token: str | None) -> str | None:
+  """Return HF token from explicit arg or HF_TOKEN env var."""
+  return hf_token or os.environ.get("HF_TOKEN") or None
+
+
+def build_server_manager(cfg: dict, token: str | None) -> "VllmServerManager":
+  """Build a VllmServerManager from a merged config dict.
+
+  Handles token forwarding, MaxText adapter vs HF mode selection, and the
+  enable_expert_parallel to additional_config sharding injection.
+
+  Args:
+    cfg: Merged configuration dict. Required key: max_model_len. Common
+      optional keys: tensor_parallel_size, server_host, server_port,
+      max_num_batched_tokens, max_num_seqs, hf_mode, enable_expert_parallel.
+    token: HuggingFace token (or None).
+
+  Returns:
+    A VllmServerManager instance ready for use as a context manager (unstarted).
+  """
+  from maxtext.eval.runner.server_manager import VllmServerManager  # pylint: disable=import-outside-toplevel
+
+  hf_path = cfg["hf_path"]
+  model_name = cfg["model_name"]
+  checkpoint_path = cfg.get("checkpoint_path")
+  hf_mode = cfg.get("hf_mode", False)
+  use_maxtext_adapter = bool(checkpoint_path) and not hf_mode
+
+  tensor_parallel_size = int(cfg.get("tensor_parallel_size", 4))
+  max_model_len = int(cfg["max_model_len"])
+  server_host = cfg.get("server_host", "localhost")
+  server_port = int(cfg.get("server_port", 8000))
+
+  max_num_batched_tokens = cfg.get("max_num_batched_tokens")
+  if max_num_batched_tokens is not None:
+    max_num_batched_tokens = int(max_num_batched_tokens)
+  max_num_seqs = cfg.get("max_num_seqs")
+  if max_num_seqs is not None:
+    max_num_seqs = int(max_num_seqs)
+
+  server_env = {"HF_TOKEN": token} if token else None
+  additional_vllm_kwargs: dict = {}
+  if cfg.get("enable_expert_parallel"):
+    additional_vllm_kwargs["enable_expert_parallel"] = True
+
+  return VllmServerManager(
+      model_path=hf_path,
+      checkpoint_path=checkpoint_path if use_maxtext_adapter else None,
+      maxtext_model_name=model_name if use_maxtext_adapter else None,
+      host=server_host,
+      port=server_port,
+      tensor_parallel_size=tensor_parallel_size,
+      max_model_len=max_model_len,
+      max_num_batched_tokens=max_num_batched_tokens,
+      max_num_seqs=max_num_seqs,
+      env=server_env,
+      additional_vllm_kwargs=additional_vllm_kwargs or None,
+  )
+
+
+def maybe_upload_to_gcs(output: dict, gcs_results_path: str | None) -> None:
+  """Upload the results JSON to GCS if gcs_results_path is provided."""
+  if gcs_results_path:
+    from maxtext.eval.reporting.gcs_reporter import upload_results  # pylint: disable=import-outside-toplevel
+    upload_results(output["local_path"], gcs_results_path)
+
+
+def add_server_args(parser: argparse.ArgumentParser) -> None:
+  """Add the server/model CLI args shared by all eval runner parsers."""
+  parser.add_argument("--checkpoint_path", help="MaxText orbax checkpoint path (/0/items).")
+  parser.add_argument("--model_name", required=True, help="MaxText model name (e.g. llama3.1-8b).")
+  parser.add_argument("--hf_path", required=True, help="HF model ID or local tokenizer dir.")
+  parser.add_argument(
+      "--base_output_directory",
+      required=True,
+      help="Base output directory (local path or gs://<bucket>/).",
+  )
+  parser.add_argument("--run_name", required=True, help="Run name/identifier.")
+  parser.add_argument("--max_model_len", type=int, required=True, help="vLLM max context length.")
+  parser.add_argument(
+      "--tensor_parallel_size", type=int, default=4, help="vLLM tensor parallelism."
+  )
+  parser.add_argument("--server_host", default="localhost", help="vLLM server bind host.")
+  parser.add_argument("--server_port", type=int, default=8000, help="vLLM server port.")
+  parser.add_argument(
+      "--max_num_batched_tokens", type=int, help="vLLM tokens per scheduler step."
+  )
+  parser.add_argument("--max_num_seqs", type=int, help="vLLM max concurrent sequences.")
+  parser.add_argument("--hf_mode", action="store_true", help="HF safetensors mode.")
+  parser.add_argument(
+      "--enable_expert_parallel", action="store_true", help=ENABLE_EXPERT_PARALLEL_HELP
+  )
+  parser.add_argument("--hf_token", help="HuggingFace token for gated models.")
+  parser.add_argument(
+      "--gcs_results_path", help="Optional secondary GCS path to upload the results JSON."
+  )
+  parser.add_argument(
+      "--log_level",
+      default="INFO",
+      choices=["DEBUG", "INFO", "WARNING", "ERROR"],
+      help="Logging level.",
+  )

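A hedged usage sketch of the new helpers, as a runner might wire them together; it assumes the argparse namespace is turned into the merged config dict with vars() (the real runners' merge logic is not part of this excerpt):

import argparse

from maxtext.eval.runner.common import add_server_args, build_server_manager, resolve_token

parser = argparse.ArgumentParser()
add_server_args(parser)  # shared --model_name/--hf_path/--max_model_len/... flags
args = parser.parse_args([
    "--model_name", "llama3.1-8b",
    "--hf_path", "meta-llama/Llama-3.1-8B-Instruct",
    "--base_output_directory", "/tmp/eval/",
    "--run_name", "smoke_test",
    "--max_model_len", "4096",
])
cfg = vars(args)  # assumption: the namespace is used directly as the merged config dict
token = resolve_token(cfg, cfg.get("hf_token"))
with build_server_manager(cfg, token) as server:  # used as a context manager, as in eval_runner.py
  print(server.base_url)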
src/maxtext/eval/runner/eval_runner.py

Lines changed: 43 additions & 89 deletions
@@ -12,40 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""CLI entry point for model evaluation.
-
-MaxTextForCausalLM mode (preferred):
-  Load weights directly from the MaxText checkpoint, no HuggingFace weight
-  conversion required. Flag --hf_path supplies the tokenizer (HF model ID
-  or local tokenizer dir).
-
-  python -m maxtext.eval.runner.eval_runner \
-    --config src/maxtext/eval/configs/mlperf.yml \
-    --base_config src/maxtext/configs/base.yml \
-    --base_output_directory gs://<gcs_bucket>/ \
-    --run_name my_run \
-    --checkpoint_path gs://<gcs_bucket>/checkpoint/0/items \
-    --model_name llama3.1-8b \
-    --hf_path meta-llama/Llama-3.1-8B-Instruct
-
-HuggingFace safetensors mode:
-  Use --hf_mode and point --hf_path to an existing HF model directory.
-
-  python -m maxtext.eval.runner.eval_runner \
-    --config src/maxtext/eval/configs/mlperf.yml \
-    --hf_path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
-    --model_name tinyllama \
-    --hf_mode \
-    --base_output_directory /tmp/eval/ \
-    --run_name smoke_test \
-    --tensor_parallel_size 1
+"""Custom dataset eval runner (MLPerf OpenOrca, ROUGE scoring).
+
+Unified entry point:
+
+    python -m maxtext.eval.runner.run --runner eval ...
"""

from __future__ import annotations

import argparse
import logging
-import os
import time

import yaml
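The unified maxtext.eval.runner.run module itself is not part of this excerpt; the sketch below only illustrates the kind of --runner dispatcher the new docstring implies, with every name and signature hypothetical:

import argparse

def main() -> None:
  # Hypothetical dispatcher: pick a runner by name, let it parse its own flags.
  parser = argparse.ArgumentParser(prog="python -m maxtext.eval.runner.run")
  parser.add_argument("--runner", required=True, choices=["eval"],
                      help="Which eval runner to invoke (the real module may list more).")
  args, remaining_argv = parser.parse_known_args()
  if args.runner == "eval":
    from maxtext.eval.runner import eval_runner  # import deferred until dispatch
    eval_runner.main(remaining_argv)  # assumed entry signature, not confirmed by this diff

if __name__ == "__main__":
  main()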
@@ -116,7 +93,7 @@ def run_eval(cfg: dict, hf_token: str | None = None) -> dict:
  from maxtext.eval.datasets.registry import get_dataset
  from maxtext.eval.reporting.json_reporter import write_results
  from maxtext.eval.runner.async_client import generate_batch
-  from maxtext.eval.runner.server_manager import VllmServerManager
+  from maxtext.eval.runner.common import build_server_manager, maybe_upload_to_gcs, resolve_token
  from maxtext.eval.runner.warmup import warmup_server
  from maxtext.eval.scoring.registry import get_scorer

@@ -128,27 +105,10 @@ def run_eval(cfg: dict, hf_token: str | None = None) -> dict:
  max_tokens = int(cfg.get("max_tokens", 1024))
  temperature = float(cfg.get("temperature", 0.0))
  concurrency = int(cfg.get("concurrency", 64))
-  tensor_parallel_size = int(cfg.get("tensor_parallel_size", 4))
  if "max_model_len" not in cfg:
-    raise ValueError(
-        "Error: max_model_len is required."
-    )
-  max_model_len = int(cfg["max_model_len"])
-  server_host = cfg.get("server_host", "localhost")
-  server_port = int(cfg.get("server_port", 8000))
-  max_num_batched_tokens = cfg.get("max_num_batched_tokens")
-  if max_num_batched_tokens is not None:
-    max_num_batched_tokens = int(max_num_batched_tokens)
-  max_num_seqs = cfg.get("max_num_seqs")
-  if max_num_seqs is not None:
-    max_num_seqs = int(max_num_seqs)
+    raise ValueError("Error: max_model_len is required.")
  gcs_results_path = cfg.get("gcs_results_path")
-  token = hf_token or os.environ.get("HF_TOKEN") or None
-  checkpoint_path = cfg.get("checkpoint_path")
-  hf_mode = cfg.get("hf_mode", False)
-
-  # Determine loading mode.
-  use_maxtext_adapter = bool(checkpoint_path) and not hf_mode
+  token = resolve_token(cfg, hf_token)

  # Load tokenizer for prompt formatting.
  logger.info("Loading tokenizer from %s.", hf_path)
@@ -164,42 +124,40 @@ def run_eval(cfg: dict, hf_token: str | None = None) -> dict:
  references = [r.reference for r in requests]

  # Start vLLM server.
-  server_env = {"HF_TOKEN": token} if token else None
-  additional_vllm_kwargs = {}
-  if cfg.get("enable_expert_parallel"):
-    additional_vllm_kwargs["enable_expert_parallel"] = True
-
-  with VllmServerManager(
-      model_path=hf_path,
-      checkpoint_path=checkpoint_path if use_maxtext_adapter else None,
-      maxtext_model_name=model_name if use_maxtext_adapter else None,
-      host=server_host,
-      port=server_port,
-      tensor_parallel_size=tensor_parallel_size,
-      max_model_len=max_model_len,
-      max_num_batched_tokens=max_num_batched_tokens,
-      max_num_seqs=max_num_seqs,
-      env=server_env,
-      additional_vllm_kwargs=additional_vllm_kwargs or None,
-  ) as server:
-    base_url = server.base_url
-
-    # Warmup server.
-    warmup_server(base_url=base_url, model=model_name, sample_requests=requests)
-
-    # Generate responses.
-    logger.info("Generating responses for %d prompts.", len(prompts))
-    t0 = time.time()
-    results = generate_batch(
-        prompts=prompts,
-        base_url=base_url,
-        model=model_name,
-        max_tokens=max_tokens,
-        temperature=temperature,
-        concurrency=concurrency,
-    )
-    elapsed = time.time() - t0
-    logger.info("Generation completed in %.1fs (%.1f samples/s).", elapsed, len(prompts) / elapsed)
+  with build_server_manager(cfg, token) as server:
+    import jax as _jax  # pylint: disable=import-outside-toplevel
+    from jax.experimental import multihost_utils as _multihost_utils  # pylint: disable=import-outside-toplevel
+    is_rank0 = _jax.process_index() == 0
+
+    if is_rank0:
+      base_url = server.base_url
+
+      # Warmup server.
+      warmup_server(base_url=base_url, model=model_name, sample_requests=requests)
+
+      # Generate responses.
+      logger.info("Generating responses for %d prompts.", len(prompts))
+      t0 = time.time()
+      results = generate_batch(
+          prompts=prompts,
+          base_url=base_url,
+          model=model_name,
+          max_tokens=max_tokens,
+          temperature=temperature,
+          concurrency=concurrency,
+      )
+      elapsed = time.time() - t0
+      logger.info("Generation completed in %.1fs (%.1f samples/s).", elapsed, len(prompts) / elapsed)
+
+    # All ranks block here until rank-0 finishes generation. Non-rank-0 hosts
+    # keep their in-process LLM alive so rank-0's llm.generate() calls can
+    # complete their tensor-parallel collectives across all hosts.
+    _multihost_utils.sync_global_devices("eval_runner_complete")
+
+  # All ranks exit the context manager together above (LLM stopped on all).
+  # Only rank-0 has results/elapsed defined, non-rank-0 return early.
+  if not is_rank0:
+    return {}

  # Score.
  responses = [r.text for r in results]
@@ -229,11 +187,7 @@ def run_eval(cfg: dict, hf_token: str | None = None) -> dict:
      results_path=results_path,
  )

-  # Optional GCS Upload.
-  if gcs_results_path:
-    from maxtext.eval.reporting.gcs_reporter import upload_results  # pylint: disable=import-outside-toplevel
-    upload_results(output["local_path"], gcs_results_path)
-
+  maybe_upload_to_gcs(output, gcs_results_path)
  return output

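The multihost fix follows a standard SPMD pattern: every JAX process enters the server context so tensor-parallel collectives have all hosts, only process 0 drives the HTTP-client work, and a named barrier keeps the other processes alive until it finishes. A stripped-down sketch of that pattern on its own (it assumes a multi-process JAX setup; names are illustrative):

import jax
from jax.experimental import multihost_utils

def run_on_rank0(do_work) -> dict:
  """Run do_work() on process 0 while every other process waits at a barrier."""
  result: dict = {}
  if jax.process_index() == 0:
    result = do_work()  # e.g. issue requests against the locally hosted server
  # All processes must reach this named barrier; non-zero ranks idle here so the
  # workers they host stay available for process 0's in-flight requests.
  multihost_utils.sync_global_devices("work_complete")
  return result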