Skip to content

Commit 16cc4f4

Browse files
Merge pull request #4091 from AI-Hypercomputer:bvandermoon-test-fixes
PiperOrigin-RevId: 927552045
2 parents b2153a3 + b5296e0 commit 16cc4f4

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

tests/post_training/unit/train_distill_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,6 +1064,8 @@ def test_main_offline_mode_skips_teacher_loading(
10641064
):
10651065
"""Verifies offline mode (offline_data_dir is set) skips teacher model loading."""
10661066
# 1. Configs
1067+
devices = jax.devices()[:1]
1068+
mock_create_mesh.return_value = np.array(devices)
10671069
mock_global = mock.Mock()
10681070
mock_global.student_overrides = {}
10691071
mock_global.teacher_overrides = {} # No checkpoint needed
@@ -1172,6 +1174,8 @@ def test_main_online_mode_loads_teacher(
11721174
mock_trainer_cls,
11731175
):
11741176
"""Verifies online mode (offline_data_dir is None) loads both student and teacher models."""
1177+
devices = jax.devices()[:1]
1178+
mock_create_mesh.return_value = np.array(devices)
11751179
mock_global = mock.Mock()
11761180
mock_global.student_overrides = {}
11771181
mock_global.teacher_overrides = {"load_parameters_path": "gs://ckpt"}

0 commit comments

Comments
 (0)