Skip to content

Commit 69a477f

Browse files
committed
Bug fix: validation checkpoint
1 parent a70abbd commit 69a477f

File tree

6 files changed

+65
-12
lines changed

6 files changed

+65
-12
lines changed

verl/trainer/config/_generated_ppo_megatron_trainer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,7 @@ trainer:
595595
resume_from_path: null
596596
del_local_ckpt_after_load: false
597597
val_before_train: true
598+
validation_use_train_namespace: true
598599
test_freq: -1
599600
critic_warmup: 0
600601
default_hdfs_dir: null

verl/trainer/config/_generated_ppo_trainer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,7 @@ trainer:
511511
log_val_generations: 0
512512
rollout_data_dir: null
513513
validation_data_dir: null
514+
validation_use_train_namespace: true
514515
nnodes: 1
515516
n_gpus_per_node: 8
516517
save_freq: -1

verl/trainer/config/ppo_megatron_trainer.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ trainer:
147147
resume_from_path: null
148148
del_local_ckpt_after_load: False
149149
val_before_train: True
150+
# Whether validation generates with the active training namespace's own rollout workers (when it has them) instead of the namespace named by rollout_from
151+
validation_use_train_namespace: True
150152
test_freq: -1
151153
critic_warmup: 0
152154
default_hdfs_dir: null

verl/trainer/config/ppo_trainer.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,9 @@ trainer:
183183
# Directory for logging validation data; no dump if null
184184
validation_data_dir: null
185185

186+
# Whether validation generates with the active training namespace's own rollout workers (when it has them) instead of the namespace named by rollout_from
187+
validation_use_train_namespace: True
188+
186189
# Number of nodes used in the training
187190
nnodes: 1
188191

verl/trainer/namespace.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,18 @@ def _maybe_override_namespace_fields(cfg: DictConfig, name: str, rollout_from: s
105105

106106

107107
def _compute_spawn_roles(spec: NamespaceSpec, rollout_dependents: Iterable[str], ref_dependents: Iterable[str]):
108-
provides_rollout = spec.rollout_from == spec.name or len(list(rollout_dependents)) > 0
108+
prefer_val_rollout = bool(spec.config.trainer.get("validation_use_train_namespace", False))
109+
wants_validation = bool(
110+
spec.train
111+
and prefer_val_rollout
112+
and (
113+
spec.config.trainer.get("val_only", False)
114+
or spec.config.trainer.get("val_before_train", True)
115+
or spec.config.trainer.get("test_freq", -1) > 0
116+
)
117+
)
118+
119+
provides_rollout = spec.rollout_from == spec.name or len(list(rollout_dependents)) > 0 or wants_validation
109120
provides_ref = (spec.needs_ref and spec.ref_from == spec.name) or len(list(ref_dependents)) > 0
110121

111122
spec.provides_rollout = provides_rollout

verl/trainer/ppo/ray_trainer.py

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,7 @@ def _maybe_log_val_generations(self, inputs, outputs, scores):
808808
# Log to each configured logger
809809
self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps)
810810

811-
def _get_gen_batch(self, batch: DataProto) -> DataProto:
811+
def _get_gen_batch(self, batch: DataProto, async_rollout_mode: bool | None = None) -> DataProto:
812812
reward_model_keys = set({"data_source", "reward_model", "extra_info", "uid"}) & batch.non_tensor_batch.keys()
813813

814814
# pop those keys for generation
@@ -820,7 +820,10 @@ def _get_gen_batch(self, batch: DataProto) -> DataProto:
820820
)
821821

822822
# For agent loop, we need reward model keys to compute score.
823-
if self.async_rollout_mode:
823+
if async_rollout_mode is None:
824+
async_rollout_mode = self.async_rollout_mode
825+
826+
if async_rollout_mode:
824827
gen_batch.non_tensor_batch.update(batch.non_tensor_batch)
825828

