@@ -236,8 +236,10 @@ def init_share_inputs(self):
236236
237237 # Initialize rotary position embedding
238238 if not self .enable_mm :
239+ rotary_percent = getattr (self .model_config , "rotary_percent" , 1.0 )
240+ self .rotary_dim = int (self .model_config .head_dim * rotary_percent )
239241 self .rope_emb = get_rope (
240- rotary_dim = self .model_config .head_dim ,
242+ rotary_dim = self .rotary_dim if rotary_percent < 1.0 else self . model_config .head_dim ,
241243 position_ids = paddle .arange (self .model_config .max_model_len ).reshape ((1 , - 1 )),
242244 base = self .model_config .rope_theta ,
243245 model_config = self .model_config ,
@@ -716,8 +718,10 @@ def reset_share_inputs(self):
716718 fill_paddle_tensor (self , "attn_mask_offsets_full" , - 1 )
717719 else :
718720 # Reset non-multimodal rope_emb
721+ rotary_percent = getattr (self .model_config , "rotary_percent" , 1.0 )
722+ self .rotary_dim = int (self .model_config .head_dim * rotary_percent )
719723 self .rope_emb = get_rope (
720- rotary_dim = self .model_config .head_dim ,
724+ rotary_dim = self .rotary_dim if rotary_percent < 1.0 else self . model_config .head_dim ,
721725 position_ids = paddle .arange (self .model_config .max_model_len ).reshape ((1 , - 1 )),
722726 base = self .model_config .rope_theta ,
723727 model_config = self .model_config ,
@@ -816,8 +820,10 @@ def init_share_inputs(self):
816820
817821 tmp_position_ids = paddle .arange (self .model_config .max_model_len ).reshape ((1 , - 1 ))
818822
823+ rotary_percent = getattr (self .model_config , "rotary_percent" , 1.0 )
824+ self .rotary_dim = int (self .model_config .head_dim * rotary_percent )
819825 self .rope_emb = get_rope (
820- rotary_dim = self .model_config .head_dim ,
826+ rotary_dim = self .rotary_dim if rotary_percent < 1.0 else self . model_config .head_dim ,
821827 position_ids = tmp_position_ids ,
822828 base = self .model_config .rope_theta ,
823829 model_config = self .model_config ,
@@ -1040,8 +1046,10 @@ def reset_model_inputs(self) -> None:
10401046
10411047 # Reset rope embedding by recreating with default position_ids
10421048 tmp_position_ids = paddle .arange (self .model_config .max_model_len ).reshape ((1 , - 1 ))
1049+ rotary_percent = getattr (self .model_config , "rotary_percent" , 1.0 )
1050+ self .rotary_dim = int (self .model_config .head_dim * rotary_percent )
10431051 self .rope_emb = get_rope (
1044- rotary_dim = self .model_config .head_dim ,
1052+ rotary_dim = self .rotary_dim if rotary_percent < 1.0 else self . model_config .head_dim ,
10451053 position_ids = tmp_position_ids ,
10461054 base = self .model_config .rope_theta ,
10471055 model_config = self .model_config ,
0 commit comments