15 | 15 |
16 | 16 | """Utilities for Mixture-of-Experts (MoE) model export.""" |
17 | 17 |
| 18 | +import copy |
| 19 | +import warnings |
18 | 20 | from pathlib import Path |
19 | 21 |
| 22 | +import torch |
20 | 23 | import torch.nn as nn |
21 | 24 |
22 | 25 |
| 26 | +def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None: |
| 27 | + """Split fused MoE expert weights and export per-expert quantization scales. |
| 28 | +
| 29 | + Works with any module wrapped by ``_QuantFusedExperts`` — i.e. any HF |
| 30 | + transformers 5.0+ fused expert container that stores ``gate_up_proj`` and |
| 31 | + ``down_proj`` as 3-D ``nn.Parameter`` tensors, with the per-expert weight |
| 32 | + quantizers kept in ``nn.ModuleList`` containers. |
| 33 | +
| 34 | + Steps: |
| 35 | +
| 36 | + 1. Handle amax fallback for uncalibrated expert weight quantizers. |
| 37 | + 2. Split fused 3-D weights into per-expert 2-D projections |
| 38 | + (``gate_proj``, ``up_proj``, ``down_proj``). |
| 39 | + 3. Call ``_export_quantized_weight`` on each projection. |
| 40 | + 4. Register results under the standard naming convention:: |
| 41 | +
| 42 | + {E}.gate_proj.weight, {E}.gate_proj.weight_scale, ... |
| 43 | + {E}.up_proj.weight, {E}.up_proj.weight_scale, ... |
| 44 | + {E}.down_proj.weight, {E}.down_proj.weight_scale, ... |
| 45 | + """ |
| 46 | + from modelopt.torch.export.unified_export_hf import _export_quantized_weight |
| 47 | + from modelopt.torch.quantization.plugins.huggingface import _get_fused_expert_intermediate_dim |
| 48 | + |
| 49 | + n = module.num_experts |
| 50 | + expert_dim = _get_fused_expert_intermediate_dim(module) |
| 51 | + |
| 52 | + # 1. Shared input quantizers — one per projection type, shared across all experts. |
| 53 | + gate_up_input_q = module.gate_up_proj_input_quantizer |
| 54 | + down_input_q = module.down_proj_input_quantizer |
| 55 | + |
| 56 | + gate_up = module.gate_up_proj.data |
| 57 | + down = module.down_proj.data |
| 58 | + |
| 59 | + # 2-3. Split + export each per-expert projection. |
| 60 | + fused_dim0 = gate_up.shape[1] # dim-0 of each per-expert 2-D slice (= 2 * expert_dim) |
| 61 | + |
| 62 | + for idx in range(n): |
| 63 | + expert = nn.Module() |
| 64 | + |
| 65 | + projections = [ |
| 66 | + ("gate_proj", gate_up[idx, :expert_dim, :], 0, fused_dim0, True), |
| 67 | + ("up_proj", gate_up[idx, expert_dim:, :], expert_dim, fused_dim0, True), |
| 68 | + ("down_proj", down[idx], 0, down.shape[1], False), |
| 69 | + ] |
| 70 | + |
| 71 | + for proj_name, weight_slice, fused_start, fused_total, is_gate_up in projections: |
| 72 | + w_quantizer_src = ( |
| 73 | + module.gate_up_proj_weight_quantizers[idx] |
| 74 | + if is_gate_up |
| 75 | + else module.down_proj_weight_quantizers[idx] |
| 76 | + ) |
| 77 | + i_quantizer = gate_up_input_q if is_gate_up else down_input_q |
| 78 | + |
| 79 | + # gate/up share a weight quantizer — clone so each gets independent amax. |
| 80 | + w_quantizer = copy.deepcopy(w_quantizer_src) if is_gate_up else w_quantizer_src |
| 81 | + |
| 82 | + # For per-channel amax (dim >= 1), proportionally slice dim-0 |
| 83 | + # to match the split weight. |
| 84 | + if ( |
| 85 | + hasattr(w_quantizer, "_amax") |
| 86 | + and w_quantizer._amax is not None |
| 87 | + and w_quantizer._amax.dim() >= 1 |
| 88 | + ): |
| 89 | + amax = w_quantizer._amax |
| 90 | + amax_dim0 = amax.shape[0] |
| 91 | + if fused_total % amax_dim0 == 0: |
| 92 | + slice_start = fused_start * amax_dim0 // fused_total |
| 93 | + slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total |
| 94 | + w_quantizer.amax = amax[slice_start:slice_end].contiguous() |
| 95 | + else: |
| 96 | + warnings.warn( |
| 97 | + f"Expert {idx} {proj_name}: fused amax dim0 ({amax_dim0}) does not " |
| 98 | + f"evenly divide fused_total ({fused_total}). Skipping amax slicing, " |
| 99 | + f"which may produce incorrect quantization scales.", |
| 100 | + stacklevel=2, |
| 101 | + ) |
| 102 | + |
| 103 | + # If the weight quantizer was never calibrated, compute amax from weights. |
| 104 | + if ( |
| 105 | + hasattr(w_quantizer, "is_enabled") |
| 106 | + and w_quantizer.is_enabled |
| 107 | + and ( |
| 108 | + not hasattr(w_quantizer, "_amax") |
| 109 | + or w_quantizer._amax is None |
| 110 | + or torch.all(w_quantizer._amax == 0) |
| 111 | + ) |
| 112 | + ): |
| 113 | + w_quantizer.amax = weight_slice.abs().amax().to(torch.float32) |
| 114 | + warnings.warn( |
| 115 | + f"Expert {idx} {proj_name} weight quantizer was not calibrated " |
| 116 | + f"(amax missing or zero). Using weight-derived amax as fallback. " |
| 117 | + f"Consider using more calibration data to activate all experts.", |
| 118 | + stacklevel=2, |
| 119 | + ) |
| 120 | + |
| 121 | + wrapper = nn.Module() |
| 122 | + wrapper.weight = nn.Parameter(weight_slice.contiguous(), requires_grad=False) |
| 123 | + wrapper.weight_quantizer = w_quantizer |
| 124 | + wrapper.input_quantizer = i_quantizer |
| 125 | + |
| 126 | + _export_quantized_weight(wrapper, dtype) |
| 127 | + |
| 128 | + proj = nn.Module() |
| 129 | + proj.weight = wrapper.weight |
| 130 | + for attr in ("weight_scale", "weight_scale_2", "input_scale"): |
| 131 | + if hasattr(wrapper, attr): |
| 132 | + proj.register_buffer(attr, getattr(wrapper, attr)) |
| 133 | + |
| 134 | + expert.add_module(proj_name, proj) |
| 135 | + |
| 136 | + module.add_module(str(idx), expert) |
| 137 | + |
| 138 | + # Finally, remove the fused params and quantizer lists; they are superseded by the per-expert submodules registered above. |
| 139 | + for attr in ( |
| 140 | + "gate_up_proj", |
| 141 | + "down_proj", |
| 142 | + "gate_up_proj_weight_quantizers", |
| 143 | + "gate_up_proj_input_quantizer", |
| 144 | + "down_proj_weight_quantizers", |
| 145 | + "down_proj_input_quantizer", |
| 146 | + ): |
| 147 | + if hasattr(module, attr): |
| 148 | + delattr(module, attr) |
| 149 | + |
| 150 | + |
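To make the splitting and scale arithmetic in `_export_fused_experts` easier to check, here is a minimal torch-only sketch of steps 2-3 on toy tensors. This is not the modelopt code path: the quantizer objects and `_export_quantized_weight` are left out, the fused layout `(num_experts, 2 * expert_dim, hidden_size)` is assumed based on the `fused_dim0` comment above, and names such as `slice_amax` and `gate_up` below are invented for illustration.

```python
import torch

num_experts, expert_dim, hidden_size = 4, 8, 16
fused_rows = 2 * expert_dim  # dim-0 of each expert's fused gate_up slice

# Fused gate/up weights laid out as (num_experts, 2 * expert_dim, hidden_size).
gate_up = torch.randn(num_experts, fused_rows, hidden_size)

# Per-output-channel amax for expert 0's fused weight, shape (2 * expert_dim, 1),
# standing in for a calibrated per-channel weight quantizer.
fused_amax = gate_up[0].abs().amax(dim=1, keepdim=True)


def slice_amax(
    amax: torch.Tensor, fused_start: int, out_rows: int, fused_total: int
) -> torch.Tensor:
    """Proportionally slice dim-0 of a per-channel amax to match a dim-0 weight slice."""
    amax_rows = amax.shape[0]
    assert fused_total % amax_rows == 0, "amax rows must evenly divide the fused rows"
    start = fused_start * amax_rows // fused_total
    end = (fused_start + out_rows) * amax_rows // fused_total
    return amax[start:end].contiguous()


# Step 2: split expert 0's fused weight into 2-D gate / up projections.
gate_w = gate_up[0, :expert_dim, :]  # (expert_dim, hidden_size)
up_w = gate_up[0, expert_dim:, :]    # (expert_dim, hidden_size)

# The per-channel scales follow the same row split.
gate_amax = slice_amax(fused_amax, 0, gate_w.shape[0], fused_rows)
up_amax = slice_amax(fused_amax, expert_dim, up_w.shape[0], fused_rows)

assert gate_amax.shape == up_amax.shape == (expert_dim, 1)

# After export, the results would live under per-expert names on the MoE module,
# e.g. "0.gate_proj.weight" and "0.gate_proj.weight_scale".
```

The proportional slicing is what keeps the per-output-channel scales aligned with the rows each projection actually receives; when the fused amax length does not divide the fused row count evenly, the real function warns and skips the slice instead.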
23 | 151 | def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | None = None): |
24 | 152 | """Collect expert_token_count from all quantized MoE layers and save as an HTML table. |
25 | 153 |