Skip to content

Commit ae2f9f4

Browse files
authored
[BugFix] Enable moe_gate_fp32 using FD_ENABLE_RL (#7130)
* rl gate fp32
* clean
1 parent 18f0124 commit ae2f9f4

File tree

4 files changed

+6
-6
lines changed

4 files changed

+6
-6
lines changed

fastdeploy/engine/args_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,8 @@ def __post_init__(self):
624624
raise NotImplementedError(
625625
f"not support model_impl: '{self.model_impl}'. " f"Must be one of: {', '.join(valid_model_impls)}"
626626
)
627+
if envs.FD_ENABLE_RL == 1:
628+
self.moe_gate_fp32 = True
627629

628630
self.post_init_all_ports()
629631

fastdeploy/envs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,8 @@ def _validate_split_kv_size(value: int) -> int:
266266
"FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool(
267267
int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1"))
268268
),
269+
# Whether to align RoPE and moe gate precision with training
270+
"FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")),
269271
}
270272

271273

fastdeploy/model_executor/models/glm4_moe.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,7 @@ def __init__(
151151
output_size=fd_config.model_config.n_routed_experts,
152152
with_bias=False,
153153
skip_quant=True,
154-
weight_dtype=(
155-
"float32" if fd_config.load_config.dynamic_load_weight or fd_config.model_config.moe_gate_fp32 else ""
156-
),
154+
weight_dtype=("float32" if fd_config.model_config.moe_gate_fp32 else ""),
157155
)
158156
self.gate.e_score_correction_bias = self.create_parameter(
159157
shape=[1, fd_config.model_config.n_routed_experts],

fastdeploy/model_executor/models/qwen3moe.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,7 @@ def __init__(
7777
output_size=fd_config.model_config.num_experts,
7878
with_bias=False,
7979
skip_quant=True,
80-
weight_dtype=(
81-
"float32" if fd_config.load_config.dynamic_load_weight or fd_config.model_config.moe_gate_fp32 else ""
82-
),
80+
weight_dtype=("float32" if fd_config.model_config.moe_gate_fp32 else ""),
8381
)
8482

8583
def forward(self, x, forward_meta):

0 commit comments

Comments (0)