
Commit 8b0a564

shared_moe_weight_scale parameter
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
1 parent e504526 commit 8b0a564

3 files changed

Lines changed: 54 additions & 25 deletions


modelopt/torch/quantization/model_calib.py

Lines changed: 9 additions & 2 deletions

@@ -95,13 +95,20 @@ def _check_moe_calibration_complete(quantizer, parallel_state):
 
 
 @torch.no_grad()
-def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, distributed_sync=True):
+def max_calibrate(
+    model: nn.Module,
+    forward_loop: ForwardLoop | None = None,
+    distributed_sync=True,
+    shared_moe_weight_scale=True,
+):
     """Calibrate the model using max.
 
     Args:
         model: Model to be calibrated.
         forward_loop: A callable which takes the model as argument and
             forwards calibration data through the model.
+        distributed_sync: Whether to sync amax across distributed processes.
+        shared_moe_weight_scale: Whether to share the weight scale across local experts.
 
     See :class:`MaxCalibConfig <modelopt.torch.quantization.config.MaxCalibConfig>` for
     details on the remaining arguments.
@@ -116,7 +123,7 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
     # Sync amax across local experts within each rank (for SequentialMLP)
     for name, module in model.named_modules():
         if hasattr(module, "layer_sync_moe_local_experts_amax"):
-            module.layer_sync_moe_local_experts_amax()
+            module.layer_sync_moe_local_experts_amax(shared_moe_weight_scale)
 
     if not distributed_sync:
         return
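For orientation, here is a minimal usage sketch of the extended signature. The import path follows the file changed in this commit; `model` and `calib_batches` are placeholders for an already-quantized Megatron MoE model and its calibration data, and are not part of this change.

from modelopt.torch.quantization.model_calib import max_calibrate

# `model` is assumed to be a quantized Megatron MoE model and `calib_batches`
# an iterable of calibration inputs; both are placeholders for illustration.
def forward_loop(m):
    for batch in calib_batches:
        m(batch)

# The default (shared_moe_weight_scale=True) also max-reduces the
# weight-quantizer amax across local experts; passing False keeps the
# previous behaviour of sharing only the input-quantizer amax.
max_calibrate(
    model,
    forward_loop=forward_loop,
    distributed_sync=True,
    shared_moe_weight_scale=False,
)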

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 22 additions & 20 deletions

@@ -574,8 +574,8 @@ def _setup(self):
             expert.linear_fc1.parallel_state = self.parallel_state
             expert.linear_fc2.parallel_state = self.parallel_state
 
-    def layer_sync_moe_local_experts_amax(self):
-        """Sync input quantizer amax across local experts in a SequentialMLP.
+    def layer_sync_moe_local_experts_amax(self, shared_moe_weight_scale=True):
+        """Sync input quantizer amax across local experts in a SequentialMLP, and optionally weight scale.
 
         Ensures all experts have the same input quantizer amax.This function operates
         on a single rank and does not require distributed sync.
@@ -589,33 +589,35 @@ def layer_sync_moe_local_experts_amax(self):
         We need to guarantee that all experts must have amax. Otherwise, there will be deadlock
         when synchronizing over EP since some ranks may have amax None and not calling the collective
         communication.
+
+        Args:
+            shared_moe_weight_scale: Whether to share the weight scale across local experts.
         """
         # Collect amax from all local experts
        amax_dict = {}
        for expert in self.local_experts:
            for name, module in expert.named_modules():
-                if (
-                    isinstance(module, TensorQuantizer)
-                    and module.amax is not None
-                    and "input_quantizer" in name
-                ):
-                    stored_amax = amax_dict.get(name)
-                    amax_tensor = module.amax.detach().clone()
-                    amax_dict[name] = (
-                        amax_tensor
-                        if stored_amax is None
-                        else torch.maximum(stored_amax, amax_tensor)
-                    )
+                if isinstance(module, TensorQuantizer) and module.amax is not None:
+                    if shared_moe_weight_scale or (
+                        not shared_moe_weight_scale and "input_quantizer" in name
+                    ):
+                        # Sync both quantizers or only sync input quantizer
+                        stored_amax = amax_dict.get(name)
+                        amax_tensor = module.amax.detach().clone()
+                        amax_dict[name] = (
+                            amax_tensor
+                            if stored_amax is None
+                            else torch.maximum(stored_amax, amax_tensor)
+                        )
 
         # Apply synchronized amax values back to all local experts
         for expert in self.local_experts:
             for name, module in expert.named_modules():
-                if (
-                    isinstance(module, TensorQuantizer)
-                    and name in amax_dict
-                    and "input_quantizer" in name
-                ):
-                    module.amax = amax_dict[name].detach().clone()
+                if isinstance(module, TensorQuantizer) and name in amax_dict:
+                    if shared_moe_weight_scale or (
+                        not shared_moe_weight_scale and "input_quantizer" in name
+                    ):
+                        module.amax = amax_dict[name].detach().clone()
 
     def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
         """Override the default to enable singleton_local_shards.

tests/gpu_megatron/torch/quantization/plugins/test_megatron.py

Lines changed: 23 additions & 3 deletions

@@ -737,7 +737,8 @@ def test_te_grouped_vs_sequential_quantize(need_4_gpus):
 
 @pytest.mark.parametrize("ep_size", [1, 2])
 @pytest.mark.parametrize("moe_grouped_gemm", [True, False])
-def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm):
+@pytest.mark.parametrize("shared_moe_weight_scale", [True, False])
+def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, shared_moe_weight_scale):
     """Test expert model parallel synchronization."""
     size = torch.cuda.device_count()
     if size < ep_size:
@@ -749,12 +750,15 @@ def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm):
             _test_layer_sync_moe_local_experts_amax,
             ep_size,
             moe_grouped_gemm,
+            shared_moe_weight_scale,
         ),
         backend="nccl",
     )
 
 
-def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, size):
+def _test_layer_sync_moe_local_experts_amax(
+    ep_size, moe_grouped_gemm, shared_moe_weight_scale, rank, size
+):
     initialize_for_megatron(
         tensor_model_parallel_size=1,
         pipeline_model_parallel_size=1,
@@ -776,7 +780,7 @@ def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, siz
 
     # Sync amax across local experts in each layer
     for layer in model.decoder.layers:
-        layer.mlp.experts.layer_sync_moe_local_experts_amax()
+        layer.mlp.experts.layer_sync_moe_local_experts_amax(shared_moe_weight_scale)
 
     for layer in model.decoder.layers:
         fc1_amax = None
@@ -793,6 +797,22 @@ def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, siz
             else:
                 assert torch.allclose(fc2_amax, expert.linear_fc2.input_quantizer.amax)
 
+    if shared_moe_weight_scale:
+        for layer in model.decoder.layers:
+            fc1_amax = None
+            fc2_amax = None
+            for expert in layer.mlp.experts.local_experts:
+                assert expert.linear_fc1.weight_quantizer.amax is not None
+                assert expert.linear_fc2.weight_quantizer.amax is not None
+                if fc1_amax is None:
+                    fc1_amax = expert.linear_fc1.weight_quantizer.amax
+                else:
+                    assert torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax)
+                if fc2_amax is None:
+                    fc2_amax = expert.linear_fc2.weight_quantizer.amax
+                else:
+                    assert torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax)
+
 
 def _test_expert_model_parallel_amax_sync(
     tp_size, ep_size, etp_size, moe_grouped_gemm, config, rank, size
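A usage note: the extra parametrize expands the matrix to ep_size × moe_grouped_gemm × shared_moe_weight_scale (2 × 2 × 2 cases). Assuming a machine with enough GPUs, the whole group can be run by name with a standard pytest invocation:

pytest tests/gpu_megatron/torch/quantization/plugins/test_megatron.py \
    -k test_layer_sync_moe_local_experts_amax -v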
