
Commit b6b37a2 (1 parent: 6ca20c4)

Add DP axis, test multihost, fix tests

9 files changed: 108 additions & 168 deletions


src/maxtext/eval/README.md

Lines changed: 3 additions & 2 deletions
```diff
@@ -86,8 +86,9 @@ python -m maxtext.eval.runner.run \
 | `--model_name` | MaxText model name (e.g. `llama3.1-8b`) |
 | `--hf_path` | HF model ID or local path |
 | `--max_model_len` | vLLM max context length. |
-| `--tensor_parallel_size` | Total number of chips |
-| `--expert_parallel_size` | Number of EP chips |
+| `--tensor_parallel_size` | Chips per model replica |
+| `--expert_parallel_size` | Chips for the expert mesh axis |
+| `--data_parallel_size` | Number of model replicas |
 | `--hbm_memory_utilization` | Fraction of HBM reserved for KV cache |
 | `--hf_token` | HF token (or set `HF_TOKEN` env var) |
 | `--hf_mode` | HF safetensors mode, no MaxText checkpoint loading |
```
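Note on the new semantics: `--tensor_parallel_size` now counts chips per replica rather than total chips, so the total chip budget is roughly `tensor_parallel_size × data_parallel_size`. A minimal sketch of that accounting, assuming the standard JAX device-count API; `check_mesh` is a hypothetical helper, not part of this commit:

```python
import jax

# Hypothetical sanity check (not in this commit): each model replica
# occupies `tensor_parallel_size` chips, and `data_parallel_size`
# replicas run side by side on the same slice.
def check_mesh(tensor_parallel_size: int, data_parallel_size: int) -> None:
  requested = tensor_parallel_size * data_parallel_size
  available = jax.device_count()
  if requested > available:
    raise ValueError(
        f"tp={tensor_parallel_size} x dp={data_parallel_size} needs "
        f"{requested} chips, but only {available} are visible to JAX."
    )

check_mesh(tensor_parallel_size=4, data_parallel_size=2)  # e.g. an 8-chip slice
```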

src/maxtext/eval/runner/async_client.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -31,7 +31,7 @@
 _DEFAULT_MAX_TOKENS = 1024
 _DEFAULT_TEMPERATURE = 0.0
 _COMPLETIONS_PATH = "/v1/completions"
-_REQUEST_TIMEOUT_S = 600  # (TODO): Check if this is reasoanable.
+_REQUEST_TIMEOUT_S = 600


 @dataclass
```

src/maxtext/eval/runner/common.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -65,6 +65,7 @@ def build_server_manager(cfg: dict, token: str | None) -> "VllmServerManager":
   max_num_seqs = int(max_num_seqs)

   expert_parallel_size = int(cfg.get("expert_parallel_size") or 1)
+  data_parallel_size = int(cfg.get("data_parallel_size") or 1)
   hbm_memory_utilization = float(cfg.get("hbm_memory_utilization") or 0.3)

   server_env = {"HF_TOKEN": token} if token else None
@@ -77,6 +78,7 @@ def build_server_manager(cfg: dict, token: str | None) -> "VllmServerManager":
       port=server_port,
       tensor_parallel_size=tensor_parallel_size,
       expert_parallel_size=expert_parallel_size,
+      data_parallel_size=data_parallel_size,
       max_model_len=max_model_len,
       max_num_batched_tokens=max_num_batched_tokens,
       max_num_seqs=max_num_seqs,
@@ -122,6 +124,12 @@ def add_server_args(parser: argparse.ArgumentParser) -> None:
           "Chips allocated to the expert mesh axis (EP). "
       ),
   )
+  parser.add_argument(
+      "--data_parallel_size",
+      type=int,
+      default=1,
+      help="Number of model replicas (DP).",
+  )
   parser.add_argument(
       "--hbm_memory_utilization",
       type=float,
```
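The parsing above uses `int(cfg.get(...) or 1)` rather than a plain `dict.get()` default; the difference matters when the config carries an explicit `None` or empty value. A standalone illustration (not from the commit):

```python
# Why `or 1` instead of dict.get(key, 1): `or` also coerces an explicit
# None or empty-string value to the fallback, which a get() default would not.
cfg = {"expert_parallel_size": None}

dp = int(cfg.get("data_parallel_size") or 1)    # key missing       -> 1
ep = int(cfg.get("expert_parallel_size") or 1)  # key present, None -> 1
# int(cfg.get("expert_parallel_size", 1)) would raise TypeError on None.

assert (dp, ep) == (1, 1)
```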

src/maxtext/eval/runner/harness_runner.py

Lines changed: 8 additions & 47 deletions
```diff
@@ -42,39 +42,13 @@
 logger = logging.getLogger(__name__)

-_TASK_MAP_LM_EVAL: dict[str, str] = {
-    "mmlu": "mmlu",
-    "gpqa": "gpqa_diamond",
-    "math": "hendrycks_math",
-    "gsm8k": "gsm8k",
-}
-
-_TASK_MAP_EVALCHEMY: dict[str, str] = {
-    "ifeval": "ifeval",
-    "alpacaeval": "alpaca_eval_v2",
-    "arena_hard": "arena_hard",
-    "mtbench": "mt_bench",
-    "wildbench": "wildbench",
-    "mixeval": "mixeval",
-    "zeroeval": "zeroeval",
-    "math500": "math_500",
-    "aime24": "aime2024",
-    "aime25": "aime2025",
-    "amc23": "amc2023",
-    "gpqa_diamond": "gpqa_diamond",
-    "humaneval": "humaneval",
-    "livecodebench": "livecodebench",
-    "gsm8k": "gsm8k",
-}
-
-
-def _map_results(raw_results: dict, tasks: list[str], task_map: dict[str, str]) -> dict:
+
+def _map_results(raw_results: dict, tasks: list[str]) -> dict:
   """Extract per-task accuracy metrics from lm-eval / evalchemy output."""
   scores: dict[str, float] = {}
   results_section = raw_results.get("results", {})
   for task in tasks:
-    lm_task = task_map.get(task, task)
-    task_r = results_section.get(lm_task, {})
+    task_r = results_section.get(task, {})

     acc = None
     for key in (
@@ -125,7 +99,6 @@ def run_harness(cfg: dict, hf_token: str | None = None) -> dict:

   Raises:
     ImportError: If lm_eval (or evalchemy for that backend) is not installed.
-    ValueError: If a requested task name is not in the backend's task map.
   """
   # pylint: disable=import-outside-toplevel
   try:
@@ -154,19 +127,8 @@ def run_harness(cfg: dict, hf_token: str | None = None) -> dict:
   gcs_results_path = cfg.get("gcs_results_path")
   token = resolve_token(cfg, hf_token)

-  task_map = _TASK_MAP_EVALCHEMY if backend == "evalchemy" else _TASK_MAP_LM_EVAL
   lm_model_type = "local-chat-completions" if backend == "evalchemy" else "local-completions"

-  lm_tasks: list[str] = []
-  for t in tasks:
-    lm_task = task_map.get(t)
-    if lm_task is None:
-      raise ValueError(
-          f"No {backend} task mapping for '{t}'. "
-          f"Known tasks: {list(task_map.keys())}"
-      )
-    lm_tasks.append(lm_task)
-
   with build_server_manager(cfg, token) as server:
     import jax as _jax
     from jax.experimental import multihost_utils as _multihost_utils
@@ -191,14 +153,14 @@ def run_harness(cfg: dict, hf_token: str | None = None) -> dict:
     logger.info(
         "Running %s tasks %s via %s at %s",
         backend,
-        lm_tasks,
+        tasks,
         lm_model_type,
         server.base_url,
     )
     raw_results = lm_eval_lib.simple_evaluate(
         model=lm_model_type,
         model_args=model_args,
-        tasks=lm_tasks,
+        tasks=tasks,
         num_fewshot=num_fewshot,
         limit=num_samples,
         log_samples=False,
@@ -214,7 +176,7 @@ def run_harness(cfg: dict, hf_token: str | None = None) -> dict:
     if not is_rank0:
       return {}

-    scores = _map_results(raw_results, tasks, task_map)
+    scores = _map_results(raw_results, tasks)
     logger.info("%s scores: %s", backend, scores)

     output = write_results(
@@ -253,9 +215,8 @@ def _build_arg_parser() -> argparse.ArgumentParser:
       nargs="+",
       default=["mmlu"],
       help=(
-          "Benchmark task names. "
-          "lm_eval choices: " + ", ".join(_TASK_MAP_LM_EVAL) + ". "
-          "evalchemy choices: " + ", ".join(_TASK_MAP_EVALCHEMY) + "."
+          "lm-eval task names passed directly to simple_evaluate. "
+          "Any task registered in lm-eval or evalchemy is accepted (e.g. gsm8k, mmlu, gpqa_diamond, ifeval, math_500)."
       ),
   )
   parser.add_argument(
```
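With the hardcoded task maps gone, task names flow straight through to `simple_evaluate` and key directly into the results section, so any task registered with lm-eval or evalchemy works without touching this file. A hedged sketch of the simplified lookup; `map_results_sketch` mirrors the shape of the new `_map_results` but is an illustration, and the metric keys `exact_match,none` and `acc,none` are an assumption about typical lm-eval output, not taken from this diff:

```python
# Fake harness output in the assumed lm-eval results shape.
raw_results = {
    "results": {
        "gsm8k": {"exact_match,none": 0.81},
        "mmlu": {"acc,none": 0.67},
    }
}

def map_results_sketch(raw: dict, tasks: list[str]) -> dict[str, float]:
  scores: dict[str, float] = {}
  section = raw.get("results", {})
  for task in tasks:
    task_r = section.get(task, {})  # task name keys the results directly now
    for key in ("exact_match,none", "acc,none"):  # assumed metric keys
      if key in task_r:
        scores[task] = float(task_r[key])
        break
  return scores

print(map_results_sketch(raw_results, ["gsm8k", "mmlu"]))
# -> {'gsm8k': 0.81, 'mmlu': 0.67}
```

The trade-off is that a mistyped task name now surfaces as a missing score rather than the early `ValueError` the deleted mapping loop raised.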

src/maxtext/eval/runner/server_manager.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -210,6 +210,7 @@ def __init__(
       port: int = 8000,
       tensor_parallel_size: int = 4,
       expert_parallel_size: int = 1,
+      data_parallel_size: int = 1,
       max_model_len: int = 4096,
       dtype: str = "bfloat16",
       max_num_batched_tokens: int | None = None,
@@ -233,6 +234,7 @@
     self.port = port
     self.tensor_parallel_size = tensor_parallel_size
     self.expert_parallel_size = expert_parallel_size
+    self.data_parallel_size = data_parallel_size
     self.max_model_len = max_model_len
     self.dtype = dtype
     self.max_num_batched_tokens = max_num_batched_tokens
@@ -272,6 +274,7 @@ def start(self) -> None:
     vllm_kwargs: dict = {
         "model": self.model_path,
         "tensor_parallel_size": ici_tp,
+        "data_parallel_size": self.data_parallel_size,
         "max_model_len": self.max_model_len,
         "dtype": self.dtype,
         "gpu_memory_utilization": self.hbm_memory_utilization,
@@ -318,9 +321,10 @@ def start(self) -> None:
       vllm_kwargs[_k] = _v

     logger.info(
-        "Initializing in-process vLLM (tp=%d, ep=%d, max_len=%d)...",
+        "Initializing in-process vLLM (tp=%d, ep=%d, dp=%d, max_len=%d)...",
         ici_tp,
         ici_ep,
+        self.data_parallel_size,
         self.max_model_len,
     )
     self._llm = LLM(**vllm_kwargs)
```
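For orientation, a sketch of how the new knob reaches the engine: `data_parallel_size` rides through `vllm_kwargs` into the in-process `LLM` constructor alongside the per-replica tensor-parallel size. The values below are placeholders, and the sketch assumes a vLLM build that accepts `data_parallel_size` as a keyword, as the diff above does:

```python
from vllm import LLM  # assumes a vLLM version exposing data_parallel_size

# Placeholder values, not from the commit: 4 chips per replica x 2 replicas
# = 8 chips total. `tensor_parallel_size` here plays the role of ici_tp.
vllm_kwargs: dict = {
    "model": "org/some-model",     # hypothetical model path
    "tensor_parallel_size": 4,     # chips per model replica
    "data_parallel_size": 2,       # replicas added by this commit
    "max_model_len": 4096,
    "dtype": "bfloat16",
}
llm = LLM(**vllm_kwargs)
```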

src/maxtext/eval/scoring/rouge_scorer.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -18,6 +18,11 @@

 import numpy as np

+import nltk  # pylint: disable=import-outside-toplevel
+
+nltk.download("punkt", quiet=True)
+nltk.download("punkt_tab", quiet=True)
+

 def score_batch(
     responses: list[str],
@@ -43,10 +48,7 @@ def score_batch(
   )

   import evaluate  # pylint: disable=import-outside-toplevel
-  import nltk  # pylint: disable=import-outside-toplevel

-  nltk.download("punkt", quiet=True)
-  nltk.download("punkt_tab", quiet=True)
   metric = evaluate.load("rouge")

   preds = []
```
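Moving the `punkt` downloads to import time means the tokenizer data is fetched once per process instead of on every `score_batch` call. A sketch of why the data is needed at all; the `to_rouge_lsum_format` helper and the newline-separated-sentences convention are assumptions about the usual ROUGE-Lsum preprocessing, not code from this diff:

```python
import nltk

# Matches the commit's import-time downloads: fetch punkt once per process
# so the scoring loop never re-downloads mid-run.
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)

# Assumed convention: ROUGE-Lsum compares newline-separated sentences,
# which punkt's sentence tokenizer provides.
def to_rouge_lsum_format(text: str) -> str:
  return "\n".join(nltk.sent_tokenize(text.strip()))

print(to_rouge_lsum_format("First sentence. Second sentence."))
# First sentence.
# Second sentence.
```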
