Skip to content

Commit 29bfe93

Browse files
authored
[megatron] feat: load dist checkpoint with customized prefix for state dict keys. (verl-project#4139)
### What does this PR do? > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: https://github.com/search?q=repo%3Avolcengine%2Fverl+dist+checkpoint+prefix&type=pullrequests - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. For Megatron dist checkpoint using customized prefix for state dict keys, e.g. NeMo2 use the prefix`module.` for the keys in the state dict, user can add the spec `dist_checkpointing_prefix` to the corresponding role to load that dist ckpt. For instance, the snippet below shows an example for actor model: ```python actor_rollout_ref.actor.megatron.dist_checkpointing_prefix='module.' ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: can be covered by the existing tests. - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
1 parent d951de4 commit 29bfe93

6 files changed

Lines changed: 31 additions & 6 deletions

File tree

verl/trainer/config/_generated_ppo_megatron_trainer.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ actor_rollout_ref:
4141
use_distributed_optimizer: true
4242
use_dist_checkpointing: false
4343
dist_checkpointing_path: null
44+
dist_checkpointing_prefix: ''
4445
seed: 42
4546
override_ddp_config: {}
4647
override_transformer_config:
@@ -165,6 +166,7 @@ actor_rollout_ref:
165166
use_distributed_optimizer: true
166167
use_dist_checkpointing: false
167168
dist_checkpointing_path: null
169+
dist_checkpointing_prefix: ''
168170
seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42}
169171
override_ddp_config: {}
170172
override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}}
@@ -356,6 +358,7 @@ critic:
356358
use_distributed_optimizer: true
357359
use_dist_checkpointing: false
358360
dist_checkpointing_path: null
361+
dist_checkpointing_prefix: ''
359362
seed: 42
360363
override_ddp_config: {}
361364
override_transformer_config:
@@ -473,6 +476,7 @@ reward_model:
473476
use_distributed_optimizer: false
474477
use_dist_checkpointing: false
475478
dist_checkpointing_path: null
479+
dist_checkpointing_prefix: ''
476480
seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42}
477481
override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}}
478482
use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False}

verl/trainer/config/engine/megatron.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ use_dist_checkpointing: False
4040
# distributed checkpointing path
4141
dist_checkpointing_path: null
4242

43+
# distributed checkpointing prefix, e.g. Nemo2 will append prefix 'module.' to the state dict keys
44+
dist_checkpointing_prefix: ''
45+
4346
# oc.select: default val for ref.megatron.seed
4447
seed: 42
4548

verl/trainer/config/reward_model/megatron_reward_model.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ megatron:
5252
# Path for distributed checkpoints
5353
dist_checkpointing_path: null
5454

55+
# distributed checkpointing prefix, e.g. Nemo2 will append prefix 'module.' to the state dict keys
56+
dist_checkpointing_prefix: ''
57+
5558
# RNG seed for megatron
5659
seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42}
5760

verl/utils/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ def pad_packed_inputs(unpad_tokens: torch.Tensor, cu_seqlens, max_seqlen_in_batc
531531
return unpad_tokens, cu_seqlens, max_seqlen_in_batch
532532

533533

534-
def load_mcore_dist_weights(parallel_model, dist_weight_path, is_value_model=False):
534+
def load_mcore_dist_weights(parallel_model, dist_weight_path, is_value_model=False, prefix=""):
535535
from megatron.core import dist_checkpointing
536536
from megatron.core.dist_checkpointing.serialization import StrictHandling
537537

@@ -540,7 +540,7 @@ def load_mcore_dist_weights(parallel_model, dist_weight_path, is_value_model=Fal
540540
# strict = StrictHandling.IGNORE_ALL if is_value_model else StrictHandling.ASSUME_OK_UNEXPECTED
541541
strict = StrictHandling.ASSUME_OK_UNEXPECTED
542542
for model in parallel_model:
543-
ssd = unwrap_model(model).sharded_state_dict()
543+
ssd = unwrap_model(model).sharded_state_dict(prefix=prefix)
544544
if is_value_model:
545545
for k in list(ssd.keys()):
546546
if "output_layer" in k:

verl/workers/config/engine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ class McoreEngineConfig(BaseConfig):
6565
use_distributed_optimizer: bool = True
6666
use_dist_checkpointing: bool = False
6767
dist_checkpointing_path: Optional[str] = None
68+
dist_checkpointing_prefix: str = ""
6869
seed: int = 42
6970
override_ddp_config: dict[str, Any] = field(default_factory=dict)
7071
override_transformer_config: dict[str, Any] = field(default_factory=dict)

verl/workers/megatron_workers.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,10 @@ def _build_model_optimizer(
333333
if self.config.actor.load_weight:
334334
if self.config.actor.megatron.use_dist_checkpointing:
335335
load_mcore_dist_weights(
336-
actor_module, self.config.actor.megatron.dist_checkpointing_path, is_value_model=False
336+
actor_module,
337+
self.config.actor.megatron.dist_checkpointing_path,
338+
is_value_model=False,
339+
prefix=self.config.actor.megatron.dist_checkpointing_prefix,
337340
)
338341
else:
339342
if self.bridge is not None:
@@ -366,7 +369,10 @@ def _build_model_optimizer(
366369
print("load ref weight start")
367370
if self.config.ref.megatron.use_dist_checkpointing:
368371
load_mcore_dist_weights(
369-
ref_module, self.config.ref.megatron.dist_checkpointing_path, is_value_model=False
372+
ref_module,
373+
self.config.ref.megatron.dist_checkpointing_path,
374+
is_value_model=False,
375+
prefix=self.config.ref.megatron.dist_checkpointing_prefix,
370376
)
371377
else:
372378
if self.bridge is not None:
@@ -971,7 +977,10 @@ def _build_critic_model_optimizer(
971977
t0 = time.time()
972978
if self.config.megatron.use_dist_checkpointing:
973979
load_mcore_dist_weights(
974-
critic_module, self.config.megatron.dist_checkpointing_path, is_value_model=True
980+
critic_module,
981+
self.config.megatron.dist_checkpointing_path,
982+
is_value_model=True,
983+
prefix=self.config.megatron.dist_checkpointing_prefix,
975984
)
976985
else:
977986
if self.bridge is not None:
@@ -1233,7 +1242,12 @@ def _build_rm_model(self, model_path, tokenizer, override_model_config, override
12331242

12341243
if self.config.load_weight:
12351244
if self.config.megatron.use_dist_checkpointing:
1236-
load_mcore_dist_weights(reward_model, self.config.megatron.dist_checkpointing_path, is_value_model=True)
1245+
load_mcore_dist_weights(
1246+
reward_model,
1247+
self.config.megatron.dist_checkpointing_path,
1248+
is_value_model=True,
1249+
prefix=self.config.megatron.dist_checkpointing_prefix,
1250+
)
12371251
else:
12381252
if self.bridge is not None:
12391253
local_model_path = get_hf_model_path(self.config)

0 commit comments

Comments
 (0)