Skip to content

Commit 8b5be01

Browse files
committed
feat: Wire WeightSynchronizer into algorithm layer, replacing inline refit logic
Replace refit_policy_generation() calls and NEED_REFIT/POLICY_GENERATION_STALE flags in grpo.py and distillation.py with WeightSynchronizer method calls (sync_weights, mark_stale, is_stale). The setup() functions now create and initialize the appropriate WeightSynchronizer and return it in the tuple. Key changes: - grpo.py setup(): create WeightSynchronizer via factory, replace inline init_collective/prepare_refit_info with weight_sync.init_communicator() - grpo_train/async_grpo_train: accept weight_sync param, use it for all weight transfer instead of refit_policy_generation() - distillation.py: same treatment for distillation_train() - refit_policy_generation(): kept with deprecation warning for external users - Factory: keep NotImplementedError for non-colocated SGLang (SGLang's update_weights_from_collective() is a no-op, would silently skip transfer) - All example scripts updated to thread weight_sync through
1 parent 29fed73 commit 8b5be01

8 files changed

Lines changed: 101 additions & 104 deletions

File tree

examples/nemo_gym/run_grpo_nemo_gym.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def main() -> None:
195195
policy,
196196
policy_generation,
197197
cluster,
198+
weight_sync,
198199
dataloader,
199200
val_dataloader,
200201
loss_fn,
@@ -281,6 +282,7 @@ def main() -> None:
281282
grpo_save_state=grpo_state,
282283
master_config=master_config,
283284
max_trajectory_age_steps=async_config["max_trajectory_age_steps"],
285+
weight_sync=weight_sync,
284286
)
285287
else:
286288
print("🚀 Running synchronous GRPO training")
@@ -299,6 +301,7 @@ def main() -> None:
299301
checkpointer,
300302
grpo_state,
301303
master_config,
304+
weight_sync=weight_sync,
302305
)
303306

304307

examples/run_grpo.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def main() -> None:
100100
policy,
101101
policy_generation,
102102
cluster,
103+
weight_sync,
103104
dataloader,
104105
val_dataloader,
105106
loss_fn,
@@ -159,6 +160,7 @@ def main() -> None:
159160
grpo_save_state=grpo_state,
160161
master_config=master_config,
161162
max_trajectory_age_steps=async_config["max_trajectory_age_steps"],
163+
weight_sync=weight_sync,
162164
)
163165
else:
164166
print("🚀 Running synchronous GRPO training")
@@ -177,6 +179,7 @@ def main() -> None:
177179
checkpointer,
178180
grpo_state,
179181
master_config,
182+
weight_sync=weight_sync,
180183
)
181184

182185

examples/run_grpo_sliding_puzzle.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ def main():
253253
policy,
254254
policy_generation,
255255
cluster,
256+
weight_sync,
256257
dataloader,
257258
val_dataloader,
258259
loss_fn,
@@ -275,6 +276,7 @@ def main():
275276
checkpointer,
276277
grpo_state,
277278
master_config,
279+
weight_sync=weight_sync,
278280
)
279281

280282

examples/run_vlm_grpo.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def main() -> None:
107107
policy,
108108
policy_generation,
109109
cluster,
110+
weight_sync,
110111
dataloader,
111112
val_dataloader,
112113
loss_fn,
@@ -129,6 +130,7 @@ def main() -> None:
129130
checkpointer,
130131
grpo_state,
131132
master_config,
133+
weight_sync=weight_sync,
132134
)
133135

134136

nemo_rl/algorithms/distillation.py

Lines changed: 24 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
from transformers import AutoConfig, AutoTokenizer
2424
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
2525

