Commit f5fa7da
feat: emit per-task validation accuracy in GRPO and Distillation

Multi-validation (data.validation given as a list of datasets) already
runs correctly, but the validation aggregator collapses all datasets
into a single sample-weighted accuracy, so per-task progress (e.g.
gsm8k vs math500) is silently lost.

task_name is already on every sample (DatumSpec.task_name is preserved
through rl_collate_fn into val_batch["task_name"]); validate() simply
did not read it.

This commit teaches both validate() functions to track rewards per
task during the loop, then emit accuracy_<task> and num_samples_<task>
keys alongside the existing aggregated accuracy. logger.log_metrics
plots each as its own metric automatically.

The aggregated accuracy key is preserved unchanged for dashboard
backwards compatibility. Datasets without task_name are skipped when
the per-task keys are built, so single-task and legacy recipes behave
identically.
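
The aggregation described above can be sketched roughly as follows. This is a minimal illustration, not the code in this commit: the helper name summarize_per_task and its plain-list inputs are hypothetical, and the real validate() functions operate on validation batches. It also assumes that "skipped" means a sample with no task name is left out of the per-task keys but still counts toward the aggregated accuracy.

```python
from collections import defaultdict
from typing import Optional, Sequence


def summarize_per_task(
    rewards: Sequence[float],
    task_names: Sequence[Optional[str]],
) -> dict[str, float]:
    """Hypothetical helper: aggregate rewards overall and per task."""
    if len(rewards) != len(task_names):
        raise ValueError("rewards and task_names must have the same length")

    # Group rewards by task; samples without a task name get no per-task key
    # (assumption about what "skipped" means in the commit message).
    per_task: dict[str, list[float]] = defaultdict(list)
    for reward, task in zip(rewards, task_names):
        if task is None:
            continue
        per_task[task].append(reward)

    # The aggregated key is computed over all samples, exactly as before.
    metrics: dict[str, float] = {
        "accuracy": sum(rewards) / len(rewards) if rewards else 0.0,
    }
    for task, task_rewards in per_task.items():
        metrics[f"accuracy_{task}"] = sum(task_rewards) / len(task_rewards)
        metrics[f"num_samples_{task}"] = float(len(task_rewards))
    return metrics


# Illustrative values only; the last sample has no task name.
metrics = summarize_per_task(
    rewards=[1.0, 0.0, 1.0, 1.0],
    task_names=["gsm8k", "gsm8k", "math500", None],
)
```

With these made-up inputs the result contains accuracy, accuracy_gsm8k, num_samples_gsm8k, accuracy_math500 and num_samples_math500, and no key for the untagged sample.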

DPO already does per-dataset metrics via its dict-of-dataloaders
architecture (see dpo.validate at nemo_rl/algorithms/dpo.py:332-377),
so it is not touched here.

Tests:
* test_grpo.py adds test_validate_emits_per_task_accuracy_keys.
* test_distillation.py adds the same test plus a check that the
  aggregated accuracy key matches the sample-weighted mean across
  tasks.
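
The consistency check mentioned for test_distillation.py rests on a simple identity: when every sample carries a task name, the aggregated accuracy equals the per-task accuracies weighted by their sample counts. The sketch below illustrates that identity on made-up numbers. It is not the test from this commit, and the helper name weighted_mean_accuracy is hypothetical.

```python
import math


def weighted_mean_accuracy(metrics: dict[str, float]) -> float:
    """Recombine accuracy_<task> keys using num_samples_<task> as weights."""
    prefix = "num_samples_"
    tasks = [key[len(prefix):] for key in metrics if key.startswith(prefix)]
    total = sum(metrics[f"num_samples_{task}"] for task in tasks)
    weighted = sum(
        metrics[f"accuracy_{task}"] * metrics[f"num_samples_{task}"]
        for task in tasks
    )
    return weighted / total


# Illustrative values: 2 of 4 gsm8k samples and 6 of 6 math500 samples correct,
# so the aggregated accuracy over all 10 samples is 8 / 10.
example_metrics = {
    "accuracy": 0.8,
    "accuracy_gsm8k": 0.5,
    "num_samples_gsm8k": 4,
    "accuracy_math500": 1.0,
    "num_samples_math500": 6,
}

recombined = weighted_mean_accuracy(example_metrics)
matches = math.isclose(recombined, example_metrics["accuracy"])
```

The identity only holds when no samples are skipped; untagged samples would count toward accuracy without appearing in any num_samples_<task> key.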

Signed-off-by: Minho Ryu <ryumin93@gmail.com>

1 parent 870c987 commit f5fa7da
4 files changed
Lines changed: 212 additions & 2 deletions
File tree
- nemo_rl/algorithms
- tests/unit/algorithms