
Commit 52e36b6
Fix a few bugs
1 parent: 6352537

2 files changed: 13 additions & 8 deletions


src/maxtext/eval/runner/lm_eval_runner.py (1 addition, 0 deletions)

@@ -47,6 +47,7 @@
     "mmlu": "mmlu",  # loglikelihood, 14042 questions
     "gpqa": "gpqa_diamond",  # loglikelihood, 198 questions
     "math": "hendrycks_math",  # generation, 12500 problems (5 subjects)
+    "gsm8k": "gsm8k",  # generation, 8500 grade-school math problems
 }
 
 
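Aside: a minimal sketch of how an alias map like this is typically consumed, resolving the short name before calling lm-eval-harness. TASK_ALIASES and run_eval are illustrative names for this sketch, not the module's actual API; lm_eval.simple_evaluate is the harness entry point.

import lm_eval

# Short user-facing aliases -> lm-eval-harness task names (mirrors the diff).
TASK_ALIASES = {
    "mmlu": "mmlu",
    "gpqa": "gpqa_diamond",
    "math": "hendrycks_math",
    "gsm8k": "gsm8k",
}

def run_eval(task: str, model_args: str) -> dict:
  # Resolve the alias, then evaluate against an OpenAI-compatible endpoint
  # ("local-completions" is lm-eval's client for such servers).
  results = lm_eval.simple_evaluate(
      model="local-completions",
      model_args=model_args,  # e.g. base_url/tokenizer, supplied by the caller
      tasks=[TASK_ALIASES[task]],
  )
  return results
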
src/maxtext/eval/runner/server_manager.py (12 additions, 8 deletions)

@@ -23,7 +23,6 @@
 import uuid
 from typing import Any
 
-import jax
 import requests
 
 logger = logging.getLogger(__name__)

@@ -241,6 +240,11 @@ def start(self) -> None:
     # pylint: disable=import-outside-toplevel
     from vllm import LLM
 
+    # Disable V1 multiprocessing so EngineCore runs in-process instead.
+    # V1 engine architecture is otherwise preserved (tpu-inference plugin works),
+    # and JAX/TPU is initialised exactly once inside LLM() in this process.
+    os.environ.setdefault("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+
     if self.env:
       os.environ.update(self.env)
 
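For context on the new environment variable: with V1 multiprocessing off, vLLM's EngineCore runs inside the calling process, so JAX/TPU state stays in that process. A standalone sketch of the pattern under that assumption (the model name is illustrative, not from this commit):

import os

# Must be set before the engine is built; "0" keeps V1's EngineCore in-process
# instead of spawning a separate engine process.
os.environ.setdefault("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

from vllm import LLM, SamplingParams  # import after the env var is set

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", max_model_len=2048)
outputs = llm.generate(["What is 2 + 2?"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
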
@@ -249,7 +253,6 @@ def start(self) -> None:
         "tensor_parallel_size": self.tensor_parallel_size,
         "max_model_len": self.max_model_len,
         "dtype": self.dtype,
-        "device": "tpu",
     }
     if self.max_num_batched_tokens is not None:
       vllm_kwargs["max_num_batched_tokens"] = self.max_num_batched_tokens

@@ -269,15 +272,16 @@ def start(self) -> None:
       vllm_kwargs["load_format"] = "auto"
 
     logger.info(
-        "Rank %d: initialising in-process vLLM (tp=%d, max_len=%d)...",
-        jax.process_index(),
+        "Initializing in-process vLLM (tp=%d, max_len=%d)...",
         self.tensor_parallel_size,
         self.max_model_len,
     )
     self._llm = LLM(**vllm_kwargs)
-    logger.info("Rank %d: vLLM LLM ready.", jax.process_index())
 
-    if jax.process_index() == 0:
+    import jax as _jax  # pylint: disable=import-outside-toplevel
+    logger.info("Rank %d: vLLM LLM ready.", _jax.process_index())
+
+    if _jax.process_index() == 0:
       import uvicorn  # pylint: disable=import-outside-toplevel
 
       app = _build_app(self._llm)
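The deferred import jax as _jax and the process_index() == 0 gate implement a rank-0 serving pattern: every process builds the in-process LLM, but only process 0 exposes it over HTTP. A minimal sketch of that pattern; app stands in for whatever _build_app (not shown in this diff) returns:

import threading

import jax
import uvicorn

def serve_on_rank0(app, port: int = 8000):
  # Non-zero ranks participate in inference but never serve HTTP.
  if jax.process_index() != 0:
    return None, None
  config = uvicorn.Config(app, host="0.0.0.0", port=port, log_level="info")
  server = uvicorn.Server(config)
  thread = threading.Thread(target=server.run, daemon=True)
  thread.start()
  # Return both so a later stop() can flip server.should_exit and join.
  return server, thread
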
@@ -317,7 +321,7 @@ def _wait_until_healthy(self) -> None:
 
   def stop(self) -> None:
     """Stop the HTTP server and release the LLM."""
-    if jax.process_index() == 0 and self._uvicorn_server is not None:
+    if self._uvicorn_server is not None:
       logger.info("Stopping vLLM HTTP server...")
       self._uvicorn_server.should_exit = True
       if self._server_thread is not None:

@@ -327,7 +331,7 @@ def stop(self) -> None:
     self._llm = None
     self._uvicorn_server = None
     self._server_thread = None
-    logger.info("Rank %d: VllmServerManager stopped.", jax.process_index())
+    logger.info("VllmServerManager stopped.")
 
   def __enter__(self) -> "VllmServerManager":
     self.start()

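The stop() hunks can drop their jax.process_index() guards because only rank 0 ever created a uvicorn server, so self._uvicorn_server is None on every other rank and the attribute check alone suffices. A sketch of the shutdown sequence, with hypothetical free-function arguments standing in for the instance attributes:

import threading
from typing import Optional

import uvicorn

def stop_server(server: Optional[uvicorn.Server],
                thread: Optional[threading.Thread],
                timeout: float = 10.0) -> None:
  # Ask uvicorn's serving loop to exit, then wait for the thread to drain.
  # Safe on every rank: where no server was started, both arguments are None.
  if server is not None:
    server.should_exit = True
  if thread is not None:
    thread.join(timeout=timeout)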
