
Commit 2a33e19

remove shared_moe flag
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
1 parent 023b0a3 commit 2a33e19

3 files changed: 16 additions & 36 deletions


modelopt/torch/quantization/config.py

Lines changed: 0 additions & 4 deletions
@@ -214,7 +214,6 @@
         **_default_disabled_quantizer_cfg,
         **_mamba_moe_disabled_quantizer_cfg,
     },
-    "algorithm": {"method": "max", "shared_moe_weight_scale": False},
 }

 MAMBA_MOE_FP8_CONSERVATIVE_CFG = {
@@ -226,7 +225,6 @@
         "*mixer.in_proj*": {"enable": False},  # Skip mamba linear
         "*mixer.out_proj*": {"enable": False},  # Skip mamba linear
     },
-    "algorithm": {"method": "max", "shared_moe_weight_scale": False},
 }

 FP8_PER_CHANNEL_PER_TOKEN_CFG = {
@@ -437,7 +435,6 @@
         **_default_disabled_quantizer_cfg,
         **_mamba_moe_disabled_quantizer_cfg,
     },
-    "algorithm": {"method": "max", "shared_moe_weight_scale": False},
 }
 MAMBA_MOE_NVFP4_CONSERVATIVE_CFG = {
     "quant_cfg": {
@@ -458,7 +455,6 @@
         "*mixer.in_proj*": {"enable": False},  # Skip mamba linear
         "*mixer.out_proj*": {"enable": False},  # Skip mamba linear
     },
-    "algorithm": {"method": "max", "shared_moe_weight_scale": False},
 }
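With these deletions the Mamba MoE configs no longer carry an explicit "algorithm" entry, so quantization falls back to the default "max" calibration. A minimal usage sketch, assuming the config can be imported from modelopt.torch.quantization.config and that `model` and a calibration `forward_loop` already exist (both are placeholders here):

import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.config import MAMBA_MOE_FP8_CONSERVATIVE_CFG  # assumed import path

# After this commit the config no longer sets
#   "algorithm": {"method": "max", "shared_moe_weight_scale": False},
# so calibration simply uses the default "max" method.
model = mtq.quantize(model, MAMBA_MOE_FP8_CONSERVATIVE_CFG, forward_loop)  # model / forward_loop supplied by the caller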

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 9 additions & 14 deletions
@@ -574,8 +574,8 @@ def _setup(self):
             expert.linear_fc1.parallel_state = self.parallel_state
             expert.linear_fc2.parallel_state = self.parallel_state

-    def layer_sync_moe_local_experts_amax(self, shared_moe_weight_scale=True):
-        """Sync input quantizer amax across local experts in a SequentialMLP, and optionally weight scale.
+    def layer_sync_moe_local_experts_amax(self):
+        """Sync input quantizer amax across local experts in a SequentialMLP.

         Ensures all experts have the same input quantizer amax.This function operates
         on a single rank and does not require distributed sync.
@@ -589,24 +589,19 @@ def layer_sync_moe_local_experts_amax(self, shared_moe_weight_scale=True):
         We need to guarantee that all experts must have amax. Otherwise, there will be deadlock
         when synchronizing over EP since some ranks may have amax None and not calling the collective
         communication.
-
-        Args:
-            shared_moe_weight_scale: Whether to share the weight scale across local experts.
         """
         # Collect amax from all local experts
         amax_dict = {}
         for expert in self.local_experts:
             for name, module in expert.named_modules():
                 if isinstance(module, TensorQuantizer) and module.amax is not None:
-                    if shared_moe_weight_scale or ("weight_quantizer" not in name):
-                        # Sync both quantizers or only sync input quantizer
-                        stored_amax = amax_dict.get(name)
-                        amax_tensor = module.amax.detach().clone()
-                        amax_dict[name] = (
-                            amax_tensor
-                            if stored_amax is None
-                            else torch.maximum(stored_amax, amax_tensor)
-                        )
+                    stored_amax = amax_dict.get(name)
+                    amax_tensor = module.amax.detach().clone()
+                    amax_dict[name] = (
+                        amax_tensor
+                        if stored_amax is None
+                        else torch.maximum(stored_amax, amax_tensor)
+                    )

         # Apply synchronized amax values back to all local experts
         for expert in self.local_experts:
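The sync itself is just a running element-wise max over every quantizer's amax, keyed by module name, followed by a write-back to all experts. A self-contained sketch of that reduction, with plain dicts of tensors standing in for the experts' TensorQuantizer modules (the names below are illustrative only):

import torch

# Illustrative stand-in: each local expert exposes {quantizer_name: amax_tensor}.
expert_amaxes = [
    {"linear_fc1.input_quantizer": torch.tensor(1.5), "linear_fc2.input_quantizer": torch.tensor(0.7)},
    {"linear_fc1.input_quantizer": torch.tensor(2.0), "linear_fc2.input_quantizer": torch.tensor(0.4)},
]

# Pass 1: collect the element-wise max per quantizer name across experts.
amax_dict = {}
for amaxes in expert_amaxes:
    for name, amax in amaxes.items():
        stored = amax_dict.get(name)
        amax_dict[name] = amax.clone() if stored is None else torch.maximum(stored, amax)

# Pass 2: write the shared values back so every expert ends up with the same amax.
for amaxes in expert_amaxes:
    for name in amaxes:
        amaxes[name] = amax_dict[name]

print(amax_dict)  # fc1: tensor(2.), fc2: tensor(0.7000)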

tests/gpu_megatron/torch/quantization/plugins/test_megatron.py

Lines changed: 7 additions & 18 deletions
@@ -734,8 +734,7 @@ def test_te_grouped_vs_sequential_quantize(need_4_gpus):

 @pytest.mark.parametrize("ep_size", [1, 2])
 @pytest.mark.parametrize("moe_grouped_gemm", [True, False])
-@pytest.mark.parametrize("shared_moe_weight_scale", [True, False])
-def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, shared_moe_weight_scale):
+def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm):
     """Test expert model parallel synchronization."""
     size = torch.cuda.device_count()
     if size < ep_size:
@@ -747,15 +746,12 @@ def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, shared_moe
             _test_layer_sync_moe_local_experts_amax,
             ep_size,
             moe_grouped_gemm,
-            shared_moe_weight_scale,
         ),
         backend="nccl",
     )


-def _test_layer_sync_moe_local_experts_amax(
-    ep_size, moe_grouped_gemm, shared_moe_weight_scale, rank, size
-):
+def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, size):
     initialize_for_megatron(
         tensor_model_parallel_size=1,
         pipeline_model_parallel_size=1,
@@ -774,16 +770,14 @@ def _test_layer_sync_moe_local_experts_amax(
         transformer_impl="modelopt",
     )
     quant_cfg = mtq.FP8_DEFAULT_CFG
-    if not shared_moe_weight_scale:
-        quant_cfg = copy.deepcopy(quant_cfg)
-        quant_cfg["algorithm"] = {"method": "max", "shared_moe_weight_scale": False}
     model = mtq.quantize(model, quant_cfg, get_forward(model))

     # does layer_sync_moe_local_experts_amax happens in mtq.quantize if EP=1?
     for layer in model.decoder.layers:
-        layer.mlp.experts.layer_sync_moe_local_experts_amax(shared_moe_weight_scale)
+        layer.mlp.experts.layer_sync_moe_local_experts_amax()

     for layer in model.decoder.layers:
+        # Check input quantizer amax is synced across local experts
         fc1_amax = None
         fc2_amax = None
         for expert in layer.mlp.experts.local_experts:
@@ -798,25 +792,20 @@
             else:
                 assert torch.allclose(fc2_amax, expert.linear_fc2.input_quantizer.amax)

-    for layer in model.decoder.layers:
+        # Check weight quantizer amax is different across local experts
         fc1_amax = None
         fc2_amax = None
         for expert in layer.mlp.experts.local_experts:
             assert expert.linear_fc1.weight_quantizer.amax is not None
             assert expert.linear_fc2.weight_quantizer.amax is not None
             if fc1_amax is None:
                 fc1_amax = expert.linear_fc1.weight_quantizer.amax
-            elif shared_moe_weight_scale:
-                assert torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax)
             else:
                 assert not torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax)
-                fc1_amax = expert.linear_fc1.weight_quantizer.amax  # update most recent amax
-
             if fc2_amax is None:
                 fc2_amax = expert.linear_fc2.weight_quantizer.amax
-            elif shared_moe_weight_scale:
-                assert torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax)
-                # FC2 amaxes are the same since the input to the layer is all the same
+            else:
+                assert not torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax)


 def _test_expert_model_parallel_amax_sync(
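The post-condition the reworked test encodes is now unconditional: within a SequentialMLP every local expert shares the same input-quantizer amax per projection, while weight-quantizer amax values remain per-expert. A hedged sketch of checking that invariant on an already-quantized MoE decoder layer (attribute paths follow the test above; `layer` is assumed to be such a layer):

import torch

def check_local_expert_amax(layer):
    """Illustrative check: shared input amax, independent per-expert weight amax."""
    experts = layer.mlp.experts.local_experts
    ref = experts[0]
    for expert in experts[1:]:
        # Input quantizer amax must be identical across local experts after the sync.
        assert torch.allclose(
            ref.linear_fc1.input_quantizer.amax, expert.linear_fc1.input_quantizer.amax
        )
        # Weight quantizer amax is calibrated per expert and is expected to stay per-expert.
        assert expert.linear_fc1.weight_quantizer.amax is not None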
