diff --git a/tests/unit/train_rl_test.py b/tests/unit/train_rl_test.py
index 400b114f52..f3b178ac4f 100644
--- a/tests/unit/train_rl_test.py
+++ b/tests/unit/train_rl_test.py
@@ -93,6 +93,116 @@ def test_setup_configs_and_devices_pathways_fractional_split(self):
     self.assertEqual(trainer_devices, mock_devices[:2])
     self.assertEqual(sampler_devices, mock_devices[2:])
 
+  @pytest.mark.cpu_only
+  def test_get_rollout_kwargs_no_dp(self):
+    """Test case 1: sampler_config.rollout_data_parallelism=-1 -> verify result is calculated."""
+    # num_sampler_devices=16, tp=2, ep=4 -> dp should be 16 // (2 * 4) = 2
+    sampler_config = SimpleNamespace(
+        rollout_data_parallelism=-1,
+        rollout_tensor_parallelism=2,
+        rollout_expert_parallelism=4,
+    )
+    expected_result = {
+        "data_parallel_size": 2,
+        "tensor_parallel_size": 2,
+        "expert_parallel_size": 4,
+    }
+    self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 16), expected_result)
+
+  @pytest.mark.cpu_only
+  def test_get_rollout_kwargs_auto_tp(self):
+    """Test case 2: dp=2, tp=-1, num_sampler_devices=4."""
+    sampler_config = SimpleNamespace(
+        rollout_data_parallelism=2,
+        rollout_tensor_parallelism=-1,
+        rollout_expert_parallelism=1,
+    )
+    expected_result = {
+        "data_parallel_size": 2,
+        "tensor_parallel_size": 2,
+        "expert_parallel_size": 1,
+    }
+    self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4), expected_result)
+
+  @pytest.mark.cpu_only
+  def test_get_rollout_kwargs_fixed_tp_dp(self):
+    """Test case 3: dp=2, tp=2, num_sampler_devices=4."""
+    sampler_config = SimpleNamespace(
+        rollout_data_parallelism=2,
+        rollout_tensor_parallelism=2,
+        rollout_expert_parallelism=1,
+    )
+    expected_result = {
+        "data_parallel_size": 2,
+        "tensor_parallel_size": 2,
+        "expert_parallel_size": 1,
+    }
+    self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4), expected_result)
+
+  @pytest.mark.cpu_only
+  def test_get_rollout_kwargs_auto_ep(self):
+    """Test case 4: ep=-1 -> verify result is calculated."""
+    # num_sampler_devices=8, tp=2, dp=2 -> ep should be 8 // (2 * 2) = 2
+    sampler_config = SimpleNamespace(
+        rollout_data_parallelism=2,
+        rollout_tensor_parallelism=2,
+        rollout_expert_parallelism=-1,
+    )
+    expected_result = {
+        "data_parallel_size": 2,
+        "tensor_parallel_size": 2,
+        "expert_parallel_size": 2,
+    }
+    self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8), expected_result)
+
+  @pytest.mark.cpu_only
+  def test_get_rollout_kwargs_errors(self):
+    """Test various error cases for get_rollout_kwargs_for_parallelism."""
+    # More than one -1
+    sampler_config = SimpleNamespace(
+        rollout_data_parallelism=-1,
+        rollout_tensor_parallelism=-1,
+        rollout_expert_parallelism=1,
+    )
+    with self.assertRaisesRegex(ValueError, "At most one of .* can be -1"):
+      train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4)
+
+    # num_devices % (tp * ep) != 0 when dp == -1
+    sampler_config = SimpleNamespace(
+        rollout_data_parallelism=-1,
+        rollout_tensor_parallelism=3,
+        rollout_expert_parallelism=1,
+    )
+    with self.assertRaisesRegex(ValueError, "must be divisible by"):
+      train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4)
+
+    # num_devices % (tp * dp) != 0 when ep == -1
+    sampler_config = SimpleNamespace(
+        rollout_data_parallelism=2,
+        rollout_tensor_parallelism=3,
+        rollout_expert_parallelism=-1,
+    )
+    with self.assertRaisesRegex(ValueError, "must be divisible by"):
+      train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8)
+
+    # num_devices % (dp * ep) != 0 when tp == -1
+    sampler_config = SimpleNamespace(
+        rollout_data_parallelism=3,
+        rollout_tensor_parallelism=-1,
+        rollout_expert_parallelism=2,
+    )
+    with self.assertRaisesRegex(ValueError, "must be divisible by"):
+      train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8)
+
+    # tp * dp * ep != num_sampler_devices when all are positive
+    sampler_config = SimpleNamespace(
+        rollout_data_parallelism=2,
+        rollout_tensor_parallelism=2,
+        rollout_expert_parallelism=1,
+    )
+    with self.assertRaisesRegex(ValueError, r"!= len\(sampler_devices\)"):
+      train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8)
+
 
 if __name__ == "__main__":
   unittest.main()