826829
return gen_batch
@@ -829,6 +832,40 @@ def _validate(self):
829832
data_source_lst = []
830833
reward_extra_infos_dict: dict[str, list] = defaultdict(list)
831834

835+
prefer_train_rollout = bool(self.config.trainer.get("validation_use_train_namespace", False))
836+
rollout_ns = self.namespace_specs[self.active_namespace].rollout_from
837+
if prefer_train_rollout and self.active_namespace in self.rollout_wg_map:
838+
rollout_ns = self.active_namespace
839+
840+
rollout_wg = self.rollout_wg_map.get(rollout_ns)
841+
if rollout_wg is None and rollout_ns != self.namespace_specs[self.active_namespace].rollout_from:
842+
fallback_ns = self.namespace_specs[self.active_namespace].rollout_from
843+
rollout_wg = self.rollout_wg_map.get(fallback_ns)
844+
rollout_ns = fallback_ns
845+
846+
if rollout_wg is None:
847+
raise ValueError(
848+
f"Rollout worker for validation namespace '{rollout_ns}' (active namespace '{self.active_namespace}') is missing"
849+
)
850+
851+
rollout_cfg = self.namespace_specs[rollout_ns].config.actor_rollout_ref.rollout
852+
async_rollout_mode = rollout_cfg.mode == "async"
853+
async_rollout_manager = None
854+
if async_rollout_mode:
855+
if self.async_rollout_mode and rollout_wg is self.rollout_wg:
856+
async_rollout_manager = self.async_rollout_manager
857+
else:
858+
from verl.experimental.agent_loop import AgentLoopManager
859+
860+
rm_resource_pool = None
861+
if self.config.reward_model.enable and self.config.reward_model.enable_resource_pool:
862+
rm_resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
863+
async_rollout_manager = AgentLoopManager(
864+
config=self.namespace_specs[rollout_ns].config,
865+
worker_group=rollout_wg,
866+
rm_resource_pool=rm_resource_pool,
867+
)
868+
832869
# Lists to collect samples for the table
833870
sample_inputs = []
834871
sample_outputs = []
@@ -850,7 +887,7 @@ def _validate(self):
850887

851888
# repeat test batch
852889
test_batch = test_batch.repeat(
853-
repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True
890+
repeat_times=rollout_cfg.val_kwargs.n, interleave=True
854891
)
855892

856893
# we only do validation on rule-based rm
@@ -869,28 +906,26 @@ def _validate(self):
869906
]
870907
sample_gts.extend(ground_truths)
871908

872-
test_gen_batch = self._get_gen_batch(test_batch)
909+
test_gen_batch = self._get_gen_batch(test_batch, async_rollout_mode=async_rollout_mode)
873910
test_gen_batch.meta_info = {
874911
"eos_token_id": self.tokenizer.eos_token_id,
875912
"pad_token_id": self.tokenizer.pad_token_id,
876913
"recompute_log_prob": False,
877-
"do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,
914+
"do_sample": rollout_cfg.val_kwargs.do_sample,
878915
"validate": True,
879916
"global_steps": self.global_steps,
880917
}
881918
print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")
882919

883920
# pad to be divisible by dp_size
884921
size_divisor = (
885-
self.rollout_wg.world_size
886-
if not self.async_rollout_mode
887-
else self.config.actor_rollout_ref.rollout.agent.num_workers
922+
rollout_wg.world_size if not async_rollout_mode else rollout_cfg.agent.num_workers
888923
)
889924
test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor)
890-
if not self.async_rollout_mode:
891-
test_output_gen_batch_padded = self.rollout_wg.generate_sequences(test_gen_batch_padded)
925+
if not async_rollout_mode:
926+
test_output_gen_batch_padded = rollout_wg.generate_sequences(test_gen_batch_padded)
892927
else:
893-
test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)
928+
test_output_gen_batch_padded = async_rollout_manager.generate_sequences(test_gen_batch_padded)
894929

895930
# unpad
896931
test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)

0 commit comments

Comments
 (0)