Commit ffac127

feat: emit per-task validation accuracy in GRPO and Distillation
Multi-validation (data.validation as a list of datasets) currently runs correctly, but the validation aggregator collapses everything into a single sample-weighted accuracy, so per-task progress (e.g. gsm8k vs. math500) is silently lost. task_name is already on every sample (DatumSpec.task_name is preserved through rl_collate_fn into val_batch["task_name"]); validate() simply did not read it.

This commit teaches both validate() functions to track rewards per task during the loop, then emit accuracy_<task> and num_samples_<task> keys alongside the existing aggregated accuracy. logger.log_metrics plots each as its own metric automatically. The aggregated accuracy key is preserved unchanged for dashboard backwards compatibility. Datasets without task_name are skipped, so single-task and legacy recipes behave identically.

DPO already reports per-dataset metrics via its dict-of-dataloaders architecture (see dpo.validate at nemo_rl/algorithms/dpo.py:332-377), so it is not touched here.

Tests:
* test_grpo.py adds test_validate_emits_per_task_accuracy_keys.
* test_distillation.py adds the same, plus a check that the aggregated accuracy key matches the sample-weighted mean across tasks.

Signed-off-by: Minho Ryu <ryumin93@gmail.com>
1 parent 870c987 · commit ffac127
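
For reference, the mechanism added to both validate() loops boils down to the following minimal, self-contained Python sketch. The batch dicts here are illustrative stand-ins, not NeMo RL's BatchedDataDict:

per_task_rewards: dict[str, list[float]] = {}

# Stand-in rollout results; the real code reads these fields from val_batch.
batches = [
    {"total_reward": [1.0, 0.0], "task_name": ["gsm8k", "math500"]},
    {"total_reward": [1.0], "task_name": [None]},  # legacy sample: skipped
]
for val_batch in batches:
    task_names = val_batch.get("task_name")
    if task_names is not None:
        for r, t in zip(val_batch["total_reward"], task_names):
            if t is None:
                continue
            per_task_rewards.setdefault(t, []).append(r)

val_metrics = {}
for task_name, task_rewards in per_task_rewards.items():
    val_metrics[f"accuracy_{task_name}"] = sum(task_rewards) / len(task_rewards)
    val_metrics[f"num_samples_{task_name}"] = len(task_rewards)
print(val_metrics)
# {'accuracy_gsm8k': 1.0, 'num_samples_gsm8k': 1,
#  'accuracy_math500': 0.0, 'num_samples_math500': 1}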

4 files changed · 204 additions & 2 deletions

File tree:
- nemo_rl/algorithms/distillation.py
- nemo_rl/algorithms/grpo.py
- tests/unit/algorithms/test_distillation.py
- tests/unit/algorithms/test_grpo.py

nemo_rl/algorithms/distillation.py

Lines changed: 25 additions & 1 deletion
@@ -976,6 +976,8 @@ def validate(
     total_rewards = []  # Can be any metric. Setted to 'accuracy' by default.
     total_lengths = []
     all_message_logs = []  # Collect all message logs
+    # Per-task rewards so multi-validation reports accuracy per dataset.
+    per_task_rewards: dict[str, list[float]] = {}
 
     max_val_samples = master_config.distillation.get("max_val_samples")
     if max_val_samples is None:
@@ -1011,10 +1013,19 @@ def validate(
             greedy=False,
         )
         rewards = val_batch["total_reward"]
+        rewards_list = rewards.tolist()
 
-        total_rewards.extend(rewards.tolist())
+        total_rewards.extend(rewards_list)
         total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"])
 
+        # Skip samples without task_name (single-task / legacy datasets).
+        batch_task_names = val_batch.get("task_name")
+        if batch_task_names is not None:
+            for r, t in zip(rewards_list, batch_task_names):
+                if t is None:
+                    continue
+                per_task_rewards.setdefault(t, []).append(r)
+
         # Collect message logs for later display
         to_env = [
             get_keys_from_message_log(
@@ -1037,6 +1048,11 @@ def validate(
         "accuracy": accuracy,
         "avg_length": avg_length,
     }
+    for task_name, task_rewards in per_task_rewards.items():
+        val_metrics[f"accuracy_{task_name}"] = sum(task_rewards) / len(
+            task_rewards
+        )
+        val_metrics[f"num_samples_{task_name}"] = len(task_rewards)
 
     # Print sample conversations only once at the end of validation
     try:
@@ -1062,6 +1078,14 @@ def validate(
     print(f"  • Accuracy: {accuracy:.4f}")
     print(f"  • Average response length: {avg_length:.1f} tokens")
     print(f"  • Samples processed: {len(total_rewards)}", flush=True)
+    if per_task_rewards:
+        print("  • Per-task accuracy:")
+        for task_name in sorted(per_task_rewards.keys()):
+            tr = per_task_rewards[task_name]
+            print(
+                f"    - {task_name}: {sum(tr) / len(tr):.4f} (n={len(tr)})",
+                flush=True,
+            )
 
     # Print timing information
     print("\n  ⏱️ Validation Timing:")

nemo_rl/algorithms/grpo.py

Lines changed: 25 additions & 1 deletion
@@ -2287,6 +2287,8 @@ def validate(
     total_rewards = []
     total_lengths = []
     all_message_logs = []  # Collect all message logs
+    # Per-task rewards so multi-validation reports accuracy per dataset.
+    per_task_rewards: dict[str, list[float]] = {}
 
     max_val_samples = master_config.grpo.get("max_val_samples")
     if max_val_samples is None:
@@ -2337,9 +2339,18 @@ def validate(
             greedy=False,
         )
 
-        total_rewards.extend(val_batch["total_reward"].tolist())
+        rewards_list = val_batch["total_reward"].tolist()
+        total_rewards.extend(rewards_list)
         total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"])
 
+        # Skip samples without task_name (single-task / legacy datasets).
+        batch_task_names = val_batch.get("task_name")
+        if batch_task_names is not None:
+            for r, t in zip(rewards_list, batch_task_names):
+                if t is None:
+                    continue
+                per_task_rewards.setdefault(t, []).append(r)
+
         # Collect message logs for later display
         to_env = [
             get_keys_from_message_log(
@@ -2367,6 +2378,11 @@ def validate(
         "avg_length": avg_length,
         **additional_metrics_to_report,
     }
+    for task_name, task_rewards in per_task_rewards.items():
+        val_metrics[f"accuracy_{task_name}"] = sum(task_rewards) / len(
+            task_rewards
+        )
+        val_metrics[f"num_samples_{task_name}"] = len(task_rewards)
 
     # Print sample conversations only once at the end of validation
     try:
@@ -2392,6 +2408,14 @@ def validate(
     print(f"  • Accuracy: {accuracy:.4f}")
     print(f"  • Average response length: {avg_length:.1f} tokens")
     print(f"  • Samples processed: {len(total_rewards)}", flush=True)
+    if per_task_rewards:
+        print("  • Per-task accuracy:")
+        for task_name in sorted(per_task_rewards.keys()):
+            tr = per_task_rewards[task_name]
+            print(
+                f"    - {task_name}: {sum(tr) / len(tr):.4f} (n={len(tr)})",
+                flush=True,
+            )
 
     # Print timing information
     print("\n  ⏱️ Validation Timing:")

tests/unit/algorithms/test_distillation.py

Lines changed: 52 additions & 0 deletions
@@ -329,6 +329,58 @@ def test_validate_iterates_full_dataloader_when_max_val_samples_is_none(
     assert mock_rollout.call_count == expected_batches
 
 
+def test_validate_emits_per_task_accuracy_keys(mock_components):
+    """Multi-task validation produces accuracy_<task> and num_samples_<task> entries."""
+    config = mock_components["master_config"]
+    config.distillation["max_val_samples"] = 4
+    config.distillation["val_batch_size"] = 1
+
+    # Two batches per task; reward 1.0 for gsm8k, 0.0 for math500.
+    task_sequence = ["gsm8k", "math500", "gsm8k", "math500"]
+    reward_sequence = [1.0, 0.0, 1.0, 0.0]
+
+    def make_batch(task, reward):
+        batch = BatchedDataDict[DatumSpec](
+            {
+                "message_log": [
+                    [
+                        {"token_ids": torch.tensor([1, 2]), "role": "user", "content": "q"},
+                        {"token_ids": torch.tensor([3, 4]), "role": "assistant", "content": "a"},
+                    ]
+                ],
+                "loss_multiplier": torch.tensor([1.0]),
+                "task_name": [task],
+                "extra_env_info": [{}],
+                "length": torch.tensor([4]),
+                "idx": torch.tensor([0]),
+                "total_reward": torch.tensor([reward]),
+            }
+        )
+        return batch
+
+    rollout_batches = [make_batch(t, r) for t, r in zip(task_sequence, reward_sequence)]
+
+    with patch.object(distil_mod, "run_multi_turn_rollout") as mock_rollout:
+        mock_rollout.side_effect = [
+            (b, {"mean_gen_tokens_per_sample": 4.0}) for b in rollout_batches
+        ]
+        val_metrics, _ = validate(
+            mock_components["student_generation"],
+            mock_components["val_dataloader"],
+            mock_components["tokenizer"],
+            mock_components["val_task_to_env"],
+            step=0,
+            master_config=config,
+        )
+
+    assert val_metrics["accuracy_gsm8k"] == 1.0
+    assert val_metrics["accuracy_math500"] == 0.0
+    assert val_metrics["num_samples_gsm8k"] == 2
+    assert val_metrics["num_samples_math500"] == 2
+    # Aggregated key is preserved with the same definition (sample-weighted mean).
+    assert val_metrics["accuracy"] == pytest.approx(0.5)
+
+
 def test_validate_floor_divides_max_val_samples_by_val_batch_size(mock_components):
     """When max_val_samples is set, validate truncates with floor division (matches GRPO)."""
     config = mock_components["master_config"]
tests/unit/algorithms/test_grpo.py

Lines changed: 102 additions & 0 deletions
@@ -2433,6 +2433,108 @@ def test_validate_iterates_full_dataloader_when_max_val_samples_is_none(self):
 
         assert mock_rollout.call_count == num_batches
 
+    def test_validate_emits_per_task_accuracy_keys(self):
+        """Multi-task validation produces accuracy_<task> and num_samples_<task> keys."""
+        mock_policy_gen = MagicMock()
+        mock_tokenizer = MagicMock()
+        mock_tokenizer.pad_token_id = 0
+
+        def make_batch(task, reward):
+            return BatchedDataDict[DatumSpec](
+                {
+                    "message_log": [
+                        [
+                            {
+                                "role": "user",
+                                "content": "q",
+                                "token_ids": torch.tensor([1, 2]),
+                            },
+                            {
+                                "role": "assistant",
+                                "content": "a",
+                                "token_ids": torch.tensor([3, 4]),
+                            },
+                        ],
+                    ],
+                    "task_name": [task],
+                    "extra_env_info": [{}],
+                    "loss_multiplier": torch.tensor([1.0]),
+                    "idx": torch.tensor([0]),
+                    "length": torch.tensor([4]),
+                    "total_reward": torch.tensor([reward]),
+                }
+            )
+
+        # Two batches per task; gsm8k all correct, math500 all wrong.
+        task_sequence = ["gsm8k", "math500", "gsm8k", "math500"]
+        reward_sequence = [1.0, 0.0, 1.0, 0.0]
+        rollout_batches = [
+            make_batch(t, r) for t, r in zip(task_sequence, reward_sequence)
+        ]
+
+        mock_dataloader = MagicMock(spec=StatefulDataLoader)
+        mock_dataloader.__iter__ = MagicMock(
+            return_value=iter([make_batch("placeholder", 0.0)] * len(rollout_batches))
+        )
+        mock_dataloader.__len__ = MagicMock(return_value=len(rollout_batches))
+
+        mock_env = MagicMock(spec=EnvironmentInterface)
+        mock_env.global_post_process_and_metrics.return_value = (
+            rollout_batches[0],
+            {},
+        )
+
+        mock_config = MasterConfig.model_construct(
+            **{
+                "grpo": {
+                    "max_val_samples": 4,
+                    "val_batch_size": 1,
+                    "max_rollout_turns": 1,
+                },
+                "policy": {
+                    "max_total_sequence_length": 2048,
+                    "generation": {
+                        "temperature": 1.0,
+                        "top_p": 1.0,
+                        "top_k": None,
+                        "backend": "vllm",
+                        "colocated": {"enabled": True},
+                        "vllm_cfg": {"async_engine": False},
+                    },
+                },
+                "logger": {"num_val_samples_to_print": 1},
+            }
+        )
+
+        with patch("nemo_rl.algorithms.grpo.run_multi_turn_rollout") as mock_rollout:
+            mock_rollout.side_effect = [
+                (b, {"mean_gen_tokens_per_sample": 4.0}) for b in rollout_batches
+            ]
+            with patch(
+                "nemo_rl.algorithms.grpo._should_use_nemo_gym", return_value=False
+            ):
+                with patch(
+                    "nemo_rl.algorithms.grpo._should_use_async_rollouts",
+                    return_value=False,
+                ):
+                    with patch("nemo_rl.algorithms.grpo.print_message_log_samples"):
+                        val_metrics, _ = validate(
+                            mock_policy_gen,
+                            mock_dataloader,
+                            mock_tokenizer,
+                            {"math": mock_env},
+                            step=0,
+                            master_config=mock_config,
+                            logger=None,
+                        )
+
+        assert val_metrics["accuracy_gsm8k"] == 1.0
+        assert val_metrics["accuracy_math500"] == 0.0
+        assert val_metrics["num_samples_gsm8k"] == 2
+        assert val_metrics["num_samples_math500"] == 2
+        # Aggregated key is preserved with the same definition.
+        assert val_metrics["accuracy"] == pytest.approx(0.5)
+
 
 # ============================================================================
 # Tests for compute_and_apply_seq_logprob_error_masking function
