@@ -290,8 +290,6 @@ def __init__(self, fd_config: "FDConfig", **kwargs):
290290
291291 # Assign parallel config from fd_config.parallel_config to paddleformers_config
292292 parallel_config = fd_config .parallel_config
293- # parallel_config.tensor_parallel_size = 1
294- # parallel_config.expert_parallel_size = 2
295293 self .paddleformers_config .data_parallel_size = parallel_config .data_parallel_size
296294 self .paddleformers_config .tensor_model_parallel_size = parallel_config .tensor_parallel_size
297295 self .paddleformers_config .sequence_parallel = parallel_config .sequence_parallel
@@ -305,6 +303,8 @@ def __init__(self, fd_config: "FDConfig", **kwargs):
305303 # self.paddleformers_config.moe_grouped_gemm = True
306304 self .paddleformers_config .moe_token_dispatcher_type = "deepep"
307305 # self.paddleformers_config.use_cpu_initialization = True
306+ self .paddleformers_config .use_cpu_initialization = True
307+ self .paddleformers_config .perform_initialization = False
308308 self .paddleformers_config .gated_attention = getattr (self .paddleformers_config , "use_gated_attn" , False )
309309 if getattr (self .paddleformers_config , "multi_latent_attention" , False ):
310310 self .paddleformers_config .qk_head_dim = (
@@ -396,6 +396,16 @@ def _init_paddlefleet_parallel_state(self, fd_config) -> None:
396396 "mp" ,
397397 ],
398398 }
399+ # Reset parallel state so that PaddleFleet's fleet.init can reinitialize
400+ # with the correct EP topology instead of reusing FastDeploy's.
401+ import paddle .distributed .fleet .base .topology as tp_mod
402+ import paddle .distributed .parallel_helper as ph
403+
404+ # 1) Reset hybrid parallel group so _init_hybrid_parallel_env runs again
405+ tp_mod ._HYBRID_PARALLEL_GROUP = None
406+ # 2) Reset parallel context so init_parallel_env runs again
407+ ph .__parallel_ctx__clz__ = None
408+
399409 fleet .init (is_collective = True , strategy = strategy )
400410 logger .info (
401411 f"Initialized PaddleFleet parallel_state via initialize_fleet "
@@ -404,40 +414,23 @@ def _init_paddlefleet_parallel_state(self, fd_config) -> None:
404414 f"ep={ parallel_config .expert_parallel_size } , "
405415 f"sp={ parallel_config .sequence_parallel } )"
406416 )
407-
408417 import paddle .distributed as dist
409418 from paddlefleet import parallel_state
410419
411- hcg = fleet .get_hybrid_communicate_group ()
412- expected_tp_size = parallel_config .tensor_parallel_size
413-
414- # Check if we need to initialize or reinitialize TP group
415- need_init = False
416- if parallel_state ._TENSOR_MODEL_PARALLEL_GROUP is None :
417- need_init = True
418- reason = "TP group not initialized"
419- else :
420- # Check if current TP group size matches expected
421- current_tp_group = parallel_state ._TENSOR_MODEL_PARALLEL_GROUP
422- current_tp_size = getattr (current_tp_group , "nranks" , None )
420+ tp_group = parallel_state ._TENSOR_MODEL_PARALLEL_GROUP
421+ current_tp_size = None
422+ if tp_group is not None :
423+ current_tp_size = getattr (tp_group , "nranks" , None )
423424 if current_tp_size is None :
424- current_tp_size = getattr (current_tp_group , "world_size" , None )
425- if current_tp_size != expected_tp_size :
426- need_init = True
427- reason = f"TP group size mismatch: current={ current_tp_size } , expected={ expected_tp_size } "
425+ current_tp_size = getattr (tp_group , "world_size" , None )
428426
427+ expected_tp_size = parallel_config .tensor_parallel_size
428+ need_init = tp_group is None or current_tp_size != expected_tp_size
429429 if need_init :
430- logger .warning (f"{ reason } , reinitializing TP group with size={ expected_tp_size } " )
431430 if expected_tp_size == 1 :
432- # Single process TP group - create manually
433- current_rank = dist .get_rank ()
434- tp_ranks = [current_rank ]
435- default_pg = dist .new_group (ranks = tp_ranks )
436- parallel_state ._TENSOR_MODEL_PARALLEL_GROUP = default_pg
437- parallel_state ._TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = tp_ranks
438- logger .info (f"Reinitialized TP group with size=1, rank={ current_rank } , ranks={ tp_ranks } " )
431+ parallel_state ._TENSOR_MODEL_PARALLEL_GROUP = dist .new_group (ranks = [dist .get_rank ()])
439432 else :
440- # Multiple processes - use hcg's mp group
433+ hcg = fleet . get_hybrid_communicate_group ()
441434 parallel_state .initialize_model_parallel (hcg )
442435
443436 from paddlefleet .tensor_parallel .random import (
0 commit comments