1616from typing import Any , NotRequired , Optional , TypedDict , TypeVar , cast
1717
1818import numpy as np
19- import ray
2019import torch
2120from pydantic import BaseModel
2221from torchdata .stateful_dataloader import StatefulDataLoader
2322from transformers import AutoConfig , AutoTokenizer
2423from transformers .tokenization_utils_base import PreTrainedTokenizerBase
2524
26- from nemo_rl .algorithms .grpo import _should_use_async_rollouts , refit_policy_generation
25+ from nemo_rl .algorithms .grpo import _should_use_async_rollouts
2726from nemo_rl .algorithms .loss import (
2827 DistillationLossConfig ,
2928 DistillationLossDataDict ,
6463)
6564from nemo_rl .utils .nsys import maybe_gpu_profile_step
6665from nemo_rl .utils .timer import TimeoutChecker , Timer
66+ from nemo_rl .weight_sync import WeightSynchronizer , create_weight_synchronizer
6767
6868# ===============================================================================
6969# Configuration
@@ -165,6 +165,7 @@ def setup(
165165 ColocatablePolicyInterface , # student_policy
166166 ColocatablePolicyInterface , # teacher_policy
167167 Optional [GenerationInterface ], # student_generation
168+ Optional [WeightSynchronizer ], # weight_sync
168169 StatefulDataLoader ,
169170 Optional [StatefulDataLoader ],
170171 DistillationLossFn ,
@@ -176,7 +177,7 @@ def setup(
176177 """Main entry point for distillation algorithm.
177178
178179 Returns:
179- tuple of student_policy, teacher_policy, student_generation,
180+ tuple of student_policy, teacher_policy, student_generation, weight_sync,
180181 train_dataloader, val_dataloader,
181182 loss_fn, logger, checkpointer, distillation_save_state, master_config
182183 """
@@ -455,26 +456,18 @@ def setup(
455456 init_reference_model = False ,
456457 )
457458
459+ # Create weight synchronizer and initialize communication channels
460+ weight_sync : Optional [WeightSynchronizer ] = None
458461 if student_generation is not None :
459- state_dict_info = student_policy .prepare_refit_info ()
460- student_generation .prepare_refit_info (state_dict_info )
461-
462- # if it is not colocated inference, initialize collective communication for update weights
463- if not colocated_inference :
464- ip , port = train_cluster .get_master_address_and_port ()
465- print (f"Using ip: { ip } , port: { port } for collective communication" , flush = True )
466- train_world_size = train_cluster .world_size ()
467- # inference cluster + head node of the train cluster
468- world_size = train_world_size + inference_nodes * inference_gpus_per_node
469- # init collective
470- futures_train = student_policy .init_collective (
471- ip , port , world_size , train_world_size = train_world_size
462+ weight_sync = create_weight_synchronizer (
463+ policy = student_policy ,
464+ generation = student_generation ,
465+ generation_backend = backend ,
466+ colocated = colocated_inference ,
467+ train_cluster = train_cluster if not colocated_inference else None ,
468+ inference_cluster = inference_cluster if not colocated_inference else None ,
472469 )
473- futures_inference = student_generation .init_collective (
474- ip , port , world_size , train_world_size = train_world_size
475- ) # type: ignore
476- # wait for all futures to complete
477- ray .get (futures_train + futures_inference )
470+ weight_sync .init_communicator ()
478471
479472 loss_fn = DistillationLossFn (loss_config )
480473
@@ -486,6 +479,7 @@ def setup(
486479 student_policy ,
487480 teacher_policy ,
488481 student_generation ,
482+ weight_sync ,
489483 dataloader ,
490484 val_dataloader ,
491485 loss_fn ,
@@ -515,6 +509,7 @@ def distillation_train(
515509 checkpointer : CheckpointManager ,
516510 distillation_save_state : DistillationSaveState ,
517511 master_config : MasterConfig ,
512+ weight_sync : Optional [WeightSynchronizer ] = None ,
518513) -> None :
519514 """Run Distillation training algorithm."""
520515 timer = Timer ()
@@ -524,13 +519,10 @@ def distillation_train(
524519 )
525520 timeout .start_iterations ()
526521
527- NEED_REFIT = True
528522 # If student_generation is None, use the student_policy as the generation interface (megatron framework backend)
529523 if student_generation is None :
530524 student_generation = student_policy # type: ignore
531- NEED_REFIT = False
532- POLICY_GENERATION_STALE = True # tracks if generation needs a refit before running
533- assert student_generation is not None
525+ assert student_generation is not None # for mypy type check
534526
535527 # common config/state items
536528 current_epoch = distillation_save_state ["current_epoch" ] # current epoch
@@ -556,11 +548,8 @@ def distillation_train(
556548 # Run validation at the start if configured
557549 if val_at_start and total_steps == 0 :
558550 print ("\n 🔍 Running initial validation..." , flush = True )
559- if NEED_REFIT and POLICY_GENERATION_STALE :
560- refit_policy_generation (
561- student_policy , student_generation , colocated_inference
562- )
563- POLICY_GENERATION_STALE = False
551+ if weight_sync is not None and weight_sync .is_stale :
552+ weight_sync .sync_weights ()
564553 else :
565554 student_generation .prepare_for_generation ()
566555 val_metrics , validation_timings = validate (
@@ -611,14 +600,8 @@ def distillation_train(
611600 flush = True ,
612601 )
613602 with timer .time ("prepare_for_generation" ):
614- if NEED_REFIT and POLICY_GENERATION_STALE :
615- refit_policy_generation (
616- student_policy ,
617- student_generation ,
618- colocated_inference ,
619- timer = timer ,
620- )
621- POLICY_GENERATION_STALE = False
603+ if weight_sync is not None and weight_sync .is_stale :
604+ weight_sync .sync_weights (timer = timer )
622605 else :
623606 student_generation .prepare_for_generation ()
624607
@@ -712,7 +695,8 @@ def distillation_train(
712695 with timer .time ("training_prep" ):
713696 teacher_policy .offload_after_refit ()
714697 student_policy .prepare_for_training () # set model train and reload optim to GPU
715- POLICY_GENERATION_STALE = True
698+ if weight_sync is not None :
699+ weight_sync .mark_stale ()
716700
717701 print ("▶ Training policy..." , flush = True )
718702 with timer .time ("policy_training" ):
@@ -731,11 +715,8 @@ def distillation_train(
731715 if (val_period > 0 and (total_steps + 1 ) % val_period == 0 ) or (
732716 val_at_end and is_last_step
733717 ):
734- if NEED_REFIT and POLICY_GENERATION_STALE :
735- refit_policy_generation (
736- student_policy , student_generation , colocated_inference
737- )
738- POLICY_GENERATION_STALE = False
718+ if weight_sync is not None and weight_sync .is_stale :
719+ weight_sync .sync_weights ()
739720 else :
740721 student_generation .prepare_for_generation ()
741722 val_metrics , validation_timings = validate (
0 commit comments