 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""CLI entry point for model evaluation.
-
-MaxTextForCausalLM mode (preferred):
-  Load weights directly from the MaxText checkpoint, no HuggingFace weight
-  conversion required. Flag --hf_path supplies the tokenizer (HF model ID
-  or local tokenizer dir).
-
-  python -m maxtext.eval.runner.eval_runner \
-    --config src/maxtext/eval/configs/mlperf.yml \
-    --base_config src/maxtext/configs/base.yml \
-    --base_output_directory gs://<gcs_bucket>/ \
-    --run_name my_run \
-    --checkpoint_path gs://<gcs_bucket>/checkpoint/0/items \
-    --model_name llama3.1-8b \
-    --hf_path meta-llama/Llama-3.1-8B-Instruct
-
-HuggingFace safetensors mode:
-  Use --hf_mode and point --hf_path to an existing HF model directory.
-
-  python -m maxtext.eval.runner.eval_runner \
-    --config src/maxtext/eval/configs/mlperf.yml \
-    --hf_path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
-    --model_name tinyllama \
-    --hf_mode \
-    --base_output_directory /tmp/eval/ \
-    --run_name smoke_test \
-    --tensor_parallel_size 1
+"""Custom dataset eval runner (MLPerf OpenOrca, ROUGE scoring).
+
+Unified entry point:
+
+  python -m maxtext.eval.runner.run --runner eval ...
 """

 from __future__ import annotations

 import argparse
 import logging
-import os
 import time

 import yaml
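The new docstring routes everything through the unified runner entry point, while run_eval() itself still takes a plain config dict plus an optional HF token (see the hunks below). A minimal usage sketch, assuming this module stays importable under its current path and that hf_path / model_name are read from the same config; the other key names and defaults come from the cfg.get(...) calls visible in this diff:

    # Sketch only: drive run_eval() directly with an in-memory config.
    from maxtext.eval.runner.eval_runner import run_eval

    cfg = {
        "hf_path": "meta-llama/Llama-3.1-8B-Instruct",  # tokenizer source (assumed key name)
        "model_name": "llama3.1-8b",                    # served model name (assumed key name)
        "max_model_len": 4096,   # required; run_eval raises ValueError if missing
        "max_tokens": 1024,      # generation cap, matches the default in the code
        "temperature": 0.0,      # greedy decoding
        "concurrency": 64,       # async request fan-out
        "gcs_results_path": None,  # set to a gs:// prefix to upload results
    }
    output = run_eval(cfg, hf_token=None)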
@@ -116,7 +93,7 @@ def run_eval(cfg: dict, hf_token: str | None = None) -> dict:
   from maxtext.eval.datasets.registry import get_dataset
   from maxtext.eval.reporting.json_reporter import write_results
   from maxtext.eval.runner.async_client import generate_batch
-  from maxtext.eval.runner.server_manager import VllmServerManager
+  from maxtext.eval.runner.common import build_server_manager, maybe_upload_to_gcs, resolve_token
   from maxtext.eval.runner.warmup import warmup_server
   from maxtext.eval.scoring.registry import get_scorer

@@ -128,27 +105,10 @@ def run_eval(cfg: dict, hf_token: str | None = None) -> dict:
   max_tokens = int(cfg.get("max_tokens", 1024))
   temperature = float(cfg.get("temperature", 0.0))
   concurrency = int(cfg.get("concurrency", 64))
-  tensor_parallel_size = int(cfg.get("tensor_parallel_size", 4))
   if "max_model_len" not in cfg:
-    raise ValueError(
-        "Error: max_model_len is required."
-    )
-  max_model_len = int(cfg["max_model_len"])
-  server_host = cfg.get("server_host", "localhost")
-  server_port = int(cfg.get("server_port", 8000))
-  max_num_batched_tokens = cfg.get("max_num_batched_tokens")
-  if max_num_batched_tokens is not None:
-    max_num_batched_tokens = int(max_num_batched_tokens)
-  max_num_seqs = cfg.get("max_num_seqs")
-  if max_num_seqs is not None:
-    max_num_seqs = int(max_num_seqs)
+    raise ValueError("Error: max_model_len is required.")
   gcs_results_path = cfg.get("gcs_results_path")
-  token = hf_token or os.environ.get("HF_TOKEN") or None
-  checkpoint_path = cfg.get("checkpoint_path")
-  hf_mode = cfg.get("hf_mode", False)
-
-  # Determine loading mode.
-  use_maxtext_adapter = bool(checkpoint_path) and not hf_mode
+  token = resolve_token(cfg, hf_token)

   # Load tokenizer for prompt formatting.
   logger.info("Loading tokenizer from %s.", hf_path)
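resolve_token() replaces the inline lookup deleted above (hf_token or os.environ.get("HF_TOKEN") or None), now taking the config as well. A minimal sketch of what the helper in maxtext.eval.runner.common presumably does; the cfg key name and the exact precedence order are assumptions:

    import os

    def resolve_token(cfg: dict, hf_token: str | None = None) -> str | None:
      # Sketch: explicit argument first, then the config, then the HF_TOKEN env var.
      return hf_token or cfg.get("hf_token") or os.environ.get("HF_TOKEN") or None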
@@ -164,42 +124,40 @@ def run_eval(cfg: dict, hf_token: str | None = None) -> dict:
   references = [r.reference for r in requests]

   # Start vLLM server.
-  server_env = {"HF_TOKEN": token} if token else None
-  additional_vllm_kwargs = {}
-  if cfg.get("enable_expert_parallel"):
-    additional_vllm_kwargs["enable_expert_parallel"] = True
-
-  with VllmServerManager(
-      model_path=hf_path,
-      checkpoint_path=checkpoint_path if use_maxtext_adapter else None,
-      maxtext_model_name=model_name if use_maxtext_adapter else None,
-      host=server_host,
-      port=server_port,
-      tensor_parallel_size=tensor_parallel_size,
-      max_model_len=max_model_len,
-      max_num_batched_tokens=max_num_batched_tokens,
-      max_num_seqs=max_num_seqs,
-      env=server_env,
-      additional_vllm_kwargs=additional_vllm_kwargs or None,
-  ) as server:
-    base_url = server.base_url
-
-    # Warmup server.
-    warmup_server(base_url=base_url, model=model_name, sample_requests=requests)
-
-    # Generate responses.
-    logger.info("Generating responses for %d prompts.", len(prompts))
-    t0 = time.time()
-    results = generate_batch(
-        prompts=prompts,
-        base_url=base_url,
-        model=model_name,
-        max_tokens=max_tokens,
-        temperature=temperature,
-        concurrency=concurrency,
-    )
-    elapsed = time.time() - t0
-    logger.info("Generation completed in %.1fs (%.1f samples/s).", elapsed, len(prompts) / elapsed)
+  with build_server_manager(cfg, token) as server:
+    import jax as _jax  # pylint: disable=import-outside-toplevel
+    from jax.experimental import multihost_utils as _multihost_utils  # pylint: disable=import-outside-toplevel
+    is_rank0 = _jax.process_index() == 0
+
+    if is_rank0:
+      base_url = server.base_url
+
+      # Warmup server.
+      warmup_server(base_url=base_url, model=model_name, sample_requests=requests)
+
+      # Generate responses.
+      logger.info("Generating responses for %d prompts.", len(prompts))
+      t0 = time.time()
+      results = generate_batch(
+          prompts=prompts,
+          base_url=base_url,
+          model=model_name,
+          max_tokens=max_tokens,
+          temperature=temperature,
+          concurrency=concurrency,
+      )
+      elapsed = time.time() - t0
+      logger.info("Generation completed in %.1fs (%.1f samples/s).", elapsed, len(prompts) / elapsed)
+
+    # All ranks block here until rank-0 finishes generation. Non-rank-0 hosts
+    # keep their in-process LLM alive so rank-0's llm.generate() calls can
+    # complete their tensor-parallel collectives across all hosts.
+    _multihost_utils.sync_global_devices("eval_runner_complete")
+
+  # All ranks exit the context manager together above (the LLM is stopped on all).
+  # Only rank-0 has results/elapsed defined; non-rank-0 hosts return early.
+  if not is_rank0:
+    return {}

   # Score.
   responses = [r.text for r in results]
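build_server_manager() folds the config parsing and VllmServerManager construction deleted above into maxtext.eval.runner.common. A minimal sketch assuming the helper simply relocates the removed logic unchanged; key names and defaults come from the deleted lines, and reading hf_path / model_name straight from cfg is an assumption:

    from maxtext.eval.runner.server_manager import VllmServerManager

    def build_server_manager(cfg: dict, token: str | None) -> VllmServerManager:
      # Sketch: reproduce the inline construction that used to live in run_eval().
      checkpoint_path = cfg.get("checkpoint_path")
      hf_mode = cfg.get("hf_mode", False)
      use_maxtext_adapter = bool(checkpoint_path) and not hf_mode

      max_num_batched_tokens = cfg.get("max_num_batched_tokens")
      if max_num_batched_tokens is not None:
        max_num_batched_tokens = int(max_num_batched_tokens)
      max_num_seqs = cfg.get("max_num_seqs")
      if max_num_seqs is not None:
        max_num_seqs = int(max_num_seqs)

      additional_vllm_kwargs = {}
      if cfg.get("enable_expert_parallel"):
        additional_vllm_kwargs["enable_expert_parallel"] = True

      return VllmServerManager(
          model_path=cfg["hf_path"],
          checkpoint_path=checkpoint_path if use_maxtext_adapter else None,
          maxtext_model_name=cfg["model_name"] if use_maxtext_adapter else None,
          host=cfg.get("server_host", "localhost"),
          port=int(cfg.get("server_port", 8000)),
          tensor_parallel_size=int(cfg.get("tensor_parallel_size", 4)),
          max_model_len=int(cfg["max_model_len"]),
          max_num_batched_tokens=max_num_batched_tokens,
          max_num_seqs=max_num_seqs,
          env={"HF_TOKEN": token} if token else None,
          additional_vllm_kwargs=additional_vllm_kwargs or None,
      )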
@@ -229,11 +187,7 @@ def run_eval(cfg: dict, hf_token: str | None = None) -> dict:
       results_path=results_path,
   )

-  # Optional GCS Upload.
-  if gcs_results_path:
-    from maxtext.eval.reporting.gcs_reporter import upload_results  # pylint: disable=import-outside-toplevel
-    upload_results(output["local_path"], gcs_results_path)
-
+  maybe_upload_to_gcs(output, gcs_results_path)
   return output


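maybe_upload_to_gcs() wraps the optional GCS upload that was previously inlined at the end of run_eval(). A minimal sketch, again assuming the helper only relocates the removed code:

    def maybe_upload_to_gcs(output: dict, gcs_results_path: str | None) -> None:
      # Sketch: upload local results only when a GCS destination is configured.
      if not gcs_results_path:
        return
      from maxtext.eval.reporting.gcs_reporter import upload_results  # pylint: disable=import-outside-toplevel
      upload_results(output["local_path"], gcs_results_path)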