Commit 610bfe6

fix gptoss rope config (#1445)
1 parent 996e2fb commit 610bfe6

4 files changed

Lines changed: 42 additions & 8 deletions

tests/model/test_gpt_oss_moe.py

Lines changed: 0 additions & 3 deletions
@@ -13,7 +13,6 @@
 from xtuner.v1.model.moe.moe import SequenceContext
 from xtuner.v1.model.moe.gpt_oss import GptOss21BA3P6Config
 from xtuner.v1.config import FSDPConfig
-from xtuner.v1.utils.compile import maybe_compile
 from xtuner.v1.loss.ce_loss import CELossConfig, CELossContextInputItem

 GPT_OSS_MINI_PATH = os.environ["GPT_OSS_MINI_PATH"]
@@ -45,7 +44,6 @@ def test_gpt_oss_run(self, device, dispatcher, ep_size, compile, tol, loss_class
         self.create_pg(device)

         hf_config = AutoConfig.from_pretrained(GPT_OSS_MINI_PATH)
-        hf_config.rope_scaling = None

         hf_model = AutoModelForCausalLM.from_pretrained(
             GPT_OSS_MINI_PATH,
@@ -108,7 +106,6 @@ def test_fsdp_accuracy(self, device, dispatcher, ep_size):
         self.create_pg(device)

         hf_config = AutoConfig.from_pretrained(GPT_OSS_MINI_PATH)
-        hf_config.rope_scaling = None
         hf_model = AutoModelForCausalLM.from_pretrained(
             GPT_OSS_MINI_PATH,
             dtype=torch.bfloat16,
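
With GptOssConfig now carrying a yarn rope_scaling_cfg by default (see the gpt_oss.py diff below), the tests no longer have to blank out the HF rope scaling before building the reference model. A minimal sketch of what the tests now rely on; the exact dict values are illustrative, taken from the defaults introduced in this commit:

import os

from transformers import AutoConfig

# The HF GPT-OSS config ships a yarn-style rope_scaling dict; xtuner now
# consumes it directly instead of the tests forcing it to None.
hf_config = AutoConfig.from_pretrained(os.environ["GPT_OSS_MINI_PATH"])
print(hf_config.rope_scaling)
# roughly: {"rope_type": "yarn", "factor": 32.0, "beta_fast": 32.0,
#           "beta_slow": 1.0, "original_max_position_embeddings": 4096,
#           "truncate": False}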

xtuner/v1/model/base.py

Lines changed: 7 additions & 0 deletions
@@ -124,6 +124,13 @@ class TransformerConfig(XTunerBaseModelConfig):
         "default"
     )

+    @computed_field  # type: ignore[misc]
+    @property
+    def rope_scaling(self) -> dict | None:
+        if self.rope_scaling_cfg is not None:
+            return self.rope_scaling_cfg.model_dump()
+        return None
+
     @computed_field
     def num_attention_heads(self) -> int:
         return self.attention.num_attention_heads
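
A standalone sketch (not the xtuner classes) of how the new computed field behaves: pydantic's @computed_field over a @property surfaces the nested rope_scaling_cfg as a plain dict in model_dump(), or None when no scaling is configured. RopeScalingCfg and Cfg below are hypothetical stand-ins for RopeScalingConfig and TransformerConfig.

from pydantic import BaseModel, computed_field


class RopeScalingCfg(BaseModel):  # stand-in for RopeScalingConfig
    type: str = "yarn"
    factor: float = 32.0


class Cfg(BaseModel):  # stand-in for TransformerConfig
    rope_scaling_cfg: RopeScalingCfg | None = None

    @computed_field  # type: ignore[misc]
    @property
    def rope_scaling(self) -> dict | None:
        # None when no scaling is configured, otherwise a plain dict
        if self.rope_scaling_cfg is not None:
            return self.rope_scaling_cfg.model_dump()
        return None


print(Cfg().model_dump())
# {'rope_scaling_cfg': None, 'rope_scaling': None}
print(Cfg(rope_scaling_cfg=RopeScalingCfg()).model_dump())
# rope_scaling is serialized as {'type': 'yarn', 'factor': 32.0}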

xtuner/v1/model/moe/gpt_oss.py

Lines changed: 24 additions & 2 deletions
@@ -10,6 +10,7 @@
 from xtuner.v1.model.moe.moe import BalancingLossConfig, MoEConfig
 from xtuner.v1.module.attention import MHAConfig
 from xtuner.v1.module.decoder_layer.moe_decoder_layer import MoEActFnConfig
+from xtuner.v1.module.rope import RopeScalingConfig
 from xtuner.v1.module.router.greedy import GreedyRouterConfig

 from .moe import MoE
@@ -123,6 +124,9 @@ class GptOssConfig(MoEConfig):
     tie_word_embeddings: bool = False
     n_shared_experts: int = 0
     moe_act_fn_cfg: MoEActFnConfig = MoEActFnConfig(act_type="clipped_swiglu", clip_alpha=1.702, clip_limit=7)
+    rope_scaling_cfg: RopeScalingConfig = RopeScalingConfig(
+        type="yarn", beta_fast=32.0, beta_slow=1.0, factor=32.0, original_max_position_embeddings=4096, truncate=False
+    )

     @computed_field
     def layers_type(self) -> list[Literal["full_attention", "sliding_attention"]]:
@@ -138,7 +142,6 @@ def from_hf(cls, hf_path: str | Path) -> Self:
         assert isinstance(cfg, HFGptOssConfig)

         config = cls(
-            hf_config=cfg,
             vocab_size=cfg.vocab_size,
             max_position_embeddings=cfg.max_position_embeddings,
             pad_token_id=cfg.pad_token_id,
@@ -168,8 +171,17 @@ def from_hf(cls, hf_path: str | Path) -> Self:
                 norm_topk_prob=True,
                 router_scaling_factor=1.0,
             ),
+            rope_scaling_cfg=RopeScalingConfig(
+                type=cfg.rope_scaling.get("rope_type", "yarn"),
+                beta_fast=cfg.rope_scaling.get("beta_fast", 32.0),
+                beta_slow=cfg.rope_scaling.get("beta_slow", 1.0),
+                factor=cfg.rope_scaling.get("factor", 32.0),
+                original_max_position_embeddings=cfg.rope_scaling.get("original_max_position_embeddings", 4096),
+                truncate=cfg.rope_scaling.get("truncate", False),
+            )
+            if cfg.rope_scaling is not None
+            else None,
         )
-
         return config

     @property
@@ -201,6 +213,16 @@ def hf_config(self) -> HFGptOssConfig:
             o_bias=True,
             dtype=torch.bfloat16,
             swiglu_limit=self.moe_act_fn_cfg.clip_limit,
+            rope_scaling={
+                "rope_type": self.rope_scaling_cfg.type,
+                "beta_fast": self.rope_scaling_cfg.beta_fast,
+                "beta_slow": self.rope_scaling_cfg.beta_slow,
+                "factor": self.rope_scaling_cfg.factor,
+                "original_max_position_embeddings": self.rope_scaling_cfg.original_max_position_embeddings,
+                "truncate": self.rope_scaling_cfg.truncate,
+            }
+            if self.rope_scaling_cfg is not None
+            else None,
         )

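
The from_hf and hf_config sides of this change are symmetric: the HF rope_scaling dict is mapped into RopeScalingConfig kwargs with GPT-OSS defaults as fallbacks, then serialized back out as a dict when exporting an HF config. A rough standalone illustration of that mapping; the helper below is hypothetical, not part of xtuner:

# Hypothetical helper mirroring the .get(...) fallbacks added in from_hf;
# missing keys fall back to the GPT-OSS defaults baked into GptOssConfig.
def hf_rope_scaling_to_cfg_kwargs(rope_scaling: dict | None) -> dict | None:
    if rope_scaling is None:
        return None
    return {
        "type": rope_scaling.get("rope_type", "yarn"),
        "beta_fast": rope_scaling.get("beta_fast", 32.0),
        "beta_slow": rope_scaling.get("beta_slow", 1.0),
        "factor": rope_scaling.get("factor", 32.0),
        "original_max_position_embeddings": rope_scaling.get("original_max_position_embeddings", 4096),
        "truncate": rope_scaling.get("truncate", False),
    }


# A partial HF dict is filled in with the defaults:
print(hf_rope_scaling_to_cfg_kwargs({"rope_type": "yarn", "factor": 32.0}))
# And a missing rope_scaling section stays None:
print(hf_rope_scaling_to_cfg_kwargs(None))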
xtuner/v1/module/rope/rope.py

Lines changed: 11 additions & 3 deletions
@@ -20,13 +20,12 @@ class RopeScalingConfig(BaseModel):
     model_config = ConfigDict(extra="forbid")
     type: Literal["default", "linear", "dynamic", "yarn", "longrope", "llama3", "qwen3_vl"] = "default"

-    max_position_embeddings: int | None = None  # TODO: unused parameter, consider removing
-    original_max_position_embeddings: int | None = None  # TODO: unused parameter, consider removing
+    max_position_embeddings: int | None = None
+    original_max_position_embeddings: int | None = None

     # For Qwen3VL
     mrope_section: list[int] | None = None  # e.g. [24, 20, 20]

-    # For inference
     factor: float | None = None
     beta_fast: float | None = None
     beta_slow: float | None = None
@@ -36,6 +35,7 @@ class RopeScalingConfig(BaseModel):
     high_freq_factor: float | None = None
     mscale: float | None = None
     mscale_all_dim: float | None = None
+    truncate: bool = False

     # For FoPE
     fope_init_factor: float | None = None
@@ -73,6 +73,14 @@ def __init__(self, config, device=None):
         self.original_max_seq_len = config.max_position_embeddings
         self.rope_type = "default"
         self.config = config
+
+        rope_scaling_cfg = config.rope_scaling_cfg
+        if rope_scaling_cfg is not None:
+            self.rope_type = rope_scaling_cfg.type
+            assert self.rope_type in ["default", "linear", "yarn", "llama3"], (
+                f"Unsupported rope_type: {self.rope_type}. Supported types are: 'default', 'linear', 'yarn', 'llama3'."
+            )
+
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

         inv_freq: torch.Tensor
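
The new __init__ guard picks the frequency-init function from the configured scaling type and fails fast on anything not yet wired up. A self-contained sketch of that dispatch; the table contents here are placeholders, not the real ROPE_INIT_FUNCTIONS:

# Placeholder dispatch table; the real ROPE_INIT_FUNCTIONS maps each type
# to its inv_freq initializer.
ROPE_INIT_FUNCTIONS = {t: f"<{t} inv_freq initializer>" for t in ("default", "linear", "yarn", "llama3")}


def resolve_rope_init(rope_scaling_type: str | None):
    # No rope_scaling_cfg on the model config -> plain RoPE.
    rope_type = rope_scaling_type or "default"
    assert rope_type in ROPE_INIT_FUNCTIONS, (
        f"Unsupported rope_type: {rope_type}. "
        f"Supported types are: {sorted(ROPE_INIT_FUNCTIONS)}."
    )
    return ROPE_INIT_FUNCTIONS[rope_type]


print(resolve_rope_init("yarn"))  # yarn initializer
print(resolve_rope_init(None))    # falls back to "default"
# resolve_rope_init("dynamic") would raise: valid in the schema, but not dispatched yet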
