From 870c987218d5e216e0a174c3506ec30c322054b9 Mon Sep 17 00:00:00 2001 From: Minho Ryu Date: Fri, 15 May 2026 13:33:21 +0900 Subject: [PATCH 1/2] feat: make max_val_samples optional, unify GRPO/Distillation truncation Two related changes to the validation truncation logic. 1. Make max_val_samples optional. When the field is absent or set to None in the recipe, validate() now iterates the entire val_dataloader. * GRPO already typed it as `int | None # None for NeMo-Gym compatibility` but the main validation path crashed when reading None. Patch the read site so the main path matches the type. * Distillation widens the TypedDict from `int` to `NotRequired[int]` and applies the same read-site change. The exemplar YAMLs (examples/configs/grpo_math_1B.yaml and examples/configs/distillation_math.yaml) keep their explicit values so the recommended default is still documented. 2. Unify Distillation truncation with GRPO. GRPO uses floor division (max_val_samples // val_batch_size); Distillation used ceiling division ((max_val_samples + val_batch_size - 1) // val_batch_size). With the new None-handling branch already in place, switch Distillation to floor division so the two algorithms behave identically when the field is set. Behaviour impact for existing recipes: only Distillation runs whose max_val_samples is not divisible by val_batch_size see fewer samples evaluated by one partial batch. Recipes in examples/configs/recipes/llm all use values that divide cleanly (256/8, 512/8 etc.), so no recipe under examples/ is affected. Recipes that previously set an integer that divides cleanly remain identical; recipes that previously omitted the field could not run at all and now do. Tests: * tests/unit/algorithms/test_grpo.py adds test_validate_iterates_full_dataloader_when_max_val_samples_is_none * tests/unit/algorithms/test_distillation.py adds the same plus test_validate_floor_divides_max_val_samples_by_val_batch_size to guard the GRPO/Distillation parity. Signed-off-by: Minho Ryu --- nemo_rl/algorithms/distillation.py | 14 ++-- nemo_rl/algorithms/grpo.py | 9 ++- tests/unit/algorithms/test_distillation.py | 55 ++++++++++++++ tests/unit/algorithms/test_grpo.py | 87 ++++++++++++++++++++++ 4 files changed, 155 insertions(+), 10 deletions(-) diff --git a/nemo_rl/algorithms/distillation.py b/nemo_rl/algorithms/distillation.py index bbc54ab0c8..a9386e100c 100644 --- a/nemo_rl/algorithms/distillation.py +++ b/nemo_rl/algorithms/distillation.py @@ -83,7 +83,7 @@ class DistillationConfig(TypedDict): # Whether to run validation on the last training step. Setting this to True ensures the # final checkpoint has validation metrics, which is required for get_best_checkpoint_path(). val_at_end: bool - max_val_samples: int + max_val_samples: NotRequired[int] topk_logits_k: int seed: int @@ -977,11 +977,13 @@ def validate( total_lengths = [] all_message_logs = [] # Collect all message logs - max_batches = ( - master_config.distillation["max_val_samples"] - + master_config.distillation["val_batch_size"] - - 1 - ) // master_config.distillation["val_batch_size"] + max_val_samples = master_config.distillation.get("max_val_samples") + if max_val_samples is None: + max_batches = len(val_dataloader) + else: + max_batches = ( + max_val_samples // master_config.distillation["val_batch_size"] + ) for batch_idx, val_batch in enumerate(val_dataloader): if batch_idx >= max_batches: break diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index d26fb0bcae..9cdc1a845e 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -2288,10 +2288,11 @@ def validate( total_lengths = [] all_message_logs = [] # Collect all message logs - max_batches = ( - master_config.grpo["max_val_samples"] - // master_config.grpo["val_batch_size"] - ) + max_val_samples = master_config.grpo.get("max_val_samples") + if max_val_samples is None: + max_batches = len(val_dataloader) + else: + max_batches = max_val_samples // master_config.grpo["val_batch_size"] for batch_idx, val_batch in enumerate(val_dataloader): if batch_idx >= max_batches: break diff --git a/tests/unit/algorithms/test_distillation.py b/tests/unit/algorithms/test_distillation.py index 780e3459f0..3c15e18ca0 100644 --- a/tests/unit/algorithms/test_distillation.py +++ b/tests/unit/algorithms/test_distillation.py @@ -297,6 +297,61 @@ def test_validate_function(mock_components): # Note: validate() function itself doesn't call logger.log_metrics - that's done by the caller +def _rollout_return_for(batch): + enriched = BatchedDataDict[DatumSpec](dict(batch)) + enriched["total_reward"] = torch.tensor([1.0]) + return (enriched, {"mean_gen_tokens_per_sample": 10.0}) + + +def test_validate_iterates_full_dataloader_when_max_val_samples_is_none( + mock_components, +): + """When max_val_samples is omitted, validate iterates the entire val_dataloader.""" + config = mock_components["master_config"] + # Simulate a recipe that does not specify max_val_samples. + config.distillation.pop("max_val_samples", None) + config.distillation["val_batch_size"] = 1 + + expected_batches = mock_components["val_dataloader"].__len__.return_value + sample_batch = next(iter(mock_components["val_dataloader"])) + + with patch.object(distil_mod, "run_multi_turn_rollout") as mock_rollout: + mock_rollout.return_value = _rollout_return_for(sample_batch) + validate( + mock_components["student_generation"], + mock_components["val_dataloader"], + mock_components["tokenizer"], + mock_components["val_task_to_env"], + step=0, + master_config=config, + ) + + assert mock_rollout.call_count == expected_batches + + +def test_validate_floor_divides_max_val_samples_by_val_batch_size(mock_components): + """When max_val_samples is set, validate truncates with floor division (matches GRPO).""" + config = mock_components["master_config"] + config.distillation["max_val_samples"] = 7 + config.distillation["val_batch_size"] = 2 + + sample_batch = next(iter(mock_components["val_dataloader"])) + + with patch.object(distil_mod, "run_multi_turn_rollout") as mock_rollout: + mock_rollout.return_value = _rollout_return_for(sample_batch) + validate( + mock_components["student_generation"], + mock_components["val_dataloader"], + mock_components["tokenizer"], + mock_components["val_task_to_env"], + step=0, + master_config=config, + ) + + # 7 // 2 == 3 batches; previous ceiling-division behaviour would have been 4. + assert mock_rollout.call_count == 3 + + def test_check_vocab_equality_pass(monkeypatch): student_tokenizer = MagicMock() student_tokenizer.get_vocab.return_value = {"a": 0, "b": 1} diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index a6193f9f85..b0159aef99 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -2346,6 +2346,93 @@ def test_validate_returns_empty_when_no_dataloader(self): assert val_metrics == {} assert timing == {} + def test_validate_iterates_full_dataloader_when_max_val_samples_is_none(self): + """When max_val_samples is None, validate iterates the entire val_dataloader.""" + mock_policy_gen = MagicMock() + mock_tokenizer = MagicMock() + mock_tokenizer.pad_token_id = 0 + + mock_batch = BatchedDataDict[DatumSpec]( + { + "message_log": [ + [ + { + "role": "user", + "content": "test", + "token_ids": torch.tensor([1, 2, 3]), + }, + { + "role": "assistant", + "content": "response", + "token_ids": torch.tensor([4, 5, 6]), + }, + ], + ], + "task_name": ["math"], + "extra_env_info": [{}], + "loss_multiplier": torch.tensor([1.0]), + "idx": torch.tensor([0]), + "length": torch.tensor([6]), + "total_reward": torch.tensor([1.0]), + } + ) + + num_batches = 3 + mock_dataloader = MagicMock(spec=StatefulDataLoader) + mock_dataloader.__iter__ = MagicMock( + return_value=iter([mock_batch] * num_batches) + ) + mock_dataloader.__len__ = MagicMock(return_value=num_batches) + + mock_env = MagicMock(spec=EnvironmentInterface) + mock_env.global_post_process_and_metrics.return_value = (mock_batch, {}) + + mock_config = MasterConfig.model_construct( + **{ + "grpo": { + "max_val_samples": None, + "val_batch_size": 1, + "max_rollout_turns": 1, + }, + "policy": { + "max_total_sequence_length": 2048, + "generation": { + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, + "backend": "vllm", + "colocated": {"enabled": True}, + "vllm_cfg": {"async_engine": False}, + }, + }, + "logger": {"num_val_samples_to_print": 1}, + } + ) + + mock_rollout_metrics = {"mean_gen_tokens_per_sample": 10.0} + + with patch("nemo_rl.algorithms.grpo.run_multi_turn_rollout") as mock_rollout: + mock_rollout.return_value = (mock_batch, mock_rollout_metrics) + with patch( + "nemo_rl.algorithms.grpo._should_use_nemo_gym", return_value=False + ): + with patch( + "nemo_rl.algorithms.grpo._should_use_async_rollouts", + return_value=False, + ): + with patch("nemo_rl.algorithms.grpo.print_message_log_samples"): + validate( + mock_policy_gen, + mock_dataloader, + mock_tokenizer, + {"math": mock_env}, + step=0, + master_config=mock_config, + logger=None, + ) + + assert mock_rollout.call_count == num_batches + # ============================================================================ # Tests for compute_and_apply_seq_logprob_error_masking function From ffac12759508afb0630cce2da8ba0467ceccecbf Mon Sep 17 00:00:00 2001 From: Minho Ryu Date: Fri, 15 May 2026 13:57:36 +0900 Subject: [PATCH 2/2] feat: emit per-task validation accuracy in GRPO and Distillation Multi-validation (data.validation as a list of datasets) currently runs correctly but the validation aggregator collapses everything into a single sample-weighted accuracy. Per-task progress (e.g. gsm8k vs math500) is silently lost. task_name is already on every sample (DatumSpec.task_name preserved through rl_collate_fn into val_batch["task_name"]); validate() simply did not read it. This commit teaches both validate() functions to track rewards per task during the loop, then emit accuracy_ and num_samples_ keys alongside the existing aggregated accuracy. logger.log_metrics plots each as its own metric automatically. The aggregated accuracy key is preserved unchanged for dashboard backwards compatibility. Datasets without task_name are skipped, so single-task and legacy recipes behave identically. DPO already does per-dataset metrics via its dict-of-dataloaders architecture (see dpo.validate at nemo_rl/algorithms/dpo.py:332-377), so it is not touched here. Tests: * test_grpo.py adds test_validate_emits_per_task_accuracy_keys. * test_distillation.py adds the same plus a check that the aggregated accuracy key matches the sample-weighted mean across tasks. Signed-off-by: Minho Ryu --- nemo_rl/algorithms/distillation.py | 26 +++++- nemo_rl/algorithms/grpo.py | 26 +++++- tests/unit/algorithms/test_distillation.py | 52 +++++++++++ tests/unit/algorithms/test_grpo.py | 102 +++++++++++++++++++++ 4 files changed, 204 insertions(+), 2 deletions(-) diff --git a/nemo_rl/algorithms/distillation.py b/nemo_rl/algorithms/distillation.py index a9386e100c..4a11a6b7f0 100644 --- a/nemo_rl/algorithms/distillation.py +++ b/nemo_rl/algorithms/distillation.py @@ -976,6 +976,8 @@ def validate( total_rewards = [] # Can be any metric. Setted to 'accuracy' by default. total_lengths = [] all_message_logs = [] # Collect all message logs + # Per-task rewards so multi-validation reports accuracy per dataset. + per_task_rewards: dict[str, list[float]] = {} max_val_samples = master_config.distillation.get("max_val_samples") if max_val_samples is None: @@ -1011,10 +1013,19 @@ def validate( greedy=False, ) rewards = val_batch["total_reward"] + rewards_list = rewards.tolist() - total_rewards.extend(rewards.tolist()) + total_rewards.extend(rewards_list) total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"]) + # Skip samples without task_name (single-task / legacy datasets). + batch_task_names = val_batch.get("task_name") + if batch_task_names is not None: + for r, t in zip(rewards_list, batch_task_names): + if t is None: + continue + per_task_rewards.setdefault(t, []).append(r) + # Collect message logs for later display to_env = [ get_keys_from_message_log( @@ -1037,6 +1048,11 @@ def validate( "accuracy": accuracy, "avg_length": avg_length, } + for task_name, task_rewards in per_task_rewards.items(): + val_metrics[f"accuracy_{task_name}"] = sum(task_rewards) / len( + task_rewards + ) + val_metrics[f"num_samples_{task_name}"] = len(task_rewards) # Print sample conversations only once at the end of validation try: @@ -1062,6 +1078,14 @@ def validate( print(f" • Accuracy: {accuracy:.4f}") print(f" • Average response length: {avg_length:.1f} tokens") print(f" • Samples processed: {len(total_rewards)}", flush=True) + if per_task_rewards: + print(" • Per-task accuracy:") + for task_name in sorted(per_task_rewards.keys()): + tr = per_task_rewards[task_name] + print( + f" - {task_name}: {sum(tr) / len(tr):.4f} (n={len(tr)})", + flush=True, + ) # Print timing information print("\n ⏱️ Validation Timing:") diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 9cdc1a845e..9ba7c0913b 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -2287,6 +2287,8 @@ def validate( total_rewards = [] total_lengths = [] all_message_logs = [] # Collect all message logs + # Per-task rewards so multi-validation reports accuracy per dataset. + per_task_rewards: dict[str, list[float]] = {} max_val_samples = master_config.grpo.get("max_val_samples") if max_val_samples is None: @@ -2337,9 +2339,18 @@ def validate( greedy=False, ) - total_rewards.extend(val_batch["total_reward"].tolist()) + rewards_list = val_batch["total_reward"].tolist() + total_rewards.extend(rewards_list) total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"]) + # Skip samples without task_name (single-task / legacy datasets). + batch_task_names = val_batch.get("task_name") + if batch_task_names is not None: + for r, t in zip(rewards_list, batch_task_names): + if t is None: + continue + per_task_rewards.setdefault(t, []).append(r) + # Collect message logs for later display to_env = [ get_keys_from_message_log( @@ -2367,6 +2378,11 @@ def validate( "avg_length": avg_length, **additional_metrics_to_report, } + for task_name, task_rewards in per_task_rewards.items(): + val_metrics[f"accuracy_{task_name}"] = sum(task_rewards) / len( + task_rewards + ) + val_metrics[f"num_samples_{task_name}"] = len(task_rewards) # Print sample conversations only once at the end of validation try: @@ -2392,6 +2408,14 @@ def validate( print(f" • Accuracy: {accuracy:.4f}") print(f" • Average response length: {avg_length:.1f} tokens") print(f" • Samples processed: {len(total_rewards)}", flush=True) + if per_task_rewards: + print(" • Per-task accuracy:") + for task_name in sorted(per_task_rewards.keys()): + tr = per_task_rewards[task_name] + print( + f" - {task_name}: {sum(tr) / len(tr):.4f} (n={len(tr)})", + flush=True, + ) # Print timing information print("\n ⏱️ Validation Timing:") diff --git a/tests/unit/algorithms/test_distillation.py b/tests/unit/algorithms/test_distillation.py index 3c15e18ca0..dc42a77f63 100644 --- a/tests/unit/algorithms/test_distillation.py +++ b/tests/unit/algorithms/test_distillation.py @@ -329,6 +329,58 @@ def test_validate_iterates_full_dataloader_when_max_val_samples_is_none( assert mock_rollout.call_count == expected_batches +def test_validate_emits_per_task_accuracy_keys(mock_components): + """Multi-task validation produces accuracy_ and num_samples_ entries.""" + config = mock_components["master_config"] + config.distillation["max_val_samples"] = 4 + config.distillation["val_batch_size"] = 1 + + # Two batches per task; reward 1.0 for gsm8k, 0.0 for math500. + task_sequence = ["gsm8k", "math500", "gsm8k", "math500"] + reward_sequence = [1.0, 0.0, 1.0, 0.0] + + def make_batch(task, reward): + batch = BatchedDataDict[DatumSpec]( + { + "message_log": [ + [ + {"token_ids": torch.tensor([1, 2]), "role": "user", "content": "q"}, + {"token_ids": torch.tensor([3, 4]), "role": "assistant", "content": "a"}, + ] + ], + "loss_multiplier": torch.tensor([1.0]), + "task_name": [task], + "extra_env_info": [{}], + "length": torch.tensor([4]), + "idx": torch.tensor([0]), + "total_reward": torch.tensor([reward]), + } + ) + return batch + + rollout_batches = [make_batch(t, r) for t, r in zip(task_sequence, reward_sequence)] + + with patch.object(distil_mod, "run_multi_turn_rollout") as mock_rollout: + mock_rollout.side_effect = [ + (b, {"mean_gen_tokens_per_sample": 4.0}) for b in rollout_batches + ] + val_metrics, _ = validate( + mock_components["student_generation"], + mock_components["val_dataloader"], + mock_components["tokenizer"], + mock_components["val_task_to_env"], + step=0, + master_config=config, + ) + + assert val_metrics["accuracy_gsm8k"] == 1.0 + assert val_metrics["accuracy_math500"] == 0.0 + assert val_metrics["num_samples_gsm8k"] == 2 + assert val_metrics["num_samples_math500"] == 2 + # Aggregated key is preserved with the same definition (sample-weighted mean). + assert val_metrics["accuracy"] == pytest.approx(0.5) + + def test_validate_floor_divides_max_val_samples_by_val_batch_size(mock_components): """When max_val_samples is set, validate truncates with floor division (matches GRPO).""" config = mock_components["master_config"] diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index b0159aef99..4b3959b21a 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -2433,6 +2433,108 @@ def test_validate_iterates_full_dataloader_when_max_val_samples_is_none(self): assert mock_rollout.call_count == num_batches + def test_validate_emits_per_task_accuracy_keys(self): + """Multi-task validation produces accuracy_ and num_samples_ keys.""" + mock_policy_gen = MagicMock() + mock_tokenizer = MagicMock() + mock_tokenizer.pad_token_id = 0 + + def make_batch(task, reward): + return BatchedDataDict[DatumSpec]( + { + "message_log": [ + [ + { + "role": "user", + "content": "q", + "token_ids": torch.tensor([1, 2]), + }, + { + "role": "assistant", + "content": "a", + "token_ids": torch.tensor([3, 4]), + }, + ], + ], + "task_name": [task], + "extra_env_info": [{}], + "loss_multiplier": torch.tensor([1.0]), + "idx": torch.tensor([0]), + "length": torch.tensor([4]), + "total_reward": torch.tensor([reward]), + } + ) + + # Two batches per task; gsm8k all correct, math500 all wrong. + task_sequence = ["gsm8k", "math500", "gsm8k", "math500"] + reward_sequence = [1.0, 0.0, 1.0, 0.0] + rollout_batches = [ + make_batch(t, r) for t, r in zip(task_sequence, reward_sequence) + ] + + mock_dataloader = MagicMock(spec=StatefulDataLoader) + mock_dataloader.__iter__ = MagicMock( + return_value=iter([make_batch("placeholder", 0.0)] * len(rollout_batches)) + ) + mock_dataloader.__len__ = MagicMock(return_value=len(rollout_batches)) + + mock_env = MagicMock(spec=EnvironmentInterface) + mock_env.global_post_process_and_metrics.return_value = ( + rollout_batches[0], + {}, + ) + + mock_config = MasterConfig.model_construct( + **{ + "grpo": { + "max_val_samples": 4, + "val_batch_size": 1, + "max_rollout_turns": 1, + }, + "policy": { + "max_total_sequence_length": 2048, + "generation": { + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, + "backend": "vllm", + "colocated": {"enabled": True}, + "vllm_cfg": {"async_engine": False}, + }, + }, + "logger": {"num_val_samples_to_print": 1}, + } + ) + + with patch("nemo_rl.algorithms.grpo.run_multi_turn_rollout") as mock_rollout: + mock_rollout.side_effect = [ + (b, {"mean_gen_tokens_per_sample": 4.0}) for b in rollout_batches + ] + with patch( + "nemo_rl.algorithms.grpo._should_use_nemo_gym", return_value=False + ): + with patch( + "nemo_rl.algorithms.grpo._should_use_async_rollouts", + return_value=False, + ): + with patch("nemo_rl.algorithms.grpo.print_message_log_samples"): + val_metrics, _ = validate( + mock_policy_gen, + mock_dataloader, + mock_tokenizer, + {"math": mock_env}, + step=0, + master_config=mock_config, + logger=None, + ) + + assert val_metrics["accuracy_gsm8k"] == 1.0 + assert val_metrics["accuracy_math500"] == 0.0 + assert val_metrics["num_samples_gsm8k"] == 2 + assert val_metrics["num_samples_math500"] == 2 + # Aggregated key is preserved with the same definition. + assert val_metrics["accuracy"] == pytest.approx(0.5) + # ============================================================================ # Tests for compute_and_apply_seq_logprob_error_masking function