52 changes: 26 additions & 26 deletions examples/nemo_gym/run_grpo_nemo_gym.py
@@ -73,43 +73,42 @@ def parse_args() -> tuple[argparse.Namespace, list[str]]:
def collect_trajectories(
policy: ColocatablePolicyInterface,
policy_generation: GenerationInterface,
val_dataloader: StatefulDataLoader,
val_dataloader: dict[str, StatefulDataLoader],
tokenizer: TokenizerType,
val_task_to_env: dict[str, EnvironmentInterface],
logger: Logger,
master_config: MasterConfig,
) -> None:
"""Run trajectory collection."""
# common config/state items
"""Run trajectory collection across all validation dataloaders."""
colocated_inference = master_config.policy["generation"]["colocated"]["enabled"]
refit_policy_generation(policy, policy_generation, colocated_inference)

log_filename = "trajectory_collection.jsonl"

print("\n🔍 Running trajectory collection...", flush=True)
generation_config = master_config.policy["generation"]
for val_batch in val_dataloader:
nemo_gym_rollout_result = run_async_nemo_gym_rollout(
policy_generation=policy_generation,
input_batch=val_batch,
tokenizer=tokenizer,
task_to_env=val_task_to_env,
max_seq_len=master_config.policy["max_total_sequence_length"],
generation_config=generation_config,
max_rollout_turns=None,
greedy=False,
)
for dataset_name, dl in val_dataloader.items():
log_filename = f"trajectory_collection_{dataset_name}.jsonl"
for val_batch in dl:
nemo_gym_rollout_result = run_async_nemo_gym_rollout(
policy_generation=policy_generation,
input_batch=val_batch,
tokenizer=tokenizer,
task_to_env=val_task_to_env,
max_seq_len=master_config.policy["max_total_sequence_length"],
generation_config=generation_config,
max_rollout_turns=None,
greedy=False,
)

rows_to_log: list[str] = []
for key, value in nemo_gym_rollout_result.rollout_metrics.items():
if "full_result" not in key:
continue
rows_to_log: list[str] = []
for key, value in nemo_gym_rollout_result.rollout_metrics.items():
if "full_result" not in key:
continue

value: Table
data: list[list[str]] = value.data # (n, 1)
rows_to_log.extend(v[0] for v in data)
value: Table
data: list[list[str]] = value.data # (n, 1)
rows_to_log.extend(v[0] for v in data)

logger.log_string_list_as_jsonl(rows_to_log, log_filename)
logger.log_string_list_as_jsonl(rows_to_log, log_filename)

# TODO: once trajectory collection use cases exceed 4 hours, leverage the dataloader save functionality to resume,
# and leverage the TimeoutChecker functionality as well
@@ -180,10 +179,11 @@ def main() -> None:
)

if val_dataset is not None:
total_val_samples = sum(len(ds) for ds in val_dataset.values())
print(
f"Setting `grpo.max_val_samples` and `grpo.val_batch_size` to the length of the validation dataset, which is {len(val_dataset)}"
f"Setting `grpo.max_val_samples` and `grpo.val_batch_size` to the length of the validation dataset, which is {total_val_samples}"
)
config.grpo["max_val_samples"] = len(val_dataset)
config.grpo["max_val_samples"] = total_val_samples
config.grpo["val_batch_size"] = config.grpo["max_val_samples"]

# Print config
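The heart of this file's change is the move from one flat validation dataloader to a name-keyed dict, with one JSONL log per dataset. Below is a minimal, self-contained sketch of that pattern in isolation; `JsonlLogger` and `rollout` are hypothetical stand-ins for the repo's `Logger` and `run_async_nemo_gym_rollout`, not real NeMo-RL APIs.

```python
# Minimal sketch of the per-dataset collection loop above.
# `JsonlLogger` and `rollout` are illustrative stubs, not NeMo-RL APIs.
import json
from typing import Iterable


class JsonlLogger:
    def log_string_list_as_jsonl(self, rows: list[str], filename: str) -> None:
        with open(filename, "a") as f:
            for row in rows:
                f.write(json.dumps({"result": row}) + "\n")


def rollout(batch: list[str]) -> list[str]:
    # Stand-in for run_async_nemo_gym_rollout: one result string per sample.
    return [f"trajectory for {sample}" for sample in batch]


def collect(val_dataloader: dict[str, Iterable[list[str]]], logger: JsonlLogger) -> None:
    # One output file per validation dataset, keyed by dataset name.
    for dataset_name, dl in val_dataloader.items():
        log_filename = f"trajectory_collection_{dataset_name}.jsonl"
        for val_batch in dl:
            rows_to_log = rollout(val_batch)
            logger.log_string_list_as_jsonl(rows_to_log, log_filename)


collect({"math": [["q1", "q2"]], "code": [["q3"]]}, JsonlLogger())
```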
8 changes: 7 additions & 1 deletion examples/run_grpo_sliding_puzzle.py
@@ -241,14 +241,20 @@ def main():
* config.grpo["num_generations_per_prompt"]
* config.grpo["max_num_steps"]
)
puzzle_task_name = "sliding_puzzle_game"
dataset, val_dataset, task_to_env, val_task_to_env = setup_puzzle_data(
tokenizer=tokenizer,
env_cfg=config.env,
task_name="sliding_puzzle_game",
task_name=puzzle_task_name,
length=ds_length,
val_length=config.grpo["max_val_samples"],
add_system_prompt=config.data["add_system_prompt"],
)
# Algorithm setup expects val_dataset as a name->dataset dict so validate()
# can emit per-dataset metrics. The puzzle generator produces a single
# iterable; wrap it under its task name.
if val_dataset is not None:
val_dataset = {puzzle_task_name: val_dataset}

(
policy,
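As the new comment in this hunk explains, algorithm setup now consumes validation data as a name-to-dataset dict. For single-dataset entry points the wrapping is one line; here is a tiny sketch of that contract, where `FakeDataset` is a hypothetical stand-in for the real processed dataset type.

```python
# Sketch of the name -> dataset contract; FakeDataset is hypothetical.
from typing import Optional


class FakeDataset(list):
    """Stand-in for a task-processed dataset."""


def wrap_for_validation(
    val_dataset: Optional[FakeDataset], task_name: str
) -> Optional[dict[str, FakeDataset]]:
    # validate() consumes {name: dataset}; a lone dataset is wrapped
    # under its task name so per-dataset metric prefixes stay meaningful.
    if val_dataset is None:
        return None
    return {task_name: val_dataset}


print(wrap_for_validation(FakeDataset([1, 2, 3]), "sliding_puzzle_game"))
# {'sliding_puzzle_game': [1, 2, 3]}
```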
151 changes: 103 additions & 48 deletions nemo_rl/algorithms/distillation.py
@@ -83,7 +83,7 @@ class DistillationConfig(TypedDict):
# Whether to run validation on the last training step. Setting this to True ensures the
# final checkpoint has validation metrics, which is required for get_best_checkpoint_path().
val_at_end: bool
max_val_samples: int
max_val_samples: NotRequired[int]
topk_logits_k: int
seed: int

@@ -159,13 +159,13 @@ def setup(
master_config: MasterConfig,
tokenizer: TokenizerType,
train_dataset: AllTaskProcessedDataset,
val_dataset: Optional[AllTaskProcessedDataset],
val_dataset: Optional[dict[str, AllTaskProcessedDataset]],
) -> tuple[
ColocatablePolicyInterface, # student_policy
ColocatablePolicyInterface, # teacher_policy
Optional[GenerationInterface], # student_generation
StatefulDataLoader,
Optional[StatefulDataLoader],
Optional[dict[str, StatefulDataLoader]],
DistillationLossFn,
Logger,
CheckpointManager,
@@ -262,9 +262,9 @@ def setup(
f" ✓ Training dataloader loaded with {len(train_dataset)} samples", flush=True
)

# Load validation dataset if provided
val_dataloader: Optional[StatefulDataLoader] = None
# If validation is enabled, load the validation dataloader
# Load validation dataset if provided. One StatefulDataLoader per dataset
# so validate() can emit per-dataset metrics; mirrors the DPO pattern.
val_dataloader: Optional[dict[str, StatefulDataLoader]] = None
if (
distillation_config["val_period"] > 0
or distillation_config["val_at_start"]
@@ -273,14 +273,19 @@
assert val_dataset is not None, (
"Validation dataset is required if validation is enabled"
)
val_dataloader = StatefulDataLoader(
val_dataset,
batch_size=distillation_config["val_batch_size"],
shuffle=False,
collate_fn=rl_collate_fn,
)
val_dataloader = {
name: StatefulDataLoader(
ds,
batch_size=distillation_config["val_batch_size"],
shuffle=False,
collate_fn=rl_collate_fn,
)
for name, ds in val_dataset.items()
}
total_samples = sum(len(ds) for ds in val_dataset.values())
print(
f" ✓ Validation dataloader loaded with {len(val_dataset)} samples",
f" ✓ Validation dataloader loaded with {total_samples} samples"
f" across {len(val_dataset)} dataset(s)",
flush=True,
)

@@ -508,7 +513,7 @@ def distillation_train(
teacher_policy: ColocatablePolicyInterface,
student_generation: Optional[GenerationInterface],
dataloader: StatefulDataLoader,
val_dataloader: Optional[StatefulDataLoader],
val_dataloader: Optional[dict[str, StatefulDataLoader]],
tokenizer: TokenizerType,
loss_fn: DistillationLossFn,
task_to_env: dict[str, EnvironmentInterface],
@@ -572,10 +577,9 @@ def distillation_train(
val_task_to_env,
step=total_steps,
master_config=master_config,
logger=logger,
)
student_generation.finish_generation()
logger.log_metrics(val_metrics, total_steps, prefix="validation")
logger.log_metrics(validation_timings, total_steps, prefix="timing/validation")

# Run distillation training (multi-epoch until reaching max_num_steps or max_num_epochs)
batch: BatchedDataDict[DatumSpec]
@@ -747,14 +751,9 @@ def distillation_train(
val_task_to_env,
step=total_steps + 1,
master_config=master_config,
logger=logger,
)
student_generation.finish_generation()
logger.log_metrics(
validation_timings, total_steps + 1, prefix="timing/validation"
)
logger.log_metrics(
val_metrics, total_steps + 1, prefix="validation"
)

metrics = {
"loss": train_results["loss"].numpy(),
@@ -951,14 +950,22 @@ def distillation_train(

def validate(
policy_generation: GenerationInterface,
val_dataloader: Optional[StatefulDataLoader],
val_dataloader: Optional[dict[str, StatefulDataLoader]],
tokenizer,
val_task_to_env: Optional[dict[str, EnvironmentInterface]],
step: int,
master_config: MasterConfig,
logger: Optional[Logger] = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Run validation on the validation dataset."""
if val_dataloader is None:
"""Run validation across one or more datasets.

Each dataset gets its own dataloader and is logged under a
`validation-<dataset_name>/` prefix, mirroring the DPO setup. The returned
`val_metrics` dict carries the per-dataset entries alongside an
`accuracy` macro-mean across datasets so save-state and best-checkpoint
selection keep working.
"""
if not val_dataloader:
print(" ⚠️ No validation dataloader provided, skipping validation", flush=True)
return {}, {}

@@ -969,25 +976,84 @@ def validate(
)
return {}, {}

val_metrics: dict[str, Any] = {}
validation_timings: dict[str, Any] = {}
per_dataset_accuracy: list[float] = []

for dataset_name, dl in val_dataloader.items():
k_metrics, k_timings = validate_one_dataset(
policy_generation,
dl,
tokenizer,
val_task_to_env,
step=step,
master_config=master_config,
dataset_name=dataset_name,
)
prefix = f"validation-{dataset_name}"
if logger is not None:
logger.log_metrics(k_metrics, step, prefix=prefix)
logger.log_metrics(k_timings, step, prefix=f"timing/{prefix}")
for metric_name, value in k_metrics.items():
val_metrics[f"{prefix}_{metric_name}"] = value
for timing_name, value in k_timings.items():
validation_timings[f"{prefix}_{timing_name}"] = value
if "accuracy" in k_metrics:
per_dataset_accuracy.append(k_metrics["accuracy"])

if per_dataset_accuracy:
# Macro-mean across datasets so checkpointing.metric_name='val:accuracy'
# keeps working under multi-validation.
val_metrics["accuracy"] = sum(per_dataset_accuracy) / len(per_dataset_accuracy)

if validation_timings:
total_validation_time = sum(
v for k, v in validation_timings.items() if k.endswith("total_validation_time")
)
if logger is not None:
logger.log_metrics(
{"total_validation_time": total_validation_time},
step,
prefix="timing/validation",
)
validation_timings["total_validation_time"] = total_validation_time

return val_metrics, validation_timings


def validate_one_dataset(
policy_generation: GenerationInterface,
val_dataloader: StatefulDataLoader,
tokenizer,
val_task_to_env: dict[str, EnvironmentInterface],
step: int,
master_config: MasterConfig,
dataset_name: str,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Run validation on one validation dataset and return its metrics + timings."""
timer = Timer()
with timer.time("total_validation_time"):
print(f"▶ Starting validation at step {step}...", flush=True)
print(
f"▶ Starting validation at step {step} for `{dataset_name}`...",
flush=True,
)

total_rewards: list[float] = []
total_lengths: list[float] = []
all_message_logs = []

total_rewards = [] # Can be any metric. Setted to 'accuracy' by default.
total_lengths = []
all_message_logs = [] # Collect all message logs
max_val_samples = master_config.distillation.get("max_val_samples")
if max_val_samples is None:
max_batches = len(val_dataloader)
else:
max_batches = (
max_val_samples // master_config.distillation["val_batch_size"]
)

max_batches = (
master_config.distillation["max_val_samples"]
+ master_config.distillation["val_batch_size"]
- 1
) // master_config.distillation["val_batch_size"]
for batch_idx, val_batch in enumerate(val_dataloader):
if batch_idx >= max_batches:
break

# Generate responses (updates the LLMMessageLogType in batch_with_msg_logs)
# Use async rollouts if vLLM async engine is enabled
if _should_use_async_rollouts(master_config):
val_batch, gen_metrics = run_async_multi_turn_rollout(
policy_generation,
@@ -1013,17 +1079,14 @@ def validate(
total_rewards.extend(rewards.tolist())
total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"])

# Collect message logs for later display
to_env = [
get_keys_from_message_log(
val_batch["message_log"][i], ["role", "content"]
)
for i in range(len(val_batch["message_log"]))
]

all_message_logs.extend(to_env)

# Calculate validation metrics
accuracy = (
sum(total_rewards) / len(total_rewards) if len(total_rewards) > 0 else 0
)
@@ -1036,7 +1099,6 @@
"avg_length": avg_length,
}

# Print sample conversations only once at the end of validation
try:
print_message_log_samples(
all_message_logs,
@@ -1051,22 +1113,15 @@
print(f"\n ⚠️ Error displaying message samples: {str(e)}")
print(" ⚠️ Continuing validation without displaying samples...", flush=True)

# Get timing metrics
timing_metrics = timer.get_timing_metrics(reduction_op="sum")
validation_time = timing_metrics.get("total_validation_time", 0)

# Print summary of validation results
print("\n📊 Validation Results:")
print(f"\n📊 Validation Results for `{dataset_name}`:")
print(f" • Accuracy: {accuracy:.4f}")
print(f" • Average response length: {avg_length:.1f} tokens")
print(f" • Samples processed: {len(total_rewards)}", flush=True)

# Print timing information
print("\n ⏱️ Validation Timing:")
validation_time = timing_metrics.get("total_validation_time", 0)
print(f" • Total validation time: {validation_time:.2f}s", flush=True)

# Make sure to reset the timer after validation
timer.reset()

return val_metrics, timing_metrics
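Taken together, the new `validate()` fans out to `validate_one_dataset()` and then folds the per-dataset results back into one flat dict. The sketch below isolates that aggregation step, including the macro-mean that keeps `checkpointing.metric_name='val:accuracy'` resolving under multi-validation; the dataset names and metric values are toy data, not real run output.

```python
# Sketch of the aggregation performed by validate() above; the metric
# dicts here are toy values, not real run output.
def aggregate(per_dataset_metrics: dict[str, dict[str, float]]) -> dict[str, float]:
    val_metrics: dict[str, float] = {}
    accuracies: list[float] = []
    for dataset_name, k_metrics in per_dataset_metrics.items():
        prefix = f"validation-{dataset_name}"
        for metric_name, value in k_metrics.items():
            val_metrics[f"{prefix}_{metric_name}"] = value
        if "accuracy" in k_metrics:
            accuracies.append(k_metrics["accuracy"])
    if accuracies:
        # Macro-mean across datasets: each dataset counts equally,
        # regardless of its size.
        val_metrics["accuracy"] = sum(accuracies) / len(accuracies)
    return val_metrics


print(aggregate({"gsm8k": {"accuracy": 0.8}, "mmlu": {"accuracy": 0.6}}))
# {'validation-gsm8k_accuracy': 0.8, 'validation-mmlu_accuracy': 0.6, 'accuracy': 0.7}
```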