Skip to content

Commit fbf3f4e

Browse files
[Models] fix fleet model fallback ep init (#8039)
* support ep * fix ci test
1 parent 6076add commit fbf3f4e

2 files changed

Lines changed: 22 additions & 29 deletions

File tree

fastdeploy/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def _post_init(self):
321321
if self.runner_type == "generate" and not is_generative_model:
322322
if is_multimodal_model:
323323
pass
324-
elif self.model_impl in ("auto", "paddleformers"):
324+
elif self.model_impl in ("auto", "paddleformers", "paddlefleet"):
325325
# Skip check for auto/paddleformers/paddlefleet - may fall back to paddleformers, which supports any model
326326
pass
327327
else:

fastdeploy/model_executor/models/paddleformers/base_fleet.py

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,6 @@ def __init__(self, fd_config: "FDConfig", **kwargs):
290290

291291
# Assign parallel config from fd_config.parallel_config to paddleformers_config
292292
parallel_config = fd_config.parallel_config
293-
# parallel_config.tensor_parallel_size = 1
294-
# parallel_config.expert_parallel_size = 2
295293
self.paddleformers_config.data_parallel_size = parallel_config.data_parallel_size
296294
self.paddleformers_config.tensor_model_parallel_size = parallel_config.tensor_parallel_size
297295
self.paddleformers_config.sequence_parallel = parallel_config.sequence_parallel
@@ -305,6 +303,8 @@ def __init__(self, fd_config: "FDConfig", **kwargs):
305303
# self.paddleformers_config.moe_grouped_gemm = True
306304
self.paddleformers_config.moe_token_dispatcher_type = "deepep"
307305
# self.paddleformers_config.use_cpu_initialization = True
306+
self.paddleformers_config.use_cpu_initialization = True
307+
self.paddleformers_config.perform_initialization = False
308308
self.paddleformers_config.gated_attention = getattr(self.paddleformers_config, "use_gated_attn", False)
309309
if getattr(self.paddleformers_config, "multi_latent_attention", False):
310310
self.paddleformers_config.qk_head_dim = (
@@ -396,6 +396,16 @@ def _init_paddlefleet_parallel_state(self, fd_config) -> None:
396396
"mp",
397397
],
398398
}
399+
# Reset parallel state so that PaddleFleet's fleet.init can reinitialize
400+
# with the correct EP topology instead of reusing FastDeploy's.
401+
import paddle.distributed.fleet.base.topology as tp_mod
402+
import paddle.distributed.parallel_helper as ph
403+
404+
# 1) Reset hybrid parallel group so _init_hybrid_parallel_env runs again
405+
tp_mod._HYBRID_PARALLEL_GROUP = None
406+
# 2) Reset parallel context so init_parallel_env runs again
407+
ph.__parallel_ctx__clz__ = None
408+
399409
fleet.init(is_collective=True, strategy=strategy)
400410
logger.info(
401411
f"Initialized PaddleFleet parallel_state via initialize_fleet "
@@ -404,40 +414,23 @@ def _init_paddlefleet_parallel_state(self, fd_config) -> None:
404414
f"ep={parallel_config.expert_parallel_size}, "
405415
f"sp={parallel_config.sequence_parallel})"
406416
)
407-
408417
import paddle.distributed as dist
409418
from paddlefleet import parallel_state
410419

411-
hcg = fleet.get_hybrid_communicate_group()
412-
expected_tp_size = parallel_config.tensor_parallel_size
413-
414-
# Check if we need to initialize or reinitialize TP group
415-
need_init = False
416-
if parallel_state._TENSOR_MODEL_PARALLEL_GROUP is None:
417-
need_init = True
418-
reason = "TP group not initialized"
419-
else:
420-
# Check if current TP group size matches expected
421-
current_tp_group = parallel_state._TENSOR_MODEL_PARALLEL_GROUP
422-
current_tp_size = getattr(current_tp_group, "nranks", None)
420+
tp_group = parallel_state._TENSOR_MODEL_PARALLEL_GROUP
421+
current_tp_size = None
422+
if tp_group is not None:
423+
current_tp_size = getattr(tp_group, "nranks", None)
423424
if current_tp_size is None:
424-
current_tp_size = getattr(current_tp_group, "world_size", None)
425-
if current_tp_size != expected_tp_size:
426-
need_init = True
427-
reason = f"TP group size mismatch: current={current_tp_size}, expected={expected_tp_size}"
425+
current_tp_size = getattr(tp_group, "world_size", None)
428426

427+
expected_tp_size = parallel_config.tensor_parallel_size
428+
need_init = tp_group is None or current_tp_size != expected_tp_size
429429
if need_init:
430-
logger.warning(f"{reason}, reinitializing TP group with size={expected_tp_size}")
431430
if expected_tp_size == 1:
432-
# Single process TP group - create manually
433-
current_rank = dist.get_rank()
434-
tp_ranks = [current_rank]
435-
default_pg = dist.new_group(ranks=tp_ranks)
436-
parallel_state._TENSOR_MODEL_PARALLEL_GROUP = default_pg
437-
parallel_state._TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = tp_ranks
438-
logger.info(f"Reinitialized TP group with size=1, rank={current_rank}, ranks={tp_ranks}")
431+
parallel_state._TENSOR_MODEL_PARALLEL_GROUP = dist.new_group(ranks=[dist.get_rank()])
439432
else:
440-
# Multiple processes - use hcg's mp group
433+
hcg = fleet.get_hybrid_communicate_group()
441434
parallel_state.initialize_model_parallel(hcg)
442435

443436
from paddlefleet.tensor_parallel.random import (

0 commit comments

Comments
 (0)