diff --git a/recipe/transfer_queue/ray_trainer.py b/recipe/transfer_queue/ray_trainer.py index 4acd9791fb5..a35bc4d15f9 100644 --- a/recipe/transfer_queue/ray_trainer.py +++ b/recipe/transfer_queue/ray_trainer.py @@ -19,13 +19,10 @@ """ import asyncio -import json import logging import math -import os import uuid from collections import defaultdict -from dataclasses import dataclass, field from pprint import pprint from typing import Any, Optional @@ -51,7 +48,6 @@ from verl.experimental.dataset.sampler import AbstractCurriculumSampler from verl.single_controller.ray import ( RayClassWithInitArgs, - RayResourcePool, RayWorkerGroup, ) from verl.single_controller.ray.base import create_colocated_worker_cls @@ -64,16 +60,13 @@ compute_timing_metrics, process_validation_metrics, ) +from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager from verl.trainer.ppo.reward import compute_reward, compute_reward_async from verl.trainer.ppo.utils import ( Role, WorkerType, - need_critic, - need_reference_policy, - need_reward_model, ) from verl.utils.checkpoint.checkpoint_manager import ( - find_latest_ckpt_path, should_save_ckpt_esi, ) from verl.utils.config import omega_conf_to_dataclass @@ -85,7 +78,6 @@ log_seqlen_unbalance, ) from verl.utils.torch_functional import masked_mean -from verl.utils.tracking import ValidationGenerationsLogger from verl.utils.transferqueue_utils import ( create_transferqueue_client, get_transferqueue_client, @@ -94,63 +86,6 @@ ) -@dataclass -class ResourcePoolManager: - """ - Define a resource pool specification. Resource pool will be initialized first. - """ - - resource_pool_spec: dict[str, list[int]] - mapping: dict[Role, str] - resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict) - - def create_resource_pool(self): - """Create Ray resource pools for distributed training. - - Initializes resource pools based on the resource pool specification, - with each pool managing GPU resources across multiple nodes. - For FSDP backend, uses max_colocate_count=1 to merge WorkerGroups. - For Megatron backend, uses max_colocate_count>1 for different models. - """ - for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): - # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool - # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one. - # For Megatron backend, we recommend using max_colocate_count>1 - # that can utilize different WorkerGroup for differnt models - resource_pool = RayResourcePool( - process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name - ) - self.resource_pool_dict[resource_pool_name] = resource_pool - - self._check_resource_available() - - def get_resource_pool(self, role: Role) -> RayResourcePool: - """Get the resource pool of the worker_cls""" - return self.resource_pool_dict[self.mapping[role]] - - def get_n_gpus(self) -> int: - """Get the number of gpus in this cluster.""" - return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]) - - def _check_resource_available(self): - """Check if the resource pool can be satisfied in this ray cluster.""" - node_available_resources = ray._private.state.available_resources_per_node() - node_available_gpus = { - node: node_info.get("GPU", 0) if "GPU" in node_info else node_info.get("NPU", 0) - for node, node_info in node_available_resources.items() - } - - # check total required gpus can be satisfied - total_available_gpus = sum(node_available_gpus.values()) - total_required_gpus = sum( - [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes] - ) - if total_available_gpus < total_required_gpus: - raise ValueError( - f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}" - ) - - @tqbridge(put_data=False) def compute_reward_decorated(data, reward_fn): return compute_reward(data, reward_fn) @@ -331,7 +266,7 @@ def compute_val_reward_decorated(reward_fn, data, return_dict): return reward_fn(data, return_dict) -class RayPPOTrainer: +class TransferQueueRayPPOTrainer(RayPPOTrainer): """Distributed PPO trainer using Ray for scalable reinforcement learning. This trainer orchestrates distributed PPO training across multiple nodes and GPUs, @@ -376,40 +311,21 @@ def __init__( train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None. device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to None. """ - - # Store the tokenizer for text processing - self.tokenizer = tokenizer - self.processor = processor - self.config = config - self.reward_fn = reward_fn - self.val_reward_fn = val_reward_fn - - self.hybrid_engine = config.actor_rollout_ref.hybrid_engine - assert self.hybrid_engine, "Currently, only support hybrid engine" - - if self.hybrid_engine: - assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}" - - self.role_worker_mapping = role_worker_mapping - self.resource_pool_manager = resource_pool_manager - self.use_reference_policy = need_reference_policy(self.role_worker_mapping) - self.use_rm = need_reward_model(self.role_worker_mapping) - self.use_critic = need_critic(self.config) - self.ray_worker_group_cls = ray_worker_group_cls - self.device_name = device_name if device_name else self.config.trainer.device - self.validation_generations_logger = ValidationGenerationsLogger( - project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, + super().__init__( + config, + tokenizer, + role_worker_mapping, + resource_pool_manager, + ray_worker_group_cls, + processor, + reward_fn, + val_reward_fn, + train_dataset, + val_dataset, + collate_fn, + train_sampler, + device_name, ) - - # if ref_in_actor is True, the reference policy will be actor without lora applied - self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0 - - # define in-reward KL control - # kl loss control currently not suppoorted - if self.config.algorithm.use_kl_in_reward: - self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl) - self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) self.data_system_client = self._initialize_train_data_system( @@ -592,34 +508,6 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl except Exception as e: print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") - def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path): - """Dump rollout/validation samples as JSONL.""" - os.makedirs(dump_path, exist_ok=True) - filename = os.path.join(dump_path, f"{self.global_steps}.jsonl") - - n = len(inputs) - base_data = { - "input": inputs, - "output": outputs, - "gts": gts, - "score": scores, - "step": [self.global_steps] * n, - } - - for k, v in reward_extra_infos_dict.items(): - if len(v) == n: - base_data[k] = v - - lines = [] - for i in range(n): - entry = {k: v[i] for k, v in base_data.items()} - lines.append(json.dumps(entry, ensure_ascii=False)) - - with open(filename, "w") as f: - f.write("\n".join(lines) + "\n") - - print(f"Dumped generations to {filename}") - def _log_rollout_data( self, log_rollout_meta: BatchMeta, reward_extra_infos_dict: dict, timing_raw: dict, rollout_data_dir: str ): @@ -656,47 +544,6 @@ def _log_rollout_data( dump_path=rollout_data_dir, ) - def _maybe_log_val_generations(self, inputs, outputs, scores): - """Log a table of validation samples to the configured logger (wandb or swanlab)""" - - generations_to_log = self.config.trainer.log_val_generations - - if generations_to_log == 0: - return - - import numpy as np - - # Create tuples of (input, output, score) and sort by input text - samples = list(zip(inputs, outputs, scores, strict=True)) - samples.sort(key=lambda x: x[0]) # Sort by input text - - # Use fixed random seed for deterministic shuffling - rng = np.random.RandomState(42) - rng.shuffle(samples) - - # Take first N samples after shuffling - samples = samples[:generations_to_log] - - # Log to each configured logger - self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) - - def _get_gen_batch(self, batch: DataProto) -> DataProto: - reward_model_keys = set({"data_source", "reward_model", "extra_info", "uid"}) & batch.non_tensor_batch.keys() - - # pop those keys for generation - batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] - non_tensor_batch_keys_to_pop = set(batch.non_tensor_batch.keys()) - reward_model_keys - gen_batch = batch.pop( - batch_keys=batch_keys_to_pop, - non_tensor_batch_keys=list(non_tensor_batch_keys_to_pop), - ) - - # For agent loop, we need reward model keys to compute score. - if self.async_rollout_mode: - gen_batch.non_tensor_batch.update(batch.non_tensor_batch) - - return gen_batch - def _validate(self): data_source_lst = [] reward_extra_infos_dict: dict[str, list] = defaultdict(list) @@ -748,47 +595,15 @@ def _validate(self): ground_truths = [item.get("ground_truth", None) for item in data.get("reward_model", {})] sample_gts.extend(ground_truths) - if not self.async_rollout_mode: - test_gen_meta = asyncio.run( - self.val_data_system_client.async_get_meta( - data_fields=[ - "input_ids", - "attention_mask", - "position_ids", - "index", - "tools_kwargs", - "interaction_kwargs", - "ability", - "raw_prompt_ids", - ], - batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, # self.global_steps start from 1 - get_n_samples=False, - task_name="generate_sequences", - ) - ) - else: - test_gen_meta = asyncio.run( - self.val_data_system_client.async_get_meta( - data_fields=[ - "input_ids", - "attention_mask", - "position_ids", - "index", - "tools_kwargs", - "interaction_kwargs", - "ability", - "raw_prompt_ids", - "raw_prompt", - "reward_model", - "data_source", - ], - batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, # self.global_steps start from 1 - get_n_samples=False, - task_name="async_generate_sequences", - ) + test_gen_meta = asyncio.run( + self.val_data_system_client.async_get_meta( + data_fields=list(test_batch.keys()), + batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, + global_step=self.global_steps - 1, # self.global_steps start from 1 + get_n_samples=False, + task_name="generate_sequences", ) + ) test_gen_meta.extra_info = { "eos_token_id": self.tokenizer.eos_token_id, @@ -1060,143 +875,6 @@ def init_workers(self): self.val_data_system_controller_infos, self.val_data_system_storage_unit_infos, role="val" ) - def _save_checkpoint(self): - from verl.utils.fs import local_mkdir_safe - - # path: given_path + `/global_step_{global_steps}` + `/actor` - local_global_step_folder = os.path.join( - self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" - ) - - print(f"local_global_step_folder: {local_global_step_folder}") - actor_local_path = os.path.join(local_global_step_folder, "actor") - - actor_remote_path = ( - None - if self.config.trainer.default_hdfs_dir is None - else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") - ) - - remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False) - if remove_previous_ckpt_in_save: - print( - "Warning: remove_previous_ckpt_in_save is deprecated," - + " set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead" - ) - max_actor_ckpt_to_keep = ( - self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 - ) - max_critic_ckpt_to_keep = ( - self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 - ) - - self.actor_rollout_wg.save_checkpoint( - actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep - ) - - if self.use_critic: - critic_local_path = os.path.join(local_global_step_folder, "critic") - critic_remote_path = ( - None - if self.config.trainer.default_hdfs_dir is None - else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic") - ) - self.critic_wg.save_checkpoint( - critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep - ) - - # save dataloader - local_mkdir_safe(local_global_step_folder) - dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") - dataloader_state_dict = self.train_dataloader.state_dict() - torch.save(dataloader_state_dict, dataloader_local_path) - - # latest checkpointed iteration tracker (for atomic usage) - local_latest_checkpointed_iteration = os.path.join( - self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" - ) - with open(local_latest_checkpointed_iteration, "w") as f: - f.write(str(self.global_steps)) - - def _load_checkpoint(self): - if self.config.trainer.resume_mode == "disable": - return 0 - - # load from hdfs - if self.config.trainer.default_hdfs_dir is not None: - raise NotImplementedError("load from hdfs is not implemented yet") - else: - checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path - if not os.path.isabs(checkpoint_folder): - working_dir = os.getcwd() - checkpoint_folder = os.path.join(working_dir, checkpoint_folder) - global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest - - # find global_step_folder - if self.config.trainer.resume_mode == "auto": - if global_step_folder is None: - print("Training from scratch") - return 0 - else: - if self.config.trainer.resume_mode == "resume_path": - assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" - assert "global_step_" in self.config.trainer.resume_from_path, ( - "resume ckpt must specify the global_steps" - ) - global_step_folder = self.config.trainer.resume_from_path - if not os.path.isabs(global_step_folder): - working_dir = os.getcwd() - global_step_folder = os.path.join(working_dir, global_step_folder) - print(f"Load from checkpoint folder: {global_step_folder}") - # set global step - self.global_steps = int(global_step_folder.split("global_step_")[-1]) - - print(f"Setting global step to {self.global_steps}") - print(f"Resuming from {global_step_folder}") - - actor_path = os.path.join(global_step_folder, "actor") - critic_path = os.path.join(global_step_folder, "critic") - # load actor - self.actor_rollout_wg.load_checkpoint( - actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load - ) - # load critic - if self.use_critic: - self.critic_wg.load_checkpoint( - critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load - ) - - # load dataloader, - # TODO: from remote not implemented yet - dataloader_local_path = os.path.join(global_step_folder, "data.pt") - if os.path.exists(dataloader_local_path): - dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False) - self.train_dataloader.load_state_dict(dataloader_state_dict) - else: - print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch") - - def _start_profiling(self, do_profile: bool) -> None: - """Start profiling for all worker groups if profiling is enabled.""" - if do_profile: - self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps) - if self.use_reference_policy: - self.ref_policy_wg.start_profile(profile_step=self.global_steps) - if self.use_critic: - self.critic_wg.start_profile(profile_step=self.global_steps) - if self.use_rm: - self.rm_wg.start_profile(profile_step=self.global_steps) - - def _stop_profiling(self, do_profile: bool) -> None: - """Stop profiling for all worker groups if profiling is enabled.""" - if do_profile: - self.actor_rollout_wg.stop_profile() - if self.use_reference_policy: - self.ref_policy_wg.stop_profile() - if self.use_critic: - self.critic_wg.stop_profile() - if self.use_rm: - self.rm_wg.stop_profile() - def _balance_batch(self, batch: BatchMeta, data_system_client, metrics, logging_prefix="global_seqlen"): """Reorder the batchmeta on single controller such that each dp rank gets similar total tokens""" data = asyncio.run(data_system_client.async_get_data(batch)) @@ -1367,43 +1045,15 @@ def fit(self): ) batch: TensorDict = self.dict_to_tensordict(repeated_batch_dict) asyncio.run(self.data_system_client.async_put(data=batch, global_step=self.global_steps - 1)) - if not self.async_rollout_mode: - gen_meta = asyncio.run( - self.data_system_client.async_get_meta( - data_fields=[ - "input_ids", - "attention_mask", - "position_ids", - "index", - "tools_kwargs", - "interaction_kwargs", - "ability", - "raw_prompt_ids", - ], - task_name="generate_sequences", - **base_get_meta_kwargs, - ) - ) - else: - gen_meta = asyncio.run( - self.data_system_client.async_get_meta( - data_fields=[ - "input_ids", - "attention_mask", - "position_ids", - "index", - "tools_kwargs", - "interaction_kwargs", - "ability", - "raw_prompt_ids", - "raw_prompt", - "reward_model", - "data_source", - ], - task_name="async_generate_sequences", - **base_get_meta_kwargs, - ) + + gen_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=list(batch.keys()), + task_name="generate_sequences", + **base_get_meta_kwargs, ) + ) + # pass global_steps to trace gen_meta.set_extra_info("global_steps", self.global_steps)