40 changes: 33 additions & 7 deletions nemo_rl/algorithms/distillation.py
@@ -83,7 +83,7 @@ class DistillationConfig(TypedDict):
# Whether to run validation on the last training step. Setting this to True ensures the
# final checkpoint has validation metrics, which is required for get_best_checkpoint_path().
val_at_end: bool
max_val_samples: int
max_val_samples: NotRequired[int]
topk_logits_k: int
seed: int

@@ -976,12 +976,16 @@ def validate(
total_rewards = []  # Can be any metric; set to 'accuracy' by default.
total_lengths = []
all_message_logs = [] # Collect all message logs
# Per-task rewards so multi-validation reports accuracy per dataset.
per_task_rewards: dict[str, list[float]] = {}

max_batches = (
master_config.distillation["max_val_samples"]
+ master_config.distillation["val_batch_size"]
- 1
) // master_config.distillation["val_batch_size"]
max_val_samples = master_config.distillation.get("max_val_samples")
if max_val_samples is None:
max_batches = len(val_dataloader)
else:
max_batches = (
max_val_samples // master_config.distillation["val_batch_size"]
)
for batch_idx, val_batch in enumerate(val_dataloader):
if batch_idx >= max_batches:
break
@@ -1009,10 +1013,19 @@
greedy=False,
)
rewards = val_batch["total_reward"]
rewards_list = rewards.tolist()

total_rewards.extend(rewards.tolist())
total_rewards.extend(rewards_list)
total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"])

# Skip samples without task_name (single-task / legacy datasets).
batch_task_names = val_batch.get("task_name")
if batch_task_names is not None:
for r, t in zip(rewards_list, batch_task_names):
if t is None:
continue
per_task_rewards.setdefault(t, []).append(r)

# Collect message logs for later display
to_env = [
get_keys_from_message_log(
@@ -1035,6 +1048,11 @@
"accuracy": accuracy,
"avg_length": avg_length,
}
for task_name, task_rewards in per_task_rewards.items():
val_metrics[f"accuracy_{task_name}"] = sum(task_rewards) / len(
task_rewards
)
val_metrics[f"num_samples_{task_name}"] = len(task_rewards)

# Print sample conversations only once at the end of validation
try:
@@ -1060,6 +1078,14 @@
print(f" • Accuracy: {accuracy:.4f}")
print(f" • Average response length: {avg_length:.1f} tokens")
print(f" • Samples processed: {len(total_rewards)}", flush=True)
if per_task_rewards:
print(" • Per-task accuracy:")
for task_name in sorted(per_task_rewards.keys()):
tr = per_task_rewards[task_name]
print(
f" - {task_name}: {sum(tr) / len(tr):.4f} (n={len(tr)})",
flush=True,
)

# Print timing information
print("\n ⏱️ Validation Timing:")
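For reference, a minimal standalone sketch of the per-task aggregation this file now performs during validation. The batches, rewards, and task names below are invented for illustration and are not taken from the PR or the NeMo-RL codebase; the real code works on tensors from `run_multi_turn_rollout`, not plain lists.

```python
# Standalone sketch of the per-task metric aggregation added above.
# Batch contents are made up for illustration only.
batches = [
    {"total_reward": [1.0, 0.0], "task_name": ["gsm8k", "math500"]},
    {"total_reward": [1.0, 1.0], "task_name": ["gsm8k", None]},  # None is skipped
]

total_rewards: list[float] = []
per_task_rewards: dict[str, list[float]] = {}

for batch in batches:
    rewards_list = batch["total_reward"]
    total_rewards.extend(rewards_list)
    batch_task_names = batch.get("task_name")
    if batch_task_names is not None:
        for r, t in zip(rewards_list, batch_task_names):
            if t is None:
                continue  # single-task / legacy samples carry no task name
            per_task_rewards.setdefault(t, []).append(r)

val_metrics = {"accuracy": sum(total_rewards) / len(total_rewards)}
for task_name, task_rewards in per_task_rewards.items():
    val_metrics[f"accuracy_{task_name}"] = sum(task_rewards) / len(task_rewards)
    val_metrics[f"num_samples_{task_name}"] = len(task_rewards)

print(val_metrics)
# {'accuracy': 0.75, 'accuracy_gsm8k': 1.0, 'num_samples_gsm8k': 2,
#  'accuracy_math500': 0.0, 'num_samples_math500': 1}
```

The aggregated `accuracy` key keeps its original sample-weighted definition, so existing dashboards and `get_best_checkpoint_path()` consumers are unaffected.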
35 changes: 30 additions & 5 deletions nemo_rl/algorithms/grpo.py
@@ -2287,11 +2287,14 @@ def validate(
total_rewards = []
total_lengths = []
all_message_logs = [] # Collect all message logs
# Per-task rewards so multi-validation reports accuracy per dataset.
per_task_rewards: dict[str, list[float]] = {}

max_batches = (
master_config.grpo["max_val_samples"]
// master_config.grpo["val_batch_size"]
)
max_val_samples = master_config.grpo.get("max_val_samples")
if max_val_samples is None:
max_batches = len(val_dataloader)
else:
max_batches = max_val_samples // master_config.grpo["val_batch_size"]
for batch_idx, val_batch in enumerate(val_dataloader):
if batch_idx >= max_batches:
break
@@ -2336,9 +2339,18 @@
greedy=False,
)

total_rewards.extend(val_batch["total_reward"].tolist())
rewards_list = val_batch["total_reward"].tolist()
total_rewards.extend(rewards_list)
total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"])

# Skip samples without task_name (single-task / legacy datasets).
batch_task_names = val_batch.get("task_name")
if batch_task_names is not None:
for r, t in zip(rewards_list, batch_task_names):
if t is None:
continue
per_task_rewards.setdefault(t, []).append(r)

# Collect message logs for later display
to_env = [
get_keys_from_message_log(
@@ -2366,6 +2378,11 @@
"avg_length": avg_length,
**additional_metrics_to_report,
}
for task_name, task_rewards in per_task_rewards.items():
val_metrics[f"accuracy_{task_name}"] = sum(task_rewards) / len(
task_rewards
)
val_metrics[f"num_samples_{task_name}"] = len(task_rewards)

# Print sample conversations only once at the end of validation
try:
@@ -2391,6 +2408,14 @@
print(f" • Accuracy: {accuracy:.4f}")
print(f" • Average response length: {avg_length:.1f} tokens")
print(f" • Samples processed: {len(total_rewards)}", flush=True)
if per_task_rewards:
print(" • Per-task accuracy:")
for task_name in sorted(per_task_rewards.keys()):
tr = per_task_rewards[task_name]
print(
f" - {task_name}: {sum(tr) / len(tr):.4f} (n={len(tr)})",
flush=True,
)

# Print timing information
print("\n ⏱️ Validation Timing:")
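Both `validate()` implementations now share the same truncation rule for `max_val_samples`. A small hypothetical helper (not part of the PR) that captures the updated logic, sketched below:

```python
from typing import Optional


def compute_max_batches(
    max_val_samples: Optional[int],
    val_batch_size: int,
    num_val_batches: int,
) -> int:
    """Hypothetical helper mirroring the updated max_batches computation."""
    if max_val_samples is None:
        # max_val_samples omitted (NotRequired): consume the whole val_dataloader.
        return num_val_batches
    # Otherwise floor-divide the sample budget, matching GRPO's behaviour.
    return max_val_samples // val_batch_size


assert compute_max_batches(None, 2, 5) == 5  # full dataloader
assert compute_max_batches(7, 2, 5) == 3     # 7 // 2; the old ceiling division gave 4
```

Note that floor division drops a partial final batch when `max_val_samples` is not a multiple of `val_batch_size`, which the new distillation test below asserts explicitly.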
107 changes: 107 additions & 0 deletions tests/unit/algorithms/test_distillation.py
@@ -297,6 +297,113 @@ def test_validate_function(mock_components):
# Note: validate() function itself doesn't call logger.log_metrics - that's done by the caller


def _rollout_return_for(batch):
enriched = BatchedDataDict[DatumSpec](dict(batch))
enriched["total_reward"] = torch.tensor([1.0])
return (enriched, {"mean_gen_tokens_per_sample": 10.0})


def test_validate_iterates_full_dataloader_when_max_val_samples_is_none(
mock_components,
):
"""When max_val_samples is omitted, validate iterates the entire val_dataloader."""
config = mock_components["master_config"]
# Simulate a recipe that does not specify max_val_samples.
config.distillation.pop("max_val_samples", None)
config.distillation["val_batch_size"] = 1

expected_batches = mock_components["val_dataloader"].__len__.return_value
sample_batch = next(iter(mock_components["val_dataloader"]))

with patch.object(distil_mod, "run_multi_turn_rollout") as mock_rollout:
mock_rollout.return_value = _rollout_return_for(sample_batch)
validate(
mock_components["student_generation"],
mock_components["val_dataloader"],
mock_components["tokenizer"],
mock_components["val_task_to_env"],
step=0,
master_config=config,
)

assert mock_rollout.call_count == expected_batches


def test_validate_emits_per_task_accuracy_keys(mock_components):
"""Multi-task validation produces accuracy_<task> and num_samples_<task> entries."""
config = mock_components["master_config"]
config.distillation["max_val_samples"] = 4
config.distillation["val_batch_size"] = 1

# Two batches per task; reward 1.0 for gsm8k, 0.0 for math500.
task_sequence = ["gsm8k", "math500", "gsm8k", "math500"]
reward_sequence = [1.0, 0.0, 1.0, 0.0]

def make_batch(task, reward):
batch = BatchedDataDict[DatumSpec](
{
"message_log": [
[
{"token_ids": torch.tensor([1, 2]), "role": "user", "content": "q"},
{"token_ids": torch.tensor([3, 4]), "role": "assistant", "content": "a"},
]
],
"loss_multiplier": torch.tensor([1.0]),
"task_name": [task],
"extra_env_info": [{}],
"length": torch.tensor([4]),
"idx": torch.tensor([0]),
"total_reward": torch.tensor([reward]),
}
)
return batch

rollout_batches = [make_batch(t, r) for t, r in zip(task_sequence, reward_sequence)]

with patch.object(distil_mod, "run_multi_turn_rollout") as mock_rollout:
mock_rollout.side_effect = [
(b, {"mean_gen_tokens_per_sample": 4.0}) for b in rollout_batches
]
val_metrics, _ = validate(
mock_components["student_generation"],
mock_components["val_dataloader"],
mock_components["tokenizer"],
mock_components["val_task_to_env"],
step=0,
master_config=config,
)

assert val_metrics["accuracy_gsm8k"] == 1.0
assert val_metrics["accuracy_math500"] == 0.0
assert val_metrics["num_samples_gsm8k"] == 2
assert val_metrics["num_samples_math500"] == 2
# Aggregated key is preserved with the same definition (sample-weighted mean).
assert val_metrics["accuracy"] == pytest.approx(0.5)


def test_validate_floor_divides_max_val_samples_by_val_batch_size(mock_components):
"""When max_val_samples is set, validate truncates with floor division (matches GRPO)."""
config = mock_components["master_config"]
config.distillation["max_val_samples"] = 7
config.distillation["val_batch_size"] = 2

sample_batch = next(iter(mock_components["val_dataloader"]))

with patch.object(distil_mod, "run_multi_turn_rollout") as mock_rollout:
mock_rollout.return_value = _rollout_return_for(sample_batch)
validate(
mock_components["student_generation"],
mock_components["val_dataloader"],
mock_components["tokenizer"],
mock_components["val_task_to_env"],
step=0,
master_config=config,
)

# 7 // 2 == 3 batches; previous ceiling-division behaviour would have been 4.
assert mock_rollout.call_count == 3


def test_check_vocab_equality_pass(monkeypatch):
student_tokenizer = MagicMock()
student_tokenizer.get_vocab.return_value = {"a": 0, "b": 1}