Skip to content

Commit df6b623

Browse files
committed
feat: Wire WeightSynchronizer into algorithm layer, replacing inline refit logic
Replace refit_policy_generation() calls and the NEED_REFIT/POLICY_GENERATION_STALE flags in grpo.py and distillation.py with WeightSynchronizer method calls (sync_weights, mark_stale, is_stale). The setup() functions now create and initialize the appropriate WeightSynchronizer and return it in the tuple.

Key changes:
- grpo.py setup(): create WeightSynchronizer via factory, replace inline init_collective/prepare_refit_info with weight_sync.init_communicator()
- grpo_train/async_grpo_train: accept weight_sync param, use it for all weight transfer instead of refit_policy_generation()
- distillation.py: same treatment for distillation_train()
- refit_policy_generation(): kept with deprecation warning for external users
- Factory: keep NotImplementedError for non-colocated SGLang (SGLang's update_weights_from_collective() is a no-op, would silently skip transfer)
- All example scripts updated to thread weight_sync through

Signed-off-by: Saurabh Mishra <sauramishra@nvidia.com>
1 parent 5494d14 commit df6b623

11 files changed

Lines changed: 172 additions & 153 deletions

File tree

examples/nemo_gym/run_grpo_nemo_gym.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import argparse
1616
import os
1717
import pprint
18+
from typing import Optional
1819

1920
# Increase the W&B single object size warning threshold. Initially 100_000 (100 KB) -> 10_000_000 (10 MB)
2021
import wandb.util
@@ -35,7 +36,6 @@
3536
TokenizerType,
3637
_should_use_nemo_gym,
3738
grpo_train,
38-
refit_policy_generation,
3939
setup,
4040
)
4141
from nemo_rl.algorithms.utils import get_tokenizer
@@ -54,6 +54,7 @@
5454
register_omegaconf_resolvers,
5555
)
5656
from nemo_rl.utils.logger import get_next_experiment_dir
57+
from nemo_rl.weight_sync import WeightSynchronizer
5758

5859

5960
def parse_args() -> tuple[argparse.Namespace, list[str]]:
@@ -78,11 +79,13 @@ def collect_trajectories(
7879
val_task_to_env: dict[str, EnvironmentInterface],
7980
logger: Logger,
8081
master_config: MasterConfig,
82+
weight_sync: Optional[WeightSynchronizer] = None,
8183
) -> None:
8284
"""Run trajectory collection."""
83-
# common config/state items
84-
colocated_inference = master_config.policy["generation"]["colocated"]["enabled"]
85-
refit_policy_generation(policy, policy_generation, colocated_inference)
85+
if weight_sync is not None and weight_sync.is_stale:
86+
weight_sync.sync_weights()
87+
else:
88+
policy_generation.prepare_for_generation()
8689

8790
log_filename = "trajectory_collection.jsonl"
8891

@@ -196,6 +199,7 @@ def main() -> None:
196199
policy,
197200
policy_generation,
198201
cluster,
202+
weight_sync,
199203
dataloader,
200204
val_dataloader,
201205
loss_fn,
@@ -231,6 +235,7 @@ def main() -> None:
231235
val_task_to_env=val_task_to_env,
232236
logger=logger,
233237
master_config=master_config,
238+
weight_sync=weight_sync,
234239
)
235240
# Check if async mode is enabled
236241
elif "async_grpo" in config.grpo and config.grpo["async_grpo"]["enabled"]:
@@ -282,6 +287,7 @@ def main() -> None:
282287
grpo_save_state=grpo_state,
283288
master_config=master_config,
284289
max_trajectory_age_steps=async_config["max_trajectory_age_steps"],
290+
weight_sync=weight_sync,
285291
)
286292
else:
287293
print("🚀 Running synchronous GRPO training")
@@ -300,6 +306,7 @@ def main() -> None:
300306
checkpointer,
301307
grpo_state,
302308
master_config,
309+
weight_sync=weight_sync,
303310
)
304311

305312

examples/run_distillation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def main() -> None:
9090
student_policy,
9191
teacher_policy,
9292
student_generation,
93+
weight_sync,
9394
dataloader,
9495
val_dataloader,
9596
loss_fn,
@@ -105,14 +106,15 @@ def main() -> None:
105106
student_generation,
106107
dataloader,
107108
val_dataloader,
108-
tokenizer, # pass tokenizer parameter
109+
tokenizer,
109110
loss_fn,
110111
task_to_env,
111112
val_task_to_env,
112113
logger,
113114
checkpointer,
114115
distillation_state,
115116
master_config,
117+
weight_sync=weight_sync,
116118
)
117119

