Skip to content

Commit f5a72c4

Browse files
support ep
1 parent f161fea commit f5a72c4

2 files changed

Lines changed: 15 additions & 34 deletions

File tree

fastdeploy/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def _post_init(self):
319319
if self.runner_type == "generate" and not is_generative_model:
320320
if is_multimodal_model:
321321
pass
322-
elif self.model_impl in ("auto", "paddleformers"):
322+
elif self.model_impl in ("auto", "paddleformers", "paddlefleet"):
323323
# Skip check for auto/paddleformers - may fallback to paddleformers which supports any model
324324
pass
325325
else:

fastdeploy/model_executor/models/paddleformers/base_fleet.py

Lines changed: 14 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,6 @@ def __init__(self, fd_config: "FDConfig", **kwargs):
290290

291291
# Assign parallel config from fd_config.parallel_config to paddleformers_config
292292
parallel_config = fd_config.parallel_config
293-
# parallel_config.tensor_parallel_size = 1
294-
# parallel_config.expert_parallel_size = 2
295293
self.paddleformers_config.data_parallel_size = parallel_config.data_parallel_size
296294
self.paddleformers_config.tensor_model_parallel_size = parallel_config.tensor_parallel_size
297295
self.paddleformers_config.sequence_parallel = parallel_config.sequence_parallel
@@ -305,6 +303,8 @@ def __init__(self, fd_config: "FDConfig", **kwargs):
305303
# self.paddleformers_config.moe_grouped_gemm = True
306304
self.paddleformers_config.moe_token_dispatcher_type = "deepep"
307305
# self.paddleformers_config.use_cpu_initialization = True
306+
self.paddleformers_config.use_cpu_initialization = True
307+
self.paddleformers_config.perform_initialization = False
308308
self.paddleformers_config.gated_attention = getattr(self.paddleformers_config, "use_gated_attn", False)
309309
if getattr(self.paddleformers_config, "multi_latent_attention", False):
310310
self.paddleformers_config.qk_head_dim = (
@@ -396,6 +396,16 @@ def _init_paddlefleet_parallel_state(self, fd_config) -> None:
396396
"mp",
397397
],
398398
}
399+
# Reset parallel state so that PaddleFleet's fleet.init can reinitialize
400+
# with the correct EP topology instead of reusing FastDeploy's.
401+
import paddle.distributed.fleet.base.topology as tp_mod
402+
import paddle.distributed.parallel_helper as ph
403+
404+
# 1) Reset hybrid parallel group so _init_hybrid_parallel_env runs again
405+
tp_mod._HYBRID_PARALLEL_GROUP = None
406+
# 2) Reset parallel context so init_parallel_env runs again
407+
ph.__parallel_ctx__clz__ = None
408+
399409
fleet.init(is_collective=True, strategy=strategy)
400410
logger.info(
401411
f"Initialized PaddleFleet parallel_state via initialize_fleet "
@@ -405,40 +415,11 @@ def _init_paddlefleet_parallel_state(self, fd_config) -> None:
405415
f"sp={parallel_config.sequence_parallel})"
406416
)
407417

408-
import paddle.distributed as dist
409418
from paddlefleet import parallel_state
410419

411-
hcg = fleet.get_hybrid_communicate_group()
412-
expected_tp_size = parallel_config.tensor_parallel_size
413-
414-
# Check if we need to initialize or reinitialize TP group
415-
need_init = False
416420
if parallel_state._TENSOR_MODEL_PARALLEL_GROUP is None:
417-
need_init = True
418-
reason = "TP group not initialized"
419-
else:
420-
# Check if current TP group size matches expected
421-
current_tp_group = parallel_state._TENSOR_MODEL_PARALLEL_GROUP
422-
current_tp_size = getattr(current_tp_group, "nranks", None)
423-
if current_tp_size is None:
424-
current_tp_size = getattr(current_tp_group, "world_size", None)
425-
if current_tp_size != expected_tp_size:
426-
need_init = True
427-
reason = f"TP group size mismatch: current={current_tp_size}, expected={expected_tp_size}"
428-
429-
if need_init:
430-
logger.warning(f"{reason}, reinitializing TP group with size={expected_tp_size}")
431-
if expected_tp_size == 1:
432-
# Single process TP group - create manually
433-
current_rank = dist.get_rank()
434-
tp_ranks = [current_rank]
435-
default_pg = dist.new_group(ranks=tp_ranks)
436-
parallel_state._TENSOR_MODEL_PARALLEL_GROUP = default_pg
437-
parallel_state._TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = tp_ranks
438-
logger.info(f"Reinitialized TP group with size=1, rank={current_rank}, ranks={tp_ranks}")
439-
else:
440-
# Multiple processes - use hcg's mp group
441-
parallel_state.initialize_model_parallel(hcg)
421+
hcg = fleet.get_hybrid_communicate_group()
422+
parallel_state.initialize_model_parallel(hcg)
442423

443424
from paddlefleet.tensor_parallel.random import (
444425
model_parallel_cuda_manual_seed,

0 commit comments

Comments
 (0)