@@ -188,7 +188,8 @@ class VllmServerManager:
         maxtext_model_name: MaxText model name (e.g. "llama3.1-8b").
         host: Hostname the HTTP server binds to (rank-0 only).
         port: Port the HTTP server listens on.
-        tensor_parallel_size: Tensor parallelism.
+        tensor_parallel_size: Total number of chips (ici_tensor_parallelism x ici_expert_parallelism).
+        expert_parallel_size: Chips allocated to the expert mesh axis (EP).
         max_model_len: Maximum sequence length.
         dtype: Activation dtype string passed to vLLM (e.g. "bfloat16").
         max_num_batched_tokens: Tokens per scheduler step (None = vLLM default).
@@ -206,6 +207,7 @@ def __init__(
         host: str = "localhost",
         port: int = 8000,
         tensor_parallel_size: int = 4,
+        expert_parallel_size: int = 1,
         max_model_len: int = 4096,
         dtype: str = "bfloat16",
         max_num_batched_tokens: int | None = None,
@@ -216,12 +218,18 @@ def __init__(
     ):
         if checkpoint_path and not maxtext_model_name:
             raise ValueError("maxtext_model_name is required when checkpoint_path is set.")
+        if tensor_parallel_size % expert_parallel_size != 0:
+            raise ValueError(
+                f"tensor_parallel_size ({tensor_parallel_size}) is not divisible by "
+                f"expert_parallel_size ({expert_parallel_size})."
+            )
         self.model_path = model_path
         self.checkpoint_path = checkpoint_path
         self.maxtext_model_name = maxtext_model_name
         self.host = host
         self.port = port
         self.tensor_parallel_size = tensor_parallel_size
+        self.expert_parallel_size = expert_parallel_size
         self.max_model_len = max_model_len
         self.dtype = dtype
         self.max_num_batched_tokens = max_num_batched_tokens
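The new constructor check rejects chip counts that cannot be split evenly across the two mesh axes. A minimal illustration of a combination that now fails fast (values are made up for the example):

```python
# Hypothetical call: 4 chips cannot be split across a 3-way expert axis.
VllmServerManager(
    model_path="some/model",   # illustrative value
    tensor_parallel_size=4,
    expert_parallel_size=3,
)
# Raises: ValueError: tensor_parallel_size (4) is not divisible by expert_parallel_size (3).
```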
@@ -251,9 +259,13 @@ def start(self) -> None:
         if self.env:
             os.environ.update(self.env)
 
+        # Total chips = ici_tensor_parallelism x ici_expert_parallelism.
+        ici_tp = self.tensor_parallel_size // self.expert_parallel_size
+        ici_ep = self.expert_parallel_size
+
         vllm_kwargs: dict = {
             "model": self.model_path,
-            "tensor_parallel_size": self.tensor_parallel_size,
+            "tensor_parallel_size": ici_tp,
             "max_model_len": self.max_model_len,
             "dtype": self.dtype,
         }
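The derived mesh sizes are a straight integer split of the total chip count. A quick sketch of the arithmetic, with illustrative numbers not taken from this PR:

```python
# Illustrative values only: 8 total chips, 2 of them on the expert axis.
tensor_parallel_size = 8
expert_parallel_size = 2

ici_tp = tensor_parallel_size // expert_parallel_size  # 4 chips on the tensor-parallel axis
ici_ep = expert_parallel_size                          # 2 chips on the expert-parallel axis

assert ici_tp * ici_ep == tensor_parallel_size  # total chips = ici_tp x ici_ep
```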
@@ -269,14 +281,15 @@ def start(self) -> None:
                 "model_name": self.maxtext_model_name,
                 "load_parameters_path": self.checkpoint_path,
                 "log_config": False,
-            }
+                "ici_tensor_parallelism": ici_tp,
+                "ici_expert_parallelism": ici_ep,
+            },
+            "sharding": {
+                "sharding_strategy": {},
+            },
         }
-        if self.additional_vllm_kwargs.get("enable_expert_parallel"):
-            vllm_kwargs["additional_config"]["sharding"] = {
-                "sharding_strategy": {
-                    "expert_parallelism": self.tensor_parallel_size,
-                }
-            }
+        if ici_ep > 1:
+            vllm_kwargs["additional_config"]["sharding"]["sharding_strategy"]["expert_parallelism"] = ici_ep
         else:
             vllm_kwargs["load_format"] = "auto"
 
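Pre-creating the empty "sharding_strategy" dict means the EP branch only has to fill in a single key. Under the illustrative numbers above (ici_tp=4, ici_ep=2), and assuming the key wrapping the MaxText block is named something like "maxtext_config" (that key sits above this hunk and is not shown), additional_config would end up roughly as:

```python
# Sketch of vllm_kwargs["additional_config"] when a checkpoint is loaded.
# "maxtext_config" is a placeholder key name; model name and path are illustrative.
additional_config = {
    "maxtext_config": {
        "model_name": "llama3.1-8b",
        "load_parameters_path": "gs://bucket/checkpoint",
        "log_config": False,
        "ici_tensor_parallelism": 4,
        "ici_expert_parallelism": 2,
    },
    "sharding": {
        "sharding_strategy": {
            "expert_parallelism": 2,  # only set because ici_ep > 1
        },
    },
}
```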
@@ -298,8 +311,9 @@ def start(self) -> None:
             vllm_kwargs[_k] = _v
 
         logger.info(
-            "Initializing in-process vLLM (tp=%d, max_len=%d)...",
-            self.tensor_parallel_size,
+            "Initializing in-process vLLM (tp=%d, ep=%d, max_len=%d)...",
+            ici_tp,
+            ici_ep,
             self.max_model_len,
         )
         self._llm = LLM(**vllm_kwargs)
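For completeness, a minimal end-to-end sketch of how the new argument is meant to be used; paths and names are placeholders, not taken from this PR:

```python
# Placeholder values: 8 chips total, 2 on the expert axis -> start() derives tp=4, ep=2.
manager = VllmServerManager(
    model_path="meta-llama/Llama-3.1-8B",      # placeholder model
    checkpoint_path="gs://bucket/checkpoint",  # placeholder checkpoint
    maxtext_model_name="llama3.1-8b",
    tensor_parallel_size=8,
    expert_parallel_size=2,
)
manager.start()
```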