
Commit 35d0f52

ChenhanYu authored and danielkorzekwa committed
Fix Sequential MLP amax sync deadlock (#862)
## What does this PR do?

**Type of change:** Bug fix

**Overview:**

After `QuantMoELayer`, we rely on `layer_sync_moe_local_experts_amax` to first perform a local sync. This is supposed to create `input_quantizer.amax` for all experts, but the current logic only updates experts that already have `amax`, so some experts are still missing `amax`.

As a result, `sync_quantizer_amax_across_dp_ep` can deadlock, since the collective is invoked based on whether `quantizer._amax is None`. Any expert with `None` amax never calls the collective, so it never arrives at it and the other ranks wait forever.

We fix `layer_sync_moe_local_experts_amax` so that even if an expert does not have `amax`, we overwrite it with a clone of the global amax. The postcondition is that all experts have `amax`, which matches the precondition of `sync_quantizer_amax_across_dp_ep`.

**Note:** we found that `_check_moe_calibration_complete` did not raise an error even when some experts had no amax. We did not look into this problem.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Summary by CodeRabbit

* **Bug Fixes**
  * Improved synchronization of quantization parameters for Mixture of Experts (MoE) models with more flexible configuration support.

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
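The local-sync invariant the fix establishes can be sketched in plain Python. This is a hypothetical stand-in (dicts of floats instead of `TensorQuantizer` modules and tensors): collect the per-name maximum over experts that already have an amax, then write it back to *all* experts, including those whose amax was `None`.

```python
def layer_sync_local_experts_amax(experts):
    """Sketch: sync amax across local experts, filling in missing values.

    `experts` is a list of dicts mapping quantizer name -> amax (or None).
    Hypothetical stand-in for the real TensorQuantizer-based logic.
    """
    # Collect the per-name maximum over experts that already have an amax.
    amax_dict = {}
    for expert in experts:
        for name, amax in expert.items():
            if amax is not None:
                amax_dict[name] = max(amax_dict.get(name, amax), amax)

    # Apply back to ALL experts, including those whose amax was None.
    # Postcondition: every expert has an amax for every name seen on some
    # expert, so every rank later enters the distributed collective.
    for expert in experts:
        for name in expert:
            if name in amax_dict:
                expert[name] = amax_dict[name]
    return experts
```

With this postcondition, the later `None`-based gating in the distributed sync sees the same answer on every rank.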
1 parent b11d49b commit 35d0f52

File tree: 2 files changed, +10 −4 lines


modelopt/torch/quantization/model_calib.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -115,8 +115,8 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
 
     # Sync amax across local experts within each rank (for SequentialMLP)
     for name, module in model.named_modules():
-        if hasattr(module, "sync_moe_local_experts_amax"):
-            module.sync_moe_local_experts_amax()
+        if hasattr(module, "layer_sync_moe_local_experts_amax"):
+            module.layer_sync_moe_local_experts_amax()
 
     if not distributed_sync:
         return
```
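The renamed hook is discovered by duck typing, so only modules that define it are called. A minimal sketch of that dispatch pattern (the classes here are hypothetical, not the real Megatron modules):

```python
class _PlainLayer:
    """A module without the MoE sync hook; the loop skips it."""


class _MoELayer:
    """A module exposing the hook, like the Megatron SequentialMLP wrapper."""

    def __init__(self):
        self.synced = False

    def layer_sync_moe_local_experts_amax(self):
        self.synced = True


def sync_local_experts(modules):
    # Mirrors the max_calibrate loop: call the hook wherever it exists.
    for module in modules:
        if hasattr(module, "layer_sync_moe_local_experts_amax"):
            module.layer_sync_moe_local_experts_amax()


moe = _MoELayer()
sync_local_experts([_PlainLayer(), moe])
```

Because the pre-rename code probed for `sync_moe_local_experts_amax`, the check silently matched nothing and the local sync never ran, which is the rename this hunk fixes.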

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 8 additions & 2 deletions

```diff
@@ -583,6 +583,12 @@ def layer_sync_moe_local_experts_amax(self):
         Distributed amax sync across EP and ETP (for RowParallel) happens in model_calib.max_calibrate().
         This function should be called before the distributed sync to ensure the amax values
         are synchronized across the layer first.
+
+        Note:
+            Because there is logic that calls collective communication based on whether amax is not None,
+            we need to guarantee that all experts have amax. Otherwise, there will be a deadlock when
+            synchronizing over EP, since some ranks may have amax None and never call the collective
+            communication.
         """
         # Collect amax from all local experts
         amax_dict = {}
@@ -600,8 +606,8 @@ def layer_sync_moe_local_experts_amax(self):
         # Apply synchronized amax values back to all local experts
         for expert in self.local_experts:
             for name, module in expert.named_modules():
-                if isinstance(module, TensorQuantizer) and module.amax is not None:
-                    module.amax = amax_dict[name].detach().clone().to(module.amax.device)
+                if isinstance(module, TensorQuantizer) and name in amax_dict:
+                    module.amax = amax_dict[name].detach().clone()
```
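The deadlock described in the docstring note can be demonstrated without `torch.distributed`. The sketch below models a collective as a `threading.Barrier`: it only completes if *every* rank calls it, so a rank that skips the call because its amax is `None` strands the others. All names here are hypothetical illustrations, not the real API.

```python
import threading


def demo(amax_per_rank, timeout=0.5):
    """Sketch: a collective (modeled as a Barrier) completes only if all ranks call it."""
    barrier = threading.Barrier(len(amax_per_rank))
    results = {}

    def rank_fn(rank, amax):
        # Old behavior: a rank whose amax is None skips the collective entirely.
        if amax is None:
            results[rank] = "skipped"
            return
        try:
            barrier.wait(timeout=timeout)  # blocks until all parties arrive
            results[rank] = "synced"
        except threading.BrokenBarrierError:
            # Timeout breaks the barrier: this is the simulated deadlock.
            results[rank] = "deadlock"

    threads = [
        threading.Thread(target=rank_fn, args=(rank, amax))
        for rank, amax in enumerate(amax_per_rank)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

With the fix, the local sync guarantees no rank has `amax is None` before the distributed sync, so the `"skipped"` branch is never taken and every rank reaches the collective.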
