@@ -537,6 +537,163 @@ def test_post_process_train_step(self):
537537 values_list = mock_buffer .additional_metrics ["distill/kl_div" ][0 ]
538538 self .assertEqual (values_list [0 ], 0.5 )
539539
540+ @mock .patch ("maxtext.trainers.post_train.distillation.train_distill.distillation_utils.OfflineArrayRecordIterator" )
541+ @mock .patch ("maxtext.trainers.post_train.distillation.train_distill.MaxTextDistillationTrainer" )
542+ @mock .patch ("maxtext.trainers.post_train.distillation.train_distill.input_pipeline_interface.create_data_iterator" )
543+ @mock .patch ("maxtext.trainers.post_train.distillation.train_distill.get_maxtext_model" )
544+ @mock .patch ("maxtext.trainers.post_train.distillation.train_distill.tokenizer.build_tokenizer" )
545+ @mock .patch ("maxtext.trainers.post_train.distillation.train_distill.maxtext_utils.create_device_mesh" )
546+ @mock .patch ("maxtext.configs.pyconfig.initialize" )
547+ def test_main_offline_mode_skips_teacher_loading (
548+ self ,
549+ mock_pyconfig_init ,
550+ mock_create_mesh ,
551+ mock_build_tokenizer ,
552+ mock_get_model ,
553+ mock_create_iterator ,
554+ mock_trainer_cls ,
555+ mock_offline_iter_cls ,
556+ ):
557+ """Verifies offline mode (offline_data_dir is set) skips teacher model loading."""
558+ # 1. Configs
559+ mock_global = mock .Mock ()
560+ mock_global .student_overrides = {}
561+ mock_global .teacher_overrides = {} # No checkpoint needed
562+ mock_global .offline_data_dir = "gs://bucket/data" # Triggers offline mode
563+
564+ mock_student_cfg = mock .Mock ()
565+ mock_student_cfg .vocab_size = 32000
566+ mock_student_cfg .mesh_axes = ("data" ,)
567+ mock_student_cfg .dataset_type = "grain"
568+
569+ # Add dummy numbers for optimizer math
570+ mock_student_cfg .learning_rate = 1e-4
571+ mock_student_cfg .warmup_steps_fraction = 0.1
572+ mock_student_cfg .learning_rate_final_fraction = 0.1
573+ mock_student_cfg .steps = 100
574+ mock_student_cfg .checkpoint_period = 10
575+ mock_student_cfg .gradient_clipping_threshold = 0.0
576+ mock_student_cfg .eval_interval = - 1
577+
578+ # Add dummy numbers for strategy math/logic
579+ mock_student_cfg .distill_temperature = 1.0
580+ mock_student_cfg .distill_alpha = 0.5
581+ mock_student_cfg .distill_beta = 0.0
582+ mock_student_cfg .distill_layer_indices = None
583+ mock_student_cfg .use_sft = False
584+ mock_student_cfg .enable_dropout = False
585+
586+ # Add dummy variables for Checkpointer and Logger
587+ mock_student_cfg .max_num_checkpoints_to_keep = 1
588+ mock_student_cfg .async_checkpointing = False
589+ mock_student_cfg .profiler = "none"
590+ mock_student_cfg .tensorboard_dir = ""
591+ mock_student_cfg .checkpoint_dir = ""
592+ mock_student_cfg .log_period = 10
593+ mock_student_cfg .save_checkpoint_on_completion = False
594+ mock_student_cfg .logical_axis_rules = []
595+
596+ mock_teacher_cfg = mock .Mock ()
597+ mock_teacher_cfg .vocab_size = 32000
598+ mock_pyconfig_init .side_effect = [mock_global , mock_student_cfg , mock_teacher_cfg ]
599+
600+ # 2. Model Loading
601+ mock_student_model = mock .Mock ()
602+ mock_get_model .return_value = mock_student_model
603+
604+ # 3. Tokenizer & Data Iterator
605+ mock_build_tokenizer .return_value = mock .Mock (pad_id = 0 )
606+ mock_create_iterator .return_value = (None , None )
607+
608+ train_distill .main (["train_distill.py" , "config.yml" ])
609+
610+ # 4. Assertions
611+ # checking to ensure get_maxtext_model is only called once for student and not for teacher
612+ mock_get_model .assert_called_once_with (mock_student_cfg , mock .ANY )
613+
614+ trainer_init_kwargs = mock_trainer_cls .call_args .kwargs
615+ model_bundle = trainer_init_kwargs ["model" ]
616+ # check that student model is set but teacher model is None since offline mode should skip loading teacher
617+ self .assertIs (model_bundle .student_model , mock_student_model )
618+ self .assertIsNone (model_bundle .teacher_model )
619+
620+ @mock .patch ("maxtext.trainers.post_train.distillation.train_distill.MaxTextDistillationTrainer" )
621+ @mock .patch ("maxtext.trainers.post_train.distillation.train_distill.input_pipeline_interface.create_data_iterator" )
622+ @mock .patch ("maxtext.trainers.post_train.distillation.train_distill.get_maxtext_model" )
623+ @mock .patch ("maxtext.trainers.post_train.distillation.train_distill.tokenizer.build_tokenizer" )
624+ @mock .patch ("maxtext.trainers.post_train.distillation.train_distill.maxtext_utils.create_device_mesh" )
625+ @mock .patch ("maxtext.configs.pyconfig.initialize" )
626+ def test_main_online_mode_loads_teacher (
627+ self ,
628+ mock_pyconfig_init ,
629+ mock_create_mesh ,
630+ mock_build_tokenizer ,
631+ mock_get_model ,
632+ mock_create_iterator ,
633+ mock_trainer_cls ,
634+ ):
635+ """Verifies online mode (offline_data_dir is None) loads both student and teacher models."""
636+ mock_global = mock .Mock ()
637+ mock_global .student_overrides = {}
638+ mock_global .teacher_overrides = {"load_parameters_path" : "gs://ckpt" }
639+ mock_global .offline_data_dir = None # Triggers online mode
640+
641+ mock_student_cfg = mock .Mock ()
642+ mock_student_cfg .vocab_size = 32000
643+ mock_student_cfg .mesh_axes = ("data" ,)
644+ mock_student_cfg .dataset_type = "grain"
645+
646+ # Add dummy numbers for optimizer math
647+ mock_student_cfg .learning_rate = 1e-4
648+ mock_student_cfg .warmup_steps_fraction = 0.1
649+ mock_student_cfg .learning_rate_final_fraction = 0.1
650+ mock_student_cfg .steps = 100
651+ mock_student_cfg .checkpoint_period = 10
652+ mock_student_cfg .gradient_clipping_threshold = 0.0
653+ mock_student_cfg .eval_interval = - 1
654+
655+ # Add dummy numbers for strategy math/logic
656+ mock_student_cfg .distill_temperature = 1.0
657+ mock_student_cfg .distill_alpha = 0.5
658+ mock_student_cfg .distill_beta = 0.0
659+ mock_student_cfg .distill_layer_indices = None
660+ mock_student_cfg .use_sft = False
661+ mock_student_cfg .enable_dropout = False
662+
663+ # Add dummy variables for Checkpointer and Logger
664+ mock_student_cfg .max_num_checkpoints_to_keep = 1
665+ mock_student_cfg .async_checkpointing = False
666+ mock_student_cfg .profiler = "none"
667+ mock_student_cfg .tensorboard_dir = ""
668+ mock_student_cfg .checkpoint_dir = ""
669+ mock_student_cfg .log_period = 10
670+ mock_student_cfg .save_checkpoint_on_completion = False
671+ mock_student_cfg .logical_axis_rules = []
672+
673+ mock_teacher_cfg = mock .Mock ()
674+ mock_teacher_cfg .vocab_size = 32000
675+ mock_pyconfig_init .side_effect = [mock_global , mock_student_cfg , mock_teacher_cfg ]
676+
677+ mock_student_model = mock .Mock ()
678+ mock_teacher_model = mock .Mock ()
679+ mock_get_model .side_effect = [mock_student_model , mock_teacher_model ]
680+
681+ mock_build_tokenizer .return_value = mock .Mock (pad_id = 0 )
682+ mock_create_iterator .return_value = (mock .Mock (), mock .Mock ())
683+
684+ train_distill .main (["train_distill.py" , "config.yml" ])
685+
686+ # checking to ensure get_maxtext_model is called for both student and teacher since online mode should load both
687+ self .assertEqual (mock_get_model .call_count , 2 )
688+ mock_get_model .assert_any_call (mock_student_cfg , mock .ANY )
689+ mock_get_model .assert_any_call (mock_teacher_cfg , mock .ANY )
690+
691+ trainer_init_kwargs = mock_trainer_cls .call_args .kwargs
692+ model_bundle = trainer_init_kwargs ["model" ]
693+ # check that both student and teacher models are set since online mode should load both
694+ self .assertIs (model_bundle .student_model , mock_student_model )
695+ self .assertIs (model_bundle .teacher_model , mock_teacher_model )
696+
540697
# Allow running this test module directly as a script.
if __name__ == "__main__":
  absltest.main()
0 commit comments