Skip to content

Commit 966c8ac

Browse files
author
niushengxiao
committed
feat: add --expert_dtype param
1 parent 05ec73e commit 966c8ac

12 files changed

Lines changed: 87 additions & 24 deletions

File tree

docs/CN/source/tutorial/api_server_args.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,14 @@ PD 分离模式参数
464464

465465
示例可以在 test/advanced_config/mixed_quantization/llamacls-mix-down.yaml 中找到。
466466

467+
.. option:: --expert_dtype
468+
469+
EP MoE 专家量化类型,可选值:
470+
471+
* ``fp8``
472+
* ``fp4``,仅支持 SM100 GPU
473+
* ``None`` (默认)
474+
467475
.. option:: --vit_quant_type
468476

469477
ViT 量化方法,可选值:

docs/EN/source/tutorial/api_server_args.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,14 @@ Quantization Parameters
465465

466466
Examples can be found in test/advanced_config/mixed_quantization/llamacls-mix-down.yaml.
467467

468+
.. option:: --expert_dtype
469+
470+
Expert quantization dtype for EP MoE, optional values:
471+
472+
* ``fp8``
473+
* ``fp4``: SM100 GPUs only
474+
* ``None`` (default)
475+
468476
.. option:: --vit_quant_type
469477

470478
ViT quantization method, optional values:

lightllm/common/basemodel/basemodel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def __init__(self, kvargs):
8585
self.disable_cudagraph = kvargs.get("disable_cudagraph", False)
8686
self.quant_type = kvargs.get("quant_type", "none")
8787
self.quant_cfg_path = kvargs.get("quant_cfg", None)
88+
self.expert_dtype = kvargs.get("expert_dtype", None)
8889
self.mem_fraction = kvargs.get("mem_fraction", 0.9)
8990
self.tp_world_size_ = get_dp_world_size()
9091
self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode
@@ -156,7 +157,7 @@ def _verify_params(self):
156157
return
157158

158159
def _init_quant(self):
159-
self.quant_cfg = Quantcfg(self.config, self.quant_type, self.quant_cfg_path)
160+
self.quant_cfg = Quantcfg(self.config, self.quant_type, self.quant_cfg_path, self.expert_dtype)
160161
logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}")
161162

