Skip to content

Commit f5fa7da

Browse files
committed
feat: emit per-task validation accuracy in GRPO and Distillation
Multi-validation (data.validation as a list of datasets) currently runs correctly but the validation aggregator collapses everything into a single sample-weighted accuracy. Per-task progress (e.g. gsm8k vs math500) is silently lost. task_name is already on every sample (DatumSpec.task_name preserved through rl_collate_fn into val_batch["task_name"]); validate() simply did not read it. This commit teaches both validate() functions to track rewards per task during the loop, then emit accuracy_<task> and num_samples_<task> keys alongside the existing aggregated accuracy. logger.log_metrics plots each as its own metric automatically. The aggregated accuracy key is preserved unchanged for dashboard backwards compatibility. Datasets without task_name are skipped, so single-task and legacy recipes behave identically. DPO already does per-dataset metrics via its dict-of-dataloaders architecture (see dpo.validate at nemo_rl/algorithms/dpo.py:332-377), so it is not touched here. Tests: * test_grpo.py adds test_validate_emits_per_task_accuracy_keys. * test_distillation.py adds the same plus a check that the aggregated accuracy key matches the sample-weighted mean across tasks. Signed-off-by: Minho Ryu <ryumin93@gmail.com>
1 parent 870c987 commit f5fa7da

4 files changed

Lines changed: 212 additions & 2 deletions

File tree

