|
7 | 7 | from .awq import * |
8 | 8 | from .no_quant import * |
9 | 9 | from lightllm.utils.log_utils import init_logger |
| 10 | +from lightllm.utils.device_utils import is_sm100_gpu |
10 | 11 |
|
11 | 12 | logger = init_logger(__name__) |
12 | 13 |
|
| 14 | +EXPERT_DTYPE_TO_QUANT_TYPE = { |
| 15 | + "fp8": "deepgemm-fp8w8a8-b128", |
| 16 | + "fp4": "deepgemm-fp4fp8-b32", |
| 17 | +} |
| 18 | +SUPPORTED_EXPERT_DTYPES = tuple(EXPERT_DTYPE_TO_QUANT_TYPE) |
| 19 | + |
13 | 20 |
|
14 | 21 | class Quantcfg: |
15 | | - def __init__(self, network_config, quant_type="none", custom_cfg_path=None): |
| 22 | + def __init__(self, network_config, quant_type="none", custom_cfg_path=None, expert_dtype=None): |
16 | 23 | self.layer_num = network_config["n_layer"] |
17 | 24 | self.quant_type = quant_type |
| 25 | + self.expert_dtype = expert_dtype |
18 | 26 | self.network_config_ = network_config |
19 | 27 | self._parse_custom_cfg(custom_cfg_path) |
20 | 28 | self._parse_network_config(network_config) |
| 29 | + self._apply_custom_expert_dtype(expert_dtype) |
| 30 | + |
| 31 | + def _apply_custom_expert_dtype(self, expert_dtype): |
| 32 | + if expert_dtype is None: |
| 33 | + return |
| 34 | + quant_type = self._get_expert_quant_type(expert_dtype, "--expert_dtype") |
| 35 | + for layer_num in range(self.layer_num): |
| 36 | + self.quant_cfg[layer_num]["fused_moe"] = quant_type |
| 37 | + logger.info(f"select fused_moe quant way from --expert_dtype=`{expert_dtype}`: {quant_type}") |
| 38 | + |
| 39 | + def _get_expert_quant_type(self, expert_dtype, source): |
| 40 | + quant_type = EXPERT_DTYPE_TO_QUANT_TYPE.get(expert_dtype) |
| 41 | + if quant_type is None: |
| 42 | + raise ValueError(f"unsupported {source} `{expert_dtype}`; expected one of {list(SUPPORTED_EXPERT_DTYPES)}") |
| 43 | + if expert_dtype == "fp4" and not is_sm100_gpu(): |
| 44 | + raise RuntimeError(f"{source} `fp4` requires an SM100 GPU; please use `fp8` on non-SM100 GPUs.") |
| 45 | + return quant_type |
21 | 46 |
|
22 | 47 | def _parse_network_config(self, network_config): |
23 | 48 | hf_quantization_config = network_config.get("quantization_config", None) |
@@ -47,18 +72,9 @@ def _mapping_quant_method(self): |
47 | 72 |
|
48 | 73 | # fp8 量化下,部分 MoE 模型(如 DeepSeek-V4),可以单独声明 expert 权重精度, |
49 | 74 | # 按其值给 fused_moe 选用对应的 deepgemm 量化方法。 |
50 | | - expert_dtype = self.network_config_.get("expert_dtype", None) |
| 75 | + expert_dtype = None if self.expert_dtype is not None else self.network_config_.get("expert_dtype", None) |
51 | 76 | if expert_dtype is not None: |
52 | | - expert_dtype_to_quant_type = { |
53 | | - "fp4": "deepgemm-fp4fp8-b32", |
54 | | - "fp8": "deepgemm-fp8w8a8-b128", |
55 | | - } |
56 | | - target = expert_dtype_to_quant_type.get(expert_dtype) |
57 | | - if target is None: |
58 | | - raise ValueError( |
59 | | - f"unsupported expert_dtype `{expert_dtype}`; " |
60 | | - f"expected one of {sorted(expert_dtype_to_quant_type)}" |
61 | | - ) |
| 77 | + target = self._get_expert_quant_type(expert_dtype, "network config expert_dtype") |
62 | 78 | for layer_num in range(self.layer_num): |
63 | 79 | self.quant_cfg[layer_num].setdefault("fused_moe", target) |
64 | 80 | logger.info(f"select fused_moe quant way from expert_dtype=`{expert_dtype}`: {target}") |
|
0 commit comments