Commit 6ce6a53

bug fix: KL estimators mismatch, vram usage optimized, feat: dapo training
1 parent 69a477f commit 6ce6a53

File tree

7 files changed: +212 −138 lines changed

recipe/dapo/dapo_ray_trainer.py

Lines changed: 10 additions & 1 deletion
@@ -47,6 +47,15 @@ class RayDAPOTrainer(RayPPOTrainer):
     Note that this trainer runs on the driver process on a single CPU/GPU node.
     """
 
+    def fit(self):
+        # Delegate to the namespace-aware training loop when a topology or multiple namespaces are configured.
+        if self.topology_schedule or len(self.namespace_specs) > 1 or len(self.training_order) > 1:
+            return super().fit()
+        if not self.training_order:
+            print("No training namespaces configured; skipping training.")
+            return
+        return self._fit_single_namespace()
+
     def compute_kl_related_metrics(self, batch: DataProto, metrics: dict, timing_raw: dict):
         batch.batch["response_mask"] = compute_response_mask(batch)

@@ -73,7 +82,7 @@ def compute_kl_related_metrics(self, batch: DataProto, metrics: dict, timing_raw
 
         return batch
 
-    def fit(self):
+    def _fit_single_namespace(self):
         """
         The training loop of PPO.
         The driver process only need to call the compute functions of the worker group through RPC
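A note on the hunk above: the new fit() is a pure dispatcher over three cases. A runnable toy of the same decision rule, with the trainer faked out (the attribute names come from the diff; everything else here is illustrative, not verl code):

# Toy mirror of the dispatch logic in RayDAPOTrainer.fit() above.
def choose_fit_path(topology_schedule, namespace_specs, training_order):
    if topology_schedule or len(namespace_specs) > 1 or len(training_order) > 1:
        return "super().fit()"            # namespace-aware loop in the base trainer
    if not training_order:
        return "skip"                     # nothing configured to train
    return "_fit_single_namespace()"      # classic single-namespace DAPO loop

assert choose_fit_path(None, {"default": ...}, ["default"]) == "_fit_single_namespace()"
assert choose_fit_path(None, {"a": ..., "b": ...}, ["a"]) == "super().fit()"
assert choose_fit_path(None, {"default": ...}, []) == "skip"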

recipe/dapo/main_dapo.py

Lines changed: 122 additions & 94 deletions
@@ -22,8 +22,10 @@
 import ray
 from omegaconf import OmegaConf
 
+from verl.trainer import main_ppo as main_ppo_mod
+from verl.trainer.namespace import build_namespace_specs, namespaced_role_key
 from verl.trainer.ppo.reward import load_reward_manager
-from verl.utils.device import is_cuda_available
+from verl.trainer.ppo.utils import Role
 
 from .dapo_ray_trainer import RayDAPOTrainer
 
@@ -34,39 +36,12 @@ def main(config):
 
 
 def run_ppo(config) -> None:
-    if not ray.is_initialized():
-        # this is for local ray cluster
-        default_runtime_env = {
-            "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}
-        }
-        ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
-        runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
-        runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
-        ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
-        print(f"ray init kwargs: {ray_init_kwargs}")
-        ray.init(**OmegaConf.to_container(ray_init_kwargs))
-
-    try:
-        if (
-            is_cuda_available
-            and config.global_profiler.tool == "nsys"
-            and OmegaConf.select(config.global_profiler, "steps") is not None
-            and len(OmegaConf.select(config.global_profiler, "steps")) > 0
-        ):
-            nsight_options = OmegaConf.to_container(
-                config.global_profiler.global_tool_config.nsys.controller_nsight_options
-            )
-            runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
-        else:
-            runner = TaskRunner.remote()
-        ray.get(runner.run.remote(config))
-    finally:
-        if ray.is_initialized():
-            ray.shutdown()
+    """Entry point for running DAPO with the PPO runner."""
+    task_runner_cls = ray.remote(num_cpus=1)(TaskRunner)  # type: ignore[arg-type]
+    main_ppo_mod.run_ppo(config, task_runner_class=task_runner_cls)
 
 
-@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
-class TaskRunner:
+class TaskRunner(main_ppo_mod.TaskRunner):
     def run(self, config):
         # print initial config
         from pprint import pprint
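run_ppo now delegates Ray initialization, nsys profiling setup, and shutdown to the shared verl.trainer.main_ppo.run_ppo, passing in a remote-wrapped TaskRunner. One detail worth calling out: calling ray.remote(num_cpus=1) as a plain function is equivalent to the decorator form the old code used. A minimal sketch against Ray's public API (the Worker class here is illustrative):

import ray

class Worker:
    def run(self) -> str:
        return "done"

# Function-call form, as in the new run_ppo above...
RemoteWorker = ray.remote(num_cpus=1)(Worker)
# ...behaves the same as the decorator form the old code used:
#   @ray.remote(num_cpus=1)
#   class Worker: ...

if __name__ == "__main__":
    ray.init()
    try:
        actor = RemoteWorker.remote()       # schedule the actor with 1 CPU reserved
        print(ray.get(actor.run.remote()))  # -> "done"
    finally:
        ray.shutdown()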
@@ -80,72 +55,70 @@ def run(self, config):
         pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
         OmegaConf.resolve(config)
 
-        # download the checkpoint from hdfs
-        local_path = copy_to_local(config.actor_rollout_ref.model.path)
-
-        # instantiate tokenizer
         from verl.utils import hf_processor, hf_tokenizer
 
         trust_remote_code = config.data.get("trust_remote_code", False)
-        tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
-        # used for multimodal LLM, could be none
-        processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
-
-        from verl.single_controller.ray import RayWorkerGroup
-
-        # define worker classes
-        if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
-            assert config.critic.strategy in {"fsdp", "fsdp2"}
-
-            from verl.workers.fsdp_workers import AsyncActorRolloutRefWorker, CriticWorker
-
-            ray_worker_group_cls = RayWorkerGroup
-
-        elif config.actor_rollout_ref.actor.strategy == "megatron":
-            assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
-            from verl.workers.megatron_workers import AsyncActorRolloutRefWorker, CriticWorker
-
-            ray_worker_group_cls = RayWorkerGroup
-
-        else:
-            raise NotImplementedError
-
-        from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
-
-        role_worker_mapping = {
-            Role.ActorRollout: ray.remote(AsyncActorRolloutRefWorker),
-            Role.Critic: ray.remote(CriticWorker),
-        }
-
-        global_pool_id = "global_pool"
-        resource_pool_spec = {
-            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
-        }
-        mapping = {
-            Role.ActorRollout: global_pool_id,
-            Role.Critic: global_pool_id,
-        }
-
-        # we should adopt a multi-source reward function here
-        # - for rule-based rm, we directly call a reward score
-        # - for model-based rm, we call a model
-        # - for code related prompt, we send to a sandbox if there are test cases
-        # - finally, we combine all the rewards together
-        # - The reward type depends on the tag of the data
+        namespace_specs = build_namespace_specs(config)
+
+        # instantiate tokenizer/processor per namespace
+        tokenizers = {}
+        processors = {}
+        for name, spec in namespace_specs.items():
+            local_path = copy_to_local(
+                spec.config.actor_rollout_ref.model.path,
+                use_shm=spec.config.actor_rollout_ref.model.get("use_shm", False),
+            )
+            tokenizers[name] = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
+            processors[name] = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
+
+        active_namespace = config.trainer.get("namespace", "default")
+        tokenizer = tokenizers[active_namespace]
+        processor = processors.get(active_namespace)
+
+        self.role_worker_mapping = {}
+        self.mapping = {}
+
+        # Register actor-like workers and collect the worker group class.
+        ray_worker_group_cls = None
+        for spec in namespace_specs.values():
+            if not spec.spawn_roles:
+                continue
+            actor_cls, rg_cls = self._select_actor_worker_impl(spec.config)
+            ray_worker_group_cls = rg_cls if ray_worker_group_cls is None else ray_worker_group_cls
+            if rg_cls is not None and ray_worker_group_cls != rg_cls:
+                raise ValueError("All namespaces must share the same RayWorkerGroup class")
+
+            critic_cls = self._select_critic_worker_impl(spec.config)
+
+            for role in spec.spawn_roles:
+                key = namespaced_role_key(spec.name, role)
+                if role == Role.Critic:
+                    self.role_worker_mapping[key] = ray.remote(critic_cls)
+                else:
+                    self.role_worker_mapping[key] = ray.remote(actor_cls)
+                self.mapping[key] = spec.resource_pool
+
+        # reward model
         if config.reward_model.enable:
-            if config.reward_model.strategy in {"fsdp", "fsdp2"}:
-                from verl.workers.fsdp_workers import RewardModelWorker
-            elif config.reward_model.strategy == "megatron":
-                from verl.workers.megatron_workers import RewardModelWorker
+            use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
+            if use_legacy_worker_impl in ["auto", "enable"]:
+                if config.reward_model.strategy in {"fsdp", "fsdp2"}:
+                    from verl.workers.fsdp_workers import RewardModelWorker
+                elif config.reward_model.strategy == "megatron":
+                    from verl.workers.megatron_workers import RewardModelWorker
+                else:
+                    raise NotImplementedError
+            elif use_legacy_worker_impl == "disable":
+                from verl.workers.roles import RewardModelWorker
+
+                print("Using new worker implementation")
             else:
-                raise NotImplementedError
-            role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
-            mapping[Role.RewardModel] = global_pool_id
+                raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
 
-        # reference model
-        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
-            role_worker_mapping[Role.RefPolicy] = ray.remote(AsyncActorRolloutRefWorker)
-            mapping[Role.RefPolicy] = global_pool_id
+            available_pools = [spec.resource_pool for spec in namespace_specs.values() if spec.spawn_roles]
+            reward_pool = "reward_pool" if config.reward_model.enable_resource_pool else available_pools[0]
+            self.role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
+            self.mapping[Role.RewardModel] = reward_pool
 
         reward_fn = load_reward_manager(
             config,
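The registration loop above keys workers by namespace and role via namespaced_role_key, imported from verl.trainer.namespace; this commit does not show its implementation. A sketch of the pattern with a hypothetical tuple-based key function standing in for the real helper (the real one may encode keys differently):

from enum import Enum

class Role(Enum):              # stand-in for verl.trainer.ppo.utils.Role
    ActorRollout = "actor_rollout"
    Critic = "critic"

def namespaced_role_key(namespace: str, role: Role) -> tuple:
    # Hypothetical stand-in; illustrates only that keys are namespace-scoped.
    return (namespace, role)

role_worker_mapping: dict = {}
mapping: dict = {}
specs = {  # namespace -> (resource pool, roles to spawn)
    "default": ("pool_a", [Role.ActorRollout, Role.Critic]),
    "aux": ("pool_b", [Role.ActorRollout]),
}
for name, (pool, spawn_roles) in specs.items():
    for role in spawn_roles:
        key = namespaced_role_key(name, role)
        role_worker_mapping[key] = f"<worker cls for {role.value}>"
        mapping[key] = pool        # every role lands on its namespace's pool

assert mapping[("aux", Role.ActorRollout)] == "pool_b"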
@@ -163,17 +136,72 @@ def run(self, config):
             max_resp_len=config.data.max_response_length,
             overlong_buffer_cfg=config.reward_model.overlong_buffer,
         )
-        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
+        reward_fn_map = {spec.name: reward_fn for spec in namespace_specs.values()}
+        val_reward_fn_map = {spec.name: val_reward_fn for spec in namespace_specs.values()}
+
+        base_rm_cfg = OmegaConf.to_container(config.reward_model, resolve=True)
+        base_custom_cfg = OmegaConf.to_container(config.custom_reward_function, resolve=True)
+        for spec in namespace_specs.values():
+            rm_cfg = OmegaConf.to_container(spec.config.reward_model, resolve=True)
+            custom_cfg = OmegaConf.to_container(spec.config.custom_reward_function, resolve=True)
+            if rm_cfg != base_rm_cfg or custom_cfg != base_custom_cfg:
+                reward_fn_map[spec.name] = load_reward_manager(
+                    spec.config,
+                    tokenizer,
+                    0,
+                    max_resp_len=spec.config.data.max_response_length,
+                    overlong_buffer_cfg=spec.config.reward_model.overlong_buffer,
+                )
+                val_reward_fn_map[spec.name] = load_reward_manager(
+                    spec.config,
+                    tokenizer,
+                    1,
+                    max_resp_len=spec.config.data.max_response_length,
+                    overlong_buffer_cfg=spec.config.reward_model.overlong_buffer,
+                )
+
+        resource_pool_manager = self.init_resource_pool_mgr(config, namespace_specs=namespace_specs)
+
+        from verl.utils.dataset.rl_dataset import collate_fn
+        # Create training/validation datasets when only one namespace is present.
+        train_dataset = val_dataset = train_sampler = None
+        if len(namespace_specs) == 1:
+            train_dataset = main_ppo_mod.create_rl_dataset(
+                config.data.train_files,
+                config.data,
+                tokenizer,
+                processor,
+                is_train=True,
+                max_samples=config.data.get("train_max_samples", -1),
+            )
+            val_dataset = main_ppo_mod.create_rl_dataset(
+                config.data.val_files,
+                config.data,
+                tokenizer,
+                processor,
+                is_train=False,
+                max_samples=config.data.get("val_max_samples", -1),
+            )
+            train_sampler = main_ppo_mod.create_rl_sampler(config.data, train_dataset)
 
         trainer = RayDAPOTrainer(
             config=config,
             tokenizer=tokenizer,
             processor=processor,
-            role_worker_mapping=role_worker_mapping,
+            role_worker_mapping=self.role_worker_mapping,
             resource_pool_manager=resource_pool_manager,
             ray_worker_group_cls=ray_worker_group_cls,
             reward_fn=reward_fn,
             val_reward_fn=val_reward_fn,
+            reward_fn_map=reward_fn_map,
+            val_reward_fn_map=val_reward_fn_map,
+            train_dataset=train_dataset,
+            val_dataset=val_dataset,
+            collate_fn=collate_fn,
+            train_sampler=train_sampler,
+            namespace_specs=namespace_specs,
+            tokenizers_by_namespace=tokenizers,
+            processors_by_namespace=processors,
         )
         trainer.init_workers()
         trainer.fit()
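The reward wiring above follows a build-once, override-on-diff pattern: every namespace starts out sharing the default reward_fn/val_reward_fn, and a dedicated manager is constructed only when that namespace's reward_model or custom_reward_function config differs structurally from the base config. A generic sketch of the pattern (the loader and configs here are stand-ins, not verl APIs):

def build_fn_map(base_cfg: dict, specs: dict, load):
    """Share one default function; rebuild only for diverging configs."""
    default_fn = load(base_cfg)
    fn_map = {name: default_fn for name in specs}
    for name, cfg in specs.items():
        if cfg != base_cfg:        # plain structural comparison, as in the diff
            fn_map[name] = load(cfg)
    return fn_map

specs = {"default": {"rm": "A"}, "same": {"rm": "A"}, "custom": {"rm": "B"}}
fn_map = build_fn_map({"rm": "A"}, specs, load=lambda cfg: ("manager", cfg["rm"]))
assert fn_map["default"] is fn_map["same"]      # shared default instance
assert fn_map["custom"] == ("manager", "B")     # namespace-specific override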

verl/trainer/ppo/ray_trainer.py

Lines changed: 45 additions & 35 deletions
@@ -20,6 +20,7 @@
 
 import json
 import os
+import time
 import uuid
 from collections import defaultdict
 from copy import deepcopy
@@ -1334,43 +1335,43 @@ def _run_single_step(self, runtime: NamespaceRuntime, batch_dict, logger, progre
 
         next_step = self.global_steps + 1
         is_last_step = next_step >= runtime.total_training_steps
-        with marked_timer("step", timing_raw):
-            with marked_timer("gen", timing_raw, color="red"):
-                if not self.async_rollout_mode:
-                    gen_batch_output = self.rollout_wg.generate_sequences(gen_batch_output)
-                else:
-                    gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)
+        step_start = time.perf_counter()
+        with marked_timer("gen", timing_raw, color="red"):
+            if not self.async_rollout_mode:
+                gen_batch_output = self.rollout_wg.generate_sequences(gen_batch_output)
+            else:
+                gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)
 
-            timing_raw.update(gen_batch_output.meta_info.get("timing", {}))
-            gen_batch_output.meta_info.pop("timing", None)
+        timing_raw.update(gen_batch_output.meta_info.get("timing", {}))
+        gen_batch_output.meta_info.pop("timing", None)
 
-            if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
-                if self.reward_fn is None:
-                    raise ValueError("A reward_fn is required for REMAX advantage estimation.")
+        if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
+            if self.reward_fn is None:
+                raise ValueError("A reward_fn is required for REMAX advantage estimation.")
 
-                with marked_timer("gen_max", timing_raw, color="purple"):
-                    gen_baseline_batch = deepcopy(gen_batch)
-                    gen_baseline_batch.meta_info["do_sample"] = False
-                    if not self.async_rollout_mode:
-                        gen_baseline_output = self.rollout_wg.generate_sequences(gen_baseline_batch)
-                    else:
-                        gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)
-                    batch = batch.union(gen_baseline_output)
-                    rm_scores = None
-                    if self.use_rm and "rm_scores" not in batch.batch.keys():
-                        rm_scores = self.rm_wg.compute_rm_score(batch)
-                        batch = batch.union(rm_scores)
-                    reward_baseline_tensor, _ = compute_reward(batch, self.reward_fn)
-                    reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
-
-                    keys_to_pop = set(gen_baseline_output.batch.keys())
-                    if rm_scores is not None:
-                        keys_to_pop.update(rm_scores.batch.keys())
-                    batch.pop(batch_keys=list(keys_to_pop))
-
-                    batch.batch["reward_baselines"] = reward_baseline_tensor
-
-                    del rm_scores, gen_baseline_batch, gen_baseline_output
+            with marked_timer("gen_max", timing_raw, color="purple"):
+                gen_baseline_batch = deepcopy(gen_batch)
+                gen_baseline_batch.meta_info["do_sample"] = False
+                if not self.async_rollout_mode:
+                    gen_baseline_output = self.rollout_wg.generate_sequences(gen_baseline_batch)
+                else:
+                    gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)
+                batch = batch.union(gen_baseline_output)
+                rm_scores = None
+                if self.use_rm and "rm_scores" not in batch.batch.keys():
+                    rm_scores = self.rm_wg.compute_rm_score(batch)
+                    batch = batch.union(rm_scores)
+                reward_baseline_tensor, _ = compute_reward(batch, self.reward_fn)
+                reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
+
+                keys_to_pop = set(gen_baseline_output.batch.keys())
+                if rm_scores is not None:
+                    keys_to_pop.update(rm_scores.batch.keys())
+                batch.pop(batch_keys=list(keys_to_pop))
+
+                batch.batch["reward_baselines"] = reward_baseline_tensor
+
+                del rm_scores, gen_baseline_batch, gen_baseline_output
         # repeat to align with repeated responses in rollout
         batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
         batch = batch.union(gen_batch_output)
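Context for the gen_max block above: REMAX uses a greedy rollout (do_sample=False) of the same prompts as a reward baseline, which is why the block stores reward_baselines and then pops the baseline outputs from the batch. A toy sketch of the resulting advantage, assuming scalar per-sequence rewards (the real compute_reward works on token-level tensors):

import torch

# Toy REMAX-style advantage: sampled reward minus the greedy-rollout baseline.
sampled_rewards = torch.tensor([0.9, 0.2, 0.7])    # one sampled response per prompt
reward_baselines = torch.tensor([0.5, 0.4, 0.8])   # greedy (do_sample=False) rollouts
advantages = sampled_rewards - reward_baselines
print(advantages)  # tensor([ 0.4000, -0.2000, -0.1000])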
@@ -1586,6 +1587,8 @@ def _cached_logprob(ns: str):
         if rollout_data_dir:
             self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)
 
+        timing_raw["step"] = time.perf_counter() - step_start
+
         next_step = self.global_steps + 1
         if (
             self.val_reward_fn is not None
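With the outer marked_timer("step", ...) context removed in the earlier hunk, the step duration is now taken manually: step_start = time.perf_counter() at the top of the step, and the line added here writes the elapsed time into timing_raw["step"]. A minimal standard-library sketch of the pattern:

import time

timing_raw: dict = {}
step_start = time.perf_counter()   # where the old context manager used to open

time.sleep(0.01)                   # stand-in for generation/updates/logging

# Recorded at the point the diff inserts it, after rollout logging.
timing_raw["step"] = time.perf_counter() - step_start
print(f"step took {timing_raw['step']:.3f}s")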
@@ -1616,7 +1619,14 @@ def _cached_logprob(ns: str):
         metrics.update(compute_data_metrics(batch=metrics_source, use_critic=self.use_critic))
         metrics.update(compute_timing_metrics(batch=metrics_source, timing_raw=timing_raw))
         n_gpus = self.resource_pool_manager.get_n_gpus()
-        metrics.update(compute_throughout_metrics(batch=metrics_source, timing_raw=timing_raw, n_gpus=n_gpus))
+        perf_metrics = compute_throughout_metrics(batch=metrics_source, timing_raw=timing_raw, n_gpus=n_gpus)
+        if getattr(self, "topology_step", None):
+            topo_name = getattr(self.topology_step, "name", None)
+            if topo_name:
+                perf_metrics.update(
+                    {f"perf-{topo_name}/{k.split('/', 1)[1]}": v for k, v in perf_metrics.items() if "/" in k}
+                )
+        metrics.update(perf_metrics)
 
         if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):
             self.train_dataloader.sampler.update(batch=metrics_source)
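The comprehension added above mirrors every slash-separated perf metric under a topology-specific prefix while leaving the originals in place. Its behavior on a small input, with a hypothetical topology step named stageA:

perf_metrics = {"perf/throughput": 123.4, "perf/mfu": 0.41, "step": 7.0}
topo_name = "stageA"  # hypothetical topology_step.name

# Same comprehension as in the hunk: take the part after the first '/',
# re-prefix it with 'perf-<topo_name>/', and skip keys without a '/'.
perf_metrics.update(
    {f"perf-{topo_name}/{k.split('/', 1)[1]}": v for k, v in perf_metrics.items() if "/" in k}
)
print(sorted(perf_metrics))
# ['perf-stageA/mfu', 'perf-stageA/throughput', 'perf/mfu', 'perf/throughput', 'step']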
