Commit 28d5686

sync moe input quantizer only
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
1 parent 9e38041 commit 28d5686

3 files changed

Lines changed: 106 additions & 18 deletions

modelopt/torch/quantization/model_calib.py

Lines changed: 9 additions & 2 deletions
@@ -95,13 +95,20 @@ def _check_moe_calibration_complete(quantizer, parallel_state):
 
 
 @torch.no_grad()
-def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, distributed_sync=True):
+def max_calibrate(
+    model: nn.Module,
+    forward_loop: ForwardLoop | None = None,
+    distributed_sync=True,
+    shared_moe_weight_scale=True,
+):
     """Calibrate the model using max.
 
     Args:
         model: Model to be calibrated.
         forward_loop: A callable which takes the model as argument and
             forwards calibration data through the model.
+        distributed_sync: Whether to sync amax across distributed processes.
+        shared_moe_weight_scale: Whether to share the weight scale across local experts.
 
     See :class:`MaxCalibConfig <modelopt.torch.quantization.config.MaxCalibConfig>` for
     details on the remaining arguments.
@@ -116,7 +123,7 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
     # Sync amax across local experts within each rank (for SequentialMLP)
     for name, module in model.named_modules():
         if hasattr(module, "layer_sync_moe_local_experts_amax"):
-            module.layer_sync_moe_local_experts_amax()
+            module.layer_sync_moe_local_experts_amax(shared_moe_weight_scale)
 
     if not distributed_sync:
         return
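
For context, a minimal usage sketch of the widened signature (not part of the commit): the import path is assumed from the file location above, calibrate_moe_model is a hypothetical helper, and only the keyword names come from this diff.

# Hypothetical usage sketch; import path assumed from modelopt/torch/quantization/model_calib.py.
from modelopt.torch.quantization.model_calib import max_calibrate

def calibrate_moe_model(model, calib_batches):
    """Run max calibration, sharing only the MoE input-quantizer amax across local experts."""
    def forward_loop(m):
        for batch in calib_batches:  # calibration inputs supplied by the caller
            m(batch)

    max_calibrate(
        model,
        forward_loop=forward_loop,
        distributed_sync=True,
        shared_moe_weight_scale=False,  # keep per-expert weight scales, sync input scales only
    )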

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -574,34 +574,39 @@ def _setup(self):
574574
expert.linear_fc1.parallel_state = self.parallel_state
575575
expert.linear_fc2.parallel_state = self.parallel_state
576576

577-
def layer_sync_moe_local_experts_amax(self):
578-
"""Sync amax across local experts in a SequentialMLP.
577+
def layer_sync_moe_local_experts_amax(self, shared_moe_weight_scale=True):
578+
"""Sync input quantizer amax across local experts in a SequentialMLP, and optionally weight scale.
579579
580-
Synchronize the amax values across local experts in a lyaer such that all local experts will
581-
share the same amax. This function operates on a single rank and does not require distributed sync.
580+
Ensures all experts have the same input quantizer amax.This function operates
581+
on a single rank and does not require distributed sync.
582582
583583
Distributed amax sync across EP and ETP (for RowParallel) happens in model_calib.max_calibrate().
584584
This function should be called before the distributed sync to ensure the amax values
585585
are synchronized across the layer first.
586586
587587
Note:
588588
Because there are logic which calls collective communication based on whether amax is not None,
589-
We need to garuantee that all experts must have amax. Otherwise, there will be deadlock
590-
when synchroizing over EP since some ranks may have amax None and not calling the collective
589+
We need to guarantee that all experts must have amax. Otherwise, there will be deadlock
590+
when synchronizing over EP since some ranks may have amax None and not calling the collective
591591
communication.
592+
593+
Args:
594+
shared_moe_weight_scale: Whether to share the weight scale across local experts.
592595
"""
593596
# Collect amax from all local experts
594597
amax_dict = {}
595598
for expert in self.local_experts:
596599
for name, module in expert.named_modules():
597600
if isinstance(module, TensorQuantizer) and module.amax is not None:
598-
stored_amax = amax_dict.get(name)
599-
amax_tensor = module.amax.detach().clone()
600-
amax_dict[name] = (
601-
amax_tensor
602-
if stored_amax is None
603-
else torch.maximum(stored_amax, amax_tensor)
604-
)
601+
if shared_moe_weight_scale or ("weight_quantizer" not in name):
602+
# Sync both quantizers or only sync input quantizer
603+
stored_amax = amax_dict.get(name)
604+
amax_tensor = module.amax.detach().clone()
605+
amax_dict[name] = (
606+
amax_tensor
607+
if stored_amax is None
608+
else torch.maximum(stored_amax, amax_tensor)
609+
)
605610

606611
# Apply synchronized amax values back to all local experts
607612
for expert in self.local_experts:
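
To make the reduction above concrete, here is a self-contained sketch of the same two-pass max reduction over toy quantizer stubs; the _QuantizerStub class and the expert dictionaries are illustrative stand-ins, not Megatron or ModelOpt APIs.

# Standalone sketch of the amax reduction; all names here are hypothetical stand-ins.
import torch

class _QuantizerStub:
    def __init__(self, amax):
        self.amax = torch.tensor([amax])

# Two toy experts, each with an input and a weight quantizer.
experts = [
    {"input_quantizer": _QuantizerStub(1.0), "weight_quantizer": _QuantizerStub(0.5)},
    {"input_quantizer": _QuantizerStub(2.0), "weight_quantizer": _QuantizerStub(0.7)},
]

def sync_local_experts_amax(experts, shared_moe_weight_scale=True):
    # Pass 1: element-wise max of each quantizer's amax across experts,
    # skipping weight quantizers when only the input scale should be shared.
    amax_dict = {}
    for expert in experts:
        for name, quantizer in expert.items():
            if shared_moe_weight_scale or "weight_quantizer" not in name:
                amax = quantizer.amax.detach().clone()
                amax_dict[name] = (
                    amax if name not in amax_dict else torch.maximum(amax_dict[name], amax)
                )
    # Pass 2: write the reduced values back so every expert sees the same amax.
    for expert in experts:
        for name, quantizer in expert.items():
            if name in amax_dict:
                quantizer.amax = amax_dict[name]

sync_local_experts_amax(experts, shared_moe_weight_scale=False)
assert torch.equal(experts[0]["input_quantizer"].amax, experts[1]["input_quantizer"].amax)
assert not torch.equal(experts[0]["weight_quantizer"].amax, experts[1]["weight_quantizer"].amax)

With shared_moe_weight_scale=True, the weight-quantizer entries would be reduced as well, matching the previous behavior of the method.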

tests/gpu_megatron/torch/quantization/plugins/test_megatron.py

Lines changed: 79 additions & 3 deletions
@@ -735,6 +735,85 @@ def test_te_grouped_vs_sequential_quantize(need_4_gpus):
     )
 
 
+@pytest.mark.parametrize("ep_size", [1, 2])
+@pytest.mark.parametrize("moe_grouped_gemm", [True, False])
+@pytest.mark.parametrize("shared_moe_weight_scale", [True, False])
+def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, shared_moe_weight_scale):
+    """Test expert model parallel synchronization."""
+    size = torch.cuda.device_count()
+    if size < ep_size:
+        pytest.skip(f"Requires at least {ep_size} GPUs for expert model parallel test")
+
+    spawn_multiprocess_job(
+        size=size,
+        job=partial(
+            _test_layer_sync_moe_local_experts_amax,
+            ep_size,
+            moe_grouped_gemm,
+            shared_moe_weight_scale,
+        ),
+        backend="nccl",
+    )
+
+
+def _test_layer_sync_moe_local_experts_amax(
+    ep_size, moe_grouped_gemm, shared_moe_weight_scale, rank, size
+):
+    initialize_for_megatron(
+        tensor_model_parallel_size=1,
+        pipeline_model_parallel_size=1,
+        expert_model_parallel_size=ep_size,
+        expert_tensor_parallel_size=1,
+        seed=SEED,
+    )
+    model = _gpt_model_provider(
+        tp_size=1,
+        ep_size=ep_size,
+        etp_size=1,
+        hidden_size=256,
+        moe_grouped_gemm=moe_grouped_gemm,
+        use_te=moe_grouped_gemm,
+        num_moe_experts=8,
+        transformer_impl="modelopt",
+    )
+    model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, get_forward(model))
+
+    # Sync amax across local experts in each layer
+    for layer in model.decoder.layers:
+        layer.mlp.experts.layer_sync_moe_local_experts_amax(shared_moe_weight_scale)
+
+    for layer in model.decoder.layers:
+        fc1_amax = None
+        fc2_amax = None
+        for expert in layer.mlp.experts.local_experts:
+            assert expert.linear_fc1.input_quantizer.amax is not None
+            assert expert.linear_fc2.input_quantizer.amax is not None
+            if fc1_amax is None:
+                fc1_amax = expert.linear_fc1.input_quantizer.amax
+            else:
+                assert torch.allclose(fc1_amax, expert.linear_fc1.input_quantizer.amax)
+            if fc2_amax is None:
+                fc2_amax = expert.linear_fc2.input_quantizer.amax
+            else:
+                assert torch.allclose(fc2_amax, expert.linear_fc2.input_quantizer.amax)
+
+    if shared_moe_weight_scale:
+        for layer in model.decoder.layers:
+            fc1_amax = None
+            fc2_amax = None
+            for expert in layer.mlp.experts.local_experts:
+                assert expert.linear_fc1.weight_quantizer.amax is not None
+                assert expert.linear_fc2.weight_quantizer.amax is not None
+                if fc1_amax is None:
+                    fc1_amax = expert.linear_fc1.weight_quantizer.amax
+                else:
+                    assert torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax)
+                if fc2_amax is None:
+                    fc2_amax = expert.linear_fc2.weight_quantizer.amax
+                else:
+                    assert torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax)
+
+
 def _test_expert_model_parallel_amax_sync(
     tp_size, ep_size, etp_size, moe_grouped_gemm, config, rank, size
 ):
@@ -815,9 +894,6 @@ def test_expert_parallel_sync(config, ep_size, etp_size, moe_grouped_gemm):
     if size < ep_size * etp_size:
         pytest.skip(f"Requires at least {ep_size * etp_size} GPUs for expert model parallel test")
 
-    if moe_grouped_gemm:
-        pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently")
-
     spawn_multiprocess_job(
         size=size,
         job=partial(
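
For reference, a sketch of running only the new test locally; the file path comes from the diff header, while invoking pytest through its Python API and the GPU requirements are assumptions on my part.

# Hypothetical invocation of just the new test; needs a CUDA machine, and the
# ep_size=2 parametrizations additionally need at least 2 GPUs or they will skip.
import pytest

pytest.main([
    "tests/gpu_megatron/torch/quantization/plugins/test_megatron.py",
    "-k", "test_layer_sync_moe_local_experts_amax",
])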