26-
from nemo_rl.algorithms.grpo import _should_use_async_rollouts, refit_policy_generation
26+
from nemo_rl.algorithms.grpo import _should_use_async_rollouts
27+
from nemo_rl.weight_sync import WeightSynchronizer, create_weight_synchronizer
2728
from nemo_rl.algorithms.loss import (
2829
DistillationLossConfig,
2930
DistillationLossDataDict,
@@ -164,6 +165,7 @@ def setup(
164165
ColocatablePolicyInterface, # student_policy
165166
ColocatablePolicyInterface, # teacher_policy
166167
Optional[GenerationInterface], # student_generation
168+
Optional[WeightSynchronizer], # weight_sync
167169
StatefulDataLoader,
168170
Optional[StatefulDataLoader],
169171
DistillationLossFn,
@@ -175,7 +177,7 @@ def setup(
175177
"""Main entry point for distillation algorithm.
176178
177179
Returns:
178-
tuple of student_policy, teacher_policy, student_generation,
180+
tuple of student_policy, teacher_policy, student_generation, weight_sync,
179181
train_dataloader, val_dataloader,
180182
loss_fn, logger, checkpointer, distillation_save_state, master_config
181183
"""
@@ -457,26 +459,18 @@ def setup(
457459
init_reference_model=False,
458460
)
459461

462+
# Create weight synchronizer and initialize communication channels
463+
weight_sync: Optional[WeightSynchronizer] = None
460464
if student_generation is not None:
461-
state_dict_info = student_policy.prepare_refit_info()
462-
student_generation.prepare_refit_info(state_dict_info)
463-
464-
# if it is not colocated inference, initialize collective communication for update weights
465-
if not colocated_inference:
466-
ip, port = train_cluster.get_master_address_and_port()
467-
print(f"Using ip: {ip}, port: {port} for collective communication", flush=True)
468-
train_world_size = train_cluster.world_size()
469-
# inference cluster + head node of the train cluster
470-
world_size = train_world_size + inference_nodes * inference_gpus_per_node
471-
# init collective
472-
futures_train = student_policy.init_collective(
473-
ip, port, world_size, train_world_size=train_world_size
465+
weight_sync = create_weight_synchronizer(
466+
policy=student_policy,
467+
generation=student_generation,
468+
generation_backend=backend,
469+
colocated=colocated_inference,
470+
train_cluster=train_cluster if not colocated_inference else None,
471+
inference_cluster=inference_cluster if not colocated_inference else None,
474472
)
475-
futures_inference = student_generation.init_collective(
476-
ip, port, world_size, train_world_size=train_world_size
477-
) # type: ignore
478-
# wait for all futures to complete
479-
ray.get(futures_train + futures_inference)
473+
weight_sync.init_communicator()
480474

481475
loss_fn = DistillationLossFn(loss_config)
482476

@@ -488,6 +482,7 @@ def setup(
488482
student_policy,
489483
teacher_policy,
490484
student_generation,
485+
weight_sync,
491486
dataloader,
492487
val_dataloader,
493488
loss_fn,
@@ -517,6 +512,7 @@ def distillation_train(
517512
checkpointer: CheckpointManager,
518513
distillation_save_state: DistillationSaveState,
519514
master_config: MasterConfig,
515+
weight_sync: Optional[WeightSynchronizer] = None,
520516
) -> None:
521517
"""Run Distillation training algorithm."""
522518
timer = Timer()
@@ -526,12 +522,9 @@ def distillation_train(
526522
)
527523
timeout.start_iterations()
528524

529-
NEED_REFIT = True
530525
# If student_generation is None, use the student_policy as the generation interface (megatron framework backend)
531526
if student_generation is None:
532527
student_generation = student_policy # type: ignore
533-
NEED_REFIT = False
534-
POLICY_GENERATION_STALE = True # tracks if generation needs a refit before running
535528
assert student_generation is not None # for mypy type check
536529

537530
# common config/state items
@@ -558,11 +551,8 @@ def distillation_train(
558551
# Run validation at the start if configured
559552
if val_at_start and total_steps == 0:
560553
print("\n🔍 Running initial validation...", flush=True)
561-
if NEED_REFIT and POLICY_GENERATION_STALE:
562-
refit_policy_generation(
563-
student_policy, student_generation, colocated_inference
564-
)
565-
POLICY_GENERATION_STALE = False
554+
if weight_sync is not None and weight_sync.is_stale:
555+
weight_sync.sync_weights()
566556
else:
567557
student_generation.prepare_for_generation()
568558
val_metrics, validation_timings = validate(
@@ -613,14 +603,8 @@ def distillation_train(
613603
flush=True,
614604
)
615605
with timer.time("prepare_for_generation"):
616-
if NEED_REFIT and POLICY_GENERATION_STALE:
617-
refit_policy_generation(
618-
student_policy,
619-
student_generation,
620-
colocated_inference,
621-
timer=timer,
622-
)
623-
POLICY_GENERATION_STALE = False
606+
if weight_sync is not None and weight_sync.is_stale:
607+
weight_sync.sync_weights(timer=timer)
624608
else:
625609
student_generation.prepare_for_generation()
626610

@@ -714,7 +698,8 @@ def distillation_train(
714698
with timer.time("training_prep"):
715699
teacher_policy.offload_after_refit()
716700
student_policy.prepare_for_training() # set model train and reload optim to GPU
717-
POLICY_GENERATION_STALE = True
701+
if weight_sync is not None:
702+
weight_sync.mark_stale()
718703

719704
print("▶ Training policy...", flush=True)
720705
with timer.time("policy_training"):
@@ -733,11 +718,8 @@ def distillation_train(
733718
if (val_period > 0 and (total_steps + 1) % val_period == 0) or (
734719
val_at_end and is_last_step
735720
):
736-
if NEED_REFIT and POLICY_GENERATION_STALE:
737-
refit_policy_generation(
738-
student_policy, student_generation, colocated_inference
739-
)
740-
POLICY_GENERATION_STALE = False
721+
if weight_sync is not None and weight_sync.is_stale:
722+
weight_sync.sync_weights()
741723
else:
742724
student_generation.prepare_for_generation()
743725
val_metrics, validation_timings = validate(

0 commit comments

Comments
 (0)