 from xtuner.v1.model.moe.moe import BalancingLossConfig, MoEConfig
 from xtuner.v1.module.attention import MHAConfig
 from xtuner.v1.module.decoder_layer.moe_decoder_layer import MoEActFnConfig
+from xtuner.v1.module.rope import RopeScalingConfig
 from xtuner.v1.module.router.greedy import GreedyRouterConfig

 from .moe import MoE
@@ -123,6 +124,9 @@ class GptOssConfig(MoEConfig):
     tie_word_embeddings: bool = False
     n_shared_experts: int = 0
     moe_act_fn_cfg: MoEActFnConfig = MoEActFnConfig(act_type="clipped_swiglu", clip_alpha=1.702, clip_limit=7)
+    rope_scaling_cfg: RopeScalingConfig = RopeScalingConfig(
+        type="yarn", beta_fast=32.0, beta_slow=1.0, factor=32.0, original_max_position_embeddings=4096, truncate=False
+    )

     @computed_field
     def layers_type(self) -> list[Literal["full_attention", "sliding_attention"]]:
@@ -138,7 +142,6 @@ def from_hf(cls, hf_path: str | Path) -> Self:
         assert isinstance(cfg, HFGptOssConfig)

         config = cls(
-            hf_config=cfg,
             vocab_size=cfg.vocab_size,
             max_position_embeddings=cfg.max_position_embeddings,
             pad_token_id=cfg.pad_token_id,
@@ -168,8 +171,17 @@ def from_hf(cls, hf_path: str | Path) -> Self:
                 norm_topk_prob=True,
                 router_scaling_factor=1.0,
             ),
+            rope_scaling_cfg=RopeScalingConfig(
+                type=cfg.rope_scaling.get("rope_type", "yarn"),
+                beta_fast=cfg.rope_scaling.get("beta_fast", 32.0),
+                beta_slow=cfg.rope_scaling.get("beta_slow", 1.0),
+                factor=cfg.rope_scaling.get("factor", 32.0),
+                original_max_position_embeddings=cfg.rope_scaling.get("original_max_position_embeddings", 4096),
+                truncate=cfg.rope_scaling.get("truncate", False),
+            )
+            if cfg.rope_scaling is not None
+            else None,
         )
-
         return config

     @property
@@ -201,6 +213,16 @@ def hf_config(self) -> HFGptOssConfig:
             o_bias=True,
             dtype=torch.bfloat16,
             swiglu_limit=self.moe_act_fn_cfg.clip_limit,
+            rope_scaling={
+                "rope_type": self.rope_scaling_cfg.type,
+                "beta_fast": self.rope_scaling_cfg.beta_fast,
+                "beta_slow": self.rope_scaling_cfg.beta_slow,
+                "factor": self.rope_scaling_cfg.factor,
+                "original_max_position_embeddings": self.rope_scaling_cfg.original_max_position_embeddings,
+                "truncate": self.rope_scaling_cfg.truncate,
+            }
+            if self.rope_scaling_cfg is not None
+            else None,
         )


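For reference, a minimal sketch of the rope_scaling round-trip this change introduces: from_hf maps the HF rope_scaling dict into a RopeScalingConfig (None stays None), and hf_config maps it back to a dict. The stand-in dataclass and the helper names rope_scaling_from_hf / rope_scaling_to_hf below are hypothetical and only mirror the field mapping shown in the diff; they are not part of xtuner's API.

from dataclasses import dataclass


@dataclass
class RopeScalingConfig:  # stand-in mirroring the fields used in the diff above
    type: str = "yarn"
    beta_fast: float = 32.0
    beta_slow: float = 1.0
    factor: float = 32.0
    original_max_position_embeddings: int = 4096
    truncate: bool = False


def rope_scaling_from_hf(rope_scaling: dict | None) -> RopeScalingConfig | None:
    # Mirrors the from_hf mapping: HF dict -> config, with the same defaults as the diff.
    if rope_scaling is None:
        return None
    return RopeScalingConfig(
        type=rope_scaling.get("rope_type", "yarn"),
        beta_fast=rope_scaling.get("beta_fast", 32.0),
        beta_slow=rope_scaling.get("beta_slow", 1.0),
        factor=rope_scaling.get("factor", 32.0),
        original_max_position_embeddings=rope_scaling.get("original_max_position_embeddings", 4096),
        truncate=rope_scaling.get("truncate", False),
    )


def rope_scaling_to_hf(cfg: RopeScalingConfig | None) -> dict | None:
    # Mirrors the hf_config mapping: config -> HF dict, None stays None.
    if cfg is None:
        return None
    return {
        "rope_type": cfg.type,
        "beta_fast": cfg.beta_fast,
        "beta_slow": cfg.beta_slow,
        "factor": cfg.factor,
        "original_max_position_embeddings": cfg.original_max_position_embeddings,
        "truncate": cfg.truncate,
    }


hf_dict = {"rope_type": "yarn", "factor": 32.0, "original_max_position_embeddings": 4096}
assert rope_scaling_to_hf(rope_scaling_from_hf(hf_dict))["factor"] == 32.0
assert rope_scaling_from_hf(None) is None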