nemo_rl/algorithms/distillation.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -976,6 +976,10 @@ def validate(
976976
total_rewards = []  # Can be any metric. Set to 'accuracy' by default.
977977
total_lengths = []
978978
all_message_logs = [] # Collect all message logs
979+
# Per-task accuracy breakdown. Multi-validation concatenates several
980+
# datasets into one dataloader; without this split the aggregated
981+
# `accuracy` hides per-task progress (e.g. gsm8k vs math500).
982+
per_task_rewards: dict[str, list[float]] = {}
979983

980984
max_val_samples = master_config.distillation.get("max_val_samples")
981985
if max_val_samples is None:
@@ -1011,10 +1015,21 @@ def validate(
10111015
greedy=False,
10121016
)
10131017
rewards = val_batch["total_reward"]
1018+
rewards_list = rewards.tolist()
10141019

1015-
total_rewards.extend(rewards.tolist())
1020+
total_rewards.extend(rewards_list)
10161021
total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"])
10171022

1023+
# task_name is set per sample by rl_collate_fn from
1024+
# DatumSpec.task_name. Skip entries that lack it for backwards
1025+
# compatibility with single-task or legacy datasets.
1026+
batch_task_names = val_batch.get("task_name")
1027+
if batch_task_names is not None:
1028+
for r, t in zip(rewards_list, batch_task_names):
1029+
if t is None:
1030+
continue
1031+
per_task_rewards.setdefault(t, []).append(r)
1032+
10181033
# Collect message logs for later display
10191034
to_env = [
10201035
get_keys_from_message_log(
@@ -1037,6 +1052,11 @@ def validate(
10371052
"accuracy": accuracy,
10381053
"avg_length": avg_length,
10391054
}
1055+
for task_name, task_rewards in per_task_rewards.items():
1056+
val_metrics[f"accuracy_{task_name}"] = sum(task_rewards) / len(
1057+
task_rewards
1058+
)
1059+
val_metrics[f"num_samples_{task_name}"] = len(task_rewards)
10401060

10411061
# Print sample conversations only once at the end of validation
10421062
try:
@@ -1062,6 +1082,14 @@ def validate(
10621082
print(f" • Accuracy: {accuracy:.4f}")
10631083
print(f" • Average response length: {avg_length:.1f} tokens")
10641084
print(f" • Samples processed: {len(total_rewards)}", flush=True)
1085+
if per_task_rewards:
1086+
print(" • Per-task accuracy:")
1087+
for task_name in sorted(per_task_rewards.keys()):
1088+
tr = per_task_rewards[task_name]
1089+
print(
1090+
f" - {task_name}: {sum(tr) / len(tr):.4f} (n={len(tr)})",
1091+
flush=True,
1092+
)
10651093

10661094
# Print timing information
10671095
print("\n ⏱️ Validation Timing:")

nemo_rl/algorithms/grpo.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2287,6 +2287,10 @@ def validate(
22872287
total_rewards = []
22882288
total_lengths = []
22892289
all_message_logs = [] # Collect all message logs
2290+
# Per-task accuracy breakdown. Multi-validation concatenates several
2291+
# datasets into one dataloader; without this split the aggregated
2292+
# `accuracy` hides per-task progress (e.g. gsm8k vs math500).
2293+
per_task_rewards: dict[str, list[float]] = {}
22902294

22912295
max_val_samples = master_config.grpo.get("max_val_samples")
22922296
if max_val_samples is None:
@@ -2337,9 +2341,20 @@ def validate(
23372341
greedy=False,
23382342
)
23392343

2340-
total_rewards.extend(val_batch["total_reward"].tolist())
2344+
rewards_list = val_batch["total_reward"].tolist()
2345+
total_rewards.extend(rewards_list)
23412346
total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"])
23422347

2348+
# task_name is set per sample by rl_collate_fn from
2349+
# DatumSpec.task_name. Skip entries that lack it for backwards
2350+
# compatibility with single-task or legacy datasets.
2351+
batch_task_names = val_batch.get("task_name")
2352+
if batch_task_names is not None:
2353+
for r, t in zip(rewards_list, batch_task_names):
2354+
if t is None:
2355+
continue
2356+
per_task_rewards.setdefault(t, []).append(r)
2357+
23432358
# Collect message logs for later display
23442359
to_env = [
23452360
get_keys_from_message_log(
@@ -2367,6 +2382,11 @@ def validate(
23672382
"avg_length": avg_length,
23682383
**additional_metrics_to_report,
23692384
}
2385+
for task_name, task_rewards in per_task_rewards.items():
2386+
val_metrics[f"accuracy_{task_name}"] = sum(task_rewards) / len(
2387+
task_rewards
2388+
)
2389+
val_metrics[f"num_samples_{task_name}"] = len(task_rewards)
23702390

23712391
# Print sample conversations only once at the end of validation
23722392
try:
@@ -2392,6 +2412,14 @@ def validate(
23922412
print(f" • Accuracy: {accuracy:.4f}")
23932413
print(f" • Average response length: {avg_length:.1f} tokens")
23942414
print(f" • Samples processed: {len(total_rewards)}", flush=True)
2415+
if per_task_rewards:
2416+
print(" • Per-task accuracy:")
2417+
for task_name in sorted(per_task_rewards.keys()):
2418+
tr = per_task_rewards[task_name]
2419+
print(
2420+
f" - {task_name}: {sum(tr) / len(tr):.4f} (n={len(tr)})",
2421+
flush=True,
2422+
)
23952423

23962424
# Print timing information
23972425
print("\n ⏱️ Validation Timing:")

tests/unit/algorithms/test_distillation.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,58 @@ def test_validate_iterates_full_dataloader_when_max_val_samples_is_none(
329329
assert mock_rollout.call_count == expected_batches
330330

331331

332+
def test_validate_emits_per_task_accuracy_keys(mock_components):
333+
"""Multi-task validation produces accuracy_<task> and num_samples_<task> entries."""
334+
config = mock_components["master_config"]
335+
config.distillation["max_val_samples"] = 4
336+
config.distillation["val_batch_size"] = 1
337+
338+
# Two batches per task; reward 1.0 for gsm8k, 0.0 for math500.
339+
task_sequence = ["gsm8k", "math500", "gsm8k", "math500"]
340+
reward_sequence = [1.0, 0.0, 1.0, 0.0]
341+
342+
def make_batch(task, reward):
343+
batch = BatchedDataDict[DatumSpec](
344+
{
345+
"message_log": [
346+
[
347+
{"token_ids": torch.tensor([1, 2]), "role": "user", "content": "q"},
348+
{"token_ids": torch.tensor([3, 4]), "role": "assistant", "content": "a"},
349+
]
350+
],
351+
"loss_multiplier": torch.tensor([1.0]),
352+
"task_name": [task],
353+
"extra_env_info": [{}],
354+
"length": torch.tensor([4]),
355+
"idx": torch.tensor([0]),
356+
"total_reward": torch.tensor([reward]),
357+
}
358+
)
359+
return batch
360+
361+
rollout_batches = [make_batch(t, r) for t, r in zip(task_sequence, reward_sequence)]
362+
363+
with patch.object(distil_mod, "run_multi_turn_rollout") as mock_rollout:
364+
mock_rollout.side_effect = [
365+
(b, {"mean_gen_tokens_per_sample": 4.0}) for b in rollout_batches
366+
]
367+
val_metrics, _ = validate(
368+
mock_components["student_generation"],
369+
mock_components["val_dataloader"],
370+
mock_components["tokenizer"],
371+
mock_components["val_task_to_env"],
372+
step=0,
373+
master_config=config,
374+
)
375+
376+
assert val_metrics["accuracy_gsm8k"] == 1.0
377+
assert val_metrics["accuracy_math500"] == 0.0
378+
assert val_metrics["num_samples_gsm8k"] == 2
379+
assert val_metrics["num_samples_math500"] == 2
380+
# Aggregated key is preserved with the same definition (sample-weighted mean).
381+
assert val_metrics["accuracy"] == pytest.approx(0.5)
382+
383+
332384
def test_validate_floor_divides_max_val_samples_by_val_batch_size(mock_components):
333385
"""When max_val_samples is set, validate truncates with floor division (matches GRPO)."""
334386
config = mock_components["master_config"]

tests/unit/algorithms/test_grpo.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2433,6 +2433,108 @@ def test_validate_iterates_full_dataloader_when_max_val_samples_is_none(self):
24332433

24342434
assert mock_rollout.call_count == num_batches
24352435

2436+
def test_validate_emits_per_task_accuracy_keys(self):
2437+
"""Multi-task validation produces accuracy_<task> and num_samples_<task> keys."""
2438+
mock_policy_gen = MagicMock()
2439+
mock_tokenizer = MagicMock()
2440+
mock_tokenizer.pad_token_id = 0
2441+
2442+
def make_batch(task, reward):
2443+
return BatchedDataDict[DatumSpec](
2444+
{
2445+
"message_log": [
2446+
[
2447+
{
2448+
"role": "user",
2449+
"content": "q",
2450+
"token_ids": torch.tensor([1, 2]),
2451+
},
2452+
{
2453+
"role": "assistant",
2454+
"content": "a",
2455+
"token_ids": torch.tensor([3, 4]),
2456+
},
2457+
],
2458+
],
2459+
"task_name": [task],
2460+
"extra_env_info": [{}],
2461+
"loss_multiplier": torch.tensor([1.0]),
2462+
"idx": torch.tensor([0]),
2463+
"length": torch.tensor([4]),
2464+
"total_reward": torch.tensor([reward]),
2465+
}
2466+
)
2467+
2468+
# Two batches per task; gsm8k all correct, math500 all wrong.
2469+
task_sequence = ["gsm8k", "math500", "gsm8k", "math500"]
2470+
reward_sequence = [1.0, 0.0, 1.0, 0.0]
2471+
rollout_batches = [
2472+
make_batch(t, r) for t, r in zip(task_sequence, reward_sequence)
2473+
]
2474+
2475+
mock_dataloader = MagicMock(spec=StatefulDataLoader)
2476+
mock_dataloader.__iter__ = MagicMock(
2477+
return_value=iter([make_batch("placeholder", 0.0)] * len(rollout_batches))
2478+
)
2479+
mock_dataloader.__len__ = MagicMock(return_value=len(rollout_batches))
2480+
2481+
mock_env = MagicMock(spec=EnvironmentInterface)
2482+
mock_env.global_post_process_and_metrics.return_value = (
2483+
rollout_batches[0],
2484+
{},
2485+
)
2486+
2487+
mock_config = MasterConfig.model_construct(
2488+
**{
2489+
"grpo": {
2490+
"max_val_samples": 4,
2491+
"val_batch_size": 1,
2492+
"max_rollout_turns": 1,
2493+
},
2494+
"policy": {
2495+
"max_total_sequence_length": 2048,
2496+
"generation": {
2497+
"temperature": 1.0,
2498+
"top_p": 1.0,
2499+
"top_k": None,
2500+
"backend": "vllm",
2501+
"colocated": {"enabled": True},
2502+
"vllm_cfg": {"async_engine": False},
2503+
},
2504+
},
2505+
"logger": {"num_val_samples_to_print": 1},
2506+
}
2507+
)
2508+
2509+
with patch("nemo_rl.algorithms.grpo.run_multi_turn_rollout") as mock_rollout:
2510+
mock_rollout.side_effect = [
2511+
(b, {"mean_gen_tokens_per_sample": 4.0}) for b in rollout_batches
2512+
]
2513+
with patch(
2514+
"nemo_rl.algorithms.grpo._should_use_nemo_gym", return_value=False
2515+
):
2516+
with patch(
2517+
"nemo_rl.algorithms.grpo._should_use_async_rollouts",
2518+
return_value=False,
2519+
):
2520+
with patch("nemo_rl.algorithms.grpo.print_message_log_samples"):
2521+
val_metrics, _ = validate(
2522+
mock_policy_gen,
2523+
mock_dataloader,
2524+
mock_tokenizer,
2525+
{"math": mock_env},
2526+
step=0,
2527+
master_config=mock_config,
2528+
logger=None,
2529+
)
2530+
2531+
assert val_metrics["accuracy_gsm8k"] == 1.0
2532+
assert val_metrics["accuracy_math500"] == 0.0
2533+
assert val_metrics["num_samples_gsm8k"] == 2
2534+
assert val_metrics["num_samples_math500"] == 2
2535+
# Aggregated key is preserved with the same definition.
2536+
assert val_metrics["accuracy"] == pytest.approx(0.5)
2537+
24362538

24372539
# ============================================================================
24382540
# Tests for compute_and_apply_seq_logprob_error_masking function

0 commit comments

Comments
 (0)