Commit ffac127
feat: emit per-task validation accuracy in GRPO and Distillation
Multi-validation (data.validation as a list of datasets) currently runs
correctly but the validation aggregator collapses everything into a
single sample-weighted accuracy. Per-task progress (e.g. gsm8k vs
math500) is silently lost.
task_name is already on every sample (DatumSpec.task_name preserved
through rl_collate_fn into val_batch["task_name"]); validate() simply
did not read it.
This commit teaches both validate() functions to track rewards per
task during the loop, then emit accuracy_<task> and num_samples_<task>
keys alongside the existing aggregated accuracy. logger.log_metrics
plots each as its own metric automatically.
The aggregated accuracy key is preserved unchanged for dashboard
backwards compatibility. Datasets without task_name are skipped, so
single-task and legacy recipes behave identically.
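The per-task bookkeeping described above can be sketched roughly as follows. This is a minimal, self-contained illustration rather than the code in this commit: the helper name per_task_accuracy and its flat-list inputs are hypothetical, rewards are assumed to be 0/1 correctness scores, and the real validate() functions accumulate these values across batches inside their loops.

```python
from collections import defaultdict
from typing import Optional, Sequence


def per_task_accuracy(
    rewards: Sequence[float],
    task_names: Optional[Sequence[Optional[str]]] = None,
) -> dict:
    """Hypothetical helper: aggregated accuracy plus per-task keys.

    The aggregated "accuracy" key is always present. Per-task keys are
    only added for samples that carry a task name.
    """
    metrics = {"accuracy": sum(rewards) / len(rewards) if rewards else 0.0}
    if task_names is None:
        # No task_name on the batch: emit the aggregate only.
        return metrics

    # Group rewards by task, ignoring samples without a task name.
    by_task = defaultdict(list)
    for reward, task in zip(rewards, task_names):
        if task is not None:
            by_task[task].append(reward)

    # Emit accuracy_<task> and num_samples_<task> next to the aggregate.
    for task, task_rewards in by_task.items():
        metrics[f"accuracy_{task}"] = sum(task_rewards) / len(task_rewards)
        metrics[f"num_samples_{task}"] = len(task_rewards)
    return metrics
```

Calling per_task_accuracy([1.0, 0.0, 1.0, 1.0], ["gsm8k", "gsm8k", "math500", "math500"]) in this sketch yields the aggregate key plus accuracy_gsm8k, accuracy_math500 and the two num_samples_ keys, while calling it without task names yields only the aggregate.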
DPO already does per-dataset metrics via its dict-of-dataloaders
architecture (see dpo.validate at nemo_rl/algorithms/dpo.py:332-377),
so it is not touched here.
Tests:
* test_grpo.py adds test_validate_emits_per_task_accuracy_keys.
* test_distillation.py adds the same plus a check that the
aggregated accuracy key matches the sample-weighted mean across
tasks.
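The consistency check mentioned for test_distillation.py can be illustrated with a small sketch. It assumes only the key naming from this commit message (accuracy, accuracy_<task>, num_samples_<task>); the helper name weighted_mean_from_per_task and the example values are hypothetical and are not taken from the actual test.

```python
import math


def weighted_mean_from_per_task(metrics: dict) -> float:
    """Recombine accuracy_<task> values using num_samples_<task> as weights."""
    prefix = "num_samples_"
    total = 0
    weighted = 0.0
    for key, count in metrics.items():
        if key.startswith(prefix):
            task = key[len(prefix):]
            weighted += metrics[f"accuracy_{task}"] * count
            total += count
    return weighted / total


# Made-up metrics: 4 gsm8k samples at 0.5 and 1 math500 sample at 1.0.
example_metrics = {
    "accuracy": 0.6,
    "accuracy_gsm8k": 0.5,
    "num_samples_gsm8k": 4,
    "accuracy_math500": 1.0,
    "num_samples_math500": 1,
}

# The aggregate should equal the sample-weighted mean across tasks.
assert math.isclose(
    weighted_mean_from_per_task(example_metrics), example_metrics["accuracy"]
)
```

The identity only holds when every sample carries a task name; samples without one would count toward the aggregate but not toward any per-task key.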
Signed-off-by: Minho Ryu <ryumin93@gmail.com>
1 parent 870c987 commit ffac127
4 files changed
Lines changed: 204 additions & 2 deletions
File tree
- nemo_rl/algorithms
- tests/unit/algorithms