
Commit 2a08622

Fix moe amax remedy for dsr1 and remove global barrier in quantization megatron plugins (#808)
## What does this PR do?

**Type of change:** Bug fix

**Overview:** This PR fixes two bugs that impact DeepSeek calibration as well as the PP forward of MoE models.

1. The WAR in `MoELayer` that changes `topk` to `num_experts` only works if no group-topk (a.k.a. group routing) is used. Changing only `topk` leads to an out-of-range error, because `topk` can never be `num_experts` when `group_topk != None`. Currently only `DeepSeek-V3` uses `group_topk`, and DeepSeek-V3 has no difficulty calibrating all experts, so we disable the WAR when `group_topk` is detected (see the illustrative sketch after this description).
2. A previous PR inserted a global barrier in `quantization.plugin.megatron` (6ef9954#diff-0fa2ba4ecc36c5ff031be9f9a5af080e7aa3afa331c438f02f501b9432ec6d6aL228-R515). This leads to a deadlock when PP is used, since PP ranks can never sync during the pipeline forward; for MoE it can be even worse if the barrier is only reached by some EP/PP ranks. Collective communication over the global world (a.k.a. the global comm) in the Megatron plugin should be prohibited. Collectives on sub communication groups should also avoid `megatron.core.parallel_state` (a.k.a. `mpu`) in the future; instead, use the local `pg_collection` from each module. Any use of collective communication must be inspected carefully and tested with PP, TP, and EP (a sketch of sub-group amax synchronization follows the diff below).

## Usage

```python
# Add a code snippet demonstrating how to use this
```

## Testing

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

---------

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
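To make the first point concrete, here is a minimal, self-contained sketch of group-limited (DeepSeek-V3 style) top-k routing. It is not the Megatron implementation; the helper name, shapes, and the DeepSeek-V3-like numbers in the usage line are illustrative assumptions. The point it shows: once only `group_topk` groups are kept, at most `group_topk * experts_per_group` experts remain eligible per token, so forcing `topk = num_experts` during calibration cannot produce a valid selection.

```python
import torch


def group_limited_topk(scores, topk, num_groups, group_topk):
    """Illustrative sketch of group-limited routing (not the Megatron code).

    scores: [num_tokens, num_experts] router logits.
    """
    num_tokens, num_experts = scores.shape
    experts_per_group = num_experts // num_groups

    # Score each group by its best expert and keep only the top `group_topk` groups.
    group_scores = scores.view(num_tokens, num_groups, experts_per_group).amax(dim=-1)
    group_idx = torch.topk(group_scores, k=group_topk, dim=-1).indices

    # Mask out every expert that lives in an unselected group.
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    score_mask = (
        group_mask.unsqueeze(-1).expand(-1, -1, experts_per_group).reshape(num_tokens, num_experts)
    )
    masked_scores = scores.masked_fill(score_mask == 0, float("-inf"))

    # Only group_topk * experts_per_group experts are eligible here, so asking for
    # topk == num_experts (the calibration WAR) cannot return a valid selection.
    assert topk <= group_topk * experts_per_group, "topk exceeds the group-limited candidate set"
    return torch.topk(masked_scores, k=topk, dim=-1)


# Illustrative usage with DeepSeek-V3-like settings (256 experts, 8 groups, 4 groups kept).
scores = torch.randn(4, 256)
values, ids = group_limited_topk(scores, topk=8, num_groups=8, group_topk=4)
```

Because DeepSeek-V3 (currently the only user of `group_topk`) already has no difficulty calibrating all experts, the fix simply skips the WAR whenever group routing is configured.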
1 parent b44c60a commit 2a08622

1 file changed

Lines changed: 7 additions & 5 deletions

File tree

modelopt/torch/quantization/plugins/megatron.py

```diff
@@ -581,7 +581,6 @@ def sync_moe_local_experts_amax(self):
         This function is called to synchronize the amax values across local experts s.t. all local experts will
         share the same amax.
         """
-        torch.distributed.barrier()
         # Collect amax from all local experts
         amax_dict = {}
         for expert in self.local_experts:
@@ -754,8 +753,11 @@ def _setup(self):

     def forward(self, hidden_states):
         if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
-            original_top_k = self.router.topk
-            self.router.topk = self.router.num_experts
-            super().forward(hidden_states)
-            self.router.topk = original_top_k
+            if self.config.moe_router_num_groups is None:
+                original_top_k = self.router.topk
+                self.router.topk = self.router.num_experts
+                super().forward(hidden_states)
+                self.router.topk = original_top_k
+            else:
+                super().forward(hidden_states)
         return super().forward(hidden_states)
```
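For reference, a hedged sketch of how local-expert amax synchronization can proceed without the removed global barrier. The function name, the `amax` attribute lookup, and the optional `group` argument are illustrative assumptions rather than the plugin's actual API; the actual hunk above only collects amax across local experts, and the optional sub-group reduce merely illustrates the PR's guidance that any collective must run on a module-local process group (e.g. from `pg_collection`) instead of the global world.

```python
import torch
import torch.distributed as dist


def sync_local_expert_amax(local_experts, group=None):
    """Illustrative sketch: give all local experts a shared per-quantizer amax.

    `group`, if given, is assumed to be a sub process group taken from the
    module's local pg_collection (e.g. the expert-parallel group). A collective
    on the global world here can deadlock under pipeline parallelism, because
    not every PP rank reaches this call during the pipeline forward.
    """
    # 1. Collect amax values from all local experts, keyed by quantizer name.
    amax_dict: dict[str, list[torch.Tensor]] = {}
    for expert in local_experts:
        for name, module in expert.named_modules():
            amax = getattr(module, "amax", None)
            if amax is not None:
                amax_dict.setdefault(name, []).append(amax)

    # 2. Reduce each entry to a single max, optionally across the sub group only.
    for name, amax_list in amax_dict.items():
        merged = torch.stack([a.detach().clone() for a in amax_list]).amax(dim=0)
        if group is not None and dist.is_initialized():
            # Collective over the sub group only -- never over the global world.
            dist.all_reduce(merged, op=dist.ReduceOp.MAX, group=group)
        # 3. Write the shared value back into every local expert's quantizer.
        for amax in amax_list:
            amax.data.copy_(merged)
```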
