2424from paddleformers .transformers import PretrainedModel
2525from paddleformers .utils .log import logger
2626
27+ import fastdeploy
2728from fastdeploy .config import FDConfig
2829from fastdeploy .model_executor .forward_meta import ForwardMeta
2930from fastdeploy .model_executor .graph_optimization .decorator import (
@@ -252,6 +253,14 @@ def forward(
252253 return output
253254
254255
def rms_norm_func(x, weight, eps):
    """RMS-normalize `x` over its last axis with scale `weight`.

    Thin wrapper around `paddle.nn.functional.rms_norm` that normalizes over
    the trailing dimension (`x.shape[-1:]`) and always returns a single
    tensor cast to `weight`'s dtype — some paddle builds return a
    tuple/list whose first element is the normalized tensor, so that case
    is unwrapped here.

    Args:
        x: input tensor; normalization is applied over its last axis.
        weight: per-feature scale tensor; its dtype fixes the output dtype.
        eps: small constant added for numerical stability.

    Returns:
        The RMS-normalized tensor, cast to `weight.dtype`.
    """
    out = paddle.nn.functional.rms_norm(x, x.shape[-1:], weight, eps)
    # Unwrap multi-output variants before casting back to the weight dtype.
    normalized = out[0] if isinstance(out, (tuple, list)) else out
    return normalized.astype(weight.dtype)
263+
255264class Glm4MoeDecoderLayer (nn .Layer ):
256265 """ """
257266
@@ -305,8 +314,9 @@ def forward(
305314 residual : paddle .Tensor = None ,
306315 ):
307316 """ """
317+ proxy_rmsnorm = rms_norm_func if fastdeploy .envs .FD_USE_PHI_RMSNORM else None
308318 hidden_states , residual = self .input_layernorm (
309- hidden_states , residual_input = residual , forward_meta = forward_meta
319+ hidden_states , residual_input = residual , forward_meta = forward_meta , proxy_rmsnorm = proxy_rmsnorm
310320 )
311321
312322 hidden_states = self .self_attn (
@@ -315,7 +325,7 @@ def forward(
315325 )
316326
317327 # Fully Connected
318- hidden_states , residual = self .post_attention_layernorm (hidden_states , residual )
328+ hidden_states , residual = self .post_attention_layernorm (hidden_states , residual , proxy_rmsnorm = proxy_rmsnorm )
319329
320330 hidden_states = self .mlp (hidden_states , forward_meta )
321331
0 commit comments