Skip to content

Commit 8f41b05

Browse files
wdl339weidongliang.339gemini-code-assist[bot]
authored
[ckpt] fix: FSDP save ckpt after validation (verl-project#4799)
### What does this PR do? This PR fixes a bug in the `save_checkpoint` function for FSDPEngine. In the original logic, if the model engine is used (`use_legacy_worker_impl=disable`), the `wake_up` function in `verl/workers/engine_workers.py` will be invoked during the rollout phase of each step, which will offload the model to CPU. Under normal circumstances, the `compute_log_prob` function called during the training phase can load the model back to GPU. However, the training process is not executed during the validation phase, leaving the model on the CPU. If a checkpoint is saved immediately after validation, it will trigger the following error: `AssertionError: Expects tensor to be on the compute device cuda:0, was on cpu.` <details> <summary>Details</summary> Script: ``` set -x python examples/data_preprocess/geo3k.py --local_dir ~/data/geo3k python -m verl.trainer.main_ppo \ algorithm.adv_estimator=grpo \ data.train_files=$HOME/data/geo3k/train.parquet \ data.val_files=$HOME/data/geo3k/test.parquet \ data.train_batch_size=512 \ data.max_prompt_length=1024 \ data.max_response_length=2048 \ data.filter_overlong_prompts=True \ data.truncation='error' \ data.image_key=images \ actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.ppo_mini_batch_size=128 \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ actor_rollout_ref.actor.use_kl_loss=True \ actor_rollout_ref.actor.kl_loss_coef=0.01 \ actor_rollout_ref.actor.kl_loss_type=low_var_kl \ actor_rollout_ref.actor.entropy_coeff=0 \ actor_rollout_ref.model.enable_gradient_checkpointing=True \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ actor_rollout_ref.rollout.enable_chunked_prefill=False \ actor_rollout_ref.rollout.enforce_eager=False \ actor_rollout_ref.rollout.free_cache_engine=False \ actor_rollout_ref.rollout.n=5 \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.ref.fsdp_config.param_offload=False \ algorithm.use_kl_in_reward=False \ trainer.use_legacy_worker_impl=disable \ trainer.critic_warmup=0 \ trainer.logger=['console','wandb'] \ trainer.project_name='verl_ci_grpo_example_geo3k' \ trainer.experiment_name='qwen2_5_vl_3b_function_rm' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ trainer.log_val_generations=20 \ trainer.save_freq=5 \ trainer.test_freq=5 \ trainer.total_epochs=15 ``` Error: ``` (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) ERROR:2026-01-05 07:35:49,128:Got error when executing task. (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) Traceback (most recent call last): (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "python/ray/_raylet.pyx", line 1890, in ray._raylet.execute_task (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "python/ray/_raylet.pyx", line 1998, in ray._raylet.execute_task (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "python/ray/_raylet.pyx", line 1897, in ray._raylet.execute_task (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "python/ray/_raylet.pyx", line 1825, in ray._raylet.execute_task.function_executor (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "python/ray/_raylet.pyx", line 4651, in ray._raylet.CoreWorker.run_async_func_or_coro_in_event_loop (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/usr/lib/python3.12/concurrent/futures/_base.py", line 449, in result (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) return self.__get_result() (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) ^^^^^^^^^^^^^^^^^^^ (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/usr/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) raise self._exception (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "python/ray/_raylet.pyx", line 4638, in async_func (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/usr/local/lib/python3.12/dist-packages/ray/_private/async_compat.py", line 50, in wrapper (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) return func(*args, **kwargs) (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) ^^^^^^^^^^^^^^^^^^^^^ (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/usr/local/lib/python3.12/dist-packages/ray/_private/function_manager.py", line 691, in actor_method_executor (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) return method(__ray_actor, *args, **kwargs) (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/usr/local/lib/python3.12/dist-packages/ray/util/tracing/tracing_helper.py", line 463, in _resume_span (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) return method(self, *_args, **_kwargs) (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/opt/tiger/open_verl/verl/single_controller/ray/base.py", line 841, in func (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) return getattr(self.worker_dict[key], name)(*args, **kwargs) (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/opt/tiger/open_verl/verl/single_controller/base/decorator.py", line 456, in inner (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) return func(*args, **kwargs) (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) ^^^^^^^^^^^^^^^^^^^^^ (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/opt/tiger/open_verl/verl/utils/transferqueue_utils.py", line 314, in dummy_inner (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) output = func(*args, **kwargs) (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) ^^^^^^^^^^^^^^^^^^^^^ (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/opt/tiger/open_verl/verl/workers/engine_workers.py", line 541, in save_checkpoint (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) self.actor.save_checkpoint(local_path, hdfs_path, global_step, max_ckpt_to_keep) (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/opt/tiger/open_verl/verl/single_controller/base/decorator.py", line 456, in inner (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) return func(*args, **kwargs) (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) ^^^^^^^^^^^^^^^^^^^^^ (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/opt/tiger/open_verl/verl/utils/transferqueue_utils.py", line 314, in dummy_inner (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) output = func(*args, **kwargs) (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) ^^^^^^^^^^^^^^^^^^^^^ (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/opt/tiger/open_verl/verl/workers/engine_workers.py", line 343, in save_checkpoint (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) return self.engine.save_checkpoint(local_path, hdfs_path, global_step, max_ckpt_to_keep) (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/opt/tiger/open_verl/verl/workers/engine/fsdp/transformer_impl.py", line 607, in save_checkpoint (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) self.checkpoint_manager.save_checkpoint( (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/opt/tiger/open_verl/verl/utils/checkpoint/fsdp_checkpoint_manager.py", line 238, in save_checkpoint (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) model_state_dict = self.model.state_dict() (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) ^^^^^^^^^^^^^^^^^^^^^^^ (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 2256, in state_dict (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) hook(self, prefix, keep_vars) (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) return func(*args, **kwargs) (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) ^^^^^^^^^^^^^^^^^^^^^ (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_state_dict_utils.py", line 777, in _pre_state_dict_hook (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) _pre_state_dict_hook_fn[fsdp_state._state_dict_type]( (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_state_dict_utils.py", line 517, in _sharded_pre_state_dict_hook (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) _common_unshard_pre_state_dict_hook( (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_state_dict_utils.py", line 161, in _common_unshard_pre_state_dict_hook (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) _enter_unshard_params_ctx( (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_state_dict_utils.py", line 125, in _enter_unshard_params_ctx (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) fsdp_state._unshard_params_ctx[module].__enter__() (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/usr/lib/python3.12/contextlib.py", line 137, in __enter__ (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) return next(self.gen) (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) ^^^^^^^^^^^^^^ (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_unshard_param_utils.py", line 199, in _unshard_fsdp_state_params (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) _unshard(state, handle, computation_stream, computation_stream) (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_runtime_utils.py", line 290, in _unshard (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) ran_pre_unshard = handle.pre_unshard() (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) ^^^^^^^^^^^^^^^^^^^^ (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_flat_param.py", line 1303, in pre_unshard (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) self._check_on_compute_device(self.flat_param) (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/usr/local/lib/python3.12/dist-packages/torch/distributed/fsdp/_flat_param.py", line 2582, in _check_on_compute_device (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) _p_assert( (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) File "/usr/local/lib/python3.12/dist-packages/torch/distributed/utils.py", line 159, in _p_assert (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) raise AssertionError(s) (WorkerDict pid=42417, ip=[fdbd:dccd:cdd2:2207::30f]) AssertionError: Expects tensor to be on the compute device cuda:0, was on cpu ``` </details> To fix this bug, this PR checks whether the model is located on the CPU before saving the checkpoint and loads it onto the GPU if that is the case. The same bug also exists in Megatron, which requires further fixes. --------- Co-authored-by: weidongliang.339 <weidongliang.339@bytedance.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent b2205c2 commit 8f41b05

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

verl/workers/engine/fsdp/transformer_impl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,8 @@ def save_checkpoint(
601601
"""
602602
Save FSDP checkpoint, handling parameter offload as needed.
603603
"""
604-
if self._is_offload_param:
604+
origin_module_device = next(self.module.parameters()).device.type
605+
if self._is_offload_param or origin_module_device == "cpu":
605606
load_fsdp_model_to_gpu(self.module)
606607

607608
self.checkpoint_manager.save_checkpoint(

0 commit comments

Comments
 (0)