Skip to content

Commit d7e658d

Browse files
committed
Added unit tests to make sure offline and online distillation are loading the correct models
1 parent 9a7b2e4 commit d7e658d

1 file changed

Lines changed: 157 additions & 0 deletions

File tree

tests/unit/train_distill_test.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,163 @@ def test_post_process_train_step(self):
537537
values_list = mock_buffer.additional_metrics["distill/kl_div"][0]
538538
self.assertEqual(values_list[0], 0.5)
539539

540+
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.distillation_utils.OfflineArrayRecordIterator")
541+
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.MaxTextDistillationTrainer")
542+
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.input_pipeline_interface.create_data_iterator")
543+
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.get_maxtext_model")
544+
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.tokenizer.build_tokenizer")
545+
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.maxtext_utils.create_device_mesh")
546+
@mock.patch("maxtext.configs.pyconfig.initialize")
547+
def test_main_offline_mode_skips_teacher_loading(
548+
self,
549+
mock_pyconfig_init,
550+
mock_create_mesh,
551+
mock_build_tokenizer,
552+
mock_get_model,
553+
mock_create_iterator,
554+
mock_trainer_cls,
555+
mock_offline_iter_cls,
556+
):
557+
"""Verifies offline mode (offline_data_dir is set) skips teacher model loading."""
558+
# 1. Configs
559+
mock_global = mock.Mock()
560+
mock_global.student_overrides = {}
561+
mock_global.teacher_overrides = {} # No checkpoint needed
562+
mock_global.offline_data_dir = "gs://bucket/data" # Triggers offline mode
563+
564+
mock_student_cfg = mock.Mock()
565+
mock_student_cfg.vocab_size = 32000
566+
mock_student_cfg.mesh_axes = ("data",)
567+
mock_student_cfg.dataset_type = "grain"
568+
569+
# Add dummy numbers for optimizer math
570+
mock_student_cfg.learning_rate = 1e-4
571+
mock_student_cfg.warmup_steps_fraction = 0.1
572+
mock_student_cfg.learning_rate_final_fraction = 0.1
573+
mock_student_cfg.steps = 100
574+
mock_student_cfg.checkpoint_period = 10
575+
mock_student_cfg.gradient_clipping_threshold = 0.0
576+
mock_student_cfg.eval_interval = -1
577+
578+
# Add dummy numbers for strategy math/logic
579+
mock_student_cfg.distill_temperature = 1.0
580+
mock_student_cfg.distill_alpha = 0.5
581+
mock_student_cfg.distill_beta = 0.0
582+
mock_student_cfg.distill_layer_indices = None
583+
mock_student_cfg.use_sft = False
584+
mock_student_cfg.enable_dropout = False
585+
586+
# Add dummy variables for Checkpointer and Logger
587+
mock_student_cfg.max_num_checkpoints_to_keep = 1
588+
mock_student_cfg.async_checkpointing = False
589+
mock_student_cfg.profiler = "none"
590+
mock_student_cfg.tensorboard_dir = ""
591+
mock_student_cfg.checkpoint_dir = ""
592+
mock_student_cfg.log_period = 10
593+
mock_student_cfg.save_checkpoint_on_completion = False
594+
mock_student_cfg.logical_axis_rules = []
595+
596+
mock_teacher_cfg = mock.Mock()
597+
mock_teacher_cfg.vocab_size = 32000
598+
mock_pyconfig_init.side_effect = [mock_global, mock_student_cfg, mock_teacher_cfg]
599+
600+
# 2. Model Loading
601+
mock_student_model = mock.Mock()
602+
mock_get_model.return_value = mock_student_model
603+
604+
# 3. Tokenizer & Data Iterator
605+
mock_build_tokenizer.return_value = mock.Mock(pad_id=0)
606+
mock_create_iterator.return_value = (None, None)
607+
608+
train_distill.main(["train_distill.py", "config.yml"])
609+
610+
# 4. Assertions
611+
# checking to ensure get_maxtext_model is only called once for student and not for teacher
612+
mock_get_model.assert_called_once_with(mock_student_cfg, mock.ANY)
613+
614+
trainer_init_kwargs = mock_trainer_cls.call_args.kwargs
615+
model_bundle = trainer_init_kwargs["model"]
616+
# check that student model is set but teacher model is None since offline mode should skip loading teacher
617+
self.assertIs(model_bundle.student_model, mock_student_model)
618+
self.assertIsNone(model_bundle.teacher_model)
619+
620+
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.MaxTextDistillationTrainer")
621+
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.input_pipeline_interface.create_data_iterator")
622+
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.get_maxtext_model")
623+
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.tokenizer.build_tokenizer")
624+
@mock.patch("maxtext.trainers.post_train.distillation.train_distill.maxtext_utils.create_device_mesh")
625+
@mock.patch("maxtext.configs.pyconfig.initialize")
626+
def test_main_online_mode_loads_teacher(
627+
self,
628+
mock_pyconfig_init,
629+
mock_create_mesh,
630+
mock_build_tokenizer,
631+
mock_get_model,
632+
mock_create_iterator,
633+
mock_trainer_cls,
634+
):
635+
"""Verifies online mode (offline_data_dir is None) loads both student and teacher models."""
636+
mock_global = mock.Mock()
637+
mock_global.student_overrides = {}
638+
mock_global.teacher_overrides = {"load_parameters_path": "gs://ckpt"}
639+
mock_global.offline_data_dir = None # Triggers online mode
640+
641+
mock_student_cfg = mock.Mock()
642+
mock_student_cfg.vocab_size = 32000
643+
mock_student_cfg.mesh_axes = ("data",)
644+
mock_student_cfg.dataset_type = "grain"
645+
646+
# Add dummy numbers for optimizer math
647+
mock_student_cfg.learning_rate = 1e-4
648+
mock_student_cfg.warmup_steps_fraction = 0.1
649+
mock_student_cfg.learning_rate_final_fraction = 0.1
650+
mock_student_cfg.steps = 100
651+
mock_student_cfg.checkpoint_period = 10
652+
mock_student_cfg.gradient_clipping_threshold = 0.0
653+
mock_student_cfg.eval_interval = -1
654+
655+
# Add dummy numbers for strategy math/logic
656+
mock_student_cfg.distill_temperature = 1.0
657+
mock_student_cfg.distill_alpha = 0.5
658+
mock_student_cfg.distill_beta = 0.0
659+
mock_student_cfg.distill_layer_indices = None
660+
mock_student_cfg.use_sft = False
661+
mock_student_cfg.enable_dropout = False
662+
663+
# Add dummy variables for Checkpointer and Logger
664+
mock_student_cfg.max_num_checkpoints_to_keep = 1
665+
mock_student_cfg.async_checkpointing = False
666+
mock_student_cfg.profiler = "none"
667+
mock_student_cfg.tensorboard_dir = ""
668+
mock_student_cfg.checkpoint_dir = ""
669+
mock_student_cfg.log_period = 10
670+
mock_student_cfg.save_checkpoint_on_completion = False
671+
mock_student_cfg.logical_axis_rules = []
672+
673+
mock_teacher_cfg = mock.Mock()
674+
mock_teacher_cfg.vocab_size = 32000
675+
mock_pyconfig_init.side_effect = [mock_global, mock_student_cfg, mock_teacher_cfg]
676+
677+
mock_student_model = mock.Mock()
678+
mock_teacher_model = mock.Mock()
679+
mock_get_model.side_effect = [mock_student_model, mock_teacher_model]
680+
681+
mock_build_tokenizer.return_value = mock.Mock(pad_id=0)
682+
mock_create_iterator.return_value = (mock.Mock(), mock.Mock())
683+
684+
train_distill.main(["train_distill.py", "config.yml"])
685+
686+
# checking to ensure get_maxtext_model is called for both student and teacher since online mode should load both
687+
self.assertEqual(mock_get_model.call_count, 2)
688+
mock_get_model.assert_any_call(mock_student_cfg, mock.ANY)
689+
mock_get_model.assert_any_call(mock_teacher_cfg, mock.ANY)
690+
691+
trainer_init_kwargs = mock_trainer_cls.call_args.kwargs
692+
model_bundle = trainer_init_kwargs["model"]
693+
# check that both student and teacher models are set since online mode should load both
694+
self.assertIs(model_bundle.student_model, mock_student_model)
695+
self.assertIs(model_bundle.teacher_model, mock_teacher_model)
696+
# Allow running this test file directly (outside the pytest/bazel runner).
if __name__ == "__main__":
  absltest.main()

0 commit comments

Comments
 (0)