Skip to content

Commit 616f2b0

Browse files
committed
feat: Wire WeightSynchronizer into algorithm layer, replacing inline refit logic
Replace refit_policy_generation() calls and the NEED_REFIT/POLICY_GENERATION_STALE flags in grpo.py and distillation.py with the WeightSynchronizer API (sync_weights(), mark_stale(), is_stale). The setup() functions now create and initialize the appropriate WeightSynchronizer and return it in the result tuple.

Key changes:
- grpo.py setup(): create the WeightSynchronizer via the factory, and replace the inline init_collective/prepare_refit_info logic with weight_sync.init_communicator().
- grpo_train/async_grpo_train: accept a weight_sync parameter and use it for all weight transfer instead of refit_policy_generation().
- distillation.py: same treatment for distillation_train().
- refit_policy_generation(): kept, with a deprecation warning, for external users.
- Factory: keep raising NotImplementedError for non-colocated SGLang (SGLang's update_weights_from_collective() is a no-op, so the transfer would be silently skipped).
- All example scripts updated to thread weight_sync through.

Signed-off-by: Saurabh Mishra <sauramishra@nvidia.com>
1 parent bff7286 commit 616f2b0

8 files changed

Lines changed: 103 additions & 107 deletions

File tree

examples/nemo_gym/run_grpo_nemo_gym.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ def main() -> None:
196196
policy,
197197
policy_generation,
198198
cluster,
199+
weight_sync,
199200
dataloader,
200201
val_dataloader,
201202
loss_fn,
@@ -282,6 +283,7 @@ def main() -> None:
282283
grpo_save_state=grpo_state,
283284
master_config=master_config,
284285
max_trajectory_age_steps=async_config["max_trajectory_age_steps"],
286+
weight_sync=weight_sync,
285287
)
286288
else:
287289
print("🚀 Running synchronous GRPO training")
@@ -300,6 +302,7 @@ def main() -> None:
300302
checkpointer,
301303
grpo_state,
302304
master_config,
305+
weight_sync=weight_sync,
303306
)
304307

305308

examples/run_grpo.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def _make_policy(**kwargs):
134134
policy,
135135
policy_generation,
136136
cluster,
137+
weight_sync,
137138
dataloader,
138139
val_dataloader,
139140
loss_fn,
@@ -199,6 +200,7 @@ def _make_policy(**kwargs):
199200
grpo_save_state=grpo_state,
200201
master_config=master_config,
201202
max_trajectory_age_steps=async_config["max_trajectory_age_steps"],
203+
weight_sync=weight_sync,
202204
)
203205
else:
204206
# Two parallel synchronous trainers (verl-style — main_ppo.py vs
@@ -217,6 +219,7 @@ def _make_policy(**kwargs):
217219
checkpointer,
218220
grpo_state,
219221
master_config,
222+
weight_sync=weight_sync,
220223
)
221224

222225

examples/run_grpo_sliding_puzzle.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def main():
254254
policy,
255255
policy_generation,
256256
cluster,
257+
weight_sync,
257258
dataloader,
258259
val_dataloader,
259260
loss_fn,
@@ -276,6 +277,7 @@ def main():
276277
checkpointer,
277278
grpo_state,
278279
master_config,
280+
weight_sync=weight_sync,
279281
)
280282

281283

examples/run_vlm_grpo.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def main() -> None:
108108
policy,
109109
policy_generation,
110110
cluster,
111+
weight_sync,
111112
dataloader,
112113
val_dataloader,
113114
loss_fn,
@@ -130,6 +131,7 @@ def main() -> None:
130131
checkpointer,
131132
grpo_state,
132133
master_config,
134+
weight_sync=weight_sync,
133135
)
134136

135137

nemo_rl/algorithms/distillation.py

Lines changed: 25 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,13 @@
1616
from typing import Any, NotRequired, Optional, TypedDict, TypeVar, cast
1717

1818
import numpy as np
19-
import ray
2019
import torch
2120
from pydantic import BaseModel
2221
from torchdata.stateful_dataloader import StatefulDataLoader
2322
from transformers import AutoConfig, AutoTokenizer
2423
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
2524

26-
from nemo_rl.algorithms.grpo import _should_use_async_rollouts, refit_policy_generation
25+
from nemo_rl.algorithms.grpo import _should_use_async_rollouts
2726
from nemo_rl.algorithms.loss import (
2827
DistillationLossConfig,
2928
DistillationLossDataDict,
@@ -64,6 +63,7 @@
6463
)
6564
from nemo_rl.utils.nsys import maybe_gpu_profile_step
6665
from nemo_rl.utils.timer import TimeoutChecker, Timer
66+
from nemo_rl.weight_sync import WeightSynchronizer, create_weight_synchronizer
6767

