
Commit f000576

[Cherry-Pick] change rms norm for glm #7269 (#7275)
* change rms_norm
* refine code
* refine code
* refine code
* refine code
Parent: 1e0ab31

2 files changed: +14 −2 lines

fastdeploy/envs.py

Lines changed: 2 additions & 0 deletions
@@ -197,6 +197,8 @@
     "FD_USE_PHI_MOE_TOPK": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_TOPK", "0"))),
     # Whether to use phi MOE permute,if 1,use paddle op.
     "FD_USE_PHI_MOE_PERMUTE": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_PERMUTE", "0"))),
+    # Whether to use phi rms_norm,if 1,use paddle op.
+    "FD_USE_PHI_RMSNORM": lambda: bool(int(os.getenv("FD_USE_PHI_RMSNORM", "0"))),
     # Control class SiluAndMul to use swiglu or fusid_bias_act operator in the forward_cuda function
     "FD_SiluAndMul_USE_PHI_SWIGLU": lambda: bool(int(os.getenv("FD_SiluAndMul_USE_PHI_SWIGLU", "0"))),
     # Reserve output blocks for decoding requests when schedule new prefill requests
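For context, the new entry follows the same lazy-evaluation pattern as the surrounding FD_* flags: the environment variable's string value is parsed to an int and coerced to bool each time the entry is evaluated. A minimal standalone sketch of that parsing behaviour (not FastDeploy's actual envs module):

import os

# Same "0"/"1" string -> int -> bool pattern used by the FD_* entries above.
FD_USE_PHI_RMSNORM = lambda: bool(int(os.getenv("FD_USE_PHI_RMSNORM", "0")))

os.environ["FD_USE_PHI_RMSNORM"] = "1"   # e.g. `export FD_USE_PHI_RMSNORM=1` before launching
assert FD_USE_PHI_RMSNORM() is True      # glm4_moe.py will then route RMSNorm through the paddle op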

fastdeploy/model_executor/models/glm4_moe.py

Lines changed: 12 additions & 2 deletions
@@ -24,6 +24,7 @@
 from paddleformers.transformers import PretrainedModel
 from paddleformers.utils.log import logger
 
+import fastdeploy
 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.graph_optimization.decorator import (
@@ -252,6 +253,14 @@ def forward(
         return output
 
 
+def rms_norm_func(x, weight, eps):
+    rms_norm_out = paddle.nn.functional.rms_norm(x, x.shape[-1:], weight, eps)
+    if isinstance(rms_norm_out, (tuple, list)):
+        return rms_norm_out[0].astype(weight.dtype)
+    else:
+        return rms_norm_out.astype(weight.dtype)
+
+
 class Glm4MoeDecoderLayer(nn.Layer):
     """ """
 
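The helper added above wraps paddle.nn.functional.rms_norm and normalizes its return value; the tuple/list check suggests that some Paddle builds return more than one output from the op. For intuition only, a plain reference implementation of the same computation (a sketch for illustration, not the code in this commit):

import paddle

def rms_norm_reference(x, weight, eps):
    # RMSNorm: scale x by the reciprocal RMS over its last axis, then apply the weight.
    # Accumulate in float32 and cast back, mirroring the .astype(weight.dtype) above.
    x_fp32 = x.astype("float32")
    variance = x_fp32.pow(2).mean(axis=-1, keepdim=True)
    out = x_fp32 * paddle.rsqrt(variance + eps)
    return (out * weight.astype("float32")).astype(weight.dtype)

# Hypothetical shapes, for illustration only.
x = paddle.randn([2, 8, 64])
w = paddle.ones([64])
print(rms_norm_reference(x, w, 1e-6).shape)   # [2, 8, 64]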
@@ -305,8 +314,9 @@ def forward(
         residual: paddle.Tensor = None,
     ):
         """ """
+        proxy_rmsnorm = rms_norm_func if fastdeploy.envs.FD_USE_PHI_RMSNORM else None
         hidden_states, residual = self.input_layernorm(
-            hidden_states, residual_input=residual, forward_meta=forward_meta
+            hidden_states, residual_input=residual, forward_meta=forward_meta, proxy_rmsnorm=proxy_rmsnorm
         )
 
         hidden_states = self.self_attn(
@@ -315,7 +325,7 @@ def forward(
         )
 
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual, proxy_rmsnorm=proxy_rmsnorm)
 
         hidden_states = self.mlp(hidden_states, forward_meta)
 