
Commit 307fe71

Fix QuantSequentialMLP sharded_state_dict (#742)
## What does this PR do?

**Type of change:** Bug fix

**Overview:** These fixes are needed for the Megatron-LM `main` branch due to recent changes in `sharded_state_dict`. Qwen3-30B-A3B PTQ and resume fails: a run with EP=4 cannot load a checkpoint generated with PP=4. `singleton_local_shards` must be added to the metadata; otherwise all experts' `amax` values are packed together, and the TP `replica_id` for `linear_fc1` is currently incorrect.

**Other finding:** This limits configurations to TP=ETP=1 when EP>1; otherwise there is a `sharded_state_dict` access error. There is also a potential blind spot in using the default TP group in `ColumnParallelLinear` and `RowParallelLinear`, since these layers can be part of an MoE where tensor parallelism is controlled by ETP instead. Fixing the parallel_state will need a separate PR.

**Results:** Calibrating with EP=1 gives MMLU = 0.80, and the checkpoint can be resumed with EP=4, TP=1, ETP=1 (TP>1 does not work, as noted above). Calibrating with EP=4, however, gives MMLU = 0.71, which indicates remaining issues with amax sync under EP.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

---------

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
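The fix described in the Overview follows a common override pattern: default the checkpoint `metadata` dict locally, set the flag, then delegate to the parent class. A minimal self-contained sketch of the pattern (the `Base`/`Child` names are hypothetical stand-ins, not Megatron-LM APIs):

```python
class Base:
    def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
        # The parent consults metadata flags when laying out shards.
        packed = not (metadata or {}).get("singleton_local_shards", False)
        return {"layout": "packed" if packed else "singleton"}


class Child(Base):
    def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
        # Default the dict here; callers may legitimately pass None.
        if metadata is None:
            metadata = {}
        # Force per-expert (singleton) shards so each expert's amax is
        # stored separately instead of being packed together.
        metadata["singleton_local_shards"] = True
        return super().sharded_state_dict(prefix, sharded_offsets, metadata)


print(Child().sharded_state_dict())  # singleton layout even with no metadata passed
```

Because the flag is injected into `metadata` before delegating, every layer below the override sees it without any caller changes.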
1 parent 6f18490 commit 307fe71

File tree

2 files changed: +37 −2 lines

modelopt/torch/opt/plugins/megatron.py

Lines changed: 12 additions & 1 deletion
```diff
@@ -155,6 +155,15 @@ def _setup(self):
         pass
 
     def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
+        """Overriding the default to support scalar sharding.
+
+        Note:
+            singleton_local_shards needs to be added to the metadata as well as
+            apply_swiglu_sharded_factory to handle the swiglu case.
+        """
+        if metadata is None:
+            metadata = {}
+        metadata["singleton_local_shards"] = True
         sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
         if not self.config.gated_linear_unit:
             return sharded_state_dict
@@ -163,6 +172,8 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
                 re.compile(pattern).match(k) for pattern in self._modelopt_state_keys
            ):
                sharded_state_dict[k] = megatron_mlp.apply_swiglu_sharded_factory(
-                    v, sharded_offsets
+                    v,
+                    sharded_offsets,
+                    metadata["singleton_local_shards"],
                )
         return sharded_state_dict
```
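The second hunk above only rewrites state-dict entries whose keys match one of the modelopt state patterns. A small sketch of that filtering step in isolation (the patterns and keys below are illustrative examples, not the actual `_modelopt_state_keys` contents):

```python
import re

# Illustrative modelopt-style key patterns; the real ones live on the module.
state_key_patterns = [r".*_amax$", r".*_quantizer\..*"]

state_dict = {
    "linear_fc1.weight": 1,
    "linear_fc1.input_quantizer._amax": 2,
    "linear_fc2.output_quantizer._amax": 3,
}

# Keys matching any pattern get special shard handling (e.g. the swiglu factory);
# re.match anchors at the start of the string, so patterns cover the full key.
special = {
    k: v
    for k, v in state_dict.items()
    if any(re.compile(p).match(k) for p in state_key_patterns)
}
print(sorted(special))
```

Only the quantizer `amax` entries are selected; plain weights fall through to the default sharding path.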

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 25 additions & 1 deletion
```diff
@@ -33,6 +33,7 @@
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
 from megatron.core.utils import get_tensor_model_parallel_group_if_none
 
+from modelopt.torch.opt.dynamic import DynamicModule
 from modelopt.torch.opt.plugins.megatron import (
     _MegatronMLP,
     ensure_metadata_has_dp_cp_group,
@@ -551,8 +552,16 @@ def forward(self, input, *args, **kwargs):
 
 
 @QuantModuleRegistry.register({megatron_moe.SequentialMLP: "megatron_moe_SequentialMLP"})
-class _MegatronSequentialMLP(_MegatronMLP):
+class _MegatronSequentialMLP(DynamicModule):
     def _setup(self):
+        if (
+            self.config.expert_model_parallel_size > 1
+            and self.config.tensor_model_parallel_size > 1
+        ):
+            raise ValueError(
+                "TP+EP is not supported by QuantSequentialMLP. Set either TP or EP to 1!"
+            )
+
         if not hasattr(self, "parallel_state") or self.parallel_state is None:
             self.parallel_state = ParallelState(
                 mcore_parallel.get_expert_data_parallel_group(),
@@ -592,6 +601,21 @@ def sync_moe_local_experts_amax(self):
             if isinstance(module, TensorQuantizer) and module.amax is not None:
                 module.amax = amax_dict[name].detach().clone().to(module.amax.device)
 
+    def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
+        """Override the default to enable singleton_local_shards.
+
+        Note:
+            singleton_local_shards must be added to the metadata; otherwise, all experts
+            amax are packed to gather and currently the TP replica_id for linear_fc1
+            is incorrect. This limits TP=ETP=1 when EP>1. Otherwise, there will be
+            sharded_state_dict access error.
+        """
+        if metadata is None:
+            metadata = {}
+        metadata["singleton_local_shards"] = True
+        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
+        return sharded_state_dict
+
 
 if HAS_TE:
```

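The new `_setup` guard enforces the TP/EP restriction noted in the PR description: when EP>1, TP must be 1 (and vice versa). A runnable sketch of just that check, using a hypothetical `Config` dataclass in place of the Megatron transformer config:

```python
from dataclasses import dataclass


@dataclass
class Config:
    # Hypothetical stand-in for the two Megatron config fields the guard reads.
    expert_model_parallel_size: int = 1
    tensor_model_parallel_size: int = 1


def check_parallelism(config: Config) -> None:
    # Mirrors the guard added in _setup(): TP and EP cannot both exceed 1.
    if (
        config.expert_model_parallel_size > 1
        and config.tensor_model_parallel_size > 1
    ):
        raise ValueError(
            "TP+EP is not supported by QuantSequentialMLP. Set either TP or EP to 1!"
        )


check_parallelism(Config(expert_model_parallel_size=4))  # EP=4, TP=1: allowed
```

Failing fast at setup time surfaces the unsupported TP+EP combination immediately, instead of letting it manifest later as a `sharded_state_dict` access error.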