6868
# ===============================================================================
6969
# Configuration
@@ -165,6 +165,7 @@ def setup(
165165
ColocatablePolicyInterface, # student_policy
166166
ColocatablePolicyInterface, # teacher_policy
167167
Optional[GenerationInterface], # student_generation
168+
Optional[WeightSynchronizer], # weight_sync
168169
StatefulDataLoader,
169170
Optional[StatefulDataLoader],
170171
DistillationLossFn,
@@ -176,7 +177,7 @@ def setup(
176177
"""Main entry point for distillation algorithm.
177178
178179
Returns:
179-
tuple of student_policy, teacher_policy, student_generation,
180+
tuple of student_policy, teacher_policy, student_generation, weight_sync,
180181
train_dataloader, val_dataloader,
181182
loss_fn, logger, checkpointer, distillation_save_state, master_config
182183
"""
@@ -455,26 +456,18 @@ def setup(
455456
init_reference_model=False,
456457
)
457458

459+
# Create weight synchronizer and initialize communication channels
460+
weight_sync: Optional[WeightSynchronizer] = None
458461
if student_generation is not None:
459-
state_dict_info = student_policy.prepare_refit_info()
460-
student_generation.prepare_refit_info(state_dict_info)
461-
462-
# if it is not colocated inference, initialize collective communication for update weights
463-
if not colocated_inference:
464-
ip, port = train_cluster.get_master_address_and_port()
465-
print(f"Using ip: {ip}, port: {port} for collective communication", flush=True)
466-
train_world_size = train_cluster.world_size()
467-
# inference cluster + head node of the train cluster
468-
world_size = train_world_size + inference_nodes * inference_gpus_per_node
469-
# init collective
470-
futures_train = student_policy.init_collective(
471-
ip, port, world_size, train_world_size=train_world_size
462+
weight_sync = create_weight_synchronizer(
463+
policy=student_policy,
464+
generation=student_generation,
465+
generation_backend=backend,
466+
colocated=colocated_inference,
467+
train_cluster=train_cluster if not colocated_inference else None,
468+
inference_cluster=inference_cluster if not colocated_inference else None,
472469
)
473-
futures_inference = student_generation.init_collective(
474-
ip, port, world_size, train_world_size=train_world_size
475-
) # type: ignore
476-
# wait for all futures to complete
477-
ray.get(futures_train + futures_inference)
470+
weight_sync.init_communicator()
478471

479472
loss_fn = DistillationLossFn(loss_config)
480473

@@ -486,6 +479,7 @@ def setup(
486479
student_policy,
487480
teacher_policy,
488481
student_generation,
482+
weight_sync,
489483
dataloader,
490484
val_dataloader,
491485
loss_fn,
@@ -515,6 +509,7 @@ def distillation_train(
515509
checkpointer: CheckpointManager,
516510
distillation_save_state: DistillationSaveState,
517511
master_config: MasterConfig,
512+
weight_sync: Optional[WeightSynchronizer] = None,
518513
) -> None:
519514
"""Run Distillation training algorithm."""
520515
timer = Timer()
@@ -524,13 +519,10 @@ def distillation_train(
524519
)
525520
timeout.start_iterations()
526521

527-
NEED_REFIT = True
528522
# If student_generation is None, use the student_policy as the generation interface (megatron framework backend)
529523
if student_generation is None:
530524
student_generation = student_policy # type: ignore
531-
NEED_REFIT = False
532-
POLICY_GENERATION_STALE = True # tracks if generation needs a refit before running
533-
assert student_generation is not None
525+
assert student_generation is not None # for mypy type check
534526

535527
# common config/state items
536528
current_epoch = distillation_save_state["current_epoch"] # current epoch
@@ -556,11 +548,8 @@ def distillation_train(
556548
# Run validation at the start if configured
557549
if val_at_start and total_steps == 0:
558550
print("\n🔍 Running initial validation...", flush=True)
559-
if NEED_REFIT and POLICY_GENERATION_STALE:
560-
refit_policy_generation(
561-
student_policy, student_generation, colocated_inference
562-
)
563-
POLICY_GENERATION_STALE = False
551+
if weight_sync is not None and weight_sync.is_stale:
552+
weight_sync.sync_weights()
564553
else:
565554
student_generation.prepare_for_generation()
566555
val_metrics, validation_timings = validate(
@@ -611,14 +600,8 @@ def distillation_train(
611600
flush=True,
612601
)
613602
with timer.time("prepare_for_generation"):
614-
if NEED_REFIT and POLICY_GENERATION_STALE:
615-
refit_policy_generation(
616-
student_policy,
617-
student_generation,
618-
colocated_inference,
619-
timer=timer,
620-
)
621-
POLICY_GENERATION_STALE = False
603+
if weight_sync is not None and weight_sync.is_stale:
604+
weight_sync.sync_weights(timer=timer)
622605
else:
623606
student_generation.prepare_for_generation()
624607

@@ -712,7 +695,8 @@ def distillation_train(
712695
with timer.time("training_prep"):
713696
teacher_policy.offload_after_refit()
714697
student_policy.prepare_for_training() # set model train and reload optim to GPU
715-
POLICY_GENERATION_STALE = True
698+
if weight_sync is not None:
699+
weight_sync.mark_stale()
716700

717701
print("▶ Training policy...", flush=True)
718702
with timer.time("policy_training"):
@@ -731,11 +715,8 @@ def distillation_train(
731715
if (val_period > 0 and (total_steps + 1) % val_period == 0) or (
732716
val_at_end and is_last_step
733717
):
734-
if NEED_REFIT and POLICY_GENERATION_STALE:
735-
refit_policy_generation(
736-
student_policy, student_generation, colocated_inference
737-
)
738-
POLICY_GENERATION_STALE = False
718+
if weight_sync is not None and weight_sync.is_stale:
719+
weight_sync.sync_weights()
739720
else:
740721
student_generation.prepare_for_generation()
741722
val_metrics, validation_timings = validate(

0 commit comments

Comments
 (0)