 import ray
 from omegaconf import OmegaConf
 
+from verl.trainer import main_ppo as main_ppo_mod
+from verl.trainer.namespace import build_namespace_specs, namespaced_role_key
 from verl.trainer.ppo.reward import load_reward_manager
-from verl.utils.device import is_cuda_available
+from verl.trainer.ppo.utils import Role
 
 from .dapo_ray_trainer import RayDAPOTrainer
 
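Note: `build_namespace_specs` and `namespaced_role_key` are new helpers from `verl.trainer.namespace` whose implementation is outside this diff. The hunks below rely only on a small surface: each spec exposes `.name`, `.config`, `.spawn_roles`, and `.resource_pool`, and role keys are unique per (namespace, role) pair. A minimal sketch of that interface, inferred from usage here and therefore an assumption rather than the real code:

    # Hypothetical sketch of the verl.trainer.namespace surface used below;
    # inferred from this diff, the actual implementation may differ.
    from dataclasses import dataclass, field

    from omegaconf import DictConfig

    @dataclass
    class NamespaceSpec:
        name: str                                         # namespace id, e.g. "default"
        config: DictConfig                                # per-namespace view of the trainer config
        spawn_roles: list = field(default_factory=list)   # Role members to spawn workers for
        resource_pool: str = "global_pool"                # resource pool this namespace maps to

    def namespaced_role_key(namespace: str, role) -> str:
        # One unique worker-group key per (namespace, role) pair.
        return f"{namespace}/{role.name}"
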
@@ -34,39 +36,12 @@ def main(config):
 
 
 def run_ppo(config) -> None:
-    if not ray.is_initialized():
-        # this is for local ray cluster
-        default_runtime_env = {
-            "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}
-        }
-        ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
-        runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
-        runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
-        ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
-        print(f"ray init kwargs: {ray_init_kwargs}")
-        ray.init(**OmegaConf.to_container(ray_init_kwargs))
-
-    try:
-        if (
-            is_cuda_available
-            and config.global_profiler.tool == "nsys"
-            and OmegaConf.select(config.global_profiler, "steps") is not None
-            and len(OmegaConf.select(config.global_profiler, "steps")) > 0
-        ):
-            nsight_options = OmegaConf.to_container(
-                config.global_profiler.global_tool_config.nsys.controller_nsight_options
-            )
-            runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
-        else:
-            runner = TaskRunner.remote()
-        ray.get(runner.run.remote(config))
-    finally:
-        if ray.is_initialized():
-            ray.shutdown()
+    """Entry point for running DAPO with the PPO runner."""
+    task_runner_cls = ray.remote(num_cpus=1)(TaskRunner)  # type: ignore[arg-type]
+    main_ppo_mod.run_ppo(config, task_runner_class=task_runner_cls)
 
 
-@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
-class TaskRunner:
+class TaskRunner(main_ppo_mod.TaskRunner):
     def run(self, config):
         # print initial config
         from pprint import pprint
@@ -80,72 +55,70 @@ def run(self, config):
         pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
         OmegaConf.resolve(config)
 
-        # download the checkpoint from hdfs
-        local_path = copy_to_local(config.actor_rollout_ref.model.path)
-
-        # instantiate tokenizer
         from verl.utils import hf_processor, hf_tokenizer
 
         trust_remote_code = config.data.get("trust_remote_code", False)
-        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
-        # used for multimodal LLM, could be None
-        processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
-
-        from verl.single_controller.ray import RayWorkerGroup
-
-        # define worker classes
-        if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
-            assert config.critic.strategy in {"fsdp", "fsdp2"}
-
-            from verl.workers.fsdp_workers import AsyncActorRolloutRefWorker, CriticWorker
-
-            ray_worker_group_cls = RayWorkerGroup
-
-        elif config.actor_rollout_ref.actor.strategy == "megatron":
-            assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
-            from verl.workers.megatron_workers import AsyncActorRolloutRefWorker, CriticWorker
-
-            ray_worker_group_cls = RayWorkerGroup
-
-        else:
-            raise NotImplementedError
-
-        from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
-
-        role_worker_mapping = {
-            Role.ActorRollout: ray.remote(AsyncActorRolloutRefWorker),
-            Role.Critic: ray.remote(CriticWorker),
-        }
-
-        global_pool_id = "global_pool"
-        resource_pool_spec = {
-            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
-        }
-        mapping = {
-            Role.ActorRollout: global_pool_id,
-            Role.Critic: global_pool_id,
-        }
-
-        # we should adopt a multi-source reward function here
-        # - for rule-based rm, we directly call a reward score
-        # - for model-based rm, we call a model
-        # - for code related prompt, we send to a sandbox if there are test cases
-        # - finally, we combine all the rewards together
-        # - The reward type depends on the tag of the data
+        namespace_specs = build_namespace_specs(config)
+
+        # instantiate tokenizer/processor per namespace
+        tokenizers = {}
+        processors = {}
+        for name, spec in namespace_specs.items():
+            local_path = copy_to_local(
+                spec.config.actor_rollout_ref.model.path,
+                use_shm=spec.config.actor_rollout_ref.model.get("use_shm", False),
+            )
+            tokenizers[name] = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
+            processors[name] = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
+
+        active_namespace = config.trainer.get("namespace", "default")
+        tokenizer = tokenizers[active_namespace]
+        processor = processors.get(active_namespace)
+
+        self.role_worker_mapping = {}
+        self.mapping = {}
+
+        # Register actor-like workers and collect the worker group class.
+        ray_worker_group_cls = None
+        for spec in namespace_specs.values():
+            if not spec.spawn_roles:
+                continue
+            actor_cls, rg_cls = self._select_actor_worker_impl(spec.config)
+            ray_worker_group_cls = rg_cls if ray_worker_group_cls is None else ray_worker_group_cls
+            if rg_cls is not None and ray_worker_group_cls != rg_cls:
+                raise ValueError("All namespaces must share the same RayWorkerGroup class")
+
+            critic_cls = self._select_critic_worker_impl(spec.config)
+
+            for role in spec.spawn_roles:
+                key = namespaced_role_key(spec.name, role)
+                if role == Role.Critic:
+                    self.role_worker_mapping[key] = ray.remote(critic_cls)
+                else:
+                    self.role_worker_mapping[key] = ray.remote(actor_cls)
+                self.mapping[key] = spec.resource_pool
+
+        # reward model
         if config.reward_model.enable:
-            if config.reward_model.strategy in {"fsdp", "fsdp2"}:
-                from verl.workers.fsdp_workers import RewardModelWorker
-            elif config.reward_model.strategy == "megatron":
-                from verl.workers.megatron_workers import RewardModelWorker
+            use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
+            if use_legacy_worker_impl in ["auto", "enable"]:
+                if config.reward_model.strategy in {"fsdp", "fsdp2"}:
+                    from verl.workers.fsdp_workers import RewardModelWorker
+                elif config.reward_model.strategy == "megatron":
+                    from verl.workers.megatron_workers import RewardModelWorker
+                else:
+                    raise NotImplementedError
+            elif use_legacy_worker_impl == "disable":
+                from verl.workers.roles import RewardModelWorker
+
+                print("Using new worker implementation")
             else:
-                raise NotImplementedError
-            role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
-            mapping[Role.RewardModel] = global_pool_id
+                raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
 
-        # reference model
-        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
-            role_worker_mapping[Role.RefPolicy] = ray.remote(AsyncActorRolloutRefWorker)
-            mapping[Role.RefPolicy] = global_pool_id
+            available_pools = [spec.resource_pool for spec in namespace_specs.values() if spec.spawn_roles]
+            reward_pool = "reward_pool" if config.reward_model.enable_resource_pool else available_pools[0]
+            self.role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
+            self.mapping[Role.RewardModel] = reward_pool
 
         reward_fn = load_reward_manager(
             config,
@@ -163,17 +136,72 @@ def run(self, config):
             max_resp_len=config.data.max_response_length,
             overlong_buffer_cfg=config.reward_model.overlong_buffer,
         )
-        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
+        reward_fn_map = {spec.name: reward_fn for spec in namespace_specs.values()}
+        val_reward_fn_map = {spec.name: val_reward_fn for spec in namespace_specs.values()}
+
+        base_rm_cfg = OmegaConf.to_container(config.reward_model, resolve=True)
+        base_custom_cfg = OmegaConf.to_container(config.custom_reward_function, resolve=True)
+        for spec in namespace_specs.values():
+            rm_cfg = OmegaConf.to_container(spec.config.reward_model, resolve=True)
+            custom_cfg = OmegaConf.to_container(spec.config.custom_reward_function, resolve=True)
+            if rm_cfg != base_rm_cfg or custom_cfg != base_custom_cfg:
+                reward_fn_map[spec.name] = load_reward_manager(
+                    spec.config,
+                    tokenizer,
+                    0,
+                    max_resp_len=spec.config.data.max_response_length,
+                    overlong_buffer_cfg=spec.config.reward_model.overlong_buffer,
+                )
+                val_reward_fn_map[spec.name] = load_reward_manager(
+                    spec.config,
+                    tokenizer,
+                    1,
+                    max_resp_len=spec.config.data.max_response_length,
+                    overlong_buffer_cfg=spec.config.reward_model.overlong_buffer,
+                )
+
+        resource_pool_manager = self.init_resource_pool_mgr(config, namespace_specs=namespace_specs)
+
+        from verl.utils.dataset.rl_dataset import collate_fn
+
+        # Create training/validation datasets when only one namespace is present.
+        train_dataset = val_dataset = train_sampler = None
+        if len(namespace_specs) == 1:
+            train_dataset = main_ppo_mod.create_rl_dataset(
+                config.data.train_files,
+                config.data,
+                tokenizer,
+                processor,
+                is_train=True,
+                max_samples=config.data.get("train_max_samples", -1),
+            )
+            val_dataset = main_ppo_mod.create_rl_dataset(
+                config.data.val_files,
+                config.data,
+                tokenizer,
+                processor,
+                is_train=False,
+                max_samples=config.data.get("val_max_samples", -1),
+            )
+            train_sampler = main_ppo_mod.create_rl_sampler(config.data, train_dataset)
 
         trainer = RayDAPOTrainer(
             config=config,
             tokenizer=tokenizer,
             processor=processor,
-            role_worker_mapping=role_worker_mapping,
+            role_worker_mapping=self.role_worker_mapping,
             resource_pool_manager=resource_pool_manager,
             ray_worker_group_cls=ray_worker_group_cls,
             reward_fn=reward_fn,
             val_reward_fn=val_reward_fn,
+            reward_fn_map=reward_fn_map,
+            val_reward_fn_map=val_reward_fn_map,
+            train_dataset=train_dataset,
+            val_dataset=val_dataset,
+            collate_fn=collate_fn,
+            train_sampler=train_sampler,
+            namespace_specs=namespace_specs,
+            tokenizers_by_namespace=tokenizers,
+            processors_by_namespace=processors,
         )
         trainer.init_workers()
         trainer.fit()
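After this change, DAPO's `run_ppo` no longer initializes Ray, selects nsight profiling options, or builds resource pools itself; all of that now lives in `verl.trainer.main_ppo.run_ppo`, which receives the DAPO `TaskRunner` subclass. A minimal launch sketch, assuming the `main(config)` entry point already defined in this file is Hydra-decorated and using a config path that may differ in your checkout:

    # Hypothetical usage sketch; the real entry point is the `main(config)`
    # defined earlier in this file, which this diff leaves unchanged.
    from omegaconf import OmegaConf

    if __name__ == "__main__":
        # In practice Hydra composes the config; this explicit load and the
        # YAML path are assumptions for illustration only.
        config = OmegaConf.load("recipe/dapo/config/dapo_trainer.yaml")
        run_ppo(config)  # wraps TaskRunner with ray.remote(num_cpus=1), defers to main_ppo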