-
Notifications
You must be signed in to change notification settings - Fork 452
feat: Wire WeightSynchronizer into algorithm layer, replacing inline refit logic #2467
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -134,6 +134,7 @@ def _make_policy(**kwargs): | |
| policy, | ||
| policy_generation, | ||
| cluster, | ||
| weight_sync, | ||
| dataloader, | ||
| val_dataloader, | ||
| loss_fn, | ||
|
|
@@ -199,6 +200,7 @@ def _make_policy(**kwargs): | |
| grpo_save_state=grpo_state, | ||
| master_config=master_config, | ||
| max_trajectory_age_steps=async_config["max_trajectory_age_steps"], | ||
| weight_sync=weight_sync, | ||
| ) | ||
| else: | ||
| # Two parallel synchronous trainers (verl-style — main_ppo.py vs | ||
|
|
@@ -217,6 +219,7 @@ def _make_policy(**kwargs): | |
| checkpointer, | ||
| grpo_state, | ||
| master_config, | ||
| weight_sync=weight_sync, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When Consider either migrating |
||
| ) | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,14 +16,13 @@ | |
| from typing import Any, NotRequired, Optional, TypedDict, TypeVar, cast | ||
|
|
||
| import numpy as np | ||
| import ray | ||
| import torch | ||
| from pydantic import BaseModel | ||
| from torchdata.stateful_dataloader import StatefulDataLoader | ||
| from transformers import AutoConfig, AutoTokenizer | ||
| from transformers.tokenization_utils_base import PreTrainedTokenizerBase | ||
|
|
||
| from nemo_rl.algorithms.grpo import _should_use_async_rollouts, refit_policy_generation | ||
| from nemo_rl.algorithms.grpo import _should_use_async_rollouts | ||
| from nemo_rl.algorithms.loss import ( | ||
| DistillationLossConfig, | ||
| DistillationLossDataDict, | ||
|
|
@@ -64,6 +63,7 @@ | |
| ) | ||
| from nemo_rl.utils.nsys import maybe_gpu_profile_step | ||
| from nemo_rl.utils.timer import TimeoutChecker, Timer | ||
| from nemo_rl.weight_sync import WeightSynchronizer, create_weight_synchronizer | ||
|
|
||
| # =============================================================================== | ||
| # Configuration | ||
|
|
@@ -165,6 +165,7 @@ def setup( | |
| ColocatablePolicyInterface, # student_policy | ||
| ColocatablePolicyInterface, # teacher_policy | ||
| Optional[GenerationInterface], # student_generation | ||
| Optional[WeightSynchronizer], # weight_sync | ||
| StatefulDataLoader, | ||
| Optional[StatefulDataLoader], | ||
| DistillationLossFn, | ||
|
|
@@ -176,7 +177,7 @@ def setup( | |
| """Main entry point for distillation algorithm. | ||
|
|
||
| Returns: | ||
| tuple of student_policy, teacher_policy, student_generation, | ||
| tuple of student_policy, teacher_policy, student_generation, weight_sync, | ||
| train_dataloader, val_dataloader, | ||
| loss_fn, logger, checkpointer, distillation_save_state, master_config | ||
| """ | ||
|
|
@@ -461,26 +462,18 @@ def setup( | |
| init_reference_model=False, | ||
| ) | ||
|
|
||
| # Create weight synchronizer and initialize communication channels | ||
| weight_sync: Optional[WeightSynchronizer] = None | ||
| if student_generation is not None: | ||
| state_dict_info = student_policy.prepare_refit_info() | ||
| student_generation.prepare_refit_info(state_dict_info) | ||
|
|
||
| # if it is not colocated inference, initialize collective communication for update weights | ||
| if not colocated_inference: | ||
| ip, port = train_cluster.get_master_address_and_port() | ||
| print(f"Using ip: {ip}, port: {port} for collective communication", flush=True) | ||
| train_world_size = train_cluster.world_size() | ||
| # inference cluster + head node of the train cluster | ||
| world_size = train_world_size + inference_nodes * inference_gpus_per_node | ||
| # init collective | ||
| futures_train = student_policy.init_collective( | ||
| ip, port, world_size, train_world_size=train_world_size | ||
| weight_sync = create_weight_synchronizer( | ||
| policy=student_policy, | ||
| generation=student_generation, | ||
| generation_backend=backend, | ||
| colocated=colocated_inference, | ||
| train_cluster=train_cluster if not colocated_inference else None, | ||
| inference_cluster=inference_cluster if not colocated_inference else None, | ||
| ) | ||
| futures_inference = student_generation.init_collective( | ||
| ip, port, world_size, train_world_size=train_world_size | ||
| ) # type: ignore | ||
| # wait for all futures to complete | ||
| ray.get(futures_train + futures_inference) | ||
| weight_sync.init_communicator() | ||
|
|
||
| loss_fn = DistillationLossFn(loss_config) | ||
|
|
||
|
|
@@ -492,6 +485,7 @@ def setup( | |
| student_policy, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Also, |
||
| teacher_policy, | ||
| student_generation, | ||
| weight_sync, | ||
| dataloader, | ||
| val_dataloader, | ||
| loss_fn, | ||
|
|
@@ -521,6 +515,7 @@ def distillation_train( | |
| checkpointer: CheckpointManager, | ||
| distillation_save_state: DistillationSaveState, | ||
| master_config: MasterConfig, | ||
| weight_sync: Optional[WeightSynchronizer] = None, | ||
| ) -> None: | ||
| """Run Distillation training algorithm.""" | ||
| timer = Timer() | ||
|
|
@@ -530,13 +525,10 @@ def distillation_train( | |
| ) | ||
| timeout.start_iterations() | ||
|
|
||
| NEED_REFIT = True | ||
| # If student_generation is None, use the student_policy as the generation interface (megatron framework backend) | ||
| if student_generation is None: | ||
| student_generation = student_policy # type: ignore | ||
| NEED_REFIT = False | ||
| POLICY_GENERATION_STALE = True # tracks if generation needs a refit before running | ||
| assert student_generation is not None | ||
| assert student_generation is not None # for mypy type check | ||
|
|
||
| # common config/state items | ||
| current_epoch = distillation_save_state["current_epoch"] # current epoch | ||
|
|
@@ -562,11 +554,8 @@ def distillation_train( | |
| # Run validation at the start if configured | ||
| if val_at_start and total_steps == 0: | ||
| print("\n🔍 Running initial validation...", flush=True) | ||
| if NEED_REFIT and POLICY_GENERATION_STALE: | ||
| refit_policy_generation( | ||
| student_policy, student_generation, colocated_inference | ||
| ) | ||
| POLICY_GENERATION_STALE = False | ||
| if weight_sync is not None and weight_sync.is_stale: | ||
| weight_sync.sync_weights() | ||
| else: | ||
| student_generation.prepare_for_generation() | ||
| val_metrics, validation_timings = validate( | ||
|
|
@@ -617,14 +606,8 @@ def distillation_train( | |
| flush=True, | ||
| ) | ||
| with timer.time("prepare_for_generation"): | ||
| if NEED_REFIT and POLICY_GENERATION_STALE: | ||
| refit_policy_generation( | ||
| student_policy, | ||
| student_generation, | ||
| colocated_inference, | ||
| timer=timer, | ||
| ) | ||
| POLICY_GENERATION_STALE = False | ||
| if weight_sync is not None and weight_sync.is_stale: | ||
| weight_sync.sync_weights(timer=timer) | ||
| else: | ||
| student_generation.prepare_for_generation() | ||
|
|
||
|
|
@@ -718,7 +701,8 @@ def distillation_train( | |
| with timer.time("training_prep"): | ||
| teacher_policy.offload_after_refit() | ||
| student_policy.prepare_for_training() # set model train and reload optim to GPU | ||
| POLICY_GENERATION_STALE = True | ||
| if weight_sync is not None: | ||
| weight_sync.mark_stale() | ||
|
|
||
| print("▶ Training policy...", flush=True) | ||
| with timer.time("policy_training"): | ||
|
|
@@ -737,11 +721,8 @@ def distillation_train( | |
| if (val_period > 0 and (total_steps + 1) % val_period == 0) or ( | ||
| val_at_end and is_last_step | ||
| ): | ||
| if NEED_REFIT and POLICY_GENERATION_STALE: | ||
| refit_policy_generation( | ||
| student_policy, student_generation, colocated_inference | ||
| ) | ||
| POLICY_GENERATION_STALE = False | ||
| if weight_sync is not None and weight_sync.is_stale: | ||
| weight_sync.sync_weights() | ||
| else: | ||
| student_generation.prepare_for_generation() | ||
| val_metrics, validation_timings = validate( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
run_grpo_nemo_gym.pyweight_syncis correctly unpacked fromsetup()here, butcollect_trajectories()at L85 still directly callsrefit_policy_generation(policy, policy_generation, colocated_inference), which now emits aDeprecationWarningon every trajectory collection run.Consider threading
weight_syncintocollect_trajectories()and usingweight_sync.sync_weights()instead.