
Commit cfff760

sync moe input quantizer only

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>

1 parent: 9e38041

2 files changed: 48 additions & 7 deletions

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 15 additions & 7 deletions
@@ -575,26 +575,30 @@ def _setup(self):
             expert.linear_fc2.parallel_state = self.parallel_state
 
     def layer_sync_moe_local_experts_amax(self):
-        """Sync amax across local experts in a SequentialMLP.
+        """Sync input quantizer amax across local experts in a SequentialMLP.
 
-        Synchronize the amax values across local experts in a lyaer such that all local experts will
-        share the same amax. This function operates on a single rank and does not require distributed sync.
+        Ensures all experts have the same input quantizer amax.This function operates
+        on a single rank and does not require distributed sync.
 
         Distributed amax sync across EP and ETP (for RowParallel) happens in model_calib.max_calibrate().
         This function should be called before the distributed sync to ensure the amax values
         are synchronized across the layer first.
 
         Note:
            Because there are logic which calls collective communication based on whether amax is not None,
-           We need to garuantee that all experts must have amax. Otherwise, there will be deadlock
-           when synchroizing over EP since some ranks may have amax None and not calling the collective
+           We need to guarantee that all experts must have amax. Otherwise, there will be deadlock
+           when synchronizing over EP since some ranks may have amax None and not calling the collective
            communication.
        """
         # Collect amax from all local experts
         amax_dict = {}
         for expert in self.local_experts:
             for name, module in expert.named_modules():
-                if isinstance(module, TensorQuantizer) and module.amax is not None:
+                if (
+                    isinstance(module, TensorQuantizer)
+                    and module.amax is not None
+                    and name == "input_quantizer"
+                ):
                     stored_amax = amax_dict.get(name)
                     amax_tensor = module.amax.detach().clone()
                     amax_dict[name] = (
@@ -606,7 +610,11 @@ def layer_sync_moe_local_experts_amax(self):
         # Apply synchronized amax values back to all local experts
         for expert in self.local_experts:
             for name, module in expert.named_modules():
-                if isinstance(module, TensorQuantizer) and name in amax_dict:
+                if (
+                    isinstance(module, TensorQuantizer)
+                    and name in amax_dict
+                    and name == "input_quantizer"
+                ):
                     module.amax = amax_dict[name].detach().clone()
 
     def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
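
The first hunk above ends at amax_dict[name] = ( and the expression that completes the assignment falls between the two hunks, so the diff does not show it. A minimal sketch of what the elided update presumably computes, assuming a running elementwise maximum across experts (the helper name _reduce_amax is hypothetical, not from this file):

import torch

def _reduce_amax(stored_amax, amax_tensor):
    # Hypothetical stand-in for the elided expression: keep the elementwise
    # maximum of the amax collected so far and the current expert's amax.
    if stored_amax is None:
        return amax_tensor
    return torch.maximum(stored_amax, amax_tensor)

The second hunk then writes the resulting maximum back to every local expert, so each one quantizes its input with the same layer-wide calibration range.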

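For context on the Note in the docstring above: model_calib.max_calibrate() later performs the distributed amax sync, and a collective issued conditionally on amax being present can hang if ranks disagree. A hedged illustration of that failure mode (sync_amax_over_ep and ep_group are hypothetical names, not from the diff):

import torch.distributed as dist

def sync_amax_over_ep(quantizer, ep_group):
    # If the collective is only issued when amax exists, a rank whose local
    # experts never produced an amax skips the call while its EP peers wait
    # inside all_reduce, and the job deadlocks.
    if quantizer.amax is None:
        return
    dist.all_reduce(quantizer.amax, op=dist.ReduceOp.MAX, group=ep_group)

Syncing the input quantizer amax across local experts first helps guarantee that every rank sees an amax on its experts, so all EP ranks enter the collective together.
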
tests/gpu_megatron/torch/quantization/plugins/test_megatron.py

Lines changed: 33 additions & 0 deletions
@@ -735,6 +735,39 @@ def test_te_grouped_vs_sequential_quantize(need_4_gpus):
     )
 
 
+def test_layer_sync_moe_local_experts_amax(moe_grouped_gemm):
+    initialize_for_megatron(
+        tensor_model_parallel_size=1,
+        pipeline_model_parallel_size=1,
+        expert_model_parallel_size=2,
+        expert_tensor_parallel_size=1,
+        seed=SEED,
+    )
+    model = _gpt_model_provider(
+        tp_size=1,
+        ep_size=2,
+        etp_size=1,
+        hidden_size=256,
+        moe_grouped_gemm=moe_grouped_gemm,
+        use_te=moe_grouped_gemm,
+        num_moe_experts=8,
+        transformer_impl="modelopt",
+    )
+    # model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, get_forward(model))
+    forward = get_forward(model)
+    forward()
+    print(model)
+
+    model.layer_sync_moe_local_experts_amax()
+    prev_amax = None
+    for expert in model.local_experts:
+        assert expert.input_quantizer.amax is not None
+        if prev_amax is None:
+            prev_amax = expert.input_quantizer.amax
+        else:
+            assert torch.allclose(prev_amax, expert.input_quantizer.amax)
+
+
 def _test_expert_model_parallel_amax_sync(
     tp_size, ep_size, etp_size, moe_grouped_gemm, config, rank, size
 ):
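
The new test takes moe_grouped_gemm as an argument so it can exercise both the grouped-GEMM and sequential MoE paths. A hedged sketch of how it might be driven (the parametrize decorator is an assumption; the file may instead supply the value through a fixture):

import pytest

@pytest.mark.parametrize("moe_grouped_gemm", [False, True])
def test_layer_sync_moe_local_experts_amax(moe_grouped_gemm):
    ...  # body as in the diff above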
