@@ -3680,14 +3680,27 @@ def float(self, *args):
36803680
36813681 @classmethod
36823682 def get_init_context (
3683- cls , dtype : torch .dtype , is_quantized : bool , _is_ds_init_called : bool , allow_all_kernels : bool | None
3683+ cls ,
3684+ dtype : torch .dtype ,
3685+ is_quantized : bool ,
3686+ _is_ds_init_called : bool ,
3687+ allow_all_kernels : bool | None ,
3688+ distributed_config = None ,
36843689 ):
36853690 # Need to instantiate with correct dtype
36863691 init_contexts = [local_torch_dtype (dtype , cls .__name__ ), init .no_tie_weights (), apply_patches ()]
36873692 # Needed as we cannot forward the `allow_all_kernels` arg in the model's __init__
36883693 if allow_all_kernels :
36893694 init_contexts .append (allow_all_hub_kernels ())
3690- if is_deepspeed_zero3_enabled ():
3695+ _has_ep = distributed_config is not None and getattr (distributed_config , "enable_expert_parallel" , False )
3696+ if _has_ep and is_deepspeed_zero3_enabled ():
3697+ # EP + DeepSpeed: use meta device (same as the normal non-DS path).
3698+ # zero.Init is skipped because EP needs to shard experts via distribute_model()
3699+ # hooks, which are incompatible with ZeRO-3 lazy parameters.
3700+ # The standard weight loading path (not zero3) handles EP sharding via
3701+ # shard_and_distribute_module. deepspeed.initialize() wraps the result later.
3702+ init_contexts .extend ([torch .device ("meta" ), init .meta_device_safe_creation_ops ()])
3703+ elif is_deepspeed_zero3_enabled ():
36913704 import deepspeed
36923705
36933706 # We cannot initialize the model on meta device with deepspeed when not quantized
@@ -4095,6 +4108,12 @@ def from_pretrained(
40954108 download_kwargs_with_commit ,
40964109 ** adapter_kwargs ,
40974110 )
4111+ # EP + DeepSpeed: clear device_map (set by initialize_tensor_parallelism) so the model
4112+ # loads on CPU first. distribute_model() handles GPU placement during EP sharding.
4113+ # Without this, device_map triggers accelerate's dispatch path which breaks shard loading.
4114+ _has_ep = distributed_config is not None and getattr (distributed_config , "enable_expert_parallel" , False )
4115+ if _has_ep and is_deepspeed_zero3_enabled ():
4116+ device_map = None
40984117 device_map = check_and_set_device_map (device_map ) # warn, error and fix the device map
40994118
41004119 user_agent = {"file_type" : "model" , "framework" : "pytorch" , "from_auto_class" : from_auto_class }
@@ -4203,7 +4222,9 @@ def from_pretrained(
42034222
42044223 register_fusion_patches (cls , config , fusion_config )
42054224
4206- model_init_context = cls .get_init_context (dtype , is_quantized , _is_ds_init_called , allow_all_kernels )
4225+ model_init_context = cls .get_init_context (
4226+ dtype , is_quantized , _is_ds_init_called , allow_all_kernels , distributed_config
4227+ )
42074228
42084229 config = copy .deepcopy (config ) # We do not want to modify the config inplace in from_pretrained.
42094230 with ContextManagers (model_init_context ):
@@ -4336,7 +4357,11 @@ def _load_pretrained_model(
43364357
43374358 error_msgs = []
43384359
4339- if is_deepspeed_zero3_enabled () and not is_quantized :
4360+ # EP + DeepSpeed: skip zero3 loading path. The model was created on meta device
4361+ # (not via zero.Init), so params are not zero3-partitioned. The standard loading
4362+ # path handles EP sharding via shard_and_distribute_module using the EP plan hooks
4363+ # registered by distribute_model(). deepspeed.initialize() wraps the result later.
4364+ if is_deepspeed_zero3_enabled () and not is_quantized and not model .has_ep :
43404365 if state_dict is None :
43414366 merged_state_dict = {}
43424367 for ckpt_file in checkpoint_files :
@@ -4655,7 +4680,8 @@ def _move_missing_keys_from_meta_to_device(
46554680 """
46564681 is_quantized = hf_quantizer is not None
46574682 # This is the only case where we do not initialize the model on meta device, so we don't have to do anything here
4658- if is_deepspeed_zero3_enabled () and not is_quantized :
4683+ # Exception: EP + DeepSpeed uses meta device (not zero.Init), so it needs the standard move path.
4684+ if is_deepspeed_zero3_enabled () and not is_quantized and not self .has_ep :
46594685 return
46604686
46614687 # Leave parameters on meta on non-rank-0 FSDP ranks (rank-0 broadcast overwrites them); only buffers need real placeholders.
@@ -4710,7 +4736,7 @@ def _initialize_missing_keys(self, is_quantized: bool) -> None:
47104736 self ._is_hf_initialized = True
47114737
47124738 # This will only initialize submodules that are not marked as initialized by the line above.
4713- if is_deepspeed_zero3_enabled () and not is_quantized :
4739+ if is_deepspeed_zero3_enabled () and not is_quantized and not self . has_ep :
47144740 import deepspeed
47154741
47164742 # keep_vars=True as we need the original tensors, so that the "_is_hf_initialized" is present on them
0 commit comments