Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions tests/unit/train_rl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,116 @@ def test_setup_configs_and_devices_pathways_fractional_split(self):
self.assertEqual(trainer_devices, mock_devices[:2])
self.assertEqual(sampler_devices, mock_devices[2:])

@pytest.mark.cpu_only
def test_get_rollout_kwargs_no_dp(self):
"""Test case 1: sampler_config.rollout_data_parallelism=-1 -> verify result is calculated."""
# num_sampler_devices=16, tp=2, ep=4 -> dp should be 16 // (2 * 4) = 2
sampler_config = SimpleNamespace(
rollout_data_parallelism=-1,
rollout_tensor_parallelism=2,
rollout_expert_parallelism=4,
)
expected_result = {
"data_parallel_size": 2,
"tensor_parallel_size": 2,
"expert_parallel_size": 4,
}
self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 16), expected_result)

@pytest.mark.cpu_only
def test_get_rollout_kwargs_auto_tp(self):
"""Test case 2: dp=2, tp=-1, num_sampler_devices=4."""
sampler_config = SimpleNamespace(
rollout_data_parallelism=2,
rollout_tensor_parallelism=-1,
rollout_expert_parallelism=1,
)
expected_result = {
"data_parallel_size": 2,
"tensor_parallel_size": 2,
"expert_parallel_size": 1,
}
self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4), expected_result)

@pytest.mark.cpu_only
def test_get_rollout_kwargs_fixed_tp_dp(self):
"""Test case 3: dp=2, tp=2, num_sampler_devices=4."""
sampler_config = SimpleNamespace(
rollout_data_parallelism=2,
rollout_tensor_parallelism=2,
rollout_expert_parallelism=1,
)
expected_result = {
"data_parallel_size": 2,
"tensor_parallel_size": 2,
"expert_parallel_size": 1,
}
self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4), expected_result)

@pytest.mark.cpu_only
def test_get_rollout_kwargs_auto_ep(self):
"""Test case 4: ep=-1 -> verify result is calculated."""
# num_sampler_devices=8, tp=2, dp=2 -> ep should be 8 // (2 * 2) = 2
sampler_config = SimpleNamespace(
rollout_data_parallelism=2,
rollout_tensor_parallelism=2,
rollout_expert_parallelism=-1,
)
expected_result = {
"data_parallel_size": 2,
"tensor_parallel_size": 2,
"expert_parallel_size": 2,
}
self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8), expected_result)

@pytest.mark.cpu_only
def test_get_rollout_kwargs_errors(self):
"""Test various error cases for get_rollout_kwargs_for_parallelism."""
# More than one -1
sampler_config = SimpleNamespace(
rollout_data_parallelism=-1,
rollout_tensor_parallelism=-1,
rollout_expert_parallelism=1,
)
with self.assertRaisesRegex(ValueError, "At most one of .* can be -1"):
train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4)

# num_devices % (tp * ep) != 0 when dp == -1
sampler_config = SimpleNamespace(
rollout_data_parallelism=-1,
rollout_tensor_parallelism=3,
rollout_expert_parallelism=1,
)
with self.assertRaisesRegex(ValueError, "must be divisible by"):
train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4)

# num_devices % (tp * dp) != 0 when ep == -1
sampler_config = SimpleNamespace(
rollout_data_parallelism=2,
rollout_tensor_parallelism=3,
rollout_expert_parallelism=-1,
)
with self.assertRaisesRegex(ValueError, "must be divisible by"):
train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8)
Comment thread
igorts-git marked this conversation as resolved.

# num_devices % (dp * ep) != 0 when tp == -1
sampler_config = SimpleNamespace(
rollout_data_parallelism=3,
rollout_tensor_parallelism=-1,
rollout_expert_parallelism=2,
)
with self.assertRaisesRegex(ValueError, "must be divisible by"):
train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8)

# tp * dp * ep != num_sampler_devices when all are positive
sampler_config = SimpleNamespace(
rollout_data_parallelism=2,
rollout_tensor_parallelism=2,
rollout_expert_parallelism=1,
)
with self.assertRaisesRegex(ValueError, r"!= len\(sampler_devices\)"):
train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8)


if __name__ == "__main__":
unittest.main()
Loading