
feat: per-dataset validation dataloaders for GRPO and Distillation #2502

Open

bzantium wants to merge 1 commit into NVIDIA-NeMo:main from bzantium:feat/per-dataset-validation-dataloaders

Conversation

@bzantium

What does this PR do?

Closes #2501. Brings GRPO and Distillation validation in line with the DPO architecture so multiple validation datasets each get their own dataloader and their own validation-<name>/ wandb prefix, instead of being concatenated into a single sample-weighted accuracy.
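For orientation, here is a minimal sketch of the shape `validate()` moves to. It is illustrative only: the internals of `_validate_one_dataset` are stubbed, and the `logger.log_metrics(..., prefix=...)` interface is assumed from the description below.

```python
from typing import Any


def _validate_one_dataset(dataloader: Any) -> dict[str, float]:
    # Stub: the real helper runs generation + scoring over the dataloader.
    scores = [batch["is_correct"] for batch in dataloader]
    return {"accuracy": sum(scores) / max(len(scores), 1)}


def validate(val_dataloader: dict[str, Any], logger: Any) -> dict[str, float]:
    per_dataset_accuracy: dict[str, float] = {}
    for name, loader in val_dataloader.items():
        metrics = _validate_one_dataset(loader)
        # One wandb prefix per dataset instead of a single pooled number.
        logger.log_metrics(metrics, prefix=f"validation-{name}")
        per_dataset_accuracy[name] = metrics["accuracy"]
    # Macro-mean keeps the single val:accuracy checkpointing signal alive.
    accuracy = sum(per_dataset_accuracy.values()) / len(per_dataset_accuracy)
    return {"accuracy": accuracy}
```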

Code changes

| File | Change |
| --- | --- |
| `nemo_rl/data/utils.py` | `setup_response_data` returns `Optional[dict[str, AllTaskProcessedDataset]]` for the validation set instead of a single concatenated dataset. Order and task names are preserved so the algorithm can build one dataloader per dataset. |
| `nemo_rl/algorithms/grpo.py` | `setup()` builds `val_dataloader: dict[str, StatefulDataLoader]`. `validate()` iterates the dict, factoring per-dataset logic into a `_validate_one_dataset` helper. Per-dataset metrics are logged inside `validate()` under `prefix=validation-<name>`; the wrapping `log_metrics` calls in `grpo_train` / `async_grpo_train` were removed to avoid double-prefixing. `max_val_samples` is now `NotRequired`, and `None`/absent means "iterate the entire dataloader" (see the sketch after this table). `validate()` also returns an aggregated accuracy macro-mean across datasets so `save_state["val_reward"]` and `checkpointing.metric_name='val:accuracy'` keep working. |
| `nemo_rl/algorithms/distillation.py` | Same shape changes as `grpo.py`, including the `NotRequired` widening of `max_val_samples` and the `_validate_one_dataset` split. Distillation's truncation switches from ceiling division to floor division so it matches GRPO's behaviour when `max_val_samples` is set. |
| `examples/run_grpo_sliding_puzzle.py` | Wraps the puzzle's single iterable val dataset in `{task_name: dataset}` before passing to `setup()`. |
| `examples/nemo_gym/run_grpo_nemo_gym.py` | Sums lengths across the val dataset dict for the existing `max_val_samples` / `val_batch_size` derivation. `collect_trajectories` iterates per-dataset dataloaders and writes one `trajectory_collection_<name>.jsonl` per dataset. |
| `tests/unit/algorithms/test_grpo.py` | Fixture `val_dataloader` now wraps as `{"test_dataset": <MagicMock>}`. New tests: `test_validate_emits_per_dataset_prefixed_metrics_and_logs` verifies per-dataset key emission and wandb prefix; `test_validate_iterates_full_dataloader_when_max_val_samples_is_none` covers the optional branch. |
| `tests/unit/algorithms/test_distillation.py` | Same fixture wrap and the same two new tests. |
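A sketch of the `max_val_samples` semantics from the table above. The `TypedDict` field mirrors the PR text; the `num_val_batches` helper and its name are hypothetical, used only to pin down the truncation rules:

```python
from typing import NotRequired, TypedDict  # Python 3.11+; else typing_extensions


class ValidationConfig(TypedDict):
    val_batch_size: int
    # NotRequired: absent (or None) means "iterate the entire dataloader".
    max_val_samples: NotRequired[int | None]


def num_val_batches(cfg: ValidationConfig, dataset_len: int) -> int:
    max_samples = cfg.get("max_val_samples")
    if max_samples is None:
        return -(-dataset_len // cfg["val_batch_size"])  # no cap: full dataloader
    # Floor division when a cap is set; this is what Distillation switches
    # to so it matches GRPO's behaviour.
    return min(max_samples, dataset_len) // cfg["val_batch_size"]
```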

DPO is not touched.

Backwards compatibility

The aggregated validation/accuracy and validation/avg_length wandb keys are no longer emitted; metrics surface only under the per-dataset validation-<name>/ prefix. Dashboards that read the old keys need to be re-pointed at validation-<name>/accuracy and validation-<name>/avg_length.

The internal val_metrics["accuracy"] return value is preserved as the macro-mean across datasets so save_state["val_reward"] and checkpointing.metric_name='val:accuracy' continue to work without configuration changes.
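To make the difference concrete, a quick illustration with the numbers from the driver log below: the removed pooled key was sample-weighted, while the preserved return value is an unweighted macro-mean, so the two aggregates generally differ.

```python
# Accuracy and sample count per dataset, taken from the driver log below.
results = {"gsm8k": (0.6580, 1319), "ResponseDataset": (0.4320, 500)}

micro = sum(a * n for a, n in results.values()) / sum(n for _, n in results.values())
macro = sum(a for a, _ in results.values()) / len(results)

print(f"old pooled (sample-weighted) accuracy: {micro:.4f}")  # ~0.5959
print(f"new macro-mean accuracy:               {macro:.4f}")  # 0.5450
```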

Issues

Closes #2501.

Usage

```yaml
data:
  validation:
    - dataset_name: gsm8k
      split: test
    - dataset_name: ResponseDataset
      data_path: data/math500.parquet
distillation:
  val_batch_size: 8
  # max_val_samples omitted -> evaluate the entire val dataset
```
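Under that config, setup() should end up with one loader per entry, keyed by task name. A toy illustration, with the datasets replaced by stand-in lists and StatefulDataLoader taken from torchdata as named in the change table:

```python
from torchdata.stateful_dataloader import StatefulDataLoader

# Stand-ins for the two processed validation datasets above.
val_datasets = {"gsm8k": list(range(1319)), "ResponseDataset": list(range(500))}

# One dataloader per dataset, keyed by task name, as setup() now builds.
val_dataloader = {
    name: StatefulDataLoader(ds, batch_size=8)  # val_batch_size: 8
    for name, ds in val_datasets.items()
}
print({name: len(dl) for name, dl in val_dataloader.items()})
# -> {'gsm8k': 165, 'ResponseDataset': 63}
```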

wandb panel after this change:

```
validation-gsm8k/accuracy
validation-gsm8k/avg_length
validation-ResponseDataset/accuracy
validation-ResponseDataset/avg_length
timing/validation/total_validation_time
```

Driver log:

```
📊 Validation Results for `gsm8k`:
    • Accuracy: 0.6580
    • Average response length: 412.1 tokens
    • Samples processed: 1319

📊 Validation Results for `ResponseDataset`:
    • Accuracy: 0.4320
    • Average response length: 651.4 tokens
    • Samples processed: 500
```

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • Followed the config-conventions skill: optional max_val_samples field expressed via NotRequired, no hidden defaults, exemplar YAMLs keep their explicit recommended values.
  • Single commit; the change is one cohesive refactor and does not split cleanly into reviewable sub-patches.

@bzantium bzantium requested review from a team as code owners May 15, 2026 05:55

copy-pr-bot Bot commented May 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Validation in GRPO and Distillation now mirrors the DPO architecture:
each entry under data.validation gets its own StatefulDataLoader,
validate() iterates them, and each dataset's metrics are logged under
a validation-<dataset_name>/ wandb prefix.

Previously, multiple validation datasets were concatenated into a
single dataloader and the aggregator emitted only one
validation/accuracy and one validation/avg_length, hiding per-task
progress entirely.

Changes
* nemo_rl/data/utils.py: setup_response_data returns
  Optional[dict[str, AllTaskProcessedDataset]] for the validation set
  instead of a single concatenated AllTaskProcessedDataset. Order and
  task names are preserved so the algorithm can build one dataloader
  per dataset.
* nemo_rl/algorithms/grpo.py:
  - setup() builds val_dataloader: dict[str, StatefulDataLoader] and
    grpo_train / async_grpo_train accept the dict type.
  - validate() iterates the dict, calling a new _validate_one_dataset
    helper per dataset. Per-dataset metrics are logged inside validate()
    with prefix=validation-<name>; the wrapping log_metrics calls in
    grpo_train were removed to avoid double-prefixing.
  - max_val_samples is now NotRequired (already typed int | None for
    NeMo-Gym compatibility); when absent or None the full dataloader is
    iterated, applied per dataset.
  - validate() also returns an aggregated accuracy macro-mean across
    datasets so checkpointing.metric_name='val:accuracy' continues to
    work.
* nemo_rl/algorithms/distillation.py: same shape changes as grpo.py,
  including the NotRequired widening of max_val_samples and the per-
  dataset _validate_one_dataset split.
* examples/run_grpo_sliding_puzzle.py: wrap the puzzle val dataset in
  {task_name: dataset} before passing to setup().
* examples/nemo_gym/run_grpo_nemo_gym.py: sum lengths across the val
  dataset dict for the existing max_val_samples / val_batch_size
  derivation; collect_trajectories iterates per-dataset dataloaders
  and writes one trajectory_collection_<name>.jsonl per dataset.
* tests/unit/algorithms/test_grpo.py and test_distillation.py: fixtures
  now wrap the mock dataloader as {"test_dataset": <MagicMock>}; new
  tests cover (a) per-dataset prefixed metric emission and logging and
  (b) the optional max_val_samples branch iterating the full loader.

Backwards compatibility
The aggregated validation/accuracy and validation/avg_length wandb
keys are no longer emitted; metrics surface only under the per-dataset
prefix. Dashboards relying on the old keys need to be re-pointed at
validation-<name>/accuracy and validation-<name>/avg_length. The
internal val_metrics["accuracy"] return value is preserved as the
macro-mean across datasets so save-state and best-checkpoint selection
keep working without configuration changes.

Signed-off-by: Minho Ryu <ryumin93@gmail.com>
@bzantium bzantium force-pushed the feat/per-dataset-validation-dataloaders branch from 4c6961d to cdec8a6 on May 15, 2026 06:10