|
28 | 28 | ) |
29 | 29 |
|
30 | 30 |
|
31 | | -def _get_mock_devices(num_devices): |
32 | | - mock_devices = [mock.MagicMock() for _ in range(num_devices)] |
33 | | - for i, d in enumerate(mock_devices): |
34 | | - d.id = i |
| 31 | +def _get_mock_devices(devices_per_slice, num_slices=1): |
| 32 | + mock_devices = [] |
| 33 | + for slice_idx in range(num_slices): |
| 34 | + for _ in range(devices_per_slice): |
| 35 | + d = mock.MagicMock() |
| 36 | + d.id = len(mock_devices) |
| 37 | + d.slice_index = slice_idx |
| 38 | + mock_devices.append(d) |
35 | 39 | return mock_devices |
36 | 40 |
|
37 | 41 |
|
@@ -93,6 +97,85 @@ def test_setup_configs_and_devices_pathways_fractional_split(self): |
93 | 97 | self.assertEqual(trainer_devices, mock_devices[:2]) |
94 | 98 | self.assertEqual(sampler_devices, mock_devices[2:]) |
95 | 99 |
|
@pytest.mark.cpu_only
def test_setup_configs_and_devices_multislice_not_enough_slices(self):
  """Expect a ValueError when the config asks for more slices than exist.

  Two mocked slices are available, while the config requests two trainer
  slices plus one sampler slice.
  """
  fake_devices = _get_mock_devices(num_slices=2, devices_per_slice=4)
  base_config = SimpleNamespace(num_trainer_slices=2, num_samplers_slices=1)

  def fake_initialize(argv, **kwargs):
    # Hand back a fresh copy of the base config with keyword overrides on top.
    return SimpleNamespace(**{**vars(base_config), **kwargs})

  with (
      mock.patch.object(jax, "devices", return_value=fake_devices),
      mock.patch(
          "maxtext.trainers.post_train.rl.train_rl.pyconfig.initialize_pydantic",
          side_effect=fake_initialize,
      ),
      self.assertRaisesRegex(ValueError, "Not enough slices for trainer and samplers"),
  ):
    train_rl.setup_configs_and_devices(["dummy", "dummy"])
| 124 | + |
@pytest.mark.cpu_only
def test_setup_configs_and_devices_multislice_invalid_tp(self):
  """Expect a ValueError when tensor parallelism does not divide a slice.

  Each mocked slice holds 8 devices and the config sets
  ici_tensor_parallelism to 3.
  """
  fake_devices = _get_mock_devices(num_slices=4, devices_per_slice=8)
  base_config = SimpleNamespace(
      num_trainer_slices=2,
      num_samplers_slices=2,
      ici_tensor_parallelism=3,  # 3 does not divide the 8 devices in a slice
      ici_fsdp_parallelism=-1,
  )

  def fake_initialize(argv, **kwargs):
    # Hand back a fresh copy of the base config with keyword overrides on top.
    return SimpleNamespace(**{**vars(base_config), **kwargs})

  with (
      mock.patch.object(jax, "devices", return_value=fake_devices),
      mock.patch(
          "maxtext.trainers.post_train.rl.train_rl.pyconfig.initialize_pydantic",
          side_effect=fake_initialize,
      ),
      self.assertRaisesRegex(ValueError, "must be divisible by tensor parallelism"),
  ):
    train_rl.setup_configs_and_devices(["dummy", "dummy"])
| 151 | + |
@pytest.mark.cpu_only
def test_setup_configs_and_devices_multislice_invalid_tp_fsdp(self):
  """Expect a ValueError when TP times FSDP does not match the slice size.

  Each mocked slice holds 8 devices, while the config sets
  ici_tensor_parallelism to 4 and ici_fsdp_parallelism to 3.
  """
  fake_devices = _get_mock_devices(num_slices=4, devices_per_slice=8)
  base_config = SimpleNamespace(
      num_trainer_slices=2,
      num_samplers_slices=2,
      ici_tensor_parallelism=4,
      ici_fsdp_parallelism=3,  # 4 * 3 = 12, not the 8 devices in a slice
  )

  def fake_initialize(argv, **kwargs):
    # Hand back a fresh copy of the base config with keyword overrides on top.
    return SimpleNamespace(**{**vars(base_config), **kwargs})

  with (
      mock.patch.object(jax, "devices", return_value=fake_devices),
      mock.patch(
          "maxtext.trainers.post_train.rl.train_rl.pyconfig.initialize_pydantic",
          side_effect=fake_initialize,
      ),
      self.assertRaisesRegex(ValueError, "must equal devices_per_slice"),
  ):
    train_rl.setup_configs_and_devices(["dummy", "dummy"])
| 178 | + |
96 | 179 | @pytest.mark.cpu_only |
97 | 180 | def test_get_rollout_kwargs_no_dp(self): |
98 | 181 | """Test case 1: sampler_config.rollout_data_parallelism=-1 -> verify result is calculated.""" |
|
0 commit comments