162163
def _init_weights(self, start_layer_index=0):

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
get_ep_num_sms,
1414
masked_group_gemm,
1515
deepgemm_grouped_fp8_nt_contiguous,
16-
use_sm100_fp4_moe,
16+
use_sm100_mega_moe,
1717
)
1818
from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import (
1919
per_token_group_quant_fp8,
@@ -153,7 +153,7 @@ def select_experts_and_quant_input(
153153
scoring_func=scoring_func,
154154
)
155155
w13_weight, w13_scale = w13.weight, w13.weight_scale
156-
if use_sm100_fp4_moe(self.quant_method):
156+
if use_sm100_mega_moe(self.quant_method):
157157
from deep_gemm.utils import per_token_cast_to_fp8
158158

159159
qinput_tensor = per_token_cast_to_fp8(

lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
logger = init_logger(__name__)
2424
_MEGA_MOE_STATES: Dict[Tuple[int, int, int, int], Dict[str, Any]] = {}
25+
SUPPORTED_EP_EXPERT_DTYPES = ("deepgemm-fp8w8a8-b128", "deepgemm-fp4fp8-b32")
2526

2627
try:
2728
from deep_ep import Buffer, EventOverlap
@@ -37,10 +38,25 @@ def get_ep_num_sms() -> int:
3738
return getattr(dist_group_manager, "ep_num_sms", None) or 0
3839

3940

40-
def use_sm100_fp4_moe(quant_method: Any) -> bool:
41+
def use_sm100_mega_moe(quant_method: Any) -> bool:
    """Tell whether the mega-MoE branch applies to ``quant_method``.

    The result is truthy only when ``is_sm100_gpu()`` is truthy and
    ``quant_method.method_name`` equals ``"deepgemm-fp4fp8-b32"``.
    ``quant_method.method_name`` is read only after the GPU check passes.
    """
    mega_moe_method_name = "deepgemm-fp4fp8-b32"
    return is_sm100_gpu() and quant_method.method_name == mega_moe_method_name
4243

4344

45+
def check_ep_expert_dtype(quant_method: Any):
    """Reject fused-MoE quant methods that the EP MoE path does not accept.

    Args:
        quant_method: object whose ``method_name`` attribute names the resolved
            fused_moe quantization method. A missing attribute is treated as
            ``None`` and therefore rejected.

    Raises:
        ValueError: the method name is not in ``SUPPORTED_EP_EXPERT_DTYPES``.
        RuntimeError: the method name is ``"deepgemm-fp4fp8-b32"`` and
            ``is_sm100_gpu()`` is false.
    """
    method_name = getattr(quant_method, "method_name", None)

    if method_name in SUPPORTED_EP_EXPERT_DTYPES:
        # Only the fp4 method carries a hardware requirement; the GPU probe is
        # skipped for every other accepted method.
        if method_name != "deepgemm-fp4fp8-b32" or is_sm100_gpu():
            return
        raise RuntimeError(
            "--expert_dtype fp4 requires an SM100 GPU for EP MoE; " "please use --expert_dtype fp8 on non-SM100 GPUs."
        )

    raise ValueError(
        "EP MoE requires --expert_dtype to be one of ['fp8', 'fp4'], "
        f"but the resolved fused_moe quant method is `{method_name}`. "
        "Please start with --expert_dtype fp8 or --expert_dtype fp4. "
        "Note that --expert_dtype fp4 is only supported on SM100 GPUs."
    )
58+
59+
4460
def masked_group_gemm(
4561
recv_x: Tuple[torch.Tensor, torch.Tensor],
4662
masked_m: torch.Tensor,
@@ -155,10 +171,10 @@ def do_fused_experts(
155171
is_prefill: Optional[bool],
156172
previous_event: Optional[Any] = None,
157173
):
158-
if use_sm100_fp4_moe(quant_method):
174+
check_ep_expert_dtype(quant_method)
175+
if use_sm100_mega_moe(quant_method):
159176
return mega_moe_impl(hidden_states, w13, w2, topk_weights, topk_idx, quant_method)
160177

161-
use_fp8_w8a8 = quant_method.method_name != "none"
162178
buffer = dist_group_manager.ep_buffer if is_prefill else dist_group_manager.ep_low_latency_buffer
163179
return fused_experts_impl(
164180
hidden_states=hidden_states,
@@ -169,8 +185,8 @@ def do_fused_experts(
169185
num_experts=num_experts,
170186
buffer=buffer,
171187
is_prefill=is_prefill,
172-
use_fp8_w8a8=use_fp8_w8a8,
173-
use_fp8_all2all=use_fp8_w8a8,
188+
use_fp8_w8a8=True,
189+
use_fp8_all2all=True,
174190
use_int8_w8a16=False,
175191
w1_scale=w13.weight_scale,
176192
w2_scale=w2.weight_scale,

lightllm/common/quantization/__init__.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,42 @@
77
from .awq import *
88
from .no_quant import *
99
from lightllm.utils.log_utils import init_logger
10+
from lightllm.utils.device_utils import is_sm100_gpu
1011

1112
logger = init_logger(__name__)
1213

14+
EXPERT_DTYPE_TO_QUANT_TYPE = {
15+
"fp8": "deepgemm-fp8w8a8-b128",
16+
"fp4": "deepgemm-fp4fp8-b32",
17+
}
18+
SUPPORTED_EXPERT_DTYPES = tuple(EXPERT_DTYPE_TO_QUANT_TYPE)
19+
1320

1421
class Quantcfg:
15-
def __init__(self, network_config, quant_type="none", custom_cfg_path=None):
22+
def __init__(self, network_config, quant_type="none", custom_cfg_path=None, expert_dtype=None):
1623
self.layer_num = network_config["n_layer"]
1724
self.quant_type = quant_type
25+
self.expert_dtype = expert_dtype
1826
self.network_config_ = network_config
1927
self._parse_custom_cfg(custom_cfg_path)
2028
self._parse_network_config(network_config)
29+
self._apply_custom_expert_dtype(expert_dtype)
30+
31+
def _apply_custom_expert_dtype(self, expert_dtype):
    """Apply the ``--expert_dtype`` choice to the ``fused_moe`` entry of every layer.

    Does nothing when ``expert_dtype`` is ``None``. Otherwise the dtype is
    resolved through ``_get_expert_quant_type`` (which may raise) and then
    assigned to ``self.quant_cfg[i]["fused_moe"]`` for each ``i`` in
    ``range(self.layer_num)``, replacing whatever value was stored there.
    """
    if expert_dtype is not None:
        resolved_method = self._get_expert_quant_type(expert_dtype, "--expert_dtype")
        for layer_index in range(self.layer_num):
            # Plain assignment: an existing fused_moe entry is overwritten.
            self.quant_cfg[layer_index]["fused_moe"] = resolved_method
        logger.info(f"select fused_moe quant way from --expert_dtype=`{expert_dtype}`: {resolved_method}")
38+
39+
def _get_expert_quant_type(self, expert_dtype, source):
    """Map an expert dtype name to its fused_moe quantization method name.

    Args:
        expert_dtype: key looked up in ``EXPERT_DTYPE_TO_QUANT_TYPE``.
        source: label describing where the value came from; it only appears
            in the error messages.

    Returns:
        The quantization method name mapped to ``expert_dtype``.

    Raises:
        ValueError: ``expert_dtype`` has no entry in the mapping.
        RuntimeError: ``expert_dtype`` is ``"fp4"`` and ``is_sm100_gpu()`` is false.
    """
    mapped_method = EXPERT_DTYPE_TO_QUANT_TYPE.get(expert_dtype)

    if mapped_method is not None:
        # fp4 is the only dtype gated on hardware; other dtypes never probe the GPU.
        if expert_dtype != "fp4" or is_sm100_gpu():
            return mapped_method
        raise RuntimeError(f"{source} `fp4` requires an SM100 GPU; please use `fp8` on non-SM100 GPUs.")

    raise ValueError(f"unsupported {source} `{expert_dtype}`; expected one of {list(SUPPORTED_EXPERT_DTYPES)}")
2146

2247
def _parse_network_config(self, network_config):
2348
hf_quantization_config = network_config.get("quantization_config", None)
@@ -47,18 +72,9 @@ def _mapping_quant_method(self):
4772

4873
# fp8 量化下,部分 MoE 模型(如 DeepSeek-V4),可以单独声明 expert 权重精度,
4974
# 按其值给 fused_moe 选用对应的 deepgemm 量化方法。
50-
expert_dtype = self.network_config_.get("expert_dtype", None)
75+
expert_dtype = None if self.expert_dtype is not None else self.network_config_.get("expert_dtype", None)
5176
if expert_dtype is not None:
52-
expert_dtype_to_quant_type = {
53-
"fp4": "deepgemm-fp4fp8-b32",
54-
"fp8": "deepgemm-fp8w8a8-b128",
55-
}
56-
target = expert_dtype_to_quant_type.get(expert_dtype)
57-
if target is None:
58-
raise ValueError(
59-
f"unsupported expert_dtype `{expert_dtype}`; "
60-
f"expected one of {sorted(expert_dtype_to_quant_type)}"
61-
)
77+
target = self._get_expert_quant_type(expert_dtype, "network config expert_dtype")
6278
for layer_num in range(self.layer_num):
6379
self.quant_cfg[layer_num].setdefault("fused_moe", target)
6480
logger.info(f"select fused_moe quant way from expert_dtype=`{expert_dtype}`: {target}")

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
88
from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd
99
from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
10+
from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import use_sm100_mega_moe
1011
from functools import partial
1112
from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
1213
from lightllm.utils.envs_utils import get_env_start_args
@@ -295,7 +296,7 @@ def overlap_tpsp_token_forward(
295296
infer_state1: Deepseek2InferStateInfo,
296297
layer_weight: Deepseek2TransformerLayerWeight,
297298
):
298-
if not self.is_moe or layer_weight.experts.use_sm100_mega_moe():
299+
if not self.is_moe or use_sm100_mega_moe(layer_weight.experts.quant_method):
299300
return super().overlap_tpsp_token_forward(
300301
input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
301302
)
@@ -421,7 +422,7 @@ def overlap_tpsp_context_forward(
421422
infer_state1: Deepseek2InferStateInfo,
422423
layer_weight: Deepseek2TransformerLayerWeight,
423424
):
424-
if not self.is_moe or layer_weight.experts.use_sm100_mega_moe():
425+
if not self.is_moe or use_sm100_mega_moe(layer_weight.experts.quant_method):
425426
return super().overlap_tpsp_context_forward(
426427
input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
427428
)

lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
77
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
88
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
9+
from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import use_sm100_mega_moe
910
from lightllm.utils.dist_utils import get_global_world_size
1011
from lightllm.utils.envs_utils import get_env_start_args
1112

@@ -133,7 +134,7 @@ def overlap_tpsp_token_forward(
133134
infer_state1: LlamaInferStateInfo,
134135
layer_weight: Qwen3MOETransformerLayerWeight,
135136
):
136-
if not self.is_moe or layer_weight.experts.use_sm100_mega_moe():
137+
if not self.is_moe or use_sm100_mega_moe(layer_weight.experts.quant_method):
137138
return super().overlap_tpsp_token_forward(
138139
input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
139140
)
@@ -245,7 +246,7 @@ def overlap_tpsp_context_forward(
245246
infer_state1: LlamaInferStateInfo,
246247
layer_weight: Qwen3MOETransformerLayerWeight,
247248
):
248-
if not self.is_moe or layer_weight.experts.use_sm100_mega_moe():
249+
if not self.is_moe or use_sm100_mega_moe(layer_weight.experts.quant_method):
249250
return super().overlap_tpsp_context_forward(
250251
input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
251252
)

lightllm/server/api_cli.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,14 @@ def make_argument_parser() -> argparse.ArgumentParser:
620620
help="""Path of quantization config. It can be used for mixed quantization.
621621
Examples can be found in test/advanced_config/mixed_quantization/llamacls-mix-down.yaml.""",
622622
)
623+
parser.add_argument(
624+
"--expert_dtype",
625+
type=str,
626+
default=None,
627+
choices=["fp8", "fp4"],
628+
help="""Expert quantization dtype for EP MoE. Supported values are
629+
fp8 and fp4. Note that fp4 is only supported on SM100 GPUs.""",
630+
)
623631
parser.add_argument(
624632
"--vit_quant_type",
625633
type=str,

lightllm/server/core/objs/start_args_type.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ class StartArgs:
133133
graph_max_len_in_batch: int = field(default=0)
134134
quant_type: Optional[str] = field(default=None)
135135
quant_cfg: Optional[str] = field(default=None)
136+
expert_dtype: Optional[str] = field(default=None, metadata={"choices": ["fp8", "fp4"]})
136137
vit_quant_type: Optional[str] = field(default=None)
137138
vit_quant_cfg: Optional[str] = field(default=None)
138139
llm_prefill_att_backend: List[str] = field(

0 commit comments

Comments
 (0)