Skip to content

Commit 02fa362

Browse files
authored
Sync MOE layer input quantizer only (#903)
## What does this PR do? **Type of change:** Bug fix **Overview:** in MOE layer we currently sync both the weight and input quantizers so that all experts have the same weight amaxes and activation amaxes. VLLM/TRTLLM actually support non-uniform weight amaxes in MOE so we only need to sync the activation amaxes. ## Usage <!-- You can potentially add a usage example below. --> ```python # Add a code snippet demonstrating how to use this ``` ## Testing <!-- Mention how have you tested your change if applicable. --> ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Improved input quantizer synchronization for Mixture of Experts models to ensure correct amax value handling across local experts. * **Documentation** * Fixed typos and clarified wording in quantization documentation. * **Tests** * Added test coverage for Mixture of Experts quantizer synchronization functionality. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
1 parent 70ffb6f commit 02fa362

File tree

3 files changed

+93
-15
lines changed

3 files changed

+93
-15
lines changed

modelopt/torch/quantization/model_calib.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,18 @@ def _check_moe_calibration_complete(quantizer, parallel_state):
9696

9797

9898
@torch.no_grad()
99-
def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, distributed_sync=True):
99+
def max_calibrate(
100+
model: nn.Module,
101+
forward_loop: ForwardLoop | None = None,
102+
distributed_sync=True,
103+
):
100104
"""Calibrate the model using max.
101105
102106
Args:
103107
model: Model to be calibrated.
104108
forward_loop: A callable which takes the model as argument and
105109
forwards calibration data through the model.
110+
distributed_sync: Whether to sync input_quantizer amax across distributed processes.
106111
107112
See :class:`MaxCalibConfig <modelopt.torch.quantization.config.MaxCalibConfig>` for
108113
details on the remaining arguments.
@@ -114,7 +119,7 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
114119
forward_loop(model)
115120
finish_stats_collection(model)
116121

117-
# Sync amax across local experts within each rank (for SequentialMLP)
122+
# Sync input_quantizer amax across local experts within each rank (for SequentialMLP)
118123
for name, module in model.named_modules():
119124
if hasattr(module, "layer_sync_moe_local_experts_amax"):
120125
module.layer_sync_moe_local_experts_amax()

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -575,26 +575,30 @@ def _setup(self):
575575
expert.linear_fc2.parallel_state = self.parallel_state
576576

577577
def layer_sync_moe_local_experts_amax(self):
578-
"""Sync amax across local experts in a SequentialMLP.
578+
"""Sync input quantizer amax across local experts in a SequentialMLP.
579579
580-
Synchronize the amax values across local experts in a lyaer such that all local experts will
581-
share the same amax. This function operates on a single rank and does not require distributed sync.
580+
Ensures all experts have the same input quantizer amax. This function operates
581+
on a single rank and does not require distributed sync.
582582
583583
Distributed amax sync across EP and ETP (for RowParallel) happens in model_calib.max_calibrate().
584584
This function should be called before the distributed sync to ensure the amax values
585585
are synchronized across the layer first.
586586
587587
Note:
588588
Because there are logic which calls collective communication based on whether amax is not None,
589-
We need to garuantee that all experts must have amax. Otherwise, there will be deadlock
590-
when synchroizing over EP since some ranks may have amax None and not calling the collective
589+
We need to guarantee that all experts must have amax. Otherwise, there will be deadlock
590+
when synchronizing over EP since some ranks may have amax None and not calling the collective
591591
communication.
592592
"""
593593
# Collect amax from all local experts
594594
amax_dict = {}
595595
for expert in self.local_experts:
596596
for name, module in expert.named_modules():
597-
if isinstance(module, TensorQuantizer) and module.amax is not None:
597+
if (
598+
isinstance(module, TensorQuantizer)
599+
and module.amax is not None
600+
and "input_quantizer" in name
601+
):
598602
stored_amax = amax_dict.get(name)
599603
amax_tensor = module.amax.detach().clone()
600604
amax_dict[name] = (

tests/gpu_megatron/torch/quantization/plugins/test_megatron.py

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -469,10 +469,7 @@ def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device,
469469

470470
@pytest.mark.parametrize(
471471
"config",
472-
[
473-
NVFP4_GEMM_KV_CFG,
474-
FP8_GEMM_KV_CFG,
475-
],
472+
[NVFP4_GEMM_KV_CFG, FP8_GEMM_KV_CFG, mtq.MAMBA_MOE_NVFP4_CONSERVATIVE_CFG],
476473
)
477474
def test_homogeneous_sharded_state_dict_hybrid(tmp_path, config):
478475
"""Test sharded state dict for hybrid Mamba MOE models."""
@@ -731,6 +728,81 @@ def test_te_grouped_vs_sequential_quantize(need_4_gpus):
731728
)
732729

