Skip to content

Commit ae548bf

Browse files
committed
Fix EP + DeepSpeed ZeRO-3 loading via accelerate launch
Route EP through the standard (non-zero3) loading path when both EP and is_deepspeed_zero3_enabled() are active, then let deepspeed.initialize() wrap the EP-sharded model afterwards. - Add PreTrainedModel.has_ep property; use it in tp_plan - get_init_context: meta device for EP+DS (not zero.Init) - from_pretrained: clear device_map for EP+DS - _load_pretrained_model: skip zero3 path for EP+DS, pass model.tp_plan - _move_missing_keys_from_meta_to_device: do not early-return for EP+DS - _initialize_missing_keys: standard init (no GatheredParameters) for EP+DS - configuration_utils: strip distributed_config from serialized config
1 parent 9dff7ca commit ae548bf

2 files changed

Lines changed: 40 additions & 7 deletions

File tree

src/transformers/configuration_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,6 +1154,7 @@ def _remove_keys_not_serialized(self, d: dict[str, Any]) -> None:
11541154
"ignore_keys_at_rope_validation",
11551155
"base_model_tp_plan",
11561156
"base_model_pp_plan",
1157+
"distributed_config",
11571158
]:
11581159
d.pop(key_to_remove, None)
11591160

src/transformers/modeling_utils.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,12 +1330,18 @@ def post_init(self):
13301330
self.init_weights()
13311331
self._backward_compatibility_gradient_checkpointing()
13321332

1333+
@property
1334+
def has_ep(self) -> bool:
1335+
"""Whether expert parallelism is enabled for this model."""
1336+
distributed_config = getattr(getattr(self, "config", None), "distributed_config", None)
1337+
return distributed_config is not None and getattr(distributed_config, "enable_expert_parallel", False)
1338+
13331339
@property
13341340
def tp_plan(self) -> dict[str, str]:
13351341
"""
13361342
The full tp plan for the model's modules
13371343
"""
1338-
if hasattr(self.config, "distributed_config") and self.config.distributed_config.enable_expert_parallel:
1344+
if self.has_ep:
13391345
return self._ep_plan
13401346
return self._tp_plan
13411347

@@ -3599,14 +3605,27 @@ def float(self, *args):
35993605

36003606
@classmethod
36013607
def get_init_context(
3602-
cls, dtype: torch.dtype, is_quantized: bool, _is_ds_init_called: bool, allow_all_kernels: bool | None
3608+
cls,
3609+
dtype: torch.dtype,
3610+
is_quantized: bool,
3611+
_is_ds_init_called: bool,
3612+
allow_all_kernels: bool | None,
3613+
distributed_config=None,
36033614
):
36043615
# Need to instantiate with correct dtype
36053616
init_contexts = [local_torch_dtype(dtype, cls.__name__), init.no_tie_weights(), apply_patches()]
36063617
# Needed as we cannot forward the `allow_all_kernels` arg in the model's __init__
36073618
if allow_all_kernels:
36083619
init_contexts.append(allow_all_hub_kernels())
3609-
if is_deepspeed_zero3_enabled():
3620+
_has_ep = distributed_config is not None and getattr(distributed_config, "enable_expert_parallel", False)
3621+
if _has_ep and is_deepspeed_zero3_enabled():
3622+
# EP + DeepSpeed: use meta device (same as the normal non-DS path).
3623+
# zero.Init is skipped because EP needs to shard experts via distribute_model()
3624+
# hooks, which are incompatible with ZeRO-3 lazy parameters.
3625+
# The standard weight loading path (not zero3) handles EP sharding via
3626+
# shard_and_distribute_module. deepspeed.initialize() wraps the result later.
3627+
init_contexts.extend([torch.device("meta"), init.meta_device_safe_creation_ops()])
3628+
elif is_deepspeed_zero3_enabled():
36103629
import deepspeed
36113630

36123631
# We cannot initialize the model on meta device with deepspeed when not quantized
@@ -4007,6 +4026,12 @@ def from_pretrained(
40074026
download_kwargs_with_commit,
40084027
**adapter_kwargs,
40094028
)
4029+
# EP + DeepSpeed: clear device_map (set by initialize_tensor_parallelism) so the model
4030+
# loads on CPU first. distribute_model() handles GPU placement during EP sharding.
4031+
# Without this, device_map triggers accelerate's dispatch path which breaks shard loading.
4032+
_has_ep = distributed_config is not None and getattr(distributed_config, "enable_expert_parallel", False)
4033+
if _has_ep and is_deepspeed_zero3_enabled():
4034+
device_map = None
40104035
device_map = check_and_set_device_map(device_map) # warn, error and fix the device map
40114036

40124037
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
@@ -4110,7 +4135,9 @@ def from_pretrained(
41104135

41114136
register_fusion_patches(cls, config, fusion_config)
41124137

4113-
model_init_context = cls.get_init_context(dtype, is_quantized, _is_ds_init_called, allow_all_kernels)
4138+
model_init_context = cls.get_init_context(
4139+
dtype, is_quantized, _is_ds_init_called, allow_all_kernels, distributed_config
4140+
)
41144141

41154142
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
41164143
with ContextManagers(model_init_context):
@@ -4241,7 +4268,11 @@ def _load_pretrained_model(
42414268

42424269
error_msgs = []
42434270

4244-
if is_deepspeed_zero3_enabled() and not is_quantized:
4271+
# EP + DeepSpeed: skip zero3 loading path. The model was created on meta device
4272+
# (not via zero.Init), so params are not zero3-partitioned. The standard loading
4273+
# path handles EP sharding via shard_and_distribute_module using the EP plan hooks
4274+
# registered by distribute_model(). deepspeed.initialize() wraps the result later.
4275+
if is_deepspeed_zero3_enabled() and not is_quantized and not model.has_ep:
42454276
if state_dict is None:
42464277
merged_state_dict = {}
42474278
for ckpt_file in checkpoint_files:
@@ -4551,7 +4582,8 @@ def _move_missing_keys_from_meta_to_device(
45514582
"""
45524583
is_quantized = hf_quantizer is not None
45534584
# This is the only case where we do not initialize the model on meta device, so we don't have to do anything here
4554-
if is_deepspeed_zero3_enabled() and not is_quantized:
4585+
# Exception: EP + DeepSpeed uses meta device (not zero.Init), so it needs the standard move path.
4586+
if is_deepspeed_zero3_enabled() and not is_quantized and not self.has_ep:
45554587
return
45564588

45574589
# In this case we need to move everything back
@@ -4609,7 +4641,7 @@ def _initialize_missing_keys(self, is_quantized: bool) -> None:
46094641
self._is_hf_initialized = True
46104642

46114643
# This will only initialize submodules that are not marked as initialized by the line above.
4612-
if is_deepspeed_zero3_enabled() and not is_quantized:
4644+
if is_deepspeed_zero3_enabled() and not is_quantized and not self.has_ep:
46134645
import deepspeed
46144646

46154647
# keep_vars=True as we need the original tensors, so that the "_is_hf_initialized" is present on them

0 commit comments

Comments
 (0)