Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions examples/nemo_gym/run_grpo_nemo_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import argparse
import os
import pprint
from typing import Optional

# Increase the W&B single object size warning threshold. Initially 100_000 (100 KB) -> 10_000_000 (10 MB)
import wandb.util
Expand All @@ -35,7 +36,6 @@
TokenizerType,
_should_use_nemo_gym,
grpo_train,
refit_policy_generation,
setup,
)
from nemo_rl.algorithms.utils import get_tokenizer
Expand All @@ -54,6 +54,7 @@
register_omegaconf_resolvers,
)
from nemo_rl.utils.logger import get_next_experiment_dir
from nemo_rl.weight_sync import WeightSynchronizer


def parse_args() -> tuple[argparse.Namespace, list[str]]:
Expand All @@ -78,11 +79,13 @@ def collect_trajectories(
val_task_to_env: dict[str, EnvironmentInterface],
logger: Logger,
master_config: MasterConfig,
weight_sync: Optional[WeightSynchronizer] = None,
) -> None:
"""Run trajectory collection."""
# common config/state items
colocated_inference = master_config.policy["generation"]["colocated"]["enabled"]
refit_policy_generation(policy, policy_generation, colocated_inference)
if weight_sync is not None and weight_sync.is_stale:
weight_sync.sync_weights()
else:
policy_generation.prepare_for_generation()

log_filename = "trajectory_collection.jsonl"

Expand Down Expand Up @@ -196,6 +199,7 @@ def main() -> None:
policy,
policy_generation,
cluster,
weight_sync,
dataloader,
val_dataloader,
loss_fn,
Expand Down Expand Up @@ -240,6 +244,7 @@ def main() -> None:
val_task_to_env=val_task_to_env,
logger=logger,
master_config=master_config,
weight_sync=weight_sync,
)
# Check if async mode is enabled
elif "async_grpo" in config.grpo and config.grpo["async_grpo"]["enabled"]:
Expand Down Expand Up @@ -291,6 +296,7 @@ def main() -> None:
grpo_save_state=grpo_state,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

run_grpo_nemo_gym.py

weight_sync is correctly unpacked from setup() here, but collect_trajectories() at L85 still directly calls refit_policy_generation(policy, policy_generation, colocated_inference), which now emits a DeprecationWarning on every trajectory collection run.

Consider threading weight_sync into collect_trajectories() and using weight_sync.sync_weights() instead.

master_config=master_config,
max_trajectory_age_steps=async_config["max_trajectory_age_steps"],
weight_sync=weight_sync,
)
else:
print("🚀 Running synchronous GRPO training")
Expand All @@ -309,6 +315,7 @@ def main() -> None:
checkpointer,
grpo_state,
master_config,
weight_sync=weight_sync,
)


Expand Down
4 changes: 3 additions & 1 deletion examples/run_distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def main() -> None:
student_policy,
teacher_policy,
student_generation,
weight_sync,
dataloader,
val_dataloader,
loss_fn,
Expand All @@ -105,14 +106,15 @@ def main() -> None:
student_generation,
dataloader,
val_dataloader,
tokenizer, # pass tokenizer parameter
tokenizer,
loss_fn,
task_to_env,
val_task_to_env,
logger,
checkpointer,
distillation_state,
master_config,
weight_sync=weight_sync,
)


Expand Down
3 changes: 3 additions & 0 deletions examples/run_grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def _make_policy(**kwargs):
policy,
policy_generation,
cluster,
weight_sync,
dataloader,
val_dataloader,
loss_fn,
Expand Down Expand Up @@ -199,6 +200,7 @@ def _make_policy(**kwargs):
grpo_save_state=grpo_state,
master_config=master_config,
max_trajectory_age_steps=async_config["max_trajectory_age_steps"],
weight_sync=weight_sync,
)
else:
# Two parallel synchronous trainers (verl-style — main_ppo.py vs
Expand All @@ -217,6 +219,7 @@ def _make_policy(**kwargs):
checkpointer,
grpo_state,
master_config,
weight_sync=weight_sync,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

examples/run_grpo.py:222

When data_plane.enabled=True, _select_trainer() returns grpo_train_sync from grpo_sync.py, whose signature has no weight_sync parameter. This will raise TypeError at runtime for the entire TransferQueue/data-plane path.

Consider either migrating grpo_train_sync to accept weight_sync, or conditionally passing the kwarg only when the legacy trainer is selected.

)


Expand Down
2 changes: 2 additions & 0 deletions examples/run_grpo_sliding_puzzle.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ def main():
policy,
policy_generation,
cluster,
weight_sync,
dataloader,
val_dataloader,
loss_fn,
Expand All @@ -276,6 +277,7 @@ def main():
checkpointer,
grpo_state,
master_config,
weight_sync=weight_sync,
)


Expand Down
2 changes: 2 additions & 0 deletions examples/run_vlm_grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def main() -> None:
policy,
policy_generation,
cluster,
weight_sync,
dataloader,
val_dataloader,
loss_fn,
Expand All @@ -130,6 +131,7 @@ def main() -> None:
checkpointer,
grpo_state,
master_config,
weight_sync=weight_sync,
)


Expand Down
69 changes: 25 additions & 44 deletions nemo_rl/algorithms/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@
from typing import Any, NotRequired, Optional, TypedDict, TypeVar, cast

import numpy as np
import ray
import torch
from pydantic import BaseModel
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import AutoConfig, AutoTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from nemo_rl.algorithms.grpo import _should_use_async_rollouts, refit_policy_generation
from nemo_rl.algorithms.grpo import _should_use_async_rollouts
from nemo_rl.algorithms.loss import (
DistillationLossConfig,
DistillationLossDataDict,
Expand Down Expand Up @@ -64,6 +63,7 @@
)
from nemo_rl.utils.nsys import maybe_gpu_profile_step
from nemo_rl.utils.timer import TimeoutChecker, Timer
from nemo_rl.weight_sync import WeightSynchronizer, create_weight_synchronizer

# ===============================================================================
# Configuration
Expand Down Expand Up @@ -165,6 +165,7 @@ def setup(
ColocatablePolicyInterface, # student_policy
ColocatablePolicyInterface, # teacher_policy
Optional[GenerationInterface], # student_generation
Optional[WeightSynchronizer], # weight_sync
StatefulDataLoader,
Optional[StatefulDataLoader],
DistillationLossFn,
Expand All @@ -176,7 +177,7 @@ def setup(
"""Main entry point for distillation algorithm.

Returns:
tuple of student_policy, teacher_policy, student_generation,
tuple of student_policy, teacher_policy, student_generation, weight_sync,
train_dataloader, val_dataloader,
loss_fn, logger, checkpointer, distillation_save_state, master_config
"""
Expand Down Expand Up @@ -461,26 +462,18 @@ def setup(
init_reference_model=False,
)

# Create weight synchronizer and initialize communication channels
weight_sync: Optional[WeightSynchronizer] = None
if student_generation is not None:
state_dict_info = student_policy.prepare_refit_info()
student_generation.prepare_refit_info(state_dict_info)

# if it is not colocated inference, initialize collective communication for update weights
if not colocated_inference:
ip, port = train_cluster.get_master_address_and_port()
print(f"Using ip: {ip}, port: {port} for collective communication", flush=True)
train_world_size = train_cluster.world_size()
# inference cluster + head node of the train cluster
world_size = train_world_size + inference_nodes * inference_gpus_per_node
# init collective
futures_train = student_policy.init_collective(
ip, port, world_size, train_world_size=train_world_size
weight_sync = create_weight_synchronizer(
policy=student_policy,
generation=student_generation,
generation_backend=backend,
colocated=colocated_inference,
train_cluster=train_cluster if not colocated_inference else None,
inference_cluster=inference_cluster if not colocated_inference else None,
)
futures_inference = student_generation.init_collective(
ip, port, world_size, train_world_size=train_world_size
) # type: ignore
# wait for all futures to complete
ray.get(futures_train + futures_inference)
weight_sync.init_communicator()

loss_fn = DistillationLossFn(loss_config)

Expand All @@ -492,6 +485,7 @@ def setup(
student_policy,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

distillation.py:479

setup() now returns an 11-element tuple with weight_sync as the fourth element (right after student_generation), but examples/run_distillation.py:89-100 still unpacks only 10 values. This will crash with ValueError: too many values to unpack (expected 10).

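Also, distillation_train() needs weight_sync=weight_sync passed through: the parameter defaults to None, and with None the `weight_sync is not None and weight_sync.is_stale` check never succeeds, so the student generation weights are never synced after training steps. See suggested fix in the review summary.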

teacher_policy,
student_generation,
weight_sync,
dataloader,
val_dataloader,
loss_fn,
Expand Down Expand Up @@ -521,6 +515,7 @@ def distillation_train(
checkpointer: CheckpointManager,
distillation_save_state: DistillationSaveState,
master_config: MasterConfig,
weight_sync: Optional[WeightSynchronizer] = None,
) -> None:
"""Run Distillation training algorithm."""
timer = Timer()
Expand All @@ -530,13 +525,10 @@ def distillation_train(
)
timeout.start_iterations()

NEED_REFIT = True
# If student_generation is None, use the student_policy as the generation interface (megatron framework backend)
if student_generation is None:
student_generation = student_policy # type: ignore
NEED_REFIT = False
POLICY_GENERATION_STALE = True # tracks if generation needs a refit before running
assert student_generation is not None
assert student_generation is not None # for mypy type check

# common config/state items
current_epoch = distillation_save_state["current_epoch"] # current epoch
Expand All @@ -562,11 +554,8 @@ def distillation_train(
# Run validation at the start if configured
if val_at_start and total_steps == 0:
print("\n🔍 Running initial validation...", flush=True)
if NEED_REFIT and POLICY_GENERATION_STALE:
refit_policy_generation(
student_policy, student_generation, colocated_inference
)
POLICY_GENERATION_STALE = False
if weight_sync is not None and weight_sync.is_stale:
weight_sync.sync_weights()
else:
student_generation.prepare_for_generation()
val_metrics, validation_timings = validate(
Expand Down Expand Up @@ -617,14 +606,8 @@ def distillation_train(
flush=True,
)
with timer.time("prepare_for_generation"):
if NEED_REFIT and POLICY_GENERATION_STALE:
refit_policy_generation(
student_policy,
student_generation,
colocated_inference,
timer=timer,
)
POLICY_GENERATION_STALE = False
if weight_sync is not None and weight_sync.is_stale:
weight_sync.sync_weights(timer=timer)
else:
student_generation.prepare_for_generation()

Expand Down Expand Up @@ -718,7 +701,8 @@ def distillation_train(
with timer.time("training_prep"):
teacher_policy.offload_after_refit()
student_policy.prepare_for_training() # set model train and reload optim to GPU
POLICY_GENERATION_STALE = True
if weight_sync is not None:
weight_sync.mark_stale()

print("▶ Training policy...", flush=True)
with timer.time("policy_training"):
Expand All @@ -737,11 +721,8 @@ def distillation_train(
if (val_period > 0 and (total_steps + 1) % val_period == 0) or (
val_at_end and is_last_step
):
if NEED_REFIT and POLICY_GENERATION_STALE:
refit_policy_generation(
student_policy, student_generation, colocated_inference
)
POLICY_GENERATION_STALE = False
if weight_sync is not None and weight_sync.is_stale:
weight_sync.sync_weights()
else:
student_generation.prepare_for_generation()
val_metrics, validation_timings = validate(
Expand Down
Loading
Loading