Commit 8f41b05
[ckpt] fix: FSDP save ckpt after validation (verl-project#4799)
### What does this PR do?
This PR fixes a bug in the `save_checkpoint` function of FSDPEngine.
In the original logic, when the model engine is used
(`use_legacy_worker_impl=disable`), the `wake_up` function in
`verl/workers/engine_workers.py` is invoked during the rollout
phase of each step and offloads the model to CPU.
Normally, the `compute_log_prob` function called during the training
phase loads the model back onto the GPU. However, the training phase
does not run during validation, so the model is left on the CPU. If a
checkpoint is saved immediately after validation, the save fails with:
`AssertionError: Expects tensor to be on the compute device cuda:0, was
on cpu`
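The sequence described above can be traced with a small toy model. This is only a sketch: the phase names and the `device_after` helper are hypothetical and do not call into verl. It encodes just the two placement rules stated in the description (rollout offloads to CPU, `compute_log_prob` reloads to GPU).

```python
def device_after(phases, start="cuda:0"):
    """Return where the (pretend) model lives after running the given phases.

    Hypothetical helper: it only mirrors the two rules from the PR text.
    """
    device = start
    for phase in phases:
        if phase == "rollout":
            # Per the PR text, wake_up runs here and offloads the model.
            device = "cpu"
        elif phase == "compute_log_prob":
            # Per the PR text, the training phase loads the model back.
            device = "cuda:0"
        # Any other phase leaves the placement unchanged.
    return device


train_step = ["rollout", "compute_log_prob", "update"]
validation = ["rollout"]  # validation generates but never trains

after_train = device_after(train_step)
after_validation = device_after(train_step + validation)
print(after_train, after_validation)
```

In this toy trace, a checkpoint taken right after a plain training step sees the model on the GPU, while one taken right after validation sees it on the CPU, which is the state that trips the FSDP assertion.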
<details>
<summary>Details</summary>
Script:
```
set -x
python examples/data_preprocess/geo3k.py --local_dir ~/data/geo3k
python -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=$HOME/data/geo3k/train.parquet \
data.val_files=$HOME/data/geo3k/test.parquet \
data.train_batch_size=512 \
data.max_prompt_length=1024 \
data.max_response_length=2048 \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.image_key=images \
actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.01 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.enable_chunked_prefill=False \
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.free_cache_engine=False \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.ref.fsdp_config.param_offload=False \
algorithm.use_kl_in_reward=False \
trainer.use_legacy_worker_impl=disable \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_ci_grpo_example_geo3k' \
trainer.experiment_name='qwen2_5_vl_3b_function_rm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.log_val_generations=20 \
trainer.save_freq=5 \
trainer.test_freq=5 \
trainer.total_epochs=15
```
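The script sets `trainer.save_freq=5` and `trainer.test_freq=5`, so saving and validation fall on the same steps. The sketch below assumes each event fires when the global step is a multiple of its frequency; that trigger rule and the `is_due` helper are assumptions for illustration, not taken from the trainer code.

```python
def is_due(step: int, freq: int) -> bool:
    # Assumed trigger rule: fire on every multiple of a positive frequency.
    return freq > 0 and step % freq == 0


save_freq, test_freq = 5, 5  # values from the script above

# Steps on which validation and checkpoint saving would both run.
colliding_steps = [
    step for step in range(1, 21)
    if is_due(step, save_freq) and is_due(step, test_freq)
]
print(colliding_steps)
```

Under that assumption, every save in this configuration happens on a step that also ran validation, which matches the failure condition described above.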
Error:
```
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) ERROR:2026-01-05
07:35:49,128:Got error when executing task.
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) Traceback (most
recent call last):
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"python/ray/_raylet.pyx", line 1890, in ray._raylet.execute_task
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"python/ray/_raylet.pyx", line 1998, in ray._raylet.execute_task
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"python/ray/_raylet.pyx", line 1897, in ray._raylet.execute_task
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"python/ray/_raylet.pyx", line 1825, in
ray._raylet.execute_task.function_executor
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"python/ray/_raylet.pyx", line 4651, in
ray._raylet.CoreWorker.run_async_func_or_coro_in_event_loop
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/usr/lib/python3.12/concurrent/futures/_base.py", line 449, in result
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) return
self.__get_result()
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f])
^^^^^^^^^^^^^^^^^^^
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/usr/lib/python3.12/concurrent/futures/_base.py", line 401, in
__get_result
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) raise
self._exception
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"python/ray/_raylet.pyx", line 4638, in async_func
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/usr/local/lib/python3.12/dist-packages/ray/_private/async_compat.py",
line 50, in wrapper
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) return func(*args,
**kwargs)
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f])
^^^^^^^^^^^^^^^^^^^^^
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/usr/local/lib/python3.12/dist-packages/ray/_private/function_manager.py",
line 691, in actor_method_executor
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) return
method(__ray_actor, *args, **kwargs)
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/usr/local/lib/python3.12/dist-packages/ray/util/tracing/tracing_helper.py",
line 463, in _resume_span
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) return
method(self, *_args, **_kwargs)
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/opt/tiger/open_verl/verl/single_controller/ray/base.py", line 841, in
func
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) return
getattr(self.worker_dict[key], name)(*args, **kwargs)
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/opt/tiger/open_verl/verl/single_controller/base/decorator.py", line
456, in inner
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) return func(*args,
**kwargs)
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f])
^^^^^^^^^^^^^^^^^^^^^
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/opt/tiger/open_verl/verl/utils/transferqueue_utils.py", line 314, in
dummy_inner
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) output =
func(*args, **kwargs)
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f])
^^^^^^^^^^^^^^^^^^^^^
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/opt/tiger/open_verl/verl/workers/engine_workers.py", line 541, in
save_checkpoint
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f])
self.actor.save_checkpoint(local_path, hdfs_path, global_step,
max_ckpt_to_keep)
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/opt/tiger/open_verl/verl/single_controller/base/decorator.py", line
456, in inner
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) return func(*args,
**kwargs)
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f])
^^^^^^^^^^^^^^^^^^^^^
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/opt/tiger/open_verl/verl/utils/transferqueue_utils.py", line 314, in
dummy_inner
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) output =
func(*args, **kwargs)
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f])
^^^^^^^^^^^^^^^^^^^^^
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/opt/tiger/open_verl/verl/workers/engine_workers.py", line 343, in
save_checkpoint
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) return
self.engine.save_checkpoint(local_path, hdfs_path, global_step,
max_ckpt_to_keep)
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/opt/tiger/open_verl/verl/workers/engine/fsdp/transformer_impl.py",
line 607, in save_checkpoint
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f])
self.checkpoint_manager.save_checkpoint(
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/opt/tiger/open_verl/verl/utils/checkpoint/fsdp_checkpoint_manager.py",
line 238, in save_checkpoint
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) model_state_dict =
self.model.state_dict()
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f])
^^^^^^^^^^^^^^^^^^^^^^^
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py",
line 2256, in state_dict
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) hook(self, prefix,
keep_vars)
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py",
line 120, in decorate_context
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) return func(*args,
**kwargs)
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f])
^^^^^^^^^^^^^^^^^^^^^
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_state_dict_utils.py",
line 777, in _pre_state_dict_hook
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f])
_pre_state_dict_hook_fn[fsdp_state._state_dict_type](
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_state_dict_utils.py",
line 517, in _sharded_pre_state_dict_hook
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f])
_common_unshard_pre_state_dict_hook(
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_state_dict_utils.py",
line 161, in _common_unshard_pre_state_dict_hook
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f])
_enter_unshard_params_ctx(
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_state_dict_utils.py",
line 125, in _enter_unshard_params_ctx
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f])
fsdp_state._unshard_params_ctx[module].__enter__()
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/usr/lib/python3.12/contextlib.py", line 137, in __enter__
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) return
next(self.gen)
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) ^^^^^^^^^^^^^^
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_unshard_param_utils.py",
line 199, in _unshard_fsdp_state_params
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) _unshard(state,
handle, computation_stream, computation_stream)
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_runtime_utils.py",
line 290, in _unshard
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) ran_pre_unshard =
handle.pre_unshard()
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f])
^^^^^^^^^^^^^^^^^^^^
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_flat_param.py",
line 1303, in pre_unshard
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f])
self._check_on_compute_device(self.flat_param)
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_flat_param.py",
line 2582, in _check_on_compute_device
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) _p_assert(
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File
"/usr/local/lib/python3.12/dist-packages/torch/distributed/utils.py",
line 159, in _p_assert
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) raise
AssertionError(s)
(WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) AssertionError:
Expects tensor to be on the compute device cuda:0, was on cpu
```
</details>
To fix this bug, this PR checks before saving a checkpoint whether the
model is on the CPU and, if so, loads it onto the GPU first. The same
bug also exists in Megatron and still needs a separate fix.
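A minimal sketch of this kind of guard, using a toy class rather than the real engine. `ToyEngine` and all of its methods are hypothetical stand-ins; the actual two-line change in the PR is not reproduced here.

```python
from dataclasses import dataclass, field


@dataclass
class ToyEngine:
    """Stand-in for a training engine. This is NOT verl's FSDPEngine."""

    device: str = "cuda:0"  # where the pretend parameters currently live
    saved_steps: list = field(default_factory=list)

    def offload_to_cpu(self) -> None:
        # Mimics the offload that the PR says happens during rollout.
        self.device = "cpu"

    def load_to_gpu(self) -> None:
        self.device = "cuda:0"

    def _state_dict(self) -> dict:
        # Mirrors the assertion message quoted in the traceback.
        assert self.device != "cpu", (
            "Expects tensor to be on the compute device cuda:0, was on cpu"
        )
        return {"device": self.device}

    def save_checkpoint(self, global_step: int) -> dict:
        # The guard: if an earlier phase left the model on the CPU,
        # bring it back to the GPU before building the state dict.
        if self.device == "cpu":
            self.load_to_gpu()
        state = self._state_dict()
        self.saved_steps.append(global_step)
        return state


engine = ToyEngine()
engine.offload_to_cpu()  # validation ended with the model on CPU
state = engine.save_checkpoint(global_step=5)
print(state, engine.saved_steps)
```

Without the `if self.device == "cpu"` check, the same call sequence would hit the assertion in `_state_dict`, which is the toy equivalent of the error shown in the traceback.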
---------
Co-authored-by: weidongliang.339 <weidongliang.339@bytedance.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

1 parent b2205c2 commit 8f41b05
1 file changed
Lines changed: 2 additions & 1 deletion
Diff (line content not preserved): original line 604 removed; new lines 604 and 605 added.