Skip to content

Commit f911956

Browse files
author
niushengxiao
committed
feat: add --expert_dtype param
1 parent 05ec73e commit f911956

12 files changed

Lines changed: 84 additions & 13 deletions

File tree

docs/CN/source/tutorial/api_server_args.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,14 @@ PD 分离模式参数
464464

465465
示例可以在 test/advanced_config/mixed_quantization/llamacls-mix-down.yaml 中找到。
466466

467+
.. option:: --expert_dtype
468+
469+
EP MoE 专家量化类型,可选值:
470+
471+
* ``deepgemm-fp8w8a8-b128``
472+
* ``deepgemm-fp4fp8-b32``,仅支持 SM100 GPU
473+
* ``None`` (默认)
474+
467475
.. option:: --vit_quant_type
468476

469477
ViT 量化方法,可选值:

docs/EN/source/tutorial/api_server_args.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,14 @@ Quantization Parameters
465465

466466
Examples can be found in test/advanced_config/mixed_quantization/llamacls-mix-down.yaml.
467467

468+
.. option:: --expert_dtype
469+
470+
Expert quantization dtype for EP MoE, optional values:
471+
472+
* ``deepgemm-fp8w8a8-b128``
473+
* ``deepgemm-fp4fp8-b32``: SM100 GPUs only
474+
* ``None`` (default)
475+
468476
.. option:: --vit_quant_type
469477

470478
ViT quantization method, optional values:

lightllm/common/basemodel/basemodel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def __init__(self, kvargs):
8585
self.disable_cudagraph = kvargs.get("disable_cudagraph", False)
8686
self.quant_type = kvargs.get("quant_type", "none")
8787
self.quant_cfg_path = kvargs.get("quant_cfg", None)
88+
self.expert_dtype = kvargs.get("expert_dtype", None)
8889
self.mem_fraction = kvargs.get("mem_fraction", 0.9)
8990
self.tp_world_size_ = get_dp_world_size()
9091
self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode
@@ -156,7 +157,7 @@ def _verify_params(self):
156157
return
157158

158159
def _init_quant(self):
159-
self.quant_cfg = Quantcfg(self.config, self.quant_type, self.quant_cfg_path)
160+
self.quant_cfg = Quantcfg(self.config, self.quant_type, self.quant_cfg_path, self.expert_dtype)
160161
logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}")
161162