118120

examples/run_grpo.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def _make_policy(**kwargs):
134134
policy,
135135
policy_generation,
136136
cluster,
137+
weight_sync,
137138
dataloader,
138139
val_dataloader,
139140
loss_fn,
@@ -199,6 +200,7 @@ def _make_policy(**kwargs):
199200
grpo_save_state=grpo_state,
200201
master_config=master_config,
201202
max_trajectory_age_steps=async_config["max_trajectory_age_steps"],
203+
weight_sync=weight_sync,
202204
)
203205
else:
204206
# Two parallel synchronous trainers (verl-style — main_ppo.py vs
@@ -217,6 +219,7 @@ def _make_policy(**kwargs):
217219
checkpointer,
218220
grpo_state,
219221
master_config,
222+
weight_sync=weight_sync,
220223
)
221224

222225

examples/run_grpo_sliding_puzzle.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def main():
254254
policy,
255255
policy_generation,
256256
cluster,
257+
weight_sync,
257258
dataloader,
258259
val_dataloader,
259260
loss_fn,
@@ -276,6 +277,7 @@ def main():
276277
checkpointer,
277278
grpo_state,
278279
master_config,
280+
weight_sync=weight_sync,
279281
)
280282

281283

examples/run_vlm_grpo.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def main() -> None:
108108
policy,
109109
policy_generation,
110110
cluster,
111+
weight_sync,
111112
dataloader,
112113
val_dataloader,
113114
loss_fn,
@@ -130,6 +131,7 @@ def main() -> None:
130131
checkpointer,
131132
grpo_state,
132133
master_config,
134+
weight_sync=weight_sync,
133135
)
134136

135137

nemo_rl/algorithms/distillation.py

Lines changed: 25 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,13 @@
1616
from typing import Any, NotRequired, Optional, TypedDict, TypeVar, cast
1717

1818
import numpy as np
19-
import ray
2019
import torch
2120
from pydantic import BaseModel
2221
from torchdata.stateful_dataloader import StatefulDataLoader
2322
from transformers import AutoConfig, AutoTokenizer
2423
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
2524

26-
from nemo_rl.algorithms.grpo import _should_use_async_rollouts, refit_policy_generation
25+
from nemo_rl.algorithms.grpo import _should_use_async_rollouts
2726
from nemo_rl.algorithms.loss import (
2827
DistillationLossConfig,
2928
DistillationLossDataDict,
@@ -64,6 +63,7 @@
6463
)
6564
from nemo_rl.utils.nsys import maybe_gpu_profile_step
6665
from nemo_rl.utils.timer import TimeoutChecker, Timer
66+
from nemo_rl.weight_sync import WeightSynchronizer, create_weight_synchronizer
6767

