megatron/core/models/mimo/optimizer.py (32 changes: 31 additions, 1 deletion)
@@ -153,6 +153,7 @@ def load_state_dict(self, state_dict: Dict):

for sub_sd, inner_opt in _iter_optimizer_sub_dicts(module_sd, info.optimizer):
_restore_param_groups(sub_sd, inner_opt, name)
_restore_param_state_sharding_type(sub_sd)
_restore_grad_scaler(sub_sd)

info.optimizer.load_state_dict(module_sd)
@@ -175,6 +176,7 @@ def sharded_state_dict(self, model_sharded_state_dict, is_loading: bool = False,
):
suffix = f'.{idx}' if idx > 0 else ''
_extract_param_groups(sub_sd, name, suffix, replica_id)
_extract_param_state_sharding_type(sub_sd, name, suffix, replica_id)
_extract_grad_scaler(sub_sd, name, suffix, replica_id)

sharded_state[name] = module_sd
@@ -218,6 +220,8 @@ def _extract_param_groups(sub_sd, module_name, suffix, replica_id):
replica_id=replica_id,
)
del opt_sub['param_groups']
if not opt_sub:
del sub_sd['optimizer']


def _extract_grad_scaler(sub_sd, module_name, suffix, replica_id):
@@ -232,6 +236,18 @@ def _extract_grad_scaler(sub_sd, module_name, suffix, replica_id):
)


def _extract_param_state_sharding_type(sub_sd, module_name, suffix, replica_id):
"""Save: extract param_state_sharding_type into a ShardedObject."""
if 'param_state_sharding_type' in sub_sd:
sub_sd[f'_mimo_param_state_sharding_type{suffix}'] = ShardedObject(
f'optimizer.mimo.{module_name}{suffix}.param_state_sharding_type',
sub_sd.pop('param_state_sharding_type'),
(1,),
(0,),
replica_id=replica_id,
)


def _restore_param_groups(sub_sd, inner_optimizer, module_name):
"""Load: restore param_groups with current param IDs from the inner optimizer."""
# Find the _mimo_param_groups key (may have a suffix for chained optimizers)
@@ -253,7 +269,21 @@ def _restore_param_groups(sub_sd, inner_optimizer, module_name):
)
for loaded_g, current_g in zip(loaded_pg, current_pg):
loaded_g['params'] = current_g['params']
sub_sd['optimizer']['param_groups'] = loaded_pg
# In MIMO, rank-local module optimizer metadata is not common across ranks
# (for example, non-colocated rank 0 may own language while rank 1 owns
# vision). Distributed checkpoint load can therefore return the sharded
# tensor state plus the extracted MIMO param groups, but without the
# original nested "optimizer" common-state wrapper. Recreate it here; the
# inner optimizer load path only needs param_groups from this wrapper.
sub_sd.setdefault('optimizer', {})['param_groups'] = loaded_pg


def _restore_param_state_sharding_type(sub_sd):
"""Load: restore param_state_sharding_type from ShardedObject key."""
for k in list(sub_sd.keys()):
if k.startswith('_mimo_param_state_sharding_type'):
sub_sd['param_state_sharding_type'] = sub_sd.pop(k)
break


def _restore_grad_scaler(sub_sd):
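The sketch below is not part of the diff; it only illustrates how the new save/load helpers are meant to round-trip 'param_state_sharding_type' through a ShardedObject. It uses a simplified dataclass stand-in for ShardedObject (the real class comes from megatron.core.dist_checkpointing), hypothetical helper names, and an illustrative sharding-type string.

from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class ShardedObject:
    """Simplified stand-in for the Megatron dist-checkpointing ShardedObject."""
    key: str
    data: Any
    global_shape: Tuple[int, ...]
    global_offset: Tuple[int, ...]
    replica_id: int = 0


def extract_sharding_type(sub_sd, module_name, suffix, replica_id):
    """Save path: wrap the plain value in a per-module ShardedObject, as in
    _extract_param_state_sharding_type above."""
    if 'param_state_sharding_type' in sub_sd:
        sub_sd[f'_mimo_param_state_sharding_type{suffix}'] = ShardedObject(
            f'optimizer.mimo.{module_name}{suffix}.param_state_sharding_type',
            sub_sd.pop('param_state_sharding_type'),
            (1,),
            (0,),
            replica_id=replica_id,
        )


def restore_sharding_type(sub_sd):
    """Load path: find the (possibly suffixed) key and move the value back,
    mirroring _restore_param_state_sharding_type above."""
    for k in list(sub_sd.keys()):
        if k.startswith('_mimo_param_state_sharding_type'):
            sub_sd['param_state_sharding_type'] = sub_sd.pop(k)
            break


# Round trip on a toy module sub-state-dict.
sd = {'param_state_sharding_type': 'fully_sharded_model_space'}  # illustrative value
extract_sharding_type(sd, 'vision_model', '', replica_id=0)
assert 'param_state_sharding_type' not in sd

# Distributed-checkpoint load replaces ShardedObjects with their loaded payloads
# before load_state_dict sees the dict; simulate that step here.
for k, v in list(sd.items()):
    if isinstance(v, ShardedObject):
        sd[k] = v.data

restore_sharding_type(sd)
assert sd['param_state_sharding_type'] == 'fully_sharded_model_space'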