@@ -808,7 +808,7 @@ def _maybe_log_val_generations(self, inputs, outputs, scores):
808808 # Log to each configured logger
809809 self .validation_generations_logger .log (self .config .trainer .logger , samples , self .global_steps )
810810
811- def _get_gen_batch (self , batch : DataProto ) -> DataProto :
811+ def _get_gen_batch (self , batch : DataProto , async_rollout_mode : bool | None = None ) -> DataProto :
812812 reward_model_keys = set ({"data_source" , "reward_model" , "extra_info" , "uid" }) & batch .non_tensor_batch .keys ()
813813
814814 # pop those keys for generation
@@ -820,7 +820,10 @@ def _get_gen_batch(self, batch: DataProto) -> DataProto:
820820 )
821821
822822 # For agent loop, we need reward model keys to compute score.
823- if self .async_rollout_mode :
823+ if async_rollout_mode is None :
824+ async_rollout_mode = self .async_rollout_mode
825+
826+ if async_rollout_mode :
824827 gen_batch .non_tensor_batch .update (batch .non_tensor_batch )
825828
826829 return gen_batch
@@ -829,6 +832,40 @@ def _validate(self):
829832 data_source_lst = []
830833 reward_extra_infos_dict : dict [str , list ] = defaultdict (list )
831834
835+ prefer_train_rollout = bool (self .config .trainer .get ("validation_use_train_namespace" , False ))
836+ rollout_ns = self .namespace_specs [self .active_namespace ].rollout_from
837+ if prefer_train_rollout and self .active_namespace in self .rollout_wg_map :
838+ rollout_ns = self .active_namespace
839+
840+ rollout_wg = self .rollout_wg_map .get (rollout_ns )
841+ if rollout_wg is None and rollout_ns != self .namespace_specs [self .active_namespace ].rollout_from :
842+ fallback_ns = self .namespace_specs [self .active_namespace ].rollout_from
843+ rollout_wg = self .rollout_wg_map .get (fallback_ns )
844+ rollout_ns = fallback_ns
845+
846+ if rollout_wg is None :
847+ raise ValueError (
848+ f"Rollout worker for validation namespace '{ rollout_ns } ' (active namespace '{ self .active_namespace } ') is missing"
849+ )
850+
851+ rollout_cfg = self .namespace_specs [rollout_ns ].config .actor_rollout_ref .rollout
852+ async_rollout_mode = rollout_cfg .mode == "async"
853+ async_rollout_manager = None
854+ if async_rollout_mode :
855+ if self .async_rollout_mode and rollout_wg is self .rollout_wg :
856+ async_rollout_manager = self .async_rollout_manager
857+ else :
858+ from verl .experimental .agent_loop import AgentLoopManager
859+
860+ rm_resource_pool = None
861+ if self .config .reward_model .enable and self .config .reward_model .enable_resource_pool :
862+ rm_resource_pool = self .resource_pool_manager .get_resource_pool (Role .RewardModel )
863+ async_rollout_manager = AgentLoopManager (
864+ config = self .namespace_specs [rollout_ns ].config ,
865+ worker_group = rollout_wg ,
866+ rm_resource_pool = rm_resource_pool ,
867+ )
868+
832869 # Lists to collect samples for the table
833870 sample_inputs = []
834871 sample_outputs = []
@@ -850,7 +887,7 @@ def _validate(self):
850887
851888 # repeat test batch
852889 test_batch = test_batch .repeat (
853- repeat_times = self . config . actor_rollout_ref . rollout .val_kwargs .n , interleave = True
890+ repeat_times = rollout_cfg .val_kwargs .n , interleave = True
854891 )
855892
856893 # we only do validation on rule-based rm
@@ -869,28 +906,26 @@ def _validate(self):
869906 ]
870907 sample_gts .extend (ground_truths )
871908
872- test_gen_batch = self ._get_gen_batch (test_batch )
909+ test_gen_batch = self ._get_gen_batch (test_batch , async_rollout_mode = async_rollout_mode )
873910 test_gen_batch .meta_info = {
874911 "eos_token_id" : self .tokenizer .eos_token_id ,
875912 "pad_token_id" : self .tokenizer .pad_token_id ,
876913 "recompute_log_prob" : False ,
877- "do_sample" : self . config . actor_rollout_ref . rollout .val_kwargs .do_sample ,
914+ "do_sample" : rollout_cfg .val_kwargs .do_sample ,
878915 "validate" : True ,
879916 "global_steps" : self .global_steps ,
880917 }
881918 print (f"test_gen_batch meta info: { test_gen_batch .meta_info } " )
882919
883920 # pad to be divisible by dp_size
884921 size_divisor = (
885- self .rollout_wg .world_size
886- if not self .async_rollout_mode
887- else self .config .actor_rollout_ref .rollout .agent .num_workers
922+ rollout_wg .world_size if not async_rollout_mode else rollout_cfg .agent .num_workers
888923 )
889924 test_gen_batch_padded , pad_size = pad_dataproto_to_divisor (test_gen_batch , size_divisor )
890- if not self . async_rollout_mode :
891- test_output_gen_batch_padded = self . rollout_wg .generate_sequences (test_gen_batch_padded )
925+ if not async_rollout_mode :
926+ test_output_gen_batch_padded = rollout_wg .generate_sequences (test_gen_batch_padded )
892927 else :
893- test_output_gen_batch_padded = self . async_rollout_manager .generate_sequences (test_gen_batch_padded )
928+ test_output_gen_batch_padded = async_rollout_manager .generate_sequences (test_gen_batch_padded )
894929
895930 # unpad
896931 test_output_gen_batch = unpad_dataproto (test_output_gen_batch_padded , pad_size = pad_size )
0 commit comments