Skip to content

Commit c3bbec3

Browse files
committed
Add comprehensive unit tests for rollout kwargs calculation
1 parent e3dbd54 commit c3bbec3

1 file changed

Lines changed: 101 additions & 0 deletions

File tree

tests/unit/train_rl_test.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,107 @@ def test_setup_configs_and_devices_pathways_fractional_split(self):
9393
self.assertEqual(trainer_devices, mock_devices[:2])
9494
self.assertEqual(sampler_devices, mock_devices[2:])
9595

96+
@pytest.mark.cpu_only
97+
def test_get_rollout_kwargs_no_dp(self):
    """Data parallelism given as -1 is derived from the sampler device count.

    With 16 sampler devices, tp=2 and ep=4, the expected data parallel size is
    16 // (2 * 4) = 2; the tensor and expert sizes are returned as configured.
    """
    num_sampler_devices = 16
    cfg = SimpleNamespace(
        rollout_data_parallelism=-1,
        rollout_tensor_parallelism=2,
        rollout_expert_parallelism=4,
    )
    rollout_kwargs = train_rl.get_rollout_kwargs_for_parallelism(cfg, num_sampler_devices)
    self.assertEqual(
        rollout_kwargs,
        {
            "data_parallel_size": 2,
            "tensor_parallel_size": 2,
            "expert_parallel_size": 4,
        },
    )
111+
112+
@pytest.mark.cpu_only
113+
def test_get_rollout_kwargs_auto_tp(self):
    """Tensor parallelism given as -1 is derived from the sampler device count.

    With 4 sampler devices, dp=2 and ep=1, the expected tensor parallel size
    is 2; the data and expert sizes are returned as configured.
    """
    num_sampler_devices = 4
    cfg = SimpleNamespace(
        rollout_data_parallelism=2,
        rollout_tensor_parallelism=-1,
        rollout_expert_parallelism=1,
    )
    rollout_kwargs = train_rl.get_rollout_kwargs_for_parallelism(cfg, num_sampler_devices)
    self.assertEqual(
        rollout_kwargs,
        {
            "data_parallel_size": 2,
            "tensor_parallel_size": 2,
            "expert_parallel_size": 1,
        },
    )
126+
127+
@pytest.mark.cpu_only
128+
def test_get_rollout_kwargs_fixed_tp_dp(self):
    """Fully specified parallelism values are passed through unchanged.

    dp=2, tp=2 and ep=1 multiply to the 4 sampler devices supplied, so the
    returned sizes are expected to equal the configured ones.
    """
    num_sampler_devices = 4
    cfg = SimpleNamespace(
        rollout_data_parallelism=2,
        rollout_tensor_parallelism=2,
        rollout_expert_parallelism=1,
    )
    rollout_kwargs = train_rl.get_rollout_kwargs_for_parallelism(cfg, num_sampler_devices)
    self.assertEqual(
        rollout_kwargs,
        {
            "data_parallel_size": 2,
            "tensor_parallel_size": 2,
            "expert_parallel_size": 1,
        },
    )
141+
142+
@pytest.mark.cpu_only
143+
def test_get_rollout_kwargs_auto_ep(self):
    """Expert parallelism given as -1 is derived from the sampler device count.

    With 8 sampler devices, tp=2 and dp=2, the expected expert parallel size
    is 8 // (2 * 2) = 2; the data and tensor sizes are returned as configured.
    """
    num_sampler_devices = 8
    cfg = SimpleNamespace(
        rollout_data_parallelism=2,
        rollout_tensor_parallelism=2,
        rollout_expert_parallelism=-1,
    )
    rollout_kwargs = train_rl.get_rollout_kwargs_for_parallelism(cfg, num_sampler_devices)
    self.assertEqual(
        rollout_kwargs,
        {
            "data_parallel_size": 2,
            "tensor_parallel_size": 2,
            "expert_parallel_size": 2,
        },
    )
157+
158+
@pytest.mark.cpu_only
159+
def test_get_rollout_kwargs_errors(self):
    """Test various error cases for get_rollout_kwargs_for_parallelism.

    Every invalid configuration is expected to raise ``ValueError`` with a
    message matching the listed regular expression. Each case runs inside its
    own ``subTest`` so that, on runners that support subtests, one failing
    case does not hide the outcome of the remaining ones.
    """
    # Columns: description, dp, tp, ep, num_sampler_devices, expected message regex.
    error_cases = (
        ("more than one -1", -1, -1, 1, 4, "At most one of .* can be -1"),
        ("num_devices % (tp * ep) != 0 when dp == -1", -1, 3, 1, 4, "must be divisible by"),
        ("num_devices % (tp * dp) != 0 when ep == -1", 2, 3, -1, 8, "must be divisible by"),
        (
            "tp * dp * ep != num_sampler_devices when all are positive",
            2,
            2,
            1,
            8,
            r"!= len\(sampler_devices\)",
        ),
    )
    for description, dp, tp, ep, num_sampler_devices, message_regex in error_cases:
        with self.subTest(case=description):
            sampler_config = SimpleNamespace(
                rollout_data_parallelism=dp,
                rollout_tensor_parallelism=tp,
                rollout_expert_parallelism=ep,
            )
            with self.assertRaisesRegex(ValueError, message_regex):
                train_rl.get_rollout_kwargs_for_parallelism(sampler_config, num_sampler_devices)
196+
96197

97198
if __name__ == "__main__":
    # Run this module's tests with the unittest runner when executed as a script.
    unittest.main()

0 commit comments

Comments
 (0)