Skip to content

Commit d77a859

Browse files
committed
Merge branch 'mergeability-pr-45548' into all-defects
2 parents c038ff2 + d98bad4 commit d77a859

2 files changed

Lines changed: 33 additions & 6 deletions

File tree

src/transformers/configuration_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,6 +1161,7 @@ def _remove_keys_not_serialized(self, d: dict[str, Any]) -> None:
11611161
"ignore_keys_at_rope_validation",
11621162
"base_model_tp_plan",
11631163
"base_model_pp_plan",
1164+
"distributed_config",
11641165
]:
11651166
d.pop(key_to_remove, None)
11661167

src/transformers/modeling_utils.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3680,14 +3680,27 @@ def float(self, *args):
36803680

36813681
@classmethod
36823682
def get_init_context(
3683-
cls, dtype: torch.dtype, is_quantized: bool, _is_ds_init_called: bool, allow_all_kernels: bool | None
3683+
cls,
3684+
dtype: torch.dtype,
3685+
is_quantized: bool,
3686+
_is_ds_init_called: bool,
3687+
allow_all_kernels: bool | None,
3688+
distributed_config=None,
36843689
):
36853690
# Need to instantiate with correct dtype
36863691
init_contexts = [local_torch_dtype(dtype, cls.__name__), init.no_tie_weights(), apply_patches()]
36873692
# Needed as we cannot forward the `allow_all_kernels` arg in the model's __init__
36883693
if allow_all_kernels:
36893694
init_contexts.append(allow_all_hub_kernels())
3690-
if is_deepspeed_zero3_enabled():
3695+
_has_ep = distributed_config is not None and getattr(distributed_config, "enable_expert_parallel", False)
3696+
if _has_ep and is_deepspeed_zero3_enabled():
3697+
# EP + DeepSpeed: use meta device (same as the normal non-DS path).
3698+
# zero.Init is skipped because EP needs to shard experts via distribute_model()
3699+
# hooks, which are incompatible with ZeRO-3 lazy parameters.
3700+
# The standard (non-ZeRO-3) weight loading path handles EP sharding via
3701+
# shard_and_distribute_module. deepspeed.initialize() wraps the result later.
3702+
init_contexts.extend([torch.device("meta"), init.meta_device_safe_creation_ops()])
3703+
elif is_deepspeed_zero3_enabled():
36913704
import deepspeed
36923705

36933706
# We cannot initialize the model on meta device with deepspeed when not quantized
@@ -4095,6 +4108,12 @@ def from_pretrained(
40954108
download_kwargs_with_commit,
40964109
**adapter_kwargs,
40974110
)
4111+
# EP + DeepSpeed: clear device_map (set by initialize_tensor_parallelism) so the model
4112+
# loads on CPU first. distribute_model() handles GPU placement during EP sharding.
4113+
# Without this, device_map triggers accelerate's dispatch path, which breaks shard loading.
4114+
_has_ep = distributed_config is not None and getattr(distributed_config, "enable_expert_parallel", False)
4115+
if _has_ep and is_deepspeed_zero3_enabled():
4116+
device_map = None
40984117
device_map = check_and_set_device_map(device_map) # warn, error and fix the device map
40994118

41004119
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
@@ -4203,7 +4222,9 @@ def from_pretrained(
42034222

42044223
register_fusion_patches(cls, config, fusion_config)
42054224

4206-
model_init_context = cls.get_init_context(dtype, is_quantized, _is_ds_init_called, allow_all_kernels)
4225+
model_init_context = cls.get_init_context(
4226+
dtype, is_quantized, _is_ds_init_called, allow_all_kernels, distributed_config
4227+
)
42074228

42084229
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
42094230
with ContextManagers(model_init_context):
@@ -4336,7 +4357,11 @@ def _load_pretrained_model(
43364357

43374358
error_msgs = []
43384359

4339-
if is_deepspeed_zero3_enabled() and not is_quantized:
4360+
# EP + DeepSpeed: skip zero3 loading path. The model was created on meta device
4361+
# (not via zero.Init), so params are not zero3-partitioned. The standard loading
4362+
# path handles EP sharding via shard_and_distribute_module using the EP plan hooks
4363+
# registered by distribute_model(). deepspeed.initialize() wraps the result later.
4364+
if is_deepspeed_zero3_enabled() and not is_quantized and not model.has_ep:
43404365
if state_dict is None:
43414366
merged_state_dict = {}
43424367
for ckpt_file in checkpoint_files:
@@ -4655,7 +4680,8 @@ def _move_missing_keys_from_meta_to_device(
46554680
"""
46564681
is_quantized = hf_quantizer is not None
46574682
# This is the only case where we do not initialize the model on meta device, so we don't have to do anything here
4658-
if is_deepspeed_zero3_enabled() and not is_quantized:
4683+
# Exception: EP + DeepSpeed uses meta device (not zero.Init), so it needs the standard move path.
4684+
if is_deepspeed_zero3_enabled() and not is_quantized and not self.has_ep:
46594685
return
46604686

46614687
# Leave parameters on meta on non-rank-0 FSDP ranks (rank-0 broadcast overwrites them); only buffers need real placeholders.
@@ -4710,7 +4736,7 @@ def _initialize_missing_keys(self, is_quantized: bool) -> None:
47104736
self._is_hf_initialized = True
47114737

47124738
# This will only initialize submodules that are not marked as initialized by the line above.
4713-
if is_deepspeed_zero3_enabled() and not is_quantized:
4739+
if is_deepspeed_zero3_enabled() and not is_quantized and not self.has_ep:
47144740
import deepspeed
47154741

47164742
# keep_vars=True as we need the original tensors, so that the "_is_hf_initialized" is present on them

0 commit comments

Comments
 (0)