
Commit f74cd24

update configs with shared_moe_weight_scale
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
1 parent 28d5686

2 files changed: 11 additions & 8 deletions


modelopt/torch/quantization/config.py

Lines changed: 10 additions & 4 deletions
@@ -214,7 +214,7 @@
         **_default_disabled_quantizer_cfg,
         **_mamba_moe_disabled_quantizer_cfg,
     },
-    "algorithm": "max",
+    "algorithm": {"method": "max", "shared_moe_weight_scale": False},
 }

 MAMBA_MOE_FP8_CONSERVATIVE_CFG = {
@@ -226,7 +226,7 @@
         "*mixer.in_proj*": {"enable": False},  # Skip mamba linear
         "*mixer.out_proj*": {"enable": False},  # Skip mamba linear
     },
-    "algorithm": "max",
+    "algorithm": {"method": "max", "shared_moe_weight_scale": False},
 }

 FP8_PER_CHANNEL_PER_TOKEN_CFG = {
@@ -437,7 +437,7 @@
         **_default_disabled_quantizer_cfg,
         **_mamba_moe_disabled_quantizer_cfg,
     },
-    "algorithm": "max",
+    "algorithm": {"method": "max", "shared_moe_weight_scale": False},
 }
 MAMBA_MOE_NVFP4_CONSERVATIVE_CFG = {
     "quant_cfg": {
@@ -458,7 +458,7 @@
         "*mixer.in_proj*": {"enable": False},  # Skip mamba linear
         "*mixer.out_proj*": {"enable": False},  # Skip mamba linear
     },
-    "algorithm": "max",
+    "algorithm": {"method": "max", "shared_moe_weight_scale": False},
 }


@@ -1087,6 +1087,12 @@ class MaxCalibConfig(QuantizeAlgorithmConfig):
         description="If True, the amax will be synced across the distributed processes.",
     )

+    shared_moe_weight_scale: bool | None = ModeloptField(
+        default=True,
+        title="Whether to share the weight scale across local experts.",
+        description="If True, the weight scale will be shared across local experts.",
+    )
+

 class MseCalibConfig(QuantizeAlgorithmConfig):
     """Configuration for per-tensor MSE calibration.

tests/gpu_megatron/torch/quantization/plugins/test_megatron.py

Lines changed: 1 addition & 4 deletions
@@ -473,10 +473,7 @@ def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device,

 @pytest.mark.parametrize(
     "config",
-    [
-        NVFP4_GEMM_KV_CFG,
-        FP8_GEMM_KV_CFG,
-    ],
+    [NVFP4_GEMM_KV_CFG, FP8_GEMM_KV_CFG, mtq.MAMBA_MOE_NVFP4_CONSERVATIVE_CFG],
 )
 def test_homogeneous_sharded_state_dict_hybrid(tmp_path, config):
     """Test sharded state dict for hybrid Mamba MOE models."""