733730

731+
@pytest.mark.parametrize("ep_size", [1, 2])
732+
@pytest.mark.parametrize("moe_grouped_gemm", [True, False])
733+
def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm):
734+
"""Test expert model parallel synchronization."""
735+
size = torch.cuda.device_count()
736+
if size < ep_size:
737+
pytest.skip(f"Requires at least {ep_size} GPUs for expert model parallel test")
738+
739+
spawn_multiprocess_job(
740+
size=size,
741+
job=partial(
742+
_test_layer_sync_moe_local_experts_amax,
743+
ep_size,
744+
moe_grouped_gemm,
745+
),
746+
backend="nccl",
747+
)
748+
749+
750+
def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, size):
751+
initialize_for_megatron(
752+
tensor_model_parallel_size=1,
753+
pipeline_model_parallel_size=1,
754+
expert_model_parallel_size=ep_size,
755+
expert_tensor_parallel_size=1,
756+
seed=SEED,
757+
)
758+
model = _gpt_model_provider(
759+
tp_size=1,
760+
ep_size=ep_size,
761+
etp_size=1,
762+
hidden_size=256,
763+
moe_grouped_gemm=moe_grouped_gemm,
764+
use_te=moe_grouped_gemm,
765+
num_moe_experts=8,
766+
transformer_impl="modelopt",
767+
)
768+
quant_cfg = mtq.FP8_DEFAULT_CFG
769+
model = mtq.quantize(model, quant_cfg, get_forward(model))
770+
771+
for layer in model.decoder.layers:
772+
layer.mlp.experts.layer_sync_moe_local_experts_amax()
773+
774+
for layer in model.decoder.layers:
775+
# Check input quantizer amax is synced across local experts
776+
fc1_amax = None
777+
fc2_amax = None
778+
for expert in layer.mlp.experts.local_experts:
779+
assert expert.linear_fc1.input_quantizer.amax is not None
780+
assert expert.linear_fc2.input_quantizer.amax is not None
781+
if fc1_amax is None:
782+
fc1_amax = expert.linear_fc1.input_quantizer.amax
783+
else:
784+
assert torch.allclose(fc1_amax, expert.linear_fc1.input_quantizer.amax)
785+
if fc2_amax is None:
786+
fc2_amax = expert.linear_fc2.input_quantizer.amax
787+
else:
788+
assert torch.allclose(fc2_amax, expert.linear_fc2.input_quantizer.amax)
789+
790+
# Check weight quantizer amax is different across local experts
791+
fc1_amax = None
792+
fc2_amax = None
793+
for expert in layer.mlp.experts.local_experts:
794+
assert expert.linear_fc1.weight_quantizer.amax is not None
795+
assert expert.linear_fc2.weight_quantizer.amax is not None
796+
if fc1_amax is None:
797+
fc1_amax = expert.linear_fc1.weight_quantizer.amax
798+
else:
799+
assert not torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax)
800+
if fc2_amax is None:
801+
fc2_amax = expert.linear_fc2.weight_quantizer.amax
802+
else:
803+
assert not torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax)
804+
805+
734806
def _test_expert_model_parallel_amax_sync(
735807
tp_size, ep_size, etp_size, moe_grouped_gemm, config, rank, size
736808
):
@@ -811,9 +883,6 @@ def test_expert_parallel_sync(config, ep_size, etp_size, moe_grouped_gemm):
811883
if size < ep_size * etp_size:
812884
pytest.skip(f"Requires at least {ep_size * etp_size} GPUs for expert model parallel test")
813885

814-
if moe_grouped_gemm:
815-
pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently")
816-
817886
spawn_multiprocess_job(
818887
size=size,
819888
job=partial(

0 commit comments

Comments
 (0)