Commit 1efb8c1

fix test

Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>

1 parent cfff760

2 files changed: 46 additions & 18 deletions

modelopt/torch/quantization/plugins/megatron.py (3 additions & 2 deletions)
@@ -594,10 +594,11 @@ def layer_sync_moe_local_experts_amax(self):
         amax_dict = {}
         for expert in self.local_experts:
             for name, module in expert.named_modules():
+                print(name, module)
                 if (
                     isinstance(module, TensorQuantizer)
                     and module.amax is not None
-                    and name == "input_quantizer"
+                    and "input_quantizer" in name
                 ):
                     stored_amax = amax_dict.get(name)
                     amax_tensor = module.amax.detach().clone()
@@ -613,7 +614,7 @@ def layer_sync_moe_local_experts_amax(self):
                 if (
                     isinstance(module, TensorQuantizer)
                     and name in amax_dict
-                    and name == "input_quantizer"
+                    and "input_quantizer" in name
                 ):
                     module.amax = amax_dict[name].detach().clone()
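
Why the check changed: nn.Module.named_modules() yields names as dotted paths relative to the root module (e.g. "linear_fc1.input_quantizer"), so the old exact comparison name == "input_quantizer" only matched a quantizer attached directly to the expert and skipped quantizers nested under the expert's linear layers. A minimal self-contained sketch of the difference, using hypothetical Expert/TensorQuantizer stand-ins rather than modelopt's classes:

import torch.nn as nn

class TensorQuantizer(nn.Module):
    """Stand-in for modelopt's TensorQuantizer; used here only as a type tag."""

class Expert(nn.Module):
    """Hypothetical expert with input quantizers nested under its linear layers."""

    def __init__(self):
        super().__init__()
        self.linear_fc1 = nn.Module()
        self.linear_fc1.input_quantizer = TensorQuantizer()
        self.linear_fc2 = nn.Module()
        self.linear_fc2.input_quantizer = TensorQuantizer()

names = [n for n, m in Expert().named_modules() if isinstance(m, TensorQuantizer)]
print(names)  # ['linear_fc1.input_quantizer', 'linear_fc2.input_quantizer']

assert not any(n == "input_quantizer" for n in names)  # old check: never fires
assert all("input_quantizer" in n for n in names)      # new check: matches both

A side effect worth noting: amax_dict is keyed by the full dotted name, so the fc1 and fc2 input quantizers keep separate entries and are synchronized independently across experts, which is exactly what the updated test below asserts.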

tests/gpu_megatron/torch/quantization/plugins/test_megatron.py (43 additions & 16 deletions)

@@ -735,37 +735,63 @@ def test_te_grouped_vs_sequential_quantize(need_4_gpus):
     )
 
 
-def test_layer_sync_moe_local_experts_amax(moe_grouped_gemm):
+@pytest.mark.parametrize("ep_size", [1, 2])
+@pytest.mark.parametrize("moe_grouped_gemm", [True, False])
+def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm):
+    """Test expert model parallel synchronization."""
+    size = torch.cuda.device_count()
+    if size < ep_size:
+        pytest.skip(f"Requires at least {ep_size} GPUs for expert model parallel test")
+
+    spawn_multiprocess_job(
+        size=size,
+        job=partial(
+            _test_layer_sync_moe_local_experts_amax,
+            ep_size,
+            moe_grouped_gemm,
+        ),
+        backend="nccl",
+    )
+
+
+def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, size):
     initialize_for_megatron(
         tensor_model_parallel_size=1,
         pipeline_model_parallel_size=1,
-        expert_model_parallel_size=2,
+        expert_model_parallel_size=ep_size,
         expert_tensor_parallel_size=1,
         seed=SEED,
     )
     model = _gpt_model_provider(
         tp_size=1,
-        ep_size=2,
+        ep_size=ep_size,
         etp_size=1,
         hidden_size=256,
         moe_grouped_gemm=moe_grouped_gemm,
         use_te=moe_grouped_gemm,
         num_moe_experts=8,
         transformer_impl="modelopt",
     )
-    # model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, get_forward(model))
-    forward = get_forward(model)
-    forward()
-    print(model)
-
-    model.layer_sync_moe_local_experts_amax()
-    prev_amax = None
-    for expert in model.local_experts:
-        assert expert.input_quantizer.amax is not None
-        if prev_amax is None:
-            prev_amax = expert.input_quantizer.amax
-        else:
-            assert torch.allclose(prev_amax, expert.input_quantizer.amax)
+    model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, get_forward(model))
+
+    # Sync amax across local experts in each layer
+    for layer in model.decoder.layers:
+        layer.mlp.experts.layer_sync_moe_local_experts_amax()
+
+    for layer in model.decoder.layers:
+        fc1_amax = None
+        fc2_amax = None
+        for expert in layer.mlp.experts.local_experts:
+            assert expert.linear_fc1.input_quantizer.amax is not None
+            assert expert.linear_fc2.input_quantizer.amax is not None
+            if fc1_amax is None:
+                fc1_amax = expert.linear_fc1.input_quantizer.amax
+            else:
+                assert torch.allclose(fc1_amax, expert.linear_fc1.input_quantizer.amax)
+            if fc2_amax is None:
+                fc2_amax = expert.linear_fc2.input_quantizer.amax
+            else:
+                assert torch.allclose(fc2_amax, expert.linear_fc2.input_quantizer.amax)
 
 
 def _test_expert_model_parallel_amax_sync(
@@ -815,6 +841,7 @@ def _test_expert_model_parallel_amax_sync(
 
     # quantize the model
     model = mtq.quantize(model, config, forward)
+
     # Check initial sync status
     initial_sync, quantizer_type, rank_values = compare_amax_sync_across_expert_parallel(model)
     assert initial_sync, (
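
What the rewritten test checks: it parametrizes over ep_size and moe_grouped_gemm, spawns one process per available GPU, quantizes with mtq.FP8_DEFAULT_CFG, calls layer_sync_moe_local_experts_amax() on each layer's expert container, and asserts that every local expert in a layer ends up with identical fc1 and fc2 input amax values. Per the first file's hunks, the sync is a two-pass walk: collect one amax per quantizer name across experts, then write it back to all of them. Below is a minimal dictionary-based sketch of that contract; the hunk truncates before the reduction step, so the elementwise torch.maximum here is an assumption, and the flat data layout is hypothetical rather than modelopt's module tree.

import torch

def sync_local_experts_amax(experts: list[dict[str, torch.Tensor]]) -> None:
    """In-place sync; `experts` holds one {quantizer_name: amax} dict per expert."""
    amax_dict: dict[str, torch.Tensor] = {}
    # Pass 1: accumulate a shared amax per quantizer name across experts
    # (assumed elementwise max; the commit's hunk cuts off before this step).
    for expert in experts:
        for name, amax in expert.items():
            if "input_quantizer" in name:
                stored = amax_dict.get(name)
                amax_dict[name] = (
                    amax.detach().clone() if stored is None else torch.maximum(stored, amax)
                )
    # Pass 2: write the shared value back to every expert.
    for expert in experts:
        for name in expert:
            if name in amax_dict:
                expert[name] = amax_dict[name].detach().clone()

experts = [
    {"linear_fc1.input_quantizer": torch.tensor(1.5), "linear_fc2.input_quantizer": torch.tensor(0.7)},
    {"linear_fc1.input_quantizer": torch.tensor(2.0), "linear_fc2.input_quantizer": torch.tensor(0.4)},
]
sync_local_experts_amax(experts)
# Mirrors the test's per-layer assertions: fc1 amax agrees across experts, fc2 likewise.
assert torch.allclose(experts[0]["linear_fc1.input_quantizer"], experts[1]["linear_fc1.input_quantizer"])
assert torch.allclose(experts[0]["linear_fc2.input_quantizer"], experts[1]["linear_fc2.input_quantizer"])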
