Skip to content

Commit ad84a0d

Browse files
jenchen13 and danielkorzekwa
authored and committed
Mamba MOE Quant Configs + Fix Export Bug (#882)
## What does this PR do? **Type of change:** ? Bug fix **Overview:** ? - Fix a bug in MCore export `exclude_modules` where the layers had an extra period at the end - Add custom quant configs for mamba moes ## Usage <!-- You can potentially add a usage example below. --> ```python # Add a code snippet demonstrating how to use this ``` ## Testing <!-- Mention how have you tested your change if applicable. --> ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added four new Mamba MOE quantization configurations: aggressive and conservative variants for both FP8 and NVFP4 quantization schemes, providing enhanced flexibility in quantization options for different use cases. * **Bug Fixes** * Improved quantization export module exclusion pattern handling to properly normalize trailing dots from exclude patterns during export. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jennifer Chen <jennifchen@nvidia.com> Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
1 parent 0bd4313 commit ad84a0d

File tree

2 files changed

+82
-4
lines changed

2 files changed

+82
-4
lines changed

modelopt/torch/quantization/calib/max.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,15 @@ def collect(self, x):
6666
if x.device.type == "meta":
6767
self._calib_amax = local_amax
6868
return
69+
assert not torch.any(torch.isnan(local_amax)), (
70+
f"detected nan values in amax. nan in original tensor: {torch.any(torch.isnan(x))}"
71+
)
6972
assert torch.all(local_amax >= 0), (
7073
"detected negative values after abs, could be torch or cuda bug"
7174
)
7275
assert not torch.any(torch.isinf(local_amax)), (
7376
f"detected inf values in amax. inf in original tensor: {torch.any(torch.isinf(x))}"
7477
)
75-
assert not torch.any(torch.isnan(local_amax)), (
76-
f"detected nan values in amax. nan in original tensor: {torch.any(torch.isnan(x))}"
77-
)
7878
if self._calib_amax is None:
7979
self._calib_amax = local_amax
8080
else:

modelopt/torch/quantization/config.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,21 @@
156156
"*mlp.gate.*": {"enable": False}, # Skip the MOE router
157157
"*mlp.shared_expert_gate.*": {"enable": False}, # Skip the MOE router
158158
"*linear_attn.conv1d*": {"enable": False},
159-
"*mixer.conv1d*": {"enable": False},
159+
"*mixer.conv1d*": {"enable": False}, # Skip mamba conv1d
160160
"*output_layer*": {"enable": False},
161161
"output.*": {"enable": False},
162162
"default": {"enable": False},
163163
}
164164

165+
_mamba_moe_disabled_quantizer_cfg = {
166+
"*fc1_latent_proj*": {"enable": False}, # Skip Latent MOE
167+
"*fc2_latent_proj*": {"enable": False}, # Skip Latent MOE
168+
"*q_proj*": {"enable": False}, # Skip QKV Linear
169+
"*k_proj*": {"enable": False}, # Skip QKV Linear
170+
"*v_proj*": {"enable": False}, # Skip QKV Linear
171+
"*o_proj*": {"enable": False}, # Skip QKV Output Projection
172+
}
173+
165174
INT8_DEFAULT_CFG = {
166175
"quant_cfg": {
167176
"*weight_quantizer": {"num_bits": 8, "axis": 0},
@@ -198,6 +207,28 @@
198207
"algorithm": "max",
199208
}
200209

210+
MAMBA_MOE_FP8_AGGRESSIVE_CFG = {
211+
"quant_cfg": {
212+
"*weight_quantizer": {"num_bits": (4, 3), "axis": None},
213+
"*input_quantizer": {"num_bits": (4, 3), "axis": None},
214+
**_default_disabled_quantizer_cfg,
215+
**_mamba_moe_disabled_quantizer_cfg,
216+
},
217+
"algorithm": "max",
218+
}
219+
220+
MAMBA_MOE_FP8_CONSERVATIVE_CFG = {
221+
"quant_cfg": {
222+
"*weight_quantizer": {"num_bits": (4, 3), "axis": None},
223+
"*input_quantizer": {"num_bits": (4, 3), "axis": None},
224+
**_default_disabled_quantizer_cfg,
225+
**_mamba_moe_disabled_quantizer_cfg,
226+
"*mixer.in_proj*": {"enable": False}, # Skip mamba linear
227+
"*mixer.out_proj*": {"enable": False}, # Skip mamba linear
228+
},
229+
"algorithm": "max",
230+
}
231+
201232
FP8_PER_CHANNEL_PER_TOKEN_CFG = {
202233
"quant_cfg": {
203234
"*weight_quantizer": {"num_bits": (4, 3), "axis": 0},
@@ -388,6 +419,49 @@
388419
"algorithm": "max",
389420
}
390421

422+
423+
MAMBA_MOE_NVFP4_AGGRESSIVE_CFG = {
424+
"quant_cfg": {
425+
"*weight_quantizer": {
426+
"num_bits": (2, 1),
427+
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
428+
"axis": None,
429+
"enable": True,
430+
},
431+
"*input_quantizer": {
432+
"num_bits": (2, 1),
433+
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
434+
"axis": None,
435+
"enable": True,
436+
},
437+
**_default_disabled_quantizer_cfg,
438+
**_mamba_moe_disabled_quantizer_cfg,
439+
},
440+
"algorithm": "max",
441+
}
442+
MAMBA_MOE_NVFP4_CONSERVATIVE_CFG = {
443+
"quant_cfg": {
444+
"*weight_quantizer": {
445+
"num_bits": (2, 1),
446+
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
447+
"axis": None,
448+
"enable": True,
449+
},
450+
"*input_quantizer": {
451+
"num_bits": (2, 1),
452+
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
453+
"axis": None,
454+
"enable": True,
455+
},
456+
**_default_disabled_quantizer_cfg,
457+
**_mamba_moe_disabled_quantizer_cfg,
458+
"*mixer.in_proj*": {"enable": False}, # Skip mamba linear
459+
"*mixer.out_proj*": {"enable": False}, # Skip mamba linear
460+
},
461+
"algorithm": "max",
462+
}
463+
464+
391465
NVFP4_AWQ_LITE_CFG = {
392466
"quant_cfg": {
393467
"*weight_quantizer": {
@@ -652,6 +726,10 @@
652726
"NVFP4_MLP_WEIGHT_ONLY_CFG",
653727
"MXFP4_MLP_WEIGHT_ONLY_CFG",
654728
"NVFP4_MLP_ONLY_CFG",
729+
"MAMBA_MOE_NVFP4_CONSERVATIVE_CFG",
730+
"MAMBA_MOE_NVFP4_AGGRESSIVE_CFG",
731+
"MAMBA_MOE_FP8_CONSERVATIVE_CFG",
732+
"MAMBA_MOE_FP8_AGGRESSIVE_CFG",
655733
}
656734

657735
BiasType = Literal["static", "dynamic"]

0 commit comments

Comments (0)