File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -1064,6 +1064,8 @@ def test_main_offline_mode_skips_teacher_loading(
10641064 ):
10651065 """Verifies offline mode (offline_data_dir is set) skips teacher model loading."""
10661066 # 1. Configs
1067+ devices = jax .devices ()[:1 ]
1068+ mock_create_mesh .return_value = np .array (devices )
10671069 mock_global = mock .Mock ()
10681070 mock_global .student_overrides = {}
10691071 mock_global .teacher_overrides = {} # No checkpoint needed
@@ -1172,6 +1174,8 @@ def test_main_online_mode_loads_teacher(
11721174 mock_trainer_cls ,
11731175 ):
11741176 """Verifies online mode (offline_data_dir is None) loads both student and teacher models."""
1177+ devices = jax .devices ()[:1 ]
1178+ mock_create_mesh .return_value = np .array (devices )
11751179 mock_global = mock .Mock ()
11761180 mock_global .student_overrides = {}
11771181 mock_global .teacher_overrides = {"load_parameters_path" : "gs://ckpt" }
You can’t perform that action at this time.
0 commit comments