@@ -108,6 +108,9 @@ def __init__(self, fd_config: FDConfig) -> None:
108108 else :
109109 self .max_chunk_tokens = self .fd_config .get_max_chunk_tokens (self .model_config .mm_max_tokens_per_item )
110110
111+ self .swa_rope_theta = getattr (self .fd_config .model_config , "swa_rope_theta" , None )
112+ self .swa_rope_emb = None
113+
111114 def init_share_inputs (self ):
112115 max_num_seqs = self .scheduler_config .max_num_seqs
113116
@@ -245,6 +248,14 @@ def init_share_inputs(self):
245248 model_config = self .model_config ,
246249 partial_rotary_factor = self .model_config .partial_rotary_factor ,
247250 )
251+ if self .swa_rope_theta is not None :
252+ self .swa_rope_emb = get_rope (
253+ rotary_dim = self .rotary_dim if rotary_percent < 1.0 else self .model_config .head_dim ,
254+ position_ids = paddle .arange (self .model_config .max_model_len ).reshape ((1 , - 1 )),
255+ base = self .swa_rope_theta ,
256+ model_config = self .model_config ,
257+ partial_rotary_factor = self .model_config .partial_rotary_factor ,
258+ )
248259 if self .is_mm_model :
249260 self .image_features = None
250261 self .image_grid_thws = None
@@ -727,6 +738,14 @@ def reset_share_inputs(self):
727738 model_config = self .model_config ,
728739 partial_rotary_factor = self .model_config .partial_rotary_factor ,
729740 )
741+ if self .swa_rope_theta is not None :
742+ self .swa_rope_emb = get_rope (
743+ rotary_dim = self .rotary_dim if rotary_percent < 1.0 else self .model_config .head_dim ,
744+ position_ids = paddle .arange (self .model_config .max_model_len ).reshape ((1 , - 1 )),
745+ base = self .swa_rope_theta ,
746+ model_config = self .model_config ,
747+ partial_rotary_factor = self .model_config .partial_rotary_factor ,
748+ )
730749 if self .is_mm_model :
731750 self .image_features = None
732751 self .image_grid_thws = None
@@ -764,6 +783,8 @@ def __init__(self, fd_config: FDConfig, target_model_input_batch: InputBatch) ->
764783 self .cache_config : CacheConfig = fd_config .cache_config
765784 self .speculative_config : SpeculativeConfig = fd_config .speculative_config
766785 self .enable_pd_reorder : bool = False
786+ self .swa_rope_theta = getattr (self .fd_config .model_config , "swa_rope_theta" , None )
787+ self .swa_rope_emb = None
767788
768789 def init_share_inputs (self ):
769790 # share with target model
@@ -829,6 +850,14 @@ def init_share_inputs(self):
829850 model_config = self .model_config ,
830851 partial_rotary_factor = self .model_config .partial_rotary_factor ,
831852 )
853+ if self .swa_rope_theta is not None :
854+ self .swa_rope_emb = get_rope (
855+ rotary_dim = self .rotary_dim if rotary_percent < 1.0 else self .model_config .head_dim ,
856+ position_ids = tmp_position_ids ,
857+ base = self .swa_rope_theta ,
858+ model_config = self .model_config ,
859+ partial_rotary_factor = self .model_config .partial_rotary_factor ,
860+ )
832861
833862 # self.caches = self.cache_kvs
834863 # Inherit generation hyperparameters from the main model for consistency
@@ -1059,6 +1088,14 @@ def reset_model_inputs(self) -> None:
10591088 model_config = self .model_config ,
10601089 partial_rotary_factor = self .model_config .partial_rotary_factor ,
10611090 )
1091+ if self .swa_rope_theta is not None :
1092+ self .swa_rope_emb = get_rope (
1093+ rotary_dim = self .rotary_dim if rotary_percent < 1.0 else self .model_config .head_dim ,
1094+ position_ids = tmp_position_ids ,
1095+ base = self .swa_rope_theta ,
1096+ model_config = self .model_config ,
1097+ partial_rotary_factor = self .model_config .partial_rotary_factor ,
1098+ )
10621099
10631100 # Reset generation hyperparameters from the main model
10641101 self .top_p = self .target_model_input_batch ["top_p" ]
0 commit comments