diff --git a/examples/nemo_gym/run_grpo_nemo_gym.py b/examples/nemo_gym/run_grpo_nemo_gym.py index 17f78bcf15..ea08cf8c62 100644 --- a/examples/nemo_gym/run_grpo_nemo_gym.py +++ b/examples/nemo_gym/run_grpo_nemo_gym.py @@ -73,43 +73,42 @@ def parse_args() -> tuple[argparse.Namespace, list[str]]: def collect_trajectories( policy: ColocatablePolicyInterface, policy_generation: GenerationInterface, - val_dataloader: StatefulDataLoader, + val_dataloader: dict[str, StatefulDataLoader], tokenizer: TokenizerType, val_task_to_env: dict[str, EnvironmentInterface], logger: Logger, master_config: MasterConfig, ) -> None: - """Run trajectory collection.""" - # common config/state items + """Run trajectory collection across all validation dataloaders.""" colocated_inference = master_config.policy["generation"]["colocated"]["enabled"] refit_policy_generation(policy, policy_generation, colocated_inference) - log_filename = "trajectory_collection.jsonl" - print("\nšŸ” Running trajectory collection...", flush=True) generation_config = master_config.policy["generation"] - for val_batch in val_dataloader: - nemo_gym_rollout_result = run_async_nemo_gym_rollout( - policy_generation=policy_generation, - input_batch=val_batch, - tokenizer=tokenizer, - task_to_env=val_task_to_env, - max_seq_len=master_config.policy["max_total_sequence_length"], - generation_config=generation_config, - max_rollout_turns=None, - greedy=False, - ) + for dataset_name, dl in val_dataloader.items(): + log_filename = f"trajectory_collection_{dataset_name}.jsonl" + for val_batch in dl: + nemo_gym_rollout_result = run_async_nemo_gym_rollout( + policy_generation=policy_generation, + input_batch=val_batch, + tokenizer=tokenizer, + task_to_env=val_task_to_env, + max_seq_len=master_config.policy["max_total_sequence_length"], + generation_config=generation_config, + max_rollout_turns=None, + greedy=False, + ) - rows_to_log: list[str] = [] - for key, value in nemo_gym_rollout_result.rollout_metrics.items(): - if "full_result" not in key: - continue + rows_to_log: list[str] = [] + for key, value in nemo_gym_rollout_result.rollout_metrics.items(): + if "full_result" not in key: + continue - value: Table - data: list[list[str]] = value.data # (n, 1) - rows_to_log.extend(v[0] for v in data) + value: Table + data: list[list[str]] = value.data # (n, 1) + rows_to_log.extend(v[0] for v in data) - logger.log_string_list_as_jsonl(rows_to_log, log_filename) + logger.log_string_list_as_jsonl(rows_to_log, log_filename) # TODO: eventually as trajectory collection use cases exceed 4 hours, we can leverage the dataloader save functionality to resume # And also leverage the TimeoutChecker functionality as well @@ -180,10 +179,11 @@ def main() -> None: ) if val_dataset is not None: + total_val_samples = sum(len(ds) for ds in val_dataset.values()) print( - f"Setting `grpo.max_val_samples` and `grpo.val_batch_size` to the length of the validation dataset, which is {len(val_dataset)}" + f"Setting `grpo.max_val_samples` and `grpo.val_batch_size` to the length of the validation dataset, which is {total_val_samples}" ) - config.grpo["max_val_samples"] = len(val_dataset) + config.grpo["max_val_samples"] = total_val_samples config.grpo["val_batch_size"] = config.grpo["max_val_samples"] # Print config diff --git a/examples/run_grpo_sliding_puzzle.py b/examples/run_grpo_sliding_puzzle.py index 6aed9e8044..b1be15e683 100644 --- a/examples/run_grpo_sliding_puzzle.py +++ b/examples/run_grpo_sliding_puzzle.py @@ -241,14 +241,20 @@ def main(): * 
config.grpo["num_generations_per_prompt"] * config.grpo["max_num_steps"] ) + puzzle_task_name = "sliding_puzzle_game" dataset, val_dataset, task_to_env, val_task_to_env = setup_puzzle_data( tokenizer=tokenizer, env_cfg=config.env, - task_name="sliding_puzzle_game", + task_name=puzzle_task_name, length=ds_length, val_length=config.grpo["max_val_samples"], add_system_prompt=config.data["add_system_prompt"], ) + # Algorithm setup expects val_dataset as a name->dataset dict so validate() + # can emit per-dataset metrics. The puzzle generator produces a single + # iterable; wrap it under its task name. + if val_dataset is not None: + val_dataset = {puzzle_task_name: val_dataset} ( policy, diff --git a/nemo_rl/algorithms/distillation.py b/nemo_rl/algorithms/distillation.py index bbc54ab0c8..e73d7441f7 100644 --- a/nemo_rl/algorithms/distillation.py +++ b/nemo_rl/algorithms/distillation.py @@ -83,7 +83,7 @@ class DistillationConfig(TypedDict): # Whether to run validation on the last training step. Setting this to True ensures the # final checkpoint has validation metrics, which is required for get_best_checkpoint_path(). val_at_end: bool - max_val_samples: int + max_val_samples: NotRequired[int] topk_logits_k: int seed: int @@ -159,13 +159,13 @@ def setup( master_config: MasterConfig, tokenizer: TokenizerType, train_dataset: AllTaskProcessedDataset, - val_dataset: Optional[AllTaskProcessedDataset], + val_dataset: Optional[dict[str, AllTaskProcessedDataset]], ) -> tuple[ ColocatablePolicyInterface, # student_policy ColocatablePolicyInterface, # teacher_policy Optional[GenerationInterface], # student_generation StatefulDataLoader, - Optional[StatefulDataLoader], + Optional[dict[str, StatefulDataLoader]], DistillationLossFn, Logger, CheckpointManager, @@ -262,9 +262,9 @@ def setup( f" āœ“ Training dataloader loaded with {len(train_dataset)} samples", flush=True ) - # Load validation dataset if provided - val_dataloader: Optional[StatefulDataLoader] = None - # If validation is enabled, load the validation dataloader + # Load validation dataset if provided. One StatefulDataLoader per dataset + # so validate() can emit per-dataset metrics; mirrors the DPO pattern. 
+ val_dataloader: Optional[dict[str, StatefulDataLoader]] = None if ( distillation_config["val_period"] > 0 or distillation_config["val_at_start"] @@ -273,14 +273,19 @@ def setup( assert val_dataset is not None, ( "Validation dataset is required if validation is enabled" ) - val_dataloader = StatefulDataLoader( - val_dataset, - batch_size=distillation_config["val_batch_size"], - shuffle=False, - collate_fn=rl_collate_fn, - ) + val_dataloader = { + name: StatefulDataLoader( + ds, + batch_size=distillation_config["val_batch_size"], + shuffle=False, + collate_fn=rl_collate_fn, + ) + for name, ds in val_dataset.items() + } + total_samples = sum(len(ds) for ds in val_dataset.values()) print( - f" āœ“ Validation dataloader loaded with {len(val_dataset)} samples", + f" āœ“ Validation dataloader loaded with {total_samples} samples" + f" across {len(val_dataset)} dataset(s)", flush=True, ) @@ -508,7 +513,7 @@ def distillation_train( teacher_policy: ColocatablePolicyInterface, student_generation: Optional[GenerationInterface], dataloader: StatefulDataLoader, - val_dataloader: Optional[StatefulDataLoader], + val_dataloader: Optional[dict[str, StatefulDataLoader]], tokenizer: TokenizerType, loss_fn: DistillationLossFn, task_to_env: dict[str, EnvironmentInterface], @@ -572,10 +577,9 @@ def distillation_train( val_task_to_env, step=total_steps, master_config=master_config, + logger=logger, ) student_generation.finish_generation() - logger.log_metrics(val_metrics, total_steps, prefix="validation") - logger.log_metrics(validation_timings, total_steps, prefix="timing/validation") # Run distillation training (multi-epoch until reaching max_num_steps or max_num_epochs) batch: BatchedDataDict[DatumSpec] @@ -747,14 +751,9 @@ def distillation_train( val_task_to_env, step=total_steps + 1, master_config=master_config, + logger=logger, ) student_generation.finish_generation() - logger.log_metrics( - validation_timings, total_steps + 1, prefix="timing/validation" - ) - logger.log_metrics( - val_metrics, total_steps + 1, prefix="validation" - ) metrics = { "loss": train_results["loss"].numpy(), @@ -951,14 +950,22 @@ def distillation_train( def validate( policy_generation: GenerationInterface, - val_dataloader: Optional[StatefulDataLoader], + val_dataloader: Optional[dict[str, StatefulDataLoader]], tokenizer, val_task_to_env: Optional[dict[str, EnvironmentInterface]], step: int, master_config: MasterConfig, + logger: Optional[Logger] = None, ) -> tuple[dict[str, Any], dict[str, Any]]: - """Run validation on the validation dataset.""" - if val_dataloader is None: + """Run validation across one or more datasets. + + Each dataset gets its own dataloader and is logged under a + `validation-/` prefix, mirroring the DPO setup. The returned + `val_metrics` dict carries the per-dataset entries alongside an + `accuracy` macro-mean across datasets so save-state and best-checkpoint + selection keep working. 
+ """ + if not val_dataloader: print(" āš ļø No validation dataloader provided, skipping validation", flush=True) return {}, {} @@ -969,25 +976,84 @@ def validate( ) return {}, {} + val_metrics: dict[str, Any] = {} + validation_timings: dict[str, Any] = {} + per_dataset_accuracy: list[float] = [] + + for dataset_name, dl in val_dataloader.items(): + k_metrics, k_timings = validate_one_dataset( + policy_generation, + dl, + tokenizer, + val_task_to_env, + step=step, + master_config=master_config, + dataset_name=dataset_name, + ) + prefix = f"validation-{dataset_name}" + if logger is not None: + logger.log_metrics(k_metrics, step, prefix=prefix) + logger.log_metrics(k_timings, step, prefix=f"timing/{prefix}") + for metric_name, value in k_metrics.items(): + val_metrics[f"{prefix}_{metric_name}"] = value + for timing_name, value in k_timings.items(): + validation_timings[f"{prefix}_{timing_name}"] = value + if "accuracy" in k_metrics: + per_dataset_accuracy.append(k_metrics["accuracy"]) + + if per_dataset_accuracy: + # Macro-mean across datasets so checkpointing.metric_name='val:accuracy' + # keeps working under multi-validation. + val_metrics["accuracy"] = sum(per_dataset_accuracy) / len(per_dataset_accuracy) + + if validation_timings: + total_validation_time = sum( + v for k, v in validation_timings.items() if k.endswith("total_validation_time") + ) + if logger is not None: + logger.log_metrics( + {"total_validation_time": total_validation_time}, + step, + prefix="timing/validation", + ) + validation_timings["total_validation_time"] = total_validation_time + + return val_metrics, validation_timings + + +def validate_one_dataset( + policy_generation: GenerationInterface, + val_dataloader: StatefulDataLoader, + tokenizer, + val_task_to_env: dict[str, EnvironmentInterface], + step: int, + master_config: MasterConfig, + dataset_name: str, +) -> tuple[dict[str, Any], dict[str, Any]]: + """Run validation on one validation dataset and return its metrics + timings.""" timer = Timer() with timer.time("total_validation_time"): - print(f"ā–¶ Starting validation at step {step}...", flush=True) + print( + f"ā–¶ Starting validation at step {step} for `{dataset_name}`...", + flush=True, + ) + + total_rewards: list[float] = [] + total_lengths: list[float] = [] + all_message_logs = [] - total_rewards = [] # Can be any metric. Setted to 'accuracy' by default. 
- total_lengths = [] - all_message_logs = [] # Collect all message logs + max_val_samples = master_config.distillation.get("max_val_samples") + if max_val_samples is None: + max_batches = len(val_dataloader) + else: + max_batches = ( + max_val_samples // master_config.distillation["val_batch_size"] + ) - max_batches = ( - master_config.distillation["max_val_samples"] - + master_config.distillation["val_batch_size"] - - 1 - ) // master_config.distillation["val_batch_size"] for batch_idx, val_batch in enumerate(val_dataloader): if batch_idx >= max_batches: break - # Generate responses (updates the LLMMessageLogType in batch_with_msg_logs) - # Use async rollouts if vLLM async engine is enabled if _should_use_async_rollouts(master_config): val_batch, gen_metrics = run_async_multi_turn_rollout( policy_generation, @@ -1013,17 +1079,14 @@ def validate( total_rewards.extend(rewards.tolist()) total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"]) - # Collect message logs for later display to_env = [ get_keys_from_message_log( val_batch["message_log"][i], ["role", "content"] ) for i in range(len(val_batch["message_log"])) ] - all_message_logs.extend(to_env) - # Calculate validation metrics accuracy = ( sum(total_rewards) / len(total_rewards) if len(total_rewards) > 0 else 0 ) @@ -1036,7 +1099,6 @@ def validate( "avg_length": avg_length, } - # Print sample conversations only once at the end of validation try: print_message_log_samples( all_message_logs, @@ -1051,22 +1113,15 @@ def validate( print(f"\n āš ļø Error displaying message samples: {str(e)}") print(" āš ļø Continuing validation without displaying samples...", flush=True) - # Get timing metrics timing_metrics = timer.get_timing_metrics(reduction_op="sum") validation_time = timing_metrics.get("total_validation_time", 0) - # Print summary of validation results - print("\nšŸ“Š Validation Results:") + print(f"\nšŸ“Š Validation Results for `{dataset_name}`:") print(f" • Accuracy: {accuracy:.4f}") print(f" • Average response length: {avg_length:.1f} tokens") print(f" • Samples processed: {len(total_rewards)}", flush=True) - - # Print timing information print("\n ā±ļø Validation Timing:") - validation_time = timing_metrics.get("total_validation_time", 0) print(f" • Total validation time: {validation_time:.2f}s", flush=True) - # Make sure to reset the timer after validation timer.reset() - return val_metrics, timing_metrics diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index d26fb0bcae..d835a32d21 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -218,14 +218,14 @@ def setup( master_config: MasterConfig, tokenizer: TokenizerType, dataset: AllTaskProcessedDataset | dict[str, AllTaskProcessedDataset], - val_dataset: Optional[AllTaskProcessedDataset], + val_dataset: Optional[dict[str, AllTaskProcessedDataset]], processor: Optional[AutoProcessor] = None, ) -> tuple[ ColocatablePolicyInterface, Optional[GenerationInterface], tuple[RayVirtualCluster, RayVirtualCluster], StatefulDataLoader | MultipleDataloaderWrapper, - Optional[StatefulDataLoader], + Optional[dict[str, StatefulDataLoader]], ClippedPGLossFn, Logger, CheckpointManager, @@ -355,9 +355,9 @@ def init_train_dataloader(dataset, suffix: str = ""): flush=True, ) - # Load validation dataset if provided - val_dataloader: Optional[StatefulDataLoader] = None - # If validation is enabled, load the validation dataloader + # Load validation dataset if provided. 
One StatefulDataLoader per dataset + # so validate() can emit per-dataset metrics; mirrors the DPO pattern. + val_dataloader: Optional[dict[str, StatefulDataLoader]] = None if ( grpo_config["val_period"] > 0 or grpo_config["val_at_start"] @@ -366,15 +366,20 @@ def init_train_dataloader(dataset, suffix: str = ""): assert val_dataset is not None, ( "Validation dataset is required if validation is enabled" ) - val_dataloader = StatefulDataLoader( - val_dataset, - batch_size=grpo_config["val_batch_size"], - shuffle=False, - collate_fn=rl_collate_fn, - num_workers=data_config["num_workers"], - ) + val_dataloader = { + name: StatefulDataLoader( + ds, + batch_size=grpo_config["val_batch_size"], + shuffle=False, + collate_fn=rl_collate_fn, + num_workers=data_config["num_workers"], + ) + for name, ds in val_dataset.items() + } + total_samples = sum(len(ds) for ds in val_dataset.values()) print( - f" āœ“ Validation dataloader loaded with {len(val_dataset)} samples", + f" āœ“ Validation dataloader loaded with {total_samples} samples" + f" across {len(val_dataset)} dataset(s)", flush=True, ) @@ -1336,7 +1341,7 @@ def grpo_train( policy: ColocatablePolicyInterface, policy_generation: Optional[GenerationInterface], wrapped_dataloader: StatefulDataLoader | MultipleDataloaderWrapper, - val_dataloader: Optional[StatefulDataLoader], + val_dataloader: Optional[dict[str, StatefulDataLoader]], tokenizer: TokenizerType, loss_fn: LossFunction, task_to_env: dict[str, EnvironmentInterface], @@ -1414,8 +1419,6 @@ def grpo_train( logger=logger, ) policy_generation.finish_generation() - logger.log_metrics(val_metrics, current_step, prefix="validation") - logger.log_metrics(validation_timings, current_step, prefix="timing/validation") if master_config.data["use_multiple_dataloader"]: warnings.warn( @@ -1912,12 +1915,6 @@ def grpo_train( logger=logger, ) policy_generation.finish_generation() - logger.log_metrics( - validation_timings, total_steps + 1, prefix="timing/validation" - ) - logger.log_metrics( - val_metrics, total_steps + 1, prefix="validation" - ) # Get flat advantages and token mask for masked metrics computation flat_advantages = train_data["advantages"] @@ -2265,40 +2262,111 @@ def grpo_train( def validate( policy_generation: GenerationInterface, - val_dataloader: Optional[StatefulDataLoader], + val_dataloader: Optional[dict[str, StatefulDataLoader]], tokenizer, val_task_to_env: Optional[dict[str, EnvironmentInterface]], step: int, master_config: MasterConfig, logger: Optional[Logger] = None, ) -> tuple[dict[str, Any], dict[str, Any]]: - """Run validation on the validation dataset.""" - if val_dataloader is None: - assert val_dataloader is not None or master_config.grpo["val_period"] == 0, ( - "val_dataloader is None, so grpo.val_period must be 0" + """Run validation across one or more datasets. + + Each dataset gets its own dataloader and is logged under a + `validation-/` prefix, mirroring the DPO setup. The returned + `val_metrics` dict carries the per-dataset entries alongside an + `accuracy` macro-mean across datasets so save-state and best-checkpoint + selection keep working. 
+ """ + if not val_dataloader: + assert master_config.grpo["val_period"] == 0, ( + "val_dataloader is empty, so grpo.val_period must be 0" ) print(" āš ļø No validation dataloader provided, skipping validation", flush=True) return {}, {} + val_metrics: dict[str, Any] = {} + validation_timings: dict[str, Any] = {} + per_dataset_accuracy: list[float] = [] + + for dataset_name, dl in val_dataloader.items(): + k_metrics, k_timings = validate_one_dataset( + policy_generation, + dl, + tokenizer, + val_task_to_env, + step=step, + master_config=master_config, + dataset_name=dataset_name, + logger=logger, + ) + prefix = f"validation-{dataset_name}" + if logger is not None: + logger.log_metrics(k_metrics, step, prefix=prefix) + logger.log_metrics(k_timings, step, prefix=f"timing/{prefix}") + for metric_name, value in k_metrics.items(): + val_metrics[f"{prefix}_{metric_name}"] = value + for timing_name, value in k_timings.items(): + validation_timings[f"{prefix}_{timing_name}"] = value + if "accuracy" in k_metrics: + per_dataset_accuracy.append(k_metrics["accuracy"]) + + if per_dataset_accuracy: + # Macro-mean across datasets so checkpointing.metric_name='val:accuracy' + # keeps working under multi-validation. + val_metrics["accuracy"] = sum(per_dataset_accuracy) / len(per_dataset_accuracy) + + if validation_timings: + total_validation_time = sum( + v for k, v in validation_timings.items() if k.endswith("total_validation_time") + ) + if logger is not None: + logger.log_metrics( + {"total_validation_time": total_validation_time}, + step, + prefix="timing/validation", + ) + validation_timings["total_validation_time"] = total_validation_time + + # Explicit GPU memory cleanup after validation + gc.collect() + torch.cuda.empty_cache() + + return val_metrics, validation_timings + + +def validate_one_dataset( + policy_generation: GenerationInterface, + val_dataloader: StatefulDataLoader, + tokenizer, + val_task_to_env: Optional[dict[str, EnvironmentInterface]], + step: int, + master_config: MasterConfig, + dataset_name: str, + logger: Optional[Logger] = None, +) -> tuple[dict[str, Any], dict[str, Any]]: + """Run validation on one validation dataset and return its metrics + timings.""" timer = Timer() with timer.time("total_validation_time"): - print(f"ā–¶ Starting validation at step {step}...", flush=True) + print( + f"ā–¶ Starting validation at step {step} for `{dataset_name}`...", + flush=True, + ) total_rewards = [] total_lengths = [] - all_message_logs = [] # Collect all message logs + all_message_logs = [] - max_batches = ( - master_config.grpo["max_val_samples"] - // master_config.grpo["val_batch_size"] - ) + max_val_samples = master_config.grpo.get("max_val_samples") + if max_val_samples is None: + max_batches = len(val_dataloader) + else: + max_batches = max_val_samples // master_config.grpo["val_batch_size"] + + additional_metrics_to_report: dict[str, Any] = {} for batch_idx, val_batch in enumerate(val_dataloader): if batch_idx >= max_batches: break - additional_metrics_to_report = dict() - # Generate responses (updates the LLMMessageLogType in batch_with_msg_logs) - # Use async rollouts if vLLM async engine is enabled # We cascade NeMo-Gym first since NeMo-Gym also uses async rollouts. 
if _should_use_nemo_gym(master_config): generation_config = master_config.policy["generation"] @@ -2339,17 +2407,14 @@ def validate( total_rewards.extend(val_batch["total_reward"].tolist()) total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"]) - # Collect message logs for later display to_env = [ get_keys_from_message_log( val_batch["message_log"][i], ["role", "content"] ) for i in range(len(val_batch["message_log"])) ] - all_message_logs.extend(to_env) - # Calculate validation metrics num_samples = len(total_rewards) if num_samples > 0: rewards_t = torch.tensor(total_rewards, dtype=torch.float32) @@ -2367,7 +2432,6 @@ def validate( **additional_metrics_to_report, } - # Print sample conversations only once at the end of validation try: print_message_log_samples( all_message_logs, @@ -2382,36 +2446,26 @@ def validate( print(f"\n āš ļø Error displaying message samples: {str(e)}") print(" āš ļø Continuing validation without displaying samples...", flush=True) - # Get timing metrics timing_metrics = timer.get_timing_metrics(reduction_op="sum") validation_time = timing_metrics.get("total_validation_time", 0) - # Print summary of validation results - print("\nšŸ“Š Validation Results:") + print(f"\nšŸ“Š Validation Results for `{dataset_name}`:") print(f" • Accuracy: {accuracy:.4f}") print(f" • Average response length: {avg_length:.1f} tokens") - print(f" • Samples processed: {len(total_rewards)}", flush=True) - - # Print timing information + print(f" • Samples processed: {num_samples}", flush=True) print("\n ā±ļø Validation Timing:") - validation_time = timing_metrics.get("total_validation_time", 0) print(f" • Total validation time: {validation_time:.2f}s", flush=True) - # Log validation data to JSONL file if logger is not None: val_log_data = { "content": all_message_logs, "rewards": total_rewards, } - logger.log_batched_dict_as_jsonl(val_log_data, f"val_data_step{step}.jsonl") + logger.log_batched_dict_as_jsonl( + val_log_data, f"val_data_{dataset_name}_step{step}.jsonl" + ) - # Make sure to reset the timer after validation timer.reset() - - # Explicit GPU memory cleanup after validation - gc.collect() - torch.cuda.empty_cache() - return val_metrics, timing_metrics @@ -2419,7 +2473,7 @@ def async_grpo_train( policy: ColocatablePolicyInterface, policy_generation: Optional[GenerationInterface], dataloader: StatefulDataLoader, - val_dataloader: Optional[StatefulDataLoader], + val_dataloader: Optional[dict[str, StatefulDataLoader]], tokenizer: TokenizerType, loss_fn: LossFunction, task_to_env: dict[str, EnvironmentInterface], @@ -2642,8 +2696,6 @@ def async_grpo_train( logger=logger, ) policy_generation.finish_generation() - logger.log_metrics(val_metrics, step, prefix="validation") - logger.log_metrics(validation_timings, step, prefix="timing/validation") print("āœ… Initial validation completed successfully") except Exception as e: print(f"āŒ Initial validation failed: {e}") @@ -2999,10 +3051,6 @@ def async_grpo_train( logger=logger, ) policy_generation.finish_generation() - logger.log_metrics( - validation_timings, step + 1, prefix="timing/validation" - ) - logger.log_metrics(val_metrics, step + 1, prefix="validation") # Explicit GPU memory cleanup after validation in async mode import gc diff --git a/nemo_rl/data/utils.py b/nemo_rl/data/utils.py index 2819e27582..8a09ffe717 100644 --- a/nemo_rl/data/utils.py +++ b/nemo_rl/data/utils.py @@ -152,12 +152,15 @@ def setup_response_data( val_task_data_processors = {} val_task_data_preprocessors = {} val_task_to_env = {} - val_data_list 
= [] + # (task_name, raw_dataset) pairs preserved in load order so each appears + # as its own entry in the returned dict and gets per-dataset metrics + # downstream. Mirrors the pattern DPO uses in its setup. + val_data_pairs: list[tuple[str, Any]] = [] # validation dataset from train dataset (when train dataset's split_validation_size > 0) for data in data_list: if hasattr(data, "val_dataset") and data.val_dataset is not None: - val_data_list.append(data.val_dataset) + val_data_pairs.append((data.task_name, data.val_dataset)) print( f" - Loaded validation dataset {data.task_name} with {len(data.val_dataset)} samples." ) @@ -181,7 +184,7 @@ def setup_response_data( if "default" in data_config and data_config["default"] is not None: update_single_dataset_config(cfg, data_config["default"]) val_data = load_response_dataset(cfg) - val_data_list.append(val_data.dataset) + val_data_pairs.append((val_data.task_name, val_data.dataset)) print( f" - Loaded validation dataset {val_data.task_name} with {len(val_data.dataset)} samples." ) @@ -196,19 +199,36 @@ def setup_response_data( if has_envs: val_task_to_env[task_name] = envs[cfg["env_name"]] - # merge datasets - val_dataset = None - if len(val_data_list) > 0: - merged_val_data = concatenate_datasets(val_data_list) - val_dataset = AllTaskProcessedDataset( - merged_val_data, - tokenizer, - None, - val_task_data_processors, - task_data_preprocessors=val_task_data_preprocessors, - max_seq_length=data_config["max_input_seq_length"], + # Build a dict[name -> AllTaskProcessedDataset] so the algorithm setup + # can produce one StatefulDataLoader per dataset and validate() can emit + # per-dataset metrics under a `validation-/` prefix. + val_dataset: Optional[dict[str, AllTaskProcessedDataset]] = None + if len(val_data_pairs) > 0: + seen_names: set[str] = set() + for name, _ in val_data_pairs: + if name in seen_names: + raise ValueError( + f"Duplicate validation task_name '{name}'. Each entry under " + "data.validation must produce a unique task_name so that " + "per-dataset metrics can be reported under a distinct " + "validation-/ wandb prefix." + ) + seen_names.add(name) + val_dataset = { + name: AllTaskProcessedDataset( + raw, + tokenizer, + None, + val_task_data_processors, + task_data_preprocessors=val_task_data_preprocessors, + max_seq_length=data_config["max_input_seq_length"], + ) + for name, raw in val_data_pairs + } + total = sum(len(ds) for ds in val_dataset.values()) + print( + f" āœ“ Validation datasets loaded: {total} samples across {len(val_dataset)} dataset(s)." ) - print(f" āœ“ Validation dataset loaded with {len(val_dataset)} samples.") if has_envs: return dataset, val_dataset, task_to_env, val_task_to_env diff --git a/tests/unit/algorithms/test_distillation.py b/tests/unit/algorithms/test_distillation.py index 780e3459f0..0e2b276fec 100644 --- a/tests/unit/algorithms/test_distillation.py +++ b/tests/unit/algorithms/test_distillation.py @@ -95,13 +95,15 @@ def train_iter(self): train_dataloader.__iter__ = train_iter train_dataloader.__len__ = MagicMock(return_value=10) - val_dataloader = MagicMock(spec=StatefulDataLoader) + val_dataloader_single = MagicMock(spec=StatefulDataLoader) def val_iter(self): return iter([mock_batch] * 10) - val_dataloader.__iter__ = val_iter - val_dataloader.__len__ = MagicMock(return_value=10) + val_dataloader_single.__iter__ = val_iter + val_dataloader_single.__len__ = MagicMock(return_value=10) + # validate() now takes a dict[name -> StatefulDataLoader] for per-dataset metrics. 
+ val_dataloader = {"test_dataset": val_dataloader_single} tokenizer = MagicMock() tokenizer.pad_token_id = 0 @@ -294,7 +296,99 @@ def test_validate_function(mock_components): assert isinstance(validation_timings, dict) # For distillation, we don't need environment interaction since max_rollout_turns=0 # The validation focuses on generation and teacher-student knowledge transfer - # Note: validate() function itself doesn't call logger.log_metrics - that's done by the caller + + +def _build_single_batch_loader(reward: float): + batch = BatchedDataDict[DatumSpec]( + { + "message_log": [ + [ + {"token_ids": torch.tensor([1, 2]), "role": "user", "content": "q"}, + {"token_ids": torch.tensor([3, 4]), "role": "assistant", "content": "a"}, + ] + ], + "loss_multiplier": torch.tensor([1.0]), + "task_name": ["dummy"], + "extra_env_info": [{}], + "length": torch.tensor([4]), + "idx": torch.tensor([0]), + "total_reward": torch.tensor([reward]), + } + ) + dl = MagicMock(spec=StatefulDataLoader) + dl.__iter__ = MagicMock(side_effect=lambda: iter([batch])) + dl.__len__ = MagicMock(return_value=1) + return dl, batch + + +def test_validate_emits_per_dataset_prefixed_metrics_and_logs(mock_components): + """validate() splits metrics per dataset and logs each under validation-.""" + config = mock_components["master_config"] + config.distillation["val_batch_size"] = 1 + config.distillation["max_val_samples"] = 1 + + dl_a, batch_a = _build_single_batch_loader(1.0) + dl_b, batch_b = _build_single_batch_loader(0.0) + val_dataloader = {"gsm8k": dl_a, "math500": dl_b} + + rollout_returns = iter( + [ + (batch_a, {"mean_gen_tokens_per_sample": 10.0}), + (batch_b, {"mean_gen_tokens_per_sample": 12.0}), + ] + ) + logger = MagicMock() + + with patch.object(distil_mod, "run_multi_turn_rollout") as mock_rollout: + mock_rollout.side_effect = lambda *a, **k: next(rollout_returns) + val_metrics, validation_timings = validate( + mock_components["student_generation"], + val_dataloader, + mock_components["tokenizer"], + mock_components["val_task_to_env"], + step=0, + master_config=config, + logger=logger, + ) + + # Per-dataset metric keys present in the returned dict. + assert val_metrics["validation-gsm8k_accuracy"] == 1.0 + assert val_metrics["validation-math500_accuracy"] == 0.0 + # Aggregated key is the macro-mean across datasets (preserved for save_state). + assert val_metrics["accuracy"] == pytest.approx(0.5) + # Per-dataset wandb prefixes were used. + logged_prefixes = {call.kwargs.get("prefix") for call in logger.log_metrics.call_args_list} + assert "validation-gsm8k" in logged_prefixes + assert "validation-math500" in logged_prefixes + # Total validation time is also logged once. 
+ assert "timing/validation" in logged_prefixes + assert "total_validation_time" in validation_timings + + +def test_validate_iterates_full_dataloader_when_max_val_samples_is_none(mock_components): + """When max_val_samples is omitted, validate iterates the entire val_dataloader.""" + config = mock_components["master_config"] + config.distillation.pop("max_val_samples", None) + config.distillation["val_batch_size"] = 1 + + expected_batches = mock_components["val_dataloader"]["test_dataset"].__len__.return_value + sample_batch = next(iter(mock_components["val_dataloader"]["test_dataset"])) + + enriched = BatchedDataDict[DatumSpec](dict(sample_batch)) + enriched["total_reward"] = torch.tensor([1.0]) + + with patch.object(distil_mod, "run_multi_turn_rollout") as mock_rollout: + mock_rollout.return_value = (enriched, {"mean_gen_tokens_per_sample": 10.0}) + validate( + mock_components["student_generation"], + mock_components["val_dataloader"], + mock_components["tokenizer"], + mock_components["val_task_to_env"], + step=0, + master_config=config, + ) + + assert mock_rollout.call_count == expected_batches def test_check_vocab_equality_pass(monkeypatch): diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index a6193f9f85..d945622c82 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -219,13 +219,15 @@ def train_iter(self): train_dataloader.__iter__ = train_iter train_dataloader.__len__ = MagicMock(return_value=10) - val_dataloader = MagicMock(spec=StatefulDataLoader) + val_dataloader_single = MagicMock(spec=StatefulDataLoader) def val_iter(self): return iter([mock_batch] * 10) - val_dataloader.__iter__ = val_iter - val_dataloader.__len__ = MagicMock(return_value=10) + val_dataloader_single.__iter__ = val_iter + val_dataloader_single.__len__ = MagicMock(return_value=10) + # validate() now takes a dict[name -> StatefulDataLoader] for per-dataset metrics. + val_dataloader = {"test_dataset": val_dataloader_single} tokenizer = MagicMock() tokenizer.pad_token_id = 0 @@ -2154,9 +2156,11 @@ def test_validate_logs_data_when_logger_provided(self, tmp_path): } ) - # Create mock dataloader that yields mock_batch - mock_dataloader = MagicMock(spec=StatefulDataLoader) - mock_dataloader.__iter__ = MagicMock(return_value=iter([mock_batch])) + # Create mock dataloader that yields mock_batch. validate() now takes a dict. + single_dl = MagicMock(spec=StatefulDataLoader) + single_dl.__iter__ = MagicMock(return_value=iter([mock_batch])) + single_dl.__len__ = MagicMock(return_value=1) + mock_dataloader = {"math": single_dl} # Create mock environment mock_env = MagicMock(spec=EnvironmentInterface) @@ -2262,9 +2266,11 @@ def test_validate_works_without_logger(self): } ) - # Create mock dataloader - mock_dataloader = MagicMock(spec=StatefulDataLoader) - mock_dataloader.__iter__ = MagicMock(return_value=iter([mock_batch])) + # Create mock dataloader. validate() now takes a dict. 
+ single_dl = MagicMock(spec=StatefulDataLoader) + single_dl.__iter__ = MagicMock(return_value=iter([mock_batch])) + single_dl.__len__ = MagicMock(return_value=1) + mock_dataloader = {"math": single_dl} # Create mock environment mock_env = MagicMock(spec=EnvironmentInterface) @@ -2346,6 +2352,177 @@ def test_validate_returns_empty_when_no_dataloader(self): assert val_metrics == {} assert timing == {} + def test_validate_emits_per_dataset_prefixed_metrics_and_logs(self): + """validate() splits metrics per dataset and logs each under validation-.""" + mock_policy_gen = MagicMock() + mock_tokenizer = MagicMock() + mock_tokenizer.pad_token_id = 0 + + def make_batch(reward: float): + return BatchedDataDict[DatumSpec]( + { + "message_log": [ + [ + {"role": "user", "content": "q", "token_ids": torch.tensor([1, 2])}, + {"role": "assistant", "content": "a", "token_ids": torch.tensor([3, 4])}, + ], + ], + "task_name": ["dummy"], + "extra_env_info": [{}], + "loss_multiplier": torch.tensor([1.0]), + "idx": torch.tensor([0]), + "length": torch.tensor([4]), + "total_reward": torch.tensor([reward]), + } + ) + + def make_loader(batch): + dl = MagicMock(spec=StatefulDataLoader) + dl.__iter__ = MagicMock(side_effect=lambda: iter([batch])) + dl.__len__ = MagicMock(return_value=1) + return dl + + batch_a = make_batch(1.0) + batch_b = make_batch(0.0) + val_dataloader = {"gsm8k": make_loader(batch_a), "math500": make_loader(batch_b)} + + mock_env = MagicMock(spec=EnvironmentInterface) + mock_env.global_post_process_and_metrics.return_value = (batch_a, {}) + + mock_config = MasterConfig.model_construct( + **{ + "grpo": { + "max_val_samples": 1, + "val_batch_size": 1, + "max_rollout_turns": 1, + }, + "policy": { + "max_total_sequence_length": 2048, + "generation": { + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, + "backend": "vllm", + "colocated": {"enabled": True}, + "vllm_cfg": {"async_engine": False}, + }, + }, + "logger": {"num_val_samples_to_print": 1}, + } + ) + + rollout_returns = iter( + [ + (batch_a, {"mean_gen_tokens_per_sample": 10.0}), + (batch_b, {"mean_gen_tokens_per_sample": 12.0}), + ] + ) + logger = MagicMock() + + with patch("nemo_rl.algorithms.grpo.run_multi_turn_rollout") as mock_rollout: + mock_rollout.side_effect = lambda *a, **k: next(rollout_returns) + with patch("nemo_rl.algorithms.grpo._should_use_nemo_gym", return_value=False): + with patch( + "nemo_rl.algorithms.grpo._should_use_async_rollouts", + return_value=False, + ): + with patch("nemo_rl.algorithms.grpo.print_message_log_samples"): + val_metrics, validation_timings = validate( + mock_policy_gen, + val_dataloader, + mock_tokenizer, + {"math": mock_env}, + step=0, + master_config=mock_config, + logger=logger, + ) + + assert val_metrics["validation-gsm8k_accuracy"] == 1.0 + assert val_metrics["validation-math500_accuracy"] == 0.0 + # Aggregated key is the macro-mean across datasets (preserved for save_state). 
+ assert val_metrics["accuracy"] == pytest.approx(0.5) + logged_prefixes = { + call.kwargs.get("prefix") for call in logger.log_metrics.call_args_list + } + assert "validation-gsm8k" in logged_prefixes + assert "validation-math500" in logged_prefixes + assert "timing/validation" in logged_prefixes + assert "total_validation_time" in validation_timings + + def test_validate_iterates_full_dataloader_when_max_val_samples_is_none(self): + """When max_val_samples is None, validate iterates the entire val_dataloader.""" + mock_policy_gen = MagicMock() + mock_tokenizer = MagicMock() + mock_tokenizer.pad_token_id = 0 + + mock_batch = BatchedDataDict[DatumSpec]( + { + "message_log": [ + [ + {"role": "user", "content": "q", "token_ids": torch.tensor([1, 2])}, + {"role": "assistant", "content": "a", "token_ids": torch.tensor([3, 4])}, + ], + ], + "task_name": ["dummy"], + "extra_env_info": [{}], + "loss_multiplier": torch.tensor([1.0]), + "idx": torch.tensor([0]), + "length": torch.tensor([4]), + "total_reward": torch.tensor([1.0]), + } + ) + + num_batches = 3 + single_dl = MagicMock(spec=StatefulDataLoader) + single_dl.__iter__ = MagicMock(return_value=iter([mock_batch] * num_batches)) + single_dl.__len__ = MagicMock(return_value=num_batches) + mock_dataloader = {"only": single_dl} + + mock_env = MagicMock(spec=EnvironmentInterface) + mock_env.global_post_process_and_metrics.return_value = (mock_batch, {}) + + mock_config = MasterConfig.model_construct( + **{ + "grpo": { + "max_val_samples": None, + "val_batch_size": 1, + "max_rollout_turns": 1, + }, + "policy": { + "max_total_sequence_length": 2048, + "generation": { + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, + "backend": "vllm", + "colocated": {"enabled": True}, + "vllm_cfg": {"async_engine": False}, + }, + }, + "logger": {"num_val_samples_to_print": 1}, + } + ) + + with patch("nemo_rl.algorithms.grpo.run_multi_turn_rollout") as mock_rollout: + mock_rollout.return_value = (mock_batch, {"mean_gen_tokens_per_sample": 10.0}) + with patch("nemo_rl.algorithms.grpo._should_use_nemo_gym", return_value=False): + with patch( + "nemo_rl.algorithms.grpo._should_use_async_rollouts", + return_value=False, + ): + with patch("nemo_rl.algorithms.grpo.print_message_log_samples"): + validate( + mock_policy_gen, + mock_dataloader, + mock_tokenizer, + {"math": mock_env}, + step=0, + master_config=mock_config, + logger=None, + ) + + assert mock_rollout.call_count == num_batches + # ============================================================================ # Tests for compute_and_apply_seq_logprob_error_masking function
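For reviewers, a minimal standalone sketch of the aggregation contract this patch introduces in `validate()` (both `nemo_rl/algorithms/grpo.py` and `nemo_rl/algorithms/distillation.py`): per-dataset metrics are flattened under a `validation-<dataset_name>` prefix, an `accuracy` macro-mean is preserved for best-checkpoint selection, and the per-dataset `total_validation_time` entries are summed into one top-level timing. The helper name `aggregate_validation_metrics` and the `fake_results` inputs below are illustrative only and are not part of this PR.

```python
# Illustrative sketch (not part of this PR): mirrors how the new validate()
# flattens per-dataset results. Helper name and fake inputs are hypothetical.
from typing import Any


def aggregate_validation_metrics(
    per_dataset: dict[str, tuple[dict[str, Any], dict[str, Any]]],
) -> tuple[dict[str, Any], dict[str, Any]]:
    val_metrics: dict[str, Any] = {}
    validation_timings: dict[str, Any] = {}
    accuracies: list[float] = []

    for dataset_name, (metrics, timings) in per_dataset.items():
        prefix = f"validation-{dataset_name}"
        # Flatten per-dataset entries so downstream consumers (save_state,
        # wandb logging) see keys like "validation-gsm8k_accuracy".
        for name, value in metrics.items():
            val_metrics[f"{prefix}_{name}"] = value
        for name, value in timings.items():
            validation_timings[f"{prefix}_{name}"] = value
        if "accuracy" in metrics:
            accuracies.append(metrics["accuracy"])

    if accuracies:
        # Macro-mean across datasets keeps checkpointing.metric_name='val:accuracy'
        # working under multi-validation.
        val_metrics["accuracy"] = sum(accuracies) / len(accuracies)

    # Sum the per-dataset validation times into a single top-level timing entry.
    validation_timings["total_validation_time"] = sum(
        v
        for k, v in validation_timings.items()
        if k.endswith("total_validation_time")
    )
    return val_metrics, validation_timings


if __name__ == "__main__":
    fake_results = {
        "gsm8k": (
            {"accuracy": 1.0, "avg_length": 120.0},
            {"total_validation_time": 30.0},
        ),
        "math500": (
            {"accuracy": 0.0, "avg_length": 200.0},
            {"total_validation_time": 45.0},
        ),
    }
    metrics, timings = aggregate_validation_metrics(fake_results)
    assert metrics["validation-gsm8k_accuracy"] == 1.0
    assert metrics["accuracy"] == 0.5
    assert timings["total_validation_time"] == 75.0
    print(metrics)
    print(timings)
```

The same prefix/macro-mean shape is what the new tests (`test_validate_emits_per_dataset_prefixed_metrics_and_logs`) assert against; in the real code the per-dataset tuples come from `validate_one_dataset()` rather than a pre-built dict.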