feat: per-dataset validation dataloaders for GRPO and Distillation #2502
Open
bzantium wants to merge 1 commit into
Conversation
Validation in GRPO and Distillation now mirrors the DPO architecture:
each entry under data.validation gets its own StatefulDataLoader,
validate() iterates them, and each dataset's metrics are logged under
a validation-<dataset_name>/ wandb prefix.
Previously, multiple validation datasets were concatenated into a
single dataloader and the aggregator emitted only one
validation/accuracy and one validation/avg_length, hiding per-task
progress entirely.
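The reshaped flow can be sketched roughly as follows. This is a minimal illustration only: `_validate_one_dataset` here is a stand-in for the new helper (the real one runs generation and scoring), and the metric shape is simplified.

```python
def _validate_one_dataset(dataloader):
    # Stand-in for the real helper; here each "sample" just carries a
    # precomputed correctness flag instead of being generated and scored.
    correct = sum(1 for sample in dataloader if sample["correct"])
    return {"accuracy": correct / len(dataloader)}

def validate(val_dataloaders, log_metrics):
    """One pass per validation dataset, each logged under its own prefix."""
    accuracies = []
    for name, dataloader in val_dataloaders.items():
        metrics = _validate_one_dataset(dataloader)
        # Each dataset gets its own wandb namespace,
        # e.g. validation-<name>/accuracy, validation-<name>/avg_length.
        log_metrics(metrics, prefix=f"validation-{name}")
        accuracies.append(metrics["accuracy"])
    # Unweighted macro-mean across datasets keeps val:accuracy checkpointing working.
    return {"accuracy": sum(accuracies) / len(accuracies)}
```

Logging inside `validate()` is what removes the need for the wrapping `log_metrics` calls in the training loops.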
Changes
* nemo_rl/data/utils.py: setup_response_data returns
Optional[dict[str, AllTaskProcessedDataset]] for the validation set
instead of a single concatenated AllTaskProcessedDataset. Order and
task names are preserved so the algorithm can build one dataloader
per dataset.
* nemo_rl/algorithms/grpo.py:
- setup() builds val_dataloader: dict[str, StatefulDataLoader] and
grpo_train / async_grpo_train accept the dict type.
- validate() iterates the dict, calling a new _validate_one_dataset
helper per dataset. Per-dataset metrics are logged inside validate()
with prefix=validation-<name>; the wrapping log_metrics calls in
grpo_train were removed to avoid double-prefixing.
- max_val_samples is now NotRequired (already typed int | None for
NeMo-Gym compatibility); when absent or None, the full dataloader is
iterated. The limit applies per dataset.
- validate() also returns an aggregated accuracy macro-mean across
datasets so checkpointing.metric_name='val:accuracy' continues to
work.
* nemo_rl/algorithms/distillation.py: same shape changes as grpo.py,
including the NotRequired widening of max_val_samples and the per-
dataset _validate_one_dataset split.
* examples/run_grpo_sliding_puzzle.py: wrap the puzzle val dataset in
{task_name: dataset} before passing to setup().
* examples/nemo_gym/run_grpo_nemo_gym.py: sum lengths across the val
dataset dict for the existing max_val_samples / val_batch_size
derivation; collect_trajectories iterates per-dataset dataloaders
and writes one trajectory_collection_<name>.jsonl per dataset.
* tests/unit/algorithms/test_grpo.py and test_distillation.py: fixtures
now wrap the mock dataloader as {"test_dataset": <MagicMock>}; new
tests cover (a) per-dataset prefixed metric emission and logging and
(b) the optional max_val_samples branch iterating the full loader.
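The setup-side changes above amount to building one loader per dict entry and summing sizes over the dict. A rough sketch, where `make_loader` is a placeholder for the real `StatefulDataLoader(dataset, batch_size=...)` construction and `build_val_dataloaders` / `total_val_samples` are hypothetical names:

```python
def build_val_dataloaders(val_datasets, batch_size, make_loader):
    """Build one dataloader per validation dataset, preserving dict order.

    Returns None when validation is disabled (no datasets configured).
    """
    if not val_datasets:
        return None
    return {name: make_loader(ds, batch_size) for name, ds in val_datasets.items()}

def total_val_samples(val_datasets):
    # The nemo_gym example's size derivation now sums over the dict
    # instead of taking len() of one concatenated dataset.
    return sum(len(ds) for ds in val_datasets.values())
```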
Backwards compatibility
The aggregated validation/accuracy and validation/avg_length wandb
keys are no longer emitted; metrics surface only under the per-dataset
prefix. Dashboards relying on the old keys need to be re-pointed at
validation-<name>/accuracy and validation-<name>/avg_length. The
internal val_metrics["accuracy"] return value is preserved as the
macro-mean across datasets so save-state and best-checkpoint selection
keep working without configuration changes.
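For dashboard migration, the key rename is mechanical; a purely illustrative helper (the dataset names in the usage below are hypothetical examples, not from this PR):

```python
def migrate_panel_key(old_key: str, dataset_name: str) -> str:
    """Map an old aggregated wandb key to its per-dataset replacement."""
    prefix, metric = old_key.split("/", 1)
    if prefix != "validation":
        raise ValueError(f"not an aggregated validation key: {old_key}")
    # validation/<metric> -> validation-<dataset_name>/<metric>
    return f"validation-{dataset_name}/{metric}"
```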
Signed-off-by: Minho Ryu <ryumin93@gmail.com>
Force-pushed from 4c6961d to cdec8a6.
What does this PR do?
Closes #2501. Brings GRPO and Distillation validation in line with the DPO architecture so multiple validation datasets each get their own dataloader and their own validation-<name>/ wandb prefix, instead of being concatenated into a single sample-weighted accuracy.
Code changes
* nemo_rl/data/utils.py: setup_response_data returns Optional[dict[str, AllTaskProcessedDataset]] for the validation set instead of a single concatenated dataset. Order and task names are preserved so the algorithm can build one dataloader per dataset.
* nemo_rl/algorithms/grpo.py: setup() builds val_dataloader: dict[str, StatefulDataLoader]. validate() iterates the dict, factoring per-dataset logic into a _validate_one_dataset helper. Per-dataset metrics are logged inside validate() under prefix=validation-<name>; the wrapping log_metrics calls in grpo_train / async_grpo_train were removed to avoid double-prefixing. max_val_samples is now NotRequired, and None/absent means "iterate the entire dataloader". validate() also returns an aggregated accuracy macro-mean across datasets so save_state["val_reward"] and checkpointing.metric_name='val:accuracy' keep working.
* nemo_rl/algorithms/distillation.py: same shape changes as grpo.py, including the NotRequired widening of max_val_samples and the _validate_one_dataset split. Distillation's truncation switches from ceiling division to floor division so it matches GRPO's behaviour when max_val_samples is set.
* examples/run_grpo_sliding_puzzle.py: wrap the puzzle val dataset in {task_name: dataset} before passing to setup().
* examples/nemo_gym/run_grpo_nemo_gym.py: sum lengths across the val dataset dict for the existing max_val_samples / val_batch_size derivation. collect_trajectories iterates per-dataset dataloaders and writes one trajectory_collection_<name>.jsonl per dataset.
* tests/unit/algorithms/test_grpo.py: val_dataloader now wraps as {"test_dataset": <MagicMock>}. New tests: test_validate_emits_per_dataset_prefixed_metrics_and_logs verifies per-dataset key emission and wandb prefix; test_validate_iterates_full_dataloader_when_max_val_samples_is_none covers the optional branch.
* tests/unit/algorithms/test_distillation.py: same fixture and test updates.
DPO is not touched.
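The ceiling-to-floor division change only matters when max_val_samples is not a multiple of the batch size; a toy illustration (the exact variable names are assumptions):

```python
import math

max_val_samples, val_batch_size = 10, 4

# Old Distillation behaviour: ceiling division runs a final partial batch,
# yielding more validation batches than GRPO for the same limit.
ceil_batches = math.ceil(max_val_samples / val_batch_size)   # 3 batches

# New shared behaviour: floor division drops the trailing partial batch,
# matching GRPO's truncation.
floor_batches = max_val_samples // val_batch_size            # 2 batches
```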
Backwards compatibility
The aggregated validation/accuracy and validation/avg_length wandb keys are no longer emitted; metrics surface only under the per-dataset validation-<name>/ prefix. Dashboards that read the old keys need to be re-pointed at validation-<name>/accuracy and validation-<name>/avg_length.
The internal val_metrics["accuracy"] return value is preserved as the macro-mean across datasets so save_state["val_reward"] and checkpointing.metric_name='val:accuracy' continue to work without configuration changes.
Issues
Closes #2501.
Usage
wandb panel after this change: (screenshot not captured here)
Driver log: (log excerpt not captured here)
Additional Information
config-conventions skill: optional max_val_samples field expressed via NotRequired, no hidden defaults; exemplar YAMLs keep their explicit recommended values.