162163
def _init_weights(self, start_layer_index=0):

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
get_ep_num_sms,
1414
masked_group_gemm,
1515
deepgemm_grouped_fp8_nt_contiguous,
16-
use_sm100_fp4_moe,
16+
use_sm100_mega_moe,
1717
)
1818
from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import (
1919
per_token_group_quant_fp8,
@@ -153,7 +153,7 @@ def select_experts_and_quant_input(
153153
scoring_func=scoring_func,
154154
)
155155
w13_weight, w13_scale = w13.weight, w13.weight_scale
156-
if use_sm100_fp4_moe(self.quant_method):
156+
if use_sm100_mega_moe(self.quant_method):
157157
from deep_gemm.utils import per_token_cast_to_fp8
158158

159159
qinput_tensor = per_token_cast_to_fp8(

lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
logger = init_logger(__name__)
2424
_MEGA_MOE_STATES: Dict[Tuple[int, int, int, int], Dict[str, Any]] = {}
25+
SUPPORTED_EP_EXPERT_DTYPES = ("deepgemm-fp8w8a8-b128", "deepgemm-fp4fp8-b32")
2526

2627
try:
2728
from deep_ep import Buffer, EventOverlap
@@ -37,10 +38,27 @@ def get_ep_num_sms() -> int:
3738
return getattr(dist_group_manager, "ep_num_sms", None) or 0
3839

3940

40-
def use_sm100_fp4_moe(quant_method: Any) -> bool:
41+
def use_sm100_mega_moe(quant_method: Any) -> bool:
4142
return is_sm100_gpu() and quant_method.method_name == "deepgemm-fp4fp8-b32"
4243

4344

45+
def check_ep_expert_dtype(quant_method: Any):
46+
expert_dtype = getattr(quant_method, "method_name", None)
47+
if expert_dtype not in SUPPORTED_EP_EXPERT_DTYPES:
48+
raise ValueError(
49+
"EP MoE requires --expert_dtype to be one of "
50+
f"{list(SUPPORTED_EP_EXPERT_DTYPES)}, but got `{expert_dtype}`. "
51+
"Please start with --expert_dtype deepgemm-fp8w8a8-b128 or "
52+
"--expert_dtype deepgemm-fp4fp8-b32. Note that deepgemm-fp4fp8-b32 "
53+
"is only supported on SM100 GPUs."
54+
)
55+
if expert_dtype == "deepgemm-fp4fp8-b32" and not is_sm100_gpu():
56+
raise RuntimeError(
57+
"--expert_dtype deepgemm-fp4fp8-b32 requires an SM100 GPU for EP MoE; "
58+
"please use --expert_dtype deepgemm-fp8w8a8-b128 on non-SM100 GPUs."
59+
)
60+
61+
4462
def masked_group_gemm(
4563
recv_x: Tuple[torch.Tensor, torch.Tensor],
4664
masked_m: torch.Tensor,
@@ -155,10 +173,10 @@ def do_fused_experts(
155173
is_prefill: Optional[bool],
156174
previous_event: Optional[Any] = None,
157175
):
158-
if use_sm100_fp4_moe(quant_method):
176+
check_ep_expert_dtype(quant_method)
177+
if use_sm100_mega_moe(quant_method):
159178
return mega_moe_impl(hidden_states, w13, w2, topk_weights, topk_idx, quant_method)
160179

161-
use_fp8_w8a8 = quant_method.method_name != "none"
162180
buffer = dist_group_manager.ep_buffer if is_prefill else dist_group_manager.ep_low_latency_buffer
163181
return fused_experts_impl(
164182
hidden_states=hidden_states,
@@ -169,8 +187,8 @@ def do_fused_experts(
169187
num_experts=num_experts,
170188
buffer=buffer,
171189
is_prefill=is_prefill,
172-
use_fp8_w8a8=use_fp8_w8a8,
173-
use_fp8_all2all=use_fp8_w8a8,
190+
use_fp8_w8a8=True,
191+
use_fp8_all2all=True,
174192
use_int8_w8a16=False,
175193
w1_scale=w13.weight_scale,
176194
w2_scale=w2.weight_scale,

lightllm/common/quantization/__init__.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,36 @@
77
from .awq import *
88
from .no_quant import *
99
from lightllm.utils.log_utils import init_logger
10+
from lightllm.utils.device_utils import is_sm100_gpu
1011

1112
logger = init_logger(__name__)
1213

14+
SUPPORTED_EXPERT_DTYPES = ("deepgemm-fp8w8a8-b128", "deepgemm-fp4fp8-b32")
15+
1316

1417
class Quantcfg:
15-
def __init__(self, network_config, quant_type="none", custom_cfg_path=None):
18+
def __init__(self, network_config, quant_type="none", custom_cfg_path=None, expert_dtype=None):
1619
self.layer_num = network_config["n_layer"]
1720
self.quant_type = quant_type
1821
self.network_config_ = network_config
1922
self._parse_custom_cfg(custom_cfg_path)
2023
self._parse_network_config(network_config)
24+
self._apply_custom_expert_dtype(expert_dtype)
25+
26+
def _apply_custom_expert_dtype(self, expert_dtype):
27+
if expert_dtype is None:
28+
return
29+
if expert_dtype not in SUPPORTED_EXPERT_DTYPES:
30+
raise ValueError(
31+
f"unsupported --expert_dtype `{expert_dtype}`; expected one of {list(SUPPORTED_EXPERT_DTYPES)}"
32+
)
33+
if not is_sm100_gpu() and expert_dtype == "deepgemm-fp4fp8-b32":
34+
raise RuntimeError(
35+
f"deepgemm-fp4fp8-b32 requires an SM100 GPU; " "please use deepgemm-fp8w8a8-b128 on non-SM100 GPUs."
36+
)
37+
for layer_num in range(self.layer_num):
38+
self.quant_cfg[layer_num]["fused_moe"] = expert_dtype
39+
logger.info(f"select fused_moe quant way from --expert_dtype: {expert_dtype}")
2140

2241
def _parse_network_config(self, network_config):
2342
hf_quantization_config = network_config.get("quantization_config", None)

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
88
from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd
99
from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
10+
from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import use_sm100_mega_moe
1011
from functools import partial
1112
from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
1213
from lightllm.utils.envs_utils import get_env_start_args
@@ -295,7 +296,7 @@ def overlap_tpsp_token_forward(
295296
infer_state1: Deepseek2InferStateInfo,
296297
layer_weight: Deepseek2TransformerLayerWeight,
297298
):
298-
if not self.is_moe or layer_weight.experts.use_sm100_mega_moe():
299+
if not self.is_moe or use_sm100_mega_moe(layer_weight.experts.quant_method):
299300
return super().overlap_tpsp_token_forward(
300301
input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
301302
)
@@ -421,7 +422,7 @@ def overlap_tpsp_context_forward(
421422
infer_state1: Deepseek2InferStateInfo,
422423
layer_weight: Deepseek2TransformerLayerWeight,
423424
):
424-
if not self.is_moe or layer_weight.experts.use_sm100_mega_moe():
425+
if not self.is_moe or use_sm100_mega_moe(layer_weight.experts.quant_method):
425426
return super().overlap_tpsp_context_forward(
426427
input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
427428
)

lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
77
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
88
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
9+
from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import use_sm100_mega_moe
910
from lightllm.utils.dist_utils import get_global_world_size
1011
from lightllm.utils.envs_utils import get_env_start_args
1112

@@ -133,7 +134,7 @@ def overlap_tpsp_token_forward(
133134
infer_state1: LlamaInferStateInfo,
134135
layer_weight: Qwen3MOETransformerLayerWeight,
135136
):
136-
if not self.is_moe or layer_weight.experts.use_sm100_mega_moe():
137+
if not self.is_moe or use_sm100_mega_moe(layer_weight.experts.quant_method):
137138
return super().overlap_tpsp_token_forward(
138139
input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
139140
)
@@ -245,7 +246,7 @@ def overlap_tpsp_context_forward(
245246
infer_state1: LlamaInferStateInfo,
246247
layer_weight: Qwen3MOETransformerLayerWeight,
247248
):
248-
if not self.is_moe or layer_weight.experts.use_sm100_mega_moe():
249+
if not self.is_moe or use_sm100_mega_moe(layer_weight.experts.quant_method):
249250
return super().overlap_tpsp_context_forward(
250251
input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
251252
)

lightllm/server/api_cli.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,15 @@ def make_argument_parser() -> argparse.ArgumentParser:
620620
help="""Path of quantization config. It can be used for mixed quantization.
621621
Examples can be found in test/advanced_config/mixed_quantization/llamacls-mix-down.yaml.""",
622622
)
623+
parser.add_argument(
624+
"--expert_dtype",
625+
type=str,
626+
default=None,
627+
choices=["deepgemm-fp8w8a8-b128", "deepgemm-fp4fp8-b32"],
628+
help="""Expert quantization dtype for EP MoE. Supported values are
629+
deepgemm-fp8w8a8-b128 and deepgemm-fp4fp8-b32. Note that
630+
deepgemm-fp4fp8-b32 is only supported on SM100 GPUs.""",
631+
)
623632
parser.add_argument(
624633
"--vit_quant_type",
625634
type=str,

lightllm/server/core/objs/start_args_type.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,9 @@ class StartArgs:
133133
graph_max_len_in_batch: int = field(default=0)
134134
quant_type: Optional[str] = field(default=None)
135135
quant_cfg: Optional[str] = field(default=None)
136+
expert_dtype: Optional[str] = field(
137+
default=None, metadata={"choices": ["deepgemm-fp8w8a8-b128", "deepgemm-fp4fp8-b32"]}
138+
)
136139
vit_quant_type: Optional[str] = field(default=None)
137140
vit_quant_cfg: Optional[str] = field(default=None)
138141
llm_prefill_att_backend: List[str] = field(

0 commit comments

Comments
 (0)