Commit 278e455

Fix EP vllm keyerror
1 parent 4e8731b commit 278e455

3 files changed: 39 additions & 26 deletions

File tree:
- src/maxtext/eval/runner/common.py
- src/maxtext/eval/runner/server_manager.py
- src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py
src/maxtext/eval/runner/common.py

Lines changed: 9 additions & 12 deletions
```diff
@@ -23,13 +23,6 @@
 if TYPE_CHECKING:
   from maxtext.eval.runner.server_manager import VllmServerManager
 
-ENABLE_EXPERT_PARALLEL_HELP = (
-    "Enable expert parallelism in vLLM. Required for MoE models such as "
-    "qwen3-30b-a3b, qwen3-235b-a22b, deepseek-v3, etc. Without this flag "
-    "tpu-inference omits the 'expert' mesh axis and MaxText's MoE sharding "
-    "raises KeyError."
-)
-
 
 def resolve_token(cfg: dict, hf_token: str | None) -> str | None:
   """Return HF token from explicit arg or HF_TOKEN env var."""
@@ -71,10 +64,9 @@ def build_server_manager(cfg: dict, token: str | None) -> "VllmServerManager":
   if max_num_seqs is not None:
     max_num_seqs = int(max_num_seqs)
 
+  expert_parallel_size = int(cfg.get("expert_parallel_size") or 1)
+
   server_env = {"HF_TOKEN": token} if token else None
-  additional_vllm_kwargs: dict = {}
-  if cfg.get("enable_expert_parallel"):
-    additional_vllm_kwargs["enable_expert_parallel"] = True
 
   return VllmServerManager(
       model_path=hf_path,
@@ -83,11 +75,11 @@ def build_server_manager(cfg: dict, token: str | None) -> "VllmServerManager":
       host=server_host,
       port=server_port,
       tensor_parallel_size=tensor_parallel_size,
+      expert_parallel_size=expert_parallel_size,
       max_model_len=max_model_len,
       max_num_batched_tokens=max_num_batched_tokens,
       max_num_seqs=max_num_seqs,
       env=server_env,
-      additional_vllm_kwargs=additional_vllm_kwargs or None,
   )
 
 
@@ -121,7 +113,12 @@ def add_server_args(parser: argparse.ArgumentParser) -> None:
   parser.add_argument("--max_num_seqs", type=int, help="vLLM max concurrent sequences.")
   parser.add_argument("--hf_mode", action="store_true", help="HF safetensors mode.")
   parser.add_argument(
-      "--enable_expert_parallel", action="store_true", help=ENABLE_EXPERT_PARALLEL_HELP
+      "--expert_parallel_size",
+      type=int,
+      default=0,
+      help=(
+          "Chips allocated to the expert mesh axis (EP). "
+      ),
   )
   parser.add_argument("--hf_token", help="HuggingFace token for gated models.")
   parser.add_argument(
```
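Note the interplay between the argparse default of `0` and the `or 1` coercion above: a missing key, `None`, and the default `0` all resolve to EP disabled. A standalone sketch of that behavior (the helper name and the `cfg` dicts are illustrative; only the `or 1` expression comes from the commit):

```python
def resolve_expert_parallel_size(cfg: dict) -> int:
    # Mirrors the commit's coercion: falsy values (missing key, None,
    # or the argparse default of 0) all fall back to EP=1, i.e. EP off.
    return int(cfg.get("expert_parallel_size") or 1)

assert resolve_expert_parallel_size({}) == 1                           # key absent
assert resolve_expert_parallel_size({"expert_parallel_size": 0}) == 1  # argparse default
assert resolve_expert_parallel_size({"expert_parallel_size": 4}) == 4  # explicit EP
```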

src/maxtext/eval/runner/server_manager.py

Lines changed: 25 additions & 11 deletions
```diff
@@ -188,7 +188,8 @@ class VllmServerManager:
     maxtext_model_name: MaxText model name (e.g. "llama3.1-8b").
     host: Hostname the HTTP server binds to (rank-0 only).
     port: Port the HTTP server listens on.
-    tensor_parallel_size: Tensor parallelism.
+    tensor_parallel_size: Total number of chips.
+    expert_parallel_size: Chips allocated to the expert mesh axis (EP).
     max_model_len: Maximum sequence length.
     dtype: Activation dtype string passed to vLLM (e.g. "bfloat16").
     max_num_batched_tokens: Tokens per scheduler step (None = vLLM default).
@@ -206,6 +207,7 @@ def __init__(
       host: str = "localhost",
       port: int = 8000,
       tensor_parallel_size: int = 4,
+      expert_parallel_size: int = 1,
       max_model_len: int = 4096,
       dtype: str = "bfloat16",
       max_num_batched_tokens: int | None = None,
@@ -216,12 +218,18 @@ def __init__(
   ):
     if checkpoint_path and not maxtext_model_name:
       raise ValueError("maxtext_model_name is required when checkpoint_path is set.")
+    if tensor_parallel_size % expert_parallel_size != 0:
+      raise ValueError(
+          f"tensor_parallel_size ({tensor_parallel_size}) is not divisible by "
+          f"expert_parallel_size ({expert_parallel_size})."
+      )
     self.model_path = model_path
     self.checkpoint_path = checkpoint_path
     self.maxtext_model_name = maxtext_model_name
     self.host = host
     self.port = port
     self.tensor_parallel_size = tensor_parallel_size
+    self.expert_parallel_size = expert_parallel_size
     self.max_model_len = max_model_len
     self.dtype = dtype
     self.max_num_batched_tokens = max_num_batched_tokens
@@ -251,9 +259,13 @@ def start(self) -> None:
     if self.env:
       os.environ.update(self.env)
 
+    # total chips = ici_tensor_parallelism x ici_expert_parallelism.
+    ici_tp = self.tensor_parallel_size // self.expert_parallel_size
+    ici_ep = self.expert_parallel_size
+
     vllm_kwargs: dict = {
         "model": self.model_path,
-        "tensor_parallel_size": self.tensor_parallel_size,
+        "tensor_parallel_size": ici_tp,
         "max_model_len": self.max_model_len,
         "dtype": self.dtype,
     }
@@ -269,14 +281,15 @@ def start(self) -> None:
           "model_name": self.maxtext_model_name,
           "load_parameters_path": self.checkpoint_path,
           "log_config": False,
-        }
+          "ici_tensor_parallelism": ici_tp,
+          "ici_expert_parallelism": ici_ep,
+        },
+        "sharding": {
+            "sharding_strategy": {},
+        },
       }
-      if self.additional_vllm_kwargs.get("enable_expert_parallel"):
-        vllm_kwargs["additional_config"]["sharding"] = {
-            "sharding_strategy": {
-                "expert_parallelism": self.tensor_parallel_size,
-            }
-        }
+      if ici_ep > 1:
+        vllm_kwargs["additional_config"]["sharding"]["sharding_strategy"]["expert_parallelism"] = ici_ep
     else:
       vllm_kwargs["load_format"] = "auto"
 
@@ -298,8 +311,9 @@ def start(self) -> None:
         vllm_kwargs[_k] = _v
 
     logger.info(
-        "Initializing in-process vLLM (tp=%d, ep=%d, max_len=%d)...",
-        self.tensor_parallel_size,
+        "Initializing in-process vLLM (tp=%d, ep=%d, max_len=%d)...",
+        ici_tp,
+        ici_ep,
         self.max_model_len,
     )
     self._llm = LLM(**vllm_kwargs)
```
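The key arithmetic in `start()` is that `tensor_parallel_size` now means the total chip count, which is factored into `ici_tensor_parallelism x ici_expert_parallelism`. A self-contained sketch of that split (the helper name and example sizes are illustrative, not part of the commit):

```python
def split_mesh(tensor_parallel_size: int, expert_parallel_size: int) -> tuple[int, int]:
    # Same invariant as __init__: total chips must factor evenly into
    # ici_tensor_parallelism x ici_expert_parallelism.
    if tensor_parallel_size % expert_parallel_size != 0:
        raise ValueError(
            f"tensor_parallel_size ({tensor_parallel_size}) is not divisible by "
            f"expert_parallel_size ({expert_parallel_size})."
        )
    ici_tp = tensor_parallel_size // expert_parallel_size
    ici_ep = expert_parallel_size
    return ici_tp, ici_ep

# 8 chips with EP=4: vLLM sees tensor_parallel_size=2, while MaxText gets
# ici_tensor_parallelism=2 and ici_expert_parallelism=4.
assert split_mesh(8, 4) == (2, 4)
# EP=1 (the default) preserves the old behavior: all chips on the tensor axis.
assert split_mesh(8, 1) == (8, 1)
# split_mesh(8, 3) raises ValueError: 8 chips cannot host a 3-way expert axis.
```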

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -20,11 +20,12 @@
 from flax import nnx
 import flax.linen as nn
 from jax import numpy as jnp
-from jax.sharding import Mesh
+from jax.sharding import AxisType, Mesh
 from maxtext.configs import pyconfig
 from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR
 from maxtext.common.common_types import MODEL_MODE_AUTOREGRESSIVE
 from maxtext.utils import max_logging
+from maxtext.utils import maxtext_utils
 from maxtext.utils import model_creation_utils
 
 
@@ -98,8 +99,9 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
     self.cfg = vllm_config.model_config
     self.maxtext_config = generate_maxtext_config(vllm_config)
 
-    # Model configuration
-    self.mesh = mesh
+    devices_array = maxtext_utils.create_device_mesh(self.maxtext_config)
+    axis_types = tuple([AxisType.Auto] * len(self.maxtext_config.mesh_axes))
+    self.mesh = Mesh(devices_array, self.maxtext_config.mesh_axes, axis_types=axis_types)
     self.model_mode = MODEL_MODE_AUTOREGRESSIVE
     self.is_text_generation_model = True
 
```
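The adapter now builds its own mesh from the MaxText config instead of reusing the mesh handed in by vLLM, so the 'expert' axis exists whenever the config defines one; per the deleted help text above, a mesh without that axis is what made MaxText's MoE sharding raise KeyError. A minimal, runnable illustration of the same `Mesh` construction pattern; the axis names and the reshape are hypothetical stand-ins for `maxtext_config.mesh_axes` and `maxtext_utils.create_device_mesh`:

```python
import jax
import numpy as np
from jax.sharding import AxisType, Mesh

# Hypothetical stand-in for maxtext_config.mesh_axes; the real adapter
# takes the axis names (including 'expert') from the MaxText config.
mesh_axes = ("data", "tensor", "expert")

# Stand-in for maxtext_utils.create_device_mesh: reshape the available
# devices into one dimension per mesh axis (dimension sizes must
# multiply to the device count).
devices_array = np.asarray(jax.devices()).reshape(1, 1, -1)

# As in the commit: mark every axis Auto, leaving sharding along these
# axes to automatic propagation rather than fixing it explicitly.
axis_types = tuple([AxisType.Auto] * len(mesh_axes))
mesh = Mesh(devices_array, mesh_axes, axis_types=axis_types)
print(mesh.shape)  # e.g. {'data': 1, 'tensor': 1, 'expert': N}
```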
