2323from transformers import AutoConfig , AutoTokenizer
2424from transformers .tokenization_utils_base import PreTrainedTokenizerBase
2525
26- from nemo_rl .algorithms .grpo import _should_use_async_rollouts , refit_policy_generation
26+ from nemo_rl .algorithms .grpo import _should_use_async_rollouts
27+ from nemo_rl .weight_sync import WeightSynchronizer , create_weight_synchronizer
2728from nemo_rl .algorithms .loss import (
2829 DistillationLossConfig ,
2930 DistillationLossDataDict ,
@@ -164,6 +165,7 @@ def setup(
164165 ColocatablePolicyInterface , # student_policy
165166 ColocatablePolicyInterface , # teacher_policy
166167 Optional [GenerationInterface ], # student_generation
168+ Optional [WeightSynchronizer ], # weight_sync
167169 StatefulDataLoader ,
168170 Optional [StatefulDataLoader ],
169171 DistillationLossFn ,
@@ -175,7 +177,7 @@ def setup(
175177 """Main entry point for distillation algorithm.
176178
177179 Returns:
178- tuple of student_policy, teacher_policy, student_generation,
180+ tuple of student_policy, teacher_policy, student_generation, weight_sync,
179181 train_dataloader, val_dataloader,
180182 loss_fn, logger, checkpointer, distillation_save_state, master_config
181183 """
@@ -457,26 +459,18 @@ def setup(
457459 init_reference_model = False ,
458460 )
459461
462+ # Create weight synchronizer and initialize communication channels
463+ weight_sync : Optional [WeightSynchronizer ] = None
460464 if student_generation is not None :
461- state_dict_info = student_policy .prepare_refit_info ()
462- student_generation .prepare_refit_info (state_dict_info )
463-
464- # if it is not colocated inference, initialize collective communication for update weights
465- if not colocated_inference :
466- ip , port = train_cluster .get_master_address_and_port ()
467- print (f"Using ip: { ip } , port: { port } for collective communication" , flush = True )
468- train_world_size = train_cluster .world_size ()
469- # inference cluster + head node of the train cluster
470- world_size = train_world_size + inference_nodes * inference_gpus_per_node
471- # init collective
472- futures_train = student_policy .init_collective (
473- ip , port , world_size , train_world_size = train_world_size
465+ weight_sync = create_weight_synchronizer (
466+ policy = student_policy ,
467+ generation = student_generation ,
468+ generation_backend = backend ,
469+ colocated = colocated_inference ,
470+ train_cluster = train_cluster if not colocated_inference else None ,
471+ inference_cluster = inference_cluster if not colocated_inference else None ,
474472 )
475- futures_inference = student_generation .init_collective (
476- ip , port , world_size , train_world_size = train_world_size
477- ) # type: ignore
478- # wait for all futures to complete
479- ray .get (futures_train + futures_inference )
473+ weight_sync .init_communicator ()
480474
481475 loss_fn = DistillationLossFn (loss_config )
482476
@@ -488,6 +482,7 @@ def setup(
488482 student_policy ,
489483 teacher_policy ,
490484 student_generation ,
485+ weight_sync ,
491486 dataloader ,
492487 val_dataloader ,
493488 loss_fn ,
@@ -517,6 +512,7 @@ def distillation_train(
517512 checkpointer : CheckpointManager ,
518513 distillation_save_state : DistillationSaveState ,
519514 master_config : MasterConfig ,
515+ weight_sync : Optional [WeightSynchronizer ] = None ,
520516) -> None :
521517 """Run Distillation training algorithm."""
522518 timer = Timer ()
@@ -526,12 +522,9 @@ def distillation_train(
526522 )
527523 timeout .start_iterations ()
528524
529- NEED_REFIT = True
530525 # If student_generation is None, use the student_policy as the generation interface (megatron framework backend)
531526 if student_generation is None :
532527 student_generation = student_policy # type: ignore
533- NEED_REFIT = False
534- POLICY_GENERATION_STALE = True # tracks if generation needs a refit before running
535528 assert student_generation is not None # for mypy type check
536529
537530 # common config/state items
@@ -558,11 +551,8 @@ def distillation_train(
558551 # Run validation at the start if configured
559552 if val_at_start and total_steps == 0 :
560553 print ("\n 🔍 Running initial validation..." , flush = True )
561- if NEED_REFIT and POLICY_GENERATION_STALE :
562- refit_policy_generation (
563- student_policy , student_generation , colocated_inference
564- )
565- POLICY_GENERATION_STALE = False
554+ if weight_sync is not None and weight_sync .is_stale :
555+ weight_sync .sync_weights ()
566556 else :
567557 student_generation .prepare_for_generation ()
568558 val_metrics , validation_timings = validate (
@@ -613,14 +603,8 @@ def distillation_train(
613603 flush = True ,
614604 )
615605 with timer .time ("prepare_for_generation" ):
616- if NEED_REFIT and POLICY_GENERATION_STALE :
617- refit_policy_generation (
618- student_policy ,
619- student_generation ,
620- colocated_inference ,
621- timer = timer ,
622- )
623- POLICY_GENERATION_STALE = False
606+ if weight_sync is not None and weight_sync .is_stale :
607+ weight_sync .sync_weights (timer = timer )
624608 else :
625609 student_generation .prepare_for_generation ()
626610
@@ -714,7 +698,8 @@ def distillation_train(
714698 with timer .time ("training_prep" ):
715699 teacher_policy .offload_after_refit ()
716700 student_policy .prepare_for_training () # set model train and reload optim to GPU
717- POLICY_GENERATION_STALE = True
701+ if weight_sync is not None :
702+ weight_sync .mark_stale ()
718703
719704 print ("▶ Training policy..." , flush = True )
720705 with timer .time ("policy_training" ):
@@ -733,11 +718,8 @@ def distillation_train(
733718 if (val_period > 0 and (total_steps + 1 ) % val_period == 0 ) or (
734719 val_at_end and is_last_step
735720 ):
736- if NEED_REFIT and POLICY_GENERATION_STALE :
737- refit_policy_generation (
738- student_policy , student_generation , colocated_inference
739- )
740- POLICY_GENERATION_STALE = False
721+ if weight_sync is not None and weight_sync .is_stale :
722+ weight_sync .sync_weights ()
741723 else :
742724 student_generation .prepare_for_generation ()
743725 val_metrics , validation_timings = validate (
0 commit comments