Skip to content

Commit 563b37a

Browse files
authored
fix: preserve Qwen3.5 broadcast weight names (#2690)
* exp: add qwen35 kl debug configs * exp: lower qwen35 wordle kl drift * fix: preserve qwen3.5 broadcast weight names * chore: keep qwen35 debug configs local * chore: clarify qwen3.5 weight naming bypass * fix: use upstream qwen3.5 conversion mapping * chore: use released transformers package * chore: drop qwen3.5 cp patch * chore: remove wordle env packaging
1 parent e751794 commit 563b37a

6 files changed

Lines changed: 308 additions & 269 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ dependencies = [
1717
"torchvision",
1818
"torchaudio",
1919
"torchdata>=0.11.0",
20-
"transformers",
20+
"transformers==5.6.2",
2121
"vllm>=0.22.0",
2222
"mooncake-transfer-engine>=0.3.10.post2",
2323
"wandb>=0.26.1",
@@ -128,7 +128,6 @@ dev = [
128128
"ruff>=0.12.1",
129129
]
130130

131-
132131
[tool.uv]
133132
# Enforce a uv version that supports the friendly-duration form
134133
# (`"7 days"`) in the static pyproject parser. Older uvs silently parse
@@ -147,7 +146,7 @@ environments = [
147146
override-dependencies = [
148147
"nvidia-cudnn-cu12>=9.15",
149148
"nvidia-cutlass-dsl>=4.4.1",
150-
"transformers>=5.1.0.dev0",
149+
"transformers==5.6.2",
151150
"torch>=2.9.0",
152151
"openenv-core",
153152
]
@@ -231,7 +230,6 @@ torchvision = { index = "pytorch-cu128" }
231230
torchaudio = { index = "pytorch-cu128" }
232231
torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" }
233232
dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" }
234-
transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" }
235233
flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "96bd151" }
236234
vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.26/vllm_router-0.1.26-cp38-abi3-manylinux_2_28_x86_64.whl" }
237235
vllm = [

skills/training/start-run/SKILL.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ uv run rl @ examples/reverse_text/rl.toml --dry-run
3434
- Config: `RLConfig` (`packages/prime-rl-configs/src/prime_rl/configs/rl.py`)
3535
- Entrypoint: `src/prime_rl/entrypoints/rl.py`
3636
- SLURM: single- and multi-node
37+
- Environment packages: before launching a config with a non-core verifier env id,
38+
verify the package imports under `uv run` (for example
39+
`uv run python -c "import importlib.util; print(importlib.util.find_spec('rlm_swe'))"`).
40+
If a local env exists under `deps/research-environments/environments/` but does not
41+
import, add it to the root `pyproject.toml` env extra, workspace members, and
42+
`[tool.uv.sources]`, then run `uv sync --all-extras`.
3743

3844
## `sft` — SFT training
3945

src/prime_rl/trainer/ckpt.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,6 @@ def save(
420420
f"Converted PrimeRL format to HF format in {time.perf_counter() - start_time:.2f} seconds"
421421
)
422422
else:
423-
# For regular transformers models, revert internal format to original HF hub format
424423
from transformers.core_model_loading import revert_weight_conversion
425424

426425
self.logger.debug("Reverting transformers internal format to HF hub format for weight checkpoint")

src/prime_rl/trainer/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def _patch_qwen3_5_moe_conversion_mapping():
7979
incorrectly maps qwen3_5_moe → qwen2_moe, which assumes per-expert 2D checkpoint weights,
8080
causing revert_weight_conversion to produce wrong shapes during weight broadcasting.
8181
82-
Remove once the pinned transformers commit fixes this.
82+
Remove once an official Transformers release fixes this.
8383
"""
8484
from transformers.conversion_mapping import (
8585
get_checkpoint_conversion_mapping,
@@ -99,7 +99,7 @@ def _patch_qwen3_5_text_position_ids():
9999
"""Fix Qwen3.5 passing 3D MRoPE position_ids to decoder layers instead of 2D text_position_ids.
100100
101101
Upstream fix: https://github.com/huggingface/transformers/pull/44399
102-
Remove once the pinned transformers commit includes this fix.
102+
Remove once an official Transformers release includes this fix.
103103
"""
104104
import inspect
105105

src/prime_rl/trainer/rl/broadcast/filesystem.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def broadcast_weights(self, model: nn.Module, step: int) -> None:
4646
if isinstance(model, PreTrainedModelPrimeRL) and model.is_prime_state_dict(state_dict):
4747
model.convert_to_hf(state_dict)
4848
else:
49-
# For regular transformers models, revert internal format to original HF hub format
5049
from transformers.core_model_loading import revert_weight_conversion
5150

5251
state_dict = revert_weight_conversion(model, state_dict)

0 commit comments

Comments
 (0)