diff --git a/nemo_rl/algorithms/distillation.py b/nemo_rl/algorithms/distillation.py index bbc54ab0c8..4a11a6b7f0 100644 --- a/nemo_rl/algorithms/distillation.py +++ b/nemo_rl/algorithms/distillation.py @@ -83,7 +83,7 @@ class DistillationConfig(TypedDict): # Whether to run validation on the last training step. Setting this to True ensures the # final checkpoint has validation metrics, which is required for get_best_checkpoint_path(). val_at_end: bool - max_val_samples: int + max_val_samples: NotRequired[int] topk_logits_k: int seed: int @@ -976,12 +976,16 @@ def validate( total_rewards = [] # Can be any metric. Setted to 'accuracy' by default. total_lengths = [] all_message_logs = [] # Collect all message logs + # Per-task rewards so multi-validation reports accuracy per dataset. + per_task_rewards: dict[str, list[float]] = {} - max_batches = ( - master_config.distillation["max_val_samples"] - + master_config.distillation["val_batch_size"] - - 1 - ) // master_config.distillation["val_batch_size"] + max_val_samples = master_config.distillation.get("max_val_samples") + if max_val_samples is None: + max_batches = len(val_dataloader) + else: + max_batches = ( + max_val_samples // master_config.distillation["val_batch_size"] + ) for batch_idx, val_batch in enumerate(val_dataloader): if batch_idx >= max_batches: break @@ -1009,10 +1013,19 @@ def validate( greedy=False, ) rewards = val_batch["total_reward"] + rewards_list = rewards.tolist() - total_rewards.extend(rewards.tolist()) + total_rewards.extend(rewards_list) total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"]) + # Skip samples without task_name (single-task / legacy datasets). + batch_task_names = val_batch.get("task_name") + if batch_task_names is not None: + for r, t in zip(rewards_list, batch_task_names): + if t is None: + continue + per_task_rewards.setdefault(t, []).append(r) + # Collect message logs for later display to_env = [ get_keys_from_message_log( @@ -1035,6 +1048,11 @@ def validate( "accuracy": accuracy, "avg_length": avg_length, } + for task_name, task_rewards in per_task_rewards.items(): + val_metrics[f"accuracy_{task_name}"] = sum(task_rewards) / len( + task_rewards + ) + val_metrics[f"num_samples_{task_name}"] = len(task_rewards) # Print sample conversations only once at the end of validation try: @@ -1060,6 +1078,14 @@ def validate( print(f" • Accuracy: {accuracy:.4f}") print(f" • Average response length: {avg_length:.1f} tokens") print(f" • Samples processed: {len(total_rewards)}", flush=True) + if per_task_rewards: + print(" • Per-task accuracy:") + for task_name in sorted(per_task_rewards.keys()): + tr = per_task_rewards[task_name] + print( + f" - {task_name}: {sum(tr) / len(tr):.4f} (n={len(tr)})", + flush=True, + ) # Print timing information print("\n ⏱️ Validation Timing:") diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index d26fb0bcae..9ba7c0913b 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -2287,11 +2287,14 @@ def validate( total_rewards = [] total_lengths = [] all_message_logs = [] # Collect all message logs + # Per-task rewards so multi-validation reports accuracy per dataset. + per_task_rewards: dict[str, list[float]] = {} - max_batches = ( - master_config.grpo["max_val_samples"] - // master_config.grpo["val_batch_size"] - ) + max_val_samples = master_config.grpo.get("max_val_samples") + if max_val_samples is None: + max_batches = len(val_dataloader) + else: + max_batches = max_val_samples // master_config.grpo["val_batch_size"] for batch_idx, val_batch in enumerate(val_dataloader): if batch_idx >= max_batches: break @@ -2336,9 +2339,18 @@ def validate( greedy=False, ) - total_rewards.extend(val_batch["total_reward"].tolist()) + rewards_list = val_batch["total_reward"].tolist() + total_rewards.extend(rewards_list) total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"]) + # Skip samples without task_name (single-task / legacy datasets). + batch_task_names = val_batch.get("task_name") + if batch_task_names is not None: + for r, t in zip(rewards_list, batch_task_names): + if t is None: + continue + per_task_rewards.setdefault(t, []).append(r) + # Collect message logs for later display to_env = [ get_keys_from_message_log( @@ -2366,6 +2378,11 @@ def validate( "avg_length": avg_length, **additional_metrics_to_report, } + for task_name, task_rewards in per_task_rewards.items(): + val_metrics[f"accuracy_{task_name}"] = sum(task_rewards) / len( + task_rewards + ) + val_metrics[f"num_samples_{task_name}"] = len(task_rewards) # Print sample conversations only once at the end of validation try: @@ -2391,6 +2408,14 @@ def validate( print(f" • Accuracy: {accuracy:.4f}") print(f" • Average response length: {avg_length:.1f} tokens") print(f" • Samples processed: {len(total_rewards)}", flush=True) + if per_task_rewards: + print(" • Per-task accuracy:") + for task_name in sorted(per_task_rewards.keys()): + tr = per_task_rewards[task_name] + print( + f" - {task_name}: {sum(tr) / len(tr):.4f} (n={len(tr)})", + flush=True, + ) # Print timing information print("\n ⏱️ Validation Timing:") diff --git a/tests/unit/algorithms/test_distillation.py b/tests/unit/algorithms/test_distillation.py index 780e3459f0..dc42a77f63 100644 --- a/tests/unit/algorithms/test_distillation.py +++ b/tests/unit/algorithms/test_distillation.py @@ -297,6 +297,113 @@ def test_validate_function(mock_components): # Note: validate() function itself doesn't call logger.log_metrics - that's done by the caller +def _rollout_return_for(batch): + enriched = BatchedDataDict[DatumSpec](dict(batch)) + enriched["total_reward"] = torch.tensor([1.0]) + return (enriched, {"mean_gen_tokens_per_sample": 10.0}) + + +def test_validate_iterates_full_dataloader_when_max_val_samples_is_none( + mock_components, +): + """When max_val_samples is omitted, validate iterates the entire val_dataloader.""" + config = mock_components["master_config"] + # Simulate a recipe that does not specify max_val_samples. + config.distillation.pop("max_val_samples", None) + config.distillation["val_batch_size"] = 1 + + expected_batches = mock_components["val_dataloader"].__len__.return_value + sample_batch = next(iter(mock_components["val_dataloader"])) + + with patch.object(distil_mod, "run_multi_turn_rollout") as mock_rollout: + mock_rollout.return_value = _rollout_return_for(sample_batch) + validate( + mock_components["student_generation"], + mock_components["val_dataloader"], + mock_components["tokenizer"], + mock_components["val_task_to_env"], + step=0, + master_config=config, + ) + + assert mock_rollout.call_count == expected_batches + + +def test_validate_emits_per_task_accuracy_keys(mock_components): + """Multi-task validation produces accuracy_ and num_samples_ entries.""" + config = mock_components["master_config"] + config.distillation["max_val_samples"] = 4 + config.distillation["val_batch_size"] = 1 + + # Two batches per task; reward 1.0 for gsm8k, 0.0 for math500. + task_sequence = ["gsm8k", "math500", "gsm8k", "math500"] + reward_sequence = [1.0, 0.0, 1.0, 0.0] + + def make_batch(task, reward): + batch = BatchedDataDict[DatumSpec]( + { + "message_log": [ + [ + {"token_ids": torch.tensor([1, 2]), "role": "user", "content": "q"}, + {"token_ids": torch.tensor([3, 4]), "role": "assistant", "content": "a"}, + ] + ], + "loss_multiplier": torch.tensor([1.0]), + "task_name": [task], + "extra_env_info": [{}], + "length": torch.tensor([4]), + "idx": torch.tensor([0]), + "total_reward": torch.tensor([reward]), + } + ) + return batch + + rollout_batches = [make_batch(t, r) for t, r in zip(task_sequence, reward_sequence)] + + with patch.object(distil_mod, "run_multi_turn_rollout") as mock_rollout: + mock_rollout.side_effect = [ + (b, {"mean_gen_tokens_per_sample": 4.0}) for b in rollout_batches + ] + val_metrics, _ = validate( + mock_components["student_generation"], + mock_components["val_dataloader"], + mock_components["tokenizer"], + mock_components["val_task_to_env"], + step=0, + master_config=config, + ) + + assert val_metrics["accuracy_gsm8k"] == 1.0 + assert val_metrics["accuracy_math500"] == 0.0 + assert val_metrics["num_samples_gsm8k"] == 2 + assert val_metrics["num_samples_math500"] == 2 + # Aggregated key is preserved with the same definition (sample-weighted mean). + assert val_metrics["accuracy"] == pytest.approx(0.5) + + +def test_validate_floor_divides_max_val_samples_by_val_batch_size(mock_components): + """When max_val_samples is set, validate truncates with floor division (matches GRPO).""" + config = mock_components["master_config"] + config.distillation["max_val_samples"] = 7 + config.distillation["val_batch_size"] = 2 + + sample_batch = next(iter(mock_components["val_dataloader"])) + + with patch.object(distil_mod, "run_multi_turn_rollout") as mock_rollout: + mock_rollout.return_value = _rollout_return_for(sample_batch) + validate( + mock_components["student_generation"], + mock_components["val_dataloader"], + mock_components["tokenizer"], + mock_components["val_task_to_env"], + step=0, + master_config=config, + ) + + # 7 // 2 == 3 batches; previous ceiling-division behaviour would have been 4. + assert mock_rollout.call_count == 3 + + def test_check_vocab_equality_pass(monkeypatch): student_tokenizer = MagicMock() student_tokenizer.get_vocab.return_value = {"a": 0, "b": 1} diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index a6193f9f85..4b3959b21a 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -2346,6 +2346,195 @@ def test_validate_returns_empty_when_no_dataloader(self): assert val_metrics == {} assert timing == {} + def test_validate_iterates_full_dataloader_when_max_val_samples_is_none(self): + """When max_val_samples is None, validate iterates the entire val_dataloader.""" + mock_policy_gen = MagicMock() + mock_tokenizer = MagicMock() + mock_tokenizer.pad_token_id = 0 + + mock_batch = BatchedDataDict[DatumSpec]( + { + "message_log": [ + [ + { + "role": "user", + "content": "test", + "token_ids": torch.tensor([1, 2, 3]), + }, + { + "role": "assistant", + "content": "response", + "token_ids": torch.tensor([4, 5, 6]), + }, + ], + ], + "task_name": ["math"], + "extra_env_info": [{}], + "loss_multiplier": torch.tensor([1.0]), + "idx": torch.tensor([0]), + "length": torch.tensor([6]), + "total_reward": torch.tensor([1.0]), + } + ) + + num_batches = 3 + mock_dataloader = MagicMock(spec=StatefulDataLoader) + mock_dataloader.__iter__ = MagicMock( + return_value=iter([mock_batch] * num_batches) + ) + mock_dataloader.__len__ = MagicMock(return_value=num_batches) + + mock_env = MagicMock(spec=EnvironmentInterface) + mock_env.global_post_process_and_metrics.return_value = (mock_batch, {}) + + mock_config = MasterConfig.model_construct( + **{ + "grpo": { + "max_val_samples": None, + "val_batch_size": 1, + "max_rollout_turns": 1, + }, + "policy": { + "max_total_sequence_length": 2048, + "generation": { + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, + "backend": "vllm", + "colocated": {"enabled": True}, + "vllm_cfg": {"async_engine": False}, + }, + }, + "logger": {"num_val_samples_to_print": 1}, + } + ) + + mock_rollout_metrics = {"mean_gen_tokens_per_sample": 10.0} + + with patch("nemo_rl.algorithms.grpo.run_multi_turn_rollout") as mock_rollout: + mock_rollout.return_value = (mock_batch, mock_rollout_metrics) + with patch( + "nemo_rl.algorithms.grpo._should_use_nemo_gym", return_value=False + ): + with patch( + "nemo_rl.algorithms.grpo._should_use_async_rollouts", + return_value=False, + ): + with patch("nemo_rl.algorithms.grpo.print_message_log_samples"): + validate( + mock_policy_gen, + mock_dataloader, + mock_tokenizer, + {"math": mock_env}, + step=0, + master_config=mock_config, + logger=None, + ) + + assert mock_rollout.call_count == num_batches + + def test_validate_emits_per_task_accuracy_keys(self): + """Multi-task validation produces accuracy_ and num_samples_ keys.""" + mock_policy_gen = MagicMock() + mock_tokenizer = MagicMock() + mock_tokenizer.pad_token_id = 0 + + def make_batch(task, reward): + return BatchedDataDict[DatumSpec]( + { + "message_log": [ + [ + { + "role": "user", + "content": "q", + "token_ids": torch.tensor([1, 2]), + }, + { + "role": "assistant", + "content": "a", + "token_ids": torch.tensor([3, 4]), + }, + ], + ], + "task_name": [task], + "extra_env_info": [{}], + "loss_multiplier": torch.tensor([1.0]), + "idx": torch.tensor([0]), + "length": torch.tensor([4]), + "total_reward": torch.tensor([reward]), + } + ) + + # Two batches per task; gsm8k all correct, math500 all wrong. + task_sequence = ["gsm8k", "math500", "gsm8k", "math500"] + reward_sequence = [1.0, 0.0, 1.0, 0.0] + rollout_batches = [ + make_batch(t, r) for t, r in zip(task_sequence, reward_sequence) + ] + + mock_dataloader = MagicMock(spec=StatefulDataLoader) + mock_dataloader.__iter__ = MagicMock( + return_value=iter([make_batch("placeholder", 0.0)] * len(rollout_batches)) + ) + mock_dataloader.__len__ = MagicMock(return_value=len(rollout_batches)) + + mock_env = MagicMock(spec=EnvironmentInterface) + mock_env.global_post_process_and_metrics.return_value = ( + rollout_batches[0], + {}, + ) + + mock_config = MasterConfig.model_construct( + **{ + "grpo": { + "max_val_samples": 4, + "val_batch_size": 1, + "max_rollout_turns": 1, + }, + "policy": { + "max_total_sequence_length": 2048, + "generation": { + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, + "backend": "vllm", + "colocated": {"enabled": True}, + "vllm_cfg": {"async_engine": False}, + }, + }, + "logger": {"num_val_samples_to_print": 1}, + } + ) + + with patch("nemo_rl.algorithms.grpo.run_multi_turn_rollout") as mock_rollout: + mock_rollout.side_effect = [ + (b, {"mean_gen_tokens_per_sample": 4.0}) for b in rollout_batches + ] + with patch( + "nemo_rl.algorithms.grpo._should_use_nemo_gym", return_value=False + ): + with patch( + "nemo_rl.algorithms.grpo._should_use_async_rollouts", + return_value=False, + ): + with patch("nemo_rl.algorithms.grpo.print_message_log_samples"): + val_metrics, _ = validate( + mock_policy_gen, + mock_dataloader, + mock_tokenizer, + {"math": mock_env}, + step=0, + master_config=mock_config, + logger=None, + ) + + assert val_metrics["accuracy_gsm8k"] == 1.0 + assert val_metrics["accuracy_math500"] == 0.0 + assert val_metrics["num_samples_gsm8k"] == 2 + assert val_metrics["num_samples_math500"] == 2 + # Aggregated key is preserved with the same definition. + assert val_metrics["accuracy"] == pytest.approx(0.5) + # ============================================================================ # Tests for compute_and_apply_seq_logprob_error_masking function