Commit 11a2e61

Add vllm arg to enable expert parallelism for MoE models
1 parent 52e36b6 commit 11a2e61

4 files changed · 51 additions & 2 deletions

File tree

src/maxtext/eval/runner/eval_runner.py
src/maxtext/eval/runner/evalchemy_runner.py
src/maxtext/eval/runner/lm_eval_runner.py
src/maxtext/eval/runner/server_manager.py

src/maxtext/eval/runner/eval_runner.py

Lines changed: 15 additions & 0 deletions
@@ -165,6 +165,10 @@ def run_eval(cfg: dict, hf_token: str | None = None) -> dict:
 
   # Start vLLM server.
   server_env = {"HF_TOKEN": token} if token else None
+  additional_vllm_kwargs = {}
+  if cfg.get("enable_expert_parallel"):
+    additional_vllm_kwargs["enable_expert_parallel"] = True
+
   with VllmServerManager(
       model_path=hf_path,
       checkpoint_path=checkpoint_path if use_maxtext_adapter else None,
@@ -176,6 +180,7 @@ def run_eval(cfg: dict, hf_token: str | None = None) -> dict:
       max_num_batched_tokens=max_num_batched_tokens,
       max_num_seqs=max_num_seqs,
       env=server_env,
+      additional_vllm_kwargs=additional_vllm_kwargs or None,
   ) as server:
     base_url = server.base_url
 
@@ -255,6 +260,16 @@ def _build_arg_parser() -> argparse.ArgumentParser:
   parser.add_argument("--server_host", help="vLLM server host.")
   parser.add_argument("--server_port", type=int, help="vLLM server port.")
   parser.add_argument("--hf_mode", action="store_true", help="Use HF safetensors mode.")
+  parser.add_argument(
+      "--enable_expert_parallel",
+      action="store_true",
+      help=(
+          "Enable expert parallelism in vLLM. Required for MoE models such as "
+          "qwen3-30b-a3b, qwen3-235b-a22b, deepseek-v3, etc. Without this flag "
+          "tpu-inference omits the 'expert' mesh axis and MaxText's MoE sharding "
+          "raises KeyError."
+      ),
+  )
   parser.add_argument("--hf_token", help="HuggingFace token for gated models.")
   parser.add_argument(
       "--log_level",

src/maxtext/eval/runner/evalchemy_runner.py

Lines changed: 15 additions & 1 deletion
@@ -33,7 +33,7 @@
     --tensor_parallel_size 4 \\
     --hf_token $HF_TOKEN
 
-Requires: pip install evalchemy
+Requires: pip install git+https://github.com/mlfoundations/evalchemy.git
 """
 
 from __future__ import annotations
@@ -200,6 +200,9 @@ def run_evalchemy(cfg: dict, hf_token: str | None = None) -> dict:
     lm_eval_tasks.append(lm_eval_task)
 
   server_env = {"HF_TOKEN": token} if token else None
+  additional_vllm_kwargs = {}
+  if cfg.get("enable_expert_parallel"):
+    additional_vllm_kwargs["enable_expert_parallel"] = True
 
   with VllmServerManager(
       model_path=hf_path,
@@ -212,6 +215,7 @@ def run_evalchemy(cfg: dict, hf_token: str | None = None) -> dict:
       max_num_batched_tokens=max_num_batched_tokens,
       max_num_seqs=max_num_seqs,
       env=server_env,
+      additional_vllm_kwargs=additional_vllm_kwargs or None,
   ) as server:
     warmup_server(base_url=server.base_url, model=model_name)
 
@@ -329,6 +333,16 @@ def _build_arg_parser() -> argparse.ArgumentParser:
       action="store_true",
       help="HF safetensors mode.",
   )
+  parser.add_argument(
+      "--enable_expert_parallel",
+      action="store_true",
+      help=(
+          "Enable expert parallelism in vLLM. Required for MoE models such as "
+          "qwen3-30b-a3b, qwen3-235b-a22b, deepseek-v3, etc. Without this flag "
+          "tpu-inference omits the 'expert' mesh axis and MaxText's MoE sharding "
+          "raises KeyError."
+      ),
+  )
   parser.add_argument(
       "--hf_token",
       help="HuggingFace token for gated tokenizers.",

src/maxtext/eval/runner/lm_eval_runner.py

Lines changed: 14 additions & 0 deletions
@@ -142,6 +142,9 @@ def run_lm_eval(cfg: dict, hf_token: str | None = None) -> dict:
     lm_tasks.append(lm_task)
 
   server_env = {"HF_TOKEN": token} if token else None
+  additional_vllm_kwargs = {}
+  if cfg.get("enable_expert_parallel"):
+    additional_vllm_kwargs["enable_expert_parallel"] = True
 
   with VllmServerManager(
       model_path=hf_path,
@@ -154,6 +157,7 @@ def run_lm_eval(cfg: dict, hf_token: str | None = None) -> dict:
       max_num_batched_tokens=max_num_batched_tokens,
       max_num_seqs=max_num_seqs,
       env=server_env,
+      additional_vllm_kwargs=additional_vllm_kwargs or None,
   ) as server:
     warmup_server(base_url=server.base_url, model=model_name)
 
@@ -213,6 +217,16 @@ def _build_arg_parser() -> argparse.ArgumentParser:
   parser.add_argument("--num_samples", type=int, help="Limit samples per task (None = full dataset).")
   parser.add_argument("--hf_token", help="HuggingFace token for gated tokenizers.")
   parser.add_argument("--hf_mode", action="store_true", help="HF safetensors mode.")
+  parser.add_argument(
+      "--enable_expert_parallel",
+      action="store_true",
+      help=(
+          "Enable expert parallelism in vLLM. Required for MoE models such as "
+          "qwen3-30b-a3b, qwen3-235b-a22b, deepseek-v3, etc. Without this flag "
+          "tpu-inference omits the 'expert' mesh axis and MaxText's MoE sharding "
+          "raises KeyError."
+      ),
+  )
   parser.add_argument("--gcs_results_path", help="Optional GCS path to upload results.")
   parser.add_argument("--log_level", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"])
   return parser
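All three runners hand the result to VllmServerManager in the same way. A hedged usage sketch of that call site, using only parameters visible in this commit; the import path is assumed from the file's location in the repo, the model path and token values are placeholders, and dense models would simply pass additional_vllm_kwargs=None:

# Assumed import path based on src/maxtext/eval/runner/server_manager.py (src layout).
from maxtext.eval.runner.server_manager import VllmServerManager

additional_vllm_kwargs = {"enable_expert_parallel": True}  # built from the new flag

with VllmServerManager(
    model_path="Qwen/Qwen3-30B-A3B",    # placeholder HF model ID
    env={"HF_TOKEN": "hf_..."},         # placeholder, mirrors server_env in the runners
    additional_vllm_kwargs=additional_vllm_kwargs or None,
) as server:
  base_url = server.base_url            # OpenAI-compatible endpoint the eval harness talks to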

src/maxtext/eval/runner/server_manager.py

Lines changed: 7 additions & 1 deletion
@@ -183,7 +183,7 @@ class VllmServerManager:
   """Manages an in-process vLLM-TPU LLM with an OpenAI-compatible HTTP layer.
 
   Args:
-    model_path: HF model ID or local path.
+    model_path: HF model ID or local path.
     checkpoint_path: MaxText orbax checkpoint path.
     maxtext_model_name: MaxText model name (e.g. "llama3.1-8b").
     host: Hostname the HTTP server binds to (rank-0 only).
@@ -195,6 +195,7 @@ class VllmServerManager:
     max_num_seqs: Max concurrent sequences (None = vLLM default).
     startup_timeout: Seconds to wait for /health to return healthy.
     env: Optional environment-variable overrides.
+    additional_vllm_kwargs: Extra kwargs merged into the vLLM LLM() constructor.
   """
 
   def __init__(
@@ -211,6 +212,7 @@ def __init__(
       max_num_seqs: int | None = None,
       startup_timeout: int = 600,
       env: dict[str, str] | None = None,
+      additional_vllm_kwargs: dict | None = None,
   ):
     if checkpoint_path and not maxtext_model_name:
       raise ValueError("maxtext_model_name is required when checkpoint_path is set.")
@@ -226,6 +228,7 @@ def __init__(
     self.max_num_seqs = max_num_seqs
     self.startup_timeout = startup_timeout
     self.env = env
+    self.additional_vllm_kwargs = additional_vllm_kwargs or {}
 
     self._llm: Any | None = None
     self._uvicorn_server: Any | None = None
@@ -271,6 +274,9 @@ def start(self) -> None:
     else:
       vllm_kwargs["load_format"] = "auto"
 
+    if self.additional_vllm_kwargs:
+      vllm_kwargs.update(self.additional_vllm_kwargs)
+
     logger.info(
         "Initializing in-process vLLM (tp=%d, max_len=%d)...",
         self.tensor_parallel_size,
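Inside start(), the extra kwargs are applied with dict.update() after the defaults shown above, so they can also override anything the manager set itself. A short sketch of just that merge step; the base vllm_kwargs entries are illustrative stand-ins, and per the docstring the merged dict is what reaches vLLM's LLM() constructor:

# Illustrative base kwargs; in start() these are built from the constructor arguments.
vllm_kwargs = {
  "model": "Qwen/Qwen3-30B-A3B",   # placeholder
  "tensor_parallel_size": 4,
  "load_format": "auto",
}

additional_vllm_kwargs = {"enable_expert_parallel": True}
if additional_vllm_kwargs:
  vllm_kwargs.update(additional_vllm_kwargs)

# The merged dict is then expanded into vLLM's LLM(**vllm_kwargs), which is how
# expert parallelism gets enabled and tpu-inference keeps the 'expert' mesh axis.
print(vllm_kwargs["enable_expert_parallel"])  # -> True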