6868
# ===============================================================================
6969
# Configuration
@@ -165,6 +165,7 @@ def setup(
165165
ColocatablePolicyInterface, # student_policy
166166
ColocatablePolicyInterface, # teacher_policy
167167
Optional[GenerationInterface], # student_generation
168+
Optional[WeightSynchronizer], # weight_sync
168169
StatefulDataLoader,
169170
Optional[StatefulDataLoader],
170171
DistillationLossFn,
@@ -176,7 +177,7 @@ def setup(
176177
"""Main entry point for distillation algorithm.
177178
178179
Returns:
179-
tuple of student_policy, teacher_policy, student_generation,
180+
tuple of student_policy, teacher_policy, student_generation, weight_sync,
180181
train_dataloader, val_dataloader,
181182
loss_fn, logger, checkpointer, distillation_save_state, master_config
182183
"""
@@ -455,26 +456,18 @@ def setup(
455456
init_reference_model=False,
456457
)
457458

459+
# Create weight synchronizer and initialize communication channels
460+
weight_sync: Optional[WeightSynchronizer] = None
458461
if student_generation is not None:
459-
state_dict_info = student_policy.prepare_refit_info()
460-
student_generation.prepare_refit_info(state_dict_info)
461-
462-
# if it is not colocated inference, initialize collective communication for update weights
463-
if not colocated_inference:
464-
ip, port = train_cluster.get_master_address_and_port()
465-
print(f"Using ip: {ip}, port: {port} for collective communication", flush=True)
466-
train_world_size = train_cluster.world_size()
467-
# inference cluster + head node of the train cluster
468-
world_size = train_world_size + inference_nodes * inference_gpus_per_node
469-
# init collective
470-
futures_train = student_policy.init_collective(
471-
ip, port, world_size, train_world_size=train_world_size
462+
weight_sync = create_weight_synchronizer(
463+
policy=student_policy,
464+
generation=student_generation,
465+
generation_backend=backend,
466+
colocated=colocated_inference,
467+
train_cluster=train_cluster if not colocated_inference else None,
468+
inference_cluster=inference_cluster if not colocated_inference else None,
472469
)
473-
futures_inference = student_generation.init_collective(
474-
ip, port, world_size, train_world_size=train_world_size
475-
) # type: ignore
476-
# wait for all futures to complete
477-
ray.get(futures_train + futures_inference)
470+
weight_sync.init_communicator()
478471

479472
loss_fn = DistillationLossFn(loss_config)
480473

@@ -486,6 +479,7 @@ def setup(
486479
student_policy,
487480
teacher_policy,
488481
student_generation,
482+
weight_sync,
489483
dataloader,
490484
val_dataloader,
491485
loss_fn,
@@ -515,6 +509,7 @@ def distillation_train(
515509
checkpointer: CheckpointManager,
516510
distillation_save_state: DistillationSaveState,
517511
master_config: MasterConfig,
512+
weight_sync: Optional[WeightSynchronizer] = None,
518513
) -> None:
519514
"""Run Distillation training algorithm."""
520515
timer = Timer()
@@ -524,13 +519,10 @@ def distillation_train(
524519
)
525520
timeout.start_iterations()
526521

527-
NEED_REFIT = True
528522
# If student_generation is None, use the student_policy as the generation interface (megatron framework backend)
529523
if student_generation is None:
530524
student_generation = student_policy # type: ignore
531-
NEED_REFIT = False
532-
POLICY_GENERATION_STALE = True # tracks if generation needs a refit before running
533-
assert student_generation is not None
525+
assert student_generation is not None # for mypy type check
534526

535527
# common config/state items
536528
current_epoch = distillation_save_state["current_epoch"] # current epoch
@@ -556,11 +548,8 @@ def distillation_train(
556548
# Run validation at the start if configured
557549
if val_at_start and total_steps == 0:
558550
print("\n🔍 Running initial validation...", flush=True)
559-
if NEED_REFIT and POLICY_GENERATION_STALE:
560-
refit_policy_generation(
561-
student_policy, student_generation, colocated_inference
562-
)
563-
POLICY_GENERATION_STALE = False
551+
if weight_sync is not None and weight_sync.is_stale:
552+
weight_sync.sync_weights()
564553
else:
565554
student_generation.prepare_for_generation()
566555
val_metrics, validation_timings = validate(
@@ -611,14 +600,8 @@ def distillation_train(
611600
flush=True,
612601
)
613602
with timer.time("prepare_for_generation"):
614-
if NEED_REFIT and POLICY_GENERATION_STALE:
615-
refit_policy_generation(
616-
student_policy,
617-
student_generation,
618-
colocated_inference,
619-
timer=timer,
620-
)
621-
POLICY_GENERATION_STALE = False
603+
if weight_sync is not None and weight_sync.is_stale:
604+
weight_sync.sync_weights(timer=timer)
622605
else:
623606
student_generation.prepare_for_generation()
624607

@@ -712,7 +695,8 @@ def distillation_train(
712695
with timer.time("training_prep"):
713696
teacher_policy.offload_after_refit()
714697
student_policy.prepare_for_training() # set model train and reload optim to GPU
715-
POLICY_GENERATION_STALE = True
698+
if weight_sync is not None:
699+
weight_sync.mark_stale()
716700

717701
print("▶ Training policy...", flush=True)
718702
with timer.time("policy_training"):
@@ -731,11 +715,8 @@ def distillation_train(
731715
if (val_period > 0 and (total_steps + 1) % val_period == 0) or (
732716
val_at_end and is_last_step
733717
):
734-
if NEED_REFIT and POLICY_GENERATION_STALE:
735-
refit_policy_generation(
736-
student_policy, student_generation, colocated_inference
737-
)
738-
POLICY_GENERATION_STALE = False
718+
if weight_sync is not None and weight_sync.is_stale:
719+
weight_sync.sync_weights()
739720
else:
740721
student_generation.prepare_for_generation()
741722
val_metrics, validation_timings = validate(

0 commit comments

Comments
 (0)