@@ -891,7 +891,8 @@ def side_effect(self, *args, **kwargs):
891891 self .assertTrue (any (c == "1" or c .endswith ("1" ) for c in checkpoints ), f"Checkpoint 1 not found in { checkpoints } " )
892892 self .assertTrue (any (c == "2" or c .endswith ("2" ) for c in checkpoints ), f"Checkpoint 2 not found in { checkpoints } " )
893893
894- def test_checkpointing_and_resume (self ):
894+ @mock .patch .object (distillation_utils , "calculate_distillation_tflops_per_device" , return_value = (0.0 , 0.0 , 0.0 ))
895+ def test_checkpointing_and_resume (self , _mock_tflops ):
895896 """Trains a few steps, saves a checkpoint, and resumes from it."""
896897
897898 # 1. Setup minimal dummy model and models bundle
@@ -941,6 +942,8 @@ def __call__(self, input_tokens, **kwargs):
941942 strategy = strategy ,
942943 optimizer = optimizer1 ,
943944 training_config = train_config ,
945+ student_config = mock .Mock (),
946+ teacher_config = mock .Mock (),
944947 )
945948 trainer1 ._lora_enabled = False
946949 trainer1 .is_managed_externally = True
@@ -989,6 +992,8 @@ def __call__(self, input_tokens, **kwargs):
989992 strategy = strategy ,
990993 optimizer = optimizer2 ,
991994 training_config = train_config ,
995+ student_config = mock .Mock (),
996+ teacher_config = mock .Mock (),
992997 )
993998 trainer2 ._lora_enabled = False
994999
@@ -1083,8 +1088,17 @@ def test_main_offline_mode_skips_teacher_loading(
10831088 mock_student_cfg .save_checkpoint_on_completion = False
10841089 mock_student_cfg .logical_axis_rules = []
10851090
1091+ # main() validates that student/teacher share batch shape — set explicit
1092+ # equal scalars on both mocks so the assertion passes.
1093+ mock_student_cfg .per_device_batch_size = 1
1094+ mock_student_cfg .max_target_length = 16
1095+ mock_student_cfg .gradient_accumulation_steps = 1
1096+
10861097 mock_teacher_cfg = mock .Mock ()
10871098 mock_teacher_cfg .vocab_size = 32000
1099+ mock_teacher_cfg .per_device_batch_size = 1
1100+ mock_teacher_cfg .max_target_length = 16
1101+ mock_teacher_cfg .gradient_accumulation_steps = 1
10881102 mock_pyconfig_init .side_effect = [mock_global , mock_student_cfg , mock_teacher_cfg ]
10891103
10901104 # 2. Model Loading
@@ -1181,8 +1195,17 @@ def test_main_online_mode_loads_teacher(
11811195 mock_student_cfg .save_checkpoint_on_completion = False
11821196 mock_student_cfg .logical_axis_rules = []
11831197
1198+ # main() validates that student/teacher share batch shape — set explicit
1199+ # equal scalars on both mocks so the assertion passes.
1200+ mock_student_cfg .per_device_batch_size = 1
1201+ mock_student_cfg .max_target_length = 16
1202+ mock_student_cfg .gradient_accumulation_steps = 1
1203+
11841204 mock_teacher_cfg = mock .Mock ()
11851205 mock_teacher_cfg .vocab_size = 32000
1206+ mock_teacher_cfg .per_device_batch_size = 1
1207+ mock_teacher_cfg .max_target_length = 16
1208+ mock_teacher_cfg .gradient_accumulation_steps = 1
11861209 mock_pyconfig_init .side_effect = [mock_global , mock_student_cfg , mock_teacher_cfg ]
11871210
11881211 mock_student_model = mock .Mock ()
@@ -1206,7 +1229,8 @@ def test_main_online_mode_loads_teacher(
12061229 self .assertIs (model_bundle .student_model , mock_student_model )
12071230 self .assertIs (model_bundle .teacher_model , mock_teacher_model )
12081231
1209- def test_student_freeze_param_filter (self ):
1232+ @mock .patch .object (distillation_utils , "calculate_distillation_tflops_per_device" , return_value = (0.0 , 0.0 , 0.0 ))
1233+ def test_student_freeze_param_filter (self , _mock_tflops ):
12101234 """Verifies that student_freeze_param_filter correctly freezes specified parameters."""
12111235
12121236 # 1. Setup a dummy model with multiple layers
@@ -1260,6 +1284,8 @@ def freeze_filter(path):
12601284 strategy = strategy ,
12611285 optimizer = optax .sgd (0.1 ),
12621286 training_config = train_config ,
1287+ student_config = mock .Mock (),
1288+ teacher_config = mock .Mock (),
12631289 student_freeze_param_filter = freeze_filter ,
12641290 )
12651291 trainer ._lora_enabled = False
0 commit comments