2323import uuid
2424from typing import Any
2525
26- import jax
2726import requests
2827
2928logger = logging .getLogger (__name__ )
@@ -241,6 +240,11 @@ def start(self) -> None:
241240 # pylint: disable=import-outside-toplevel
242241 from vllm import LLM
243242
243+ # Disable V1 multiprocessing so EngineCore runs in-process instead.
244+ # V1 engine architecture is otherwise preserved (tpu-inference plugin works),
245+ # and JAX/TPU is initialised exactly once inside LLM() in this process.
246+ os .environ .setdefault ("VLLM_ENABLE_V1_MULTIPROCESSING" , "0" )
247+
244248 if self .env :
245249 os .environ .update (self .env )
246250
@@ -249,7 +253,6 @@ def start(self) -> None:
249253 "tensor_parallel_size" : self .tensor_parallel_size ,
250254 "max_model_len" : self .max_model_len ,
251255 "dtype" : self .dtype ,
252- "device" : "tpu" ,
253256 }
254257 if self .max_num_batched_tokens is not None :
255258 vllm_kwargs ["max_num_batched_tokens" ] = self .max_num_batched_tokens
@@ -269,15 +272,16 @@ def start(self) -> None:
269272 vllm_kwargs ["load_format" ] = "auto"
270273
271274 logger .info (
272- "Rank %d: initialising in-process vLLM (tp=%d, max_len=%d)..." ,
273- jax .process_index (),
275+ "Initializing in-process vLLM (tp=%d, max_len=%d)..." ,
274276 self .tensor_parallel_size ,
275277 self .max_model_len ,
276278 )
277279 self ._llm = LLM (** vllm_kwargs )
278- logger .info ("Rank %d: vLLM LLM ready." , jax .process_index ())
279280
280- if jax .process_index () == 0 :
281+ import jax as _jax # pylint: disable=import-outside-toplevel
282+ logger .info ("Rank %d: vLLM LLM ready." , _jax .process_index ())
283+
284+ if _jax .process_index () == 0 :
281285 import uvicorn # pylint: disable=import-outside-toplevel
282286
283287 app = _build_app (self ._llm )
@@ -317,7 +321,7 @@ def _wait_until_healthy(self) -> None:
317321
318322 def stop (self ) -> None :
319323 """Stop the HTTP server and release the LLM."""
320- if jax . process_index () == 0 and self ._uvicorn_server is not None :
324+ if self ._uvicorn_server is not None :
321325 logger .info ("Stopping vLLM HTTP server..." )
322326 self ._uvicorn_server .should_exit = True
323327 if self ._server_thread is not None :
@@ -327,7 +331,7 @@ def stop(self) -> None:
327331 self ._llm = None
328332 self ._uvicorn_server = None
329333 self ._server_thread = None
330- logger .info ("Rank %d: VllmServerManager stopped." , jax . process_index () )
334+ logger .info ("VllmServerManager stopped." )
331335
332336 def __enter__ (self ) -> "VllmServerManager" :
333337 self .start ()
0 commit comments