@@ -1330,12 +1330,18 @@ def post_init(self):
13301330 self .init_weights ()
13311331 self ._backward_compatibility_gradient_checkpointing ()
13321332
1333+ @property
1334+ def has_ep (self ) -> bool :
1335+ """Whether expert parallelism is enabled for this model."""
1336+ distributed_config = getattr (getattr (self , "config" , None ), "distributed_config" , None )
1337+ return distributed_config is not None and getattr (distributed_config , "enable_expert_parallel" , False )
1338+
13331339 @property
13341340 def tp_plan (self ) -> dict [str , str ]:
13351341 """
13361342 The full tp plan for the model's modules
13371343 """
1338- if hasattr ( self .config , "distributed_config" ) and self . config . distributed_config . enable_expert_parallel :
1344+ if self .has_ep :
13391345 return self ._ep_plan
13401346 return self ._tp_plan
13411347
@@ -3599,14 +3605,27 @@ def float(self, *args):
35993605
36003606 @classmethod
36013607 def get_init_context (
3602- cls , dtype : torch .dtype , is_quantized : bool , _is_ds_init_called : bool , allow_all_kernels : bool | None
3608+ cls ,
3609+ dtype : torch .dtype ,
3610+ is_quantized : bool ,
3611+ _is_ds_init_called : bool ,
3612+ allow_all_kernels : bool | None ,
3613+ distributed_config = None ,
36033614 ):
36043615 # Need to instantiate with correct dtype
36053616 init_contexts = [local_torch_dtype (dtype , cls .__name__ ), init .no_tie_weights (), apply_patches ()]
36063617 # Needed as we cannot forward the `allow_all_kernels` arg in the model's __init__
36073618 if allow_all_kernels :
36083619 init_contexts .append (allow_all_hub_kernels ())
3609- if is_deepspeed_zero3_enabled ():
3620+ _has_ep = distributed_config is not None and getattr (distributed_config , "enable_expert_parallel" , False )
3621+ if _has_ep and is_deepspeed_zero3_enabled ():
3622+ # EP + DeepSpeed: use meta device (same as the normal non-DS path).
3623+ # zero.Init is skipped because EP needs to shard experts via distribute_model()
3624+ # hooks, which are incompatible with ZeRO-3 lazy parameters.
3625+ # The standard weight loading path (not zero3) handles EP sharding via
3626+ # shard_and_distribute_module. deepspeed.initialize() wraps the result later.
3627+ init_contexts .extend ([torch .device ("meta" ), init .meta_device_safe_creation_ops ()])
3628+ elif is_deepspeed_zero3_enabled ():
36103629 import deepspeed
36113630
36123631 # We cannot initialize the model on meta device with deepspeed when not quantized
@@ -4007,6 +4026,12 @@ def from_pretrained(
40074026 download_kwargs_with_commit ,
40084027 ** adapter_kwargs ,
40094028 )
4029+ # EP + DeepSpeed: clear device_map (set by initialize_tensor_parallelism) so the model
4030+ # loads on CPU first. distribute_model() handles GPU placement during EP sharding.
4031+ # Without this, device_map triggers accelerate's dispatch path which breaks shard loading.
4032+ _has_ep = distributed_config is not None and getattr (distributed_config , "enable_expert_parallel" , False )
4033+ if _has_ep and is_deepspeed_zero3_enabled ():
4034+ device_map = None
40104035 device_map = check_and_set_device_map (device_map ) # warn, error and fix the device map
40114036
40124037 user_agent = {"file_type" : "model" , "framework" : "pytorch" , "from_auto_class" : from_auto_class }
@@ -4110,7 +4135,9 @@ def from_pretrained(
41104135
41114136 register_fusion_patches (cls , config , fusion_config )
41124137
4113- model_init_context = cls .get_init_context (dtype , is_quantized , _is_ds_init_called , allow_all_kernels )
4138+ model_init_context = cls .get_init_context (
4139+ dtype , is_quantized , _is_ds_init_called , allow_all_kernels , distributed_config
4140+ )
41144141
41154142 config = copy .deepcopy (config ) # We do not want to modify the config inplace in from_pretrained.
41164143 with ContextManagers (model_init_context ):
@@ -4241,7 +4268,11 @@ def _load_pretrained_model(
42414268
42424269 error_msgs = []
42434270
4244- if is_deepspeed_zero3_enabled () and not is_quantized :
4271+ # EP + DeepSpeed: skip zero3 loading path. The model was created on meta device
4272+ # (not via zero.Init), so params are not zero3-partitioned. The standard loading
4273+ # path handles EP sharding via shard_and_distribute_module using the EP plan hooks
4274+ # registered by distribute_model(). deepspeed.initialize() wraps the result later.
4275+ if is_deepspeed_zero3_enabled () and not is_quantized and not model .has_ep :
42454276 if state_dict is None :
42464277 merged_state_dict = {}
42474278 for ckpt_file in checkpoint_files :
@@ -4551,7 +4582,8 @@ def _move_missing_keys_from_meta_to_device(
45514582 """
45524583 is_quantized = hf_quantizer is not None
45534584 # This is the only case where we do not initialize the model on meta device, so we don't have to do anything here
4554- if is_deepspeed_zero3_enabled () and not is_quantized :
4585+ # Exception: EP + DeepSpeed uses meta device (not zero.Init), so it needs the standard move path.
4586+ if is_deepspeed_zero3_enabled () and not is_quantized and not self .has_ep :
45554587 return
45564588
45574589 # In this case we need to move everything back
@@ -4609,7 +4641,7 @@ def _initialize_missing_keys(self, is_quantized: bool) -> None:
46094641 self ._is_hf_initialized = True
46104642
46114643 # This will only initialize submodules that are not marked as initialized by the line above.
4612- if is_deepspeed_zero3_enabled () and not is_quantized :
4644+ if is_deepspeed_zero3_enabled () and not is_quantized and not self . has_ep :
46134645 import deepspeed
46144646
46154647 # keep_vars=True as we need the original tensors, so that the "_is_hf_initialized" is present on them
0 commit comments