@@ -93,6 +93,116 @@ def test_setup_configs_and_devices_pathways_fractional_split(self):
9393 self .assertEqual (trainer_devices , mock_devices [:2 ])
9494 self .assertEqual (sampler_devices , mock_devices [2 :])
9595
96+ @pytest .mark .cpu_only
97+ def test_get_rollout_kwargs_no_dp (self ):
98+ """Test case 1: sampler_config.rollout_data_parallelism=-1 -> verify result is calculated."""
99+ # num_sampler_devices=16, tp=2, ep=4 -> dp should be 16 // (2 * 4) = 2
100+ sampler_config = SimpleNamespace (
101+ rollout_data_parallelism = - 1 ,
102+ rollout_tensor_parallelism = 2 ,
103+ rollout_expert_parallelism = 4 ,
104+ )
105+ expected_result = {
106+ "data_parallel_size" : 2 ,
107+ "tensor_parallel_size" : 2 ,
108+ "expert_parallel_size" : 4 ,
109+ }
110+ self .assertEqual (train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 16 ), expected_result )
111+
112+ @pytest .mark .cpu_only
113+ def test_get_rollout_kwargs_auto_tp (self ):
114+ """Test case 2: dp=2, tp=-1, num_sampler_devices=4."""
115+ sampler_config = SimpleNamespace (
116+ rollout_data_parallelism = 2 ,
117+ rollout_tensor_parallelism = - 1 ,
118+ rollout_expert_parallelism = 1 ,
119+ )
120+ expected_result = {
121+ "data_parallel_size" : 2 ,
122+ "tensor_parallel_size" : 2 ,
123+ "expert_parallel_size" : 1 ,
124+ }
125+ self .assertEqual (train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 4 ), expected_result )
126+
127+ @pytest .mark .cpu_only
128+ def test_get_rollout_kwargs_fixed_tp_dp (self ):
129+ """Test case 3: dp=2, tp=2, num_sampler_devices=4."""
130+ sampler_config = SimpleNamespace (
131+ rollout_data_parallelism = 2 ,
132+ rollout_tensor_parallelism = 2 ,
133+ rollout_expert_parallelism = 1 ,
134+ )
135+ expected_result = {
136+ "data_parallel_size" : 2 ,
137+ "tensor_parallel_size" : 2 ,
138+ "expert_parallel_size" : 1 ,
139+ }
140+ self .assertEqual (train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 4 ), expected_result )
141+
142+ @pytest .mark .cpu_only
143+ def test_get_rollout_kwargs_auto_ep (self ):
144+ """Test case 4: ep=-1 -> verify result is calculated."""
145+ # num_sampler_devices=8, tp=2, dp=2 -> ep should be 8 // (2 * 2) = 2
146+ sampler_config = SimpleNamespace (
147+ rollout_data_parallelism = 2 ,
148+ rollout_tensor_parallelism = 2 ,
149+ rollout_expert_parallelism = - 1 ,
150+ )
151+ expected_result = {
152+ "data_parallel_size" : 2 ,
153+ "tensor_parallel_size" : 2 ,
154+ "expert_parallel_size" : 2 ,
155+ }
156+ self .assertEqual (train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 8 ), expected_result )
157+
158+ @pytest .mark .cpu_only
159+ def test_get_rollout_kwargs_errors (self ):
160+ """Test various error cases for get_rollout_kwargs_for_parallelism."""
161+ # More than one -1
162+ sampler_config = SimpleNamespace (
163+ rollout_data_parallelism = - 1 ,
164+ rollout_tensor_parallelism = - 1 ,
165+ rollout_expert_parallelism = 1 ,
166+ )
167+ with self .assertRaisesRegex (ValueError , "At most one of .* can be -1" ):
168+ train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 4 )
169+
170+ # num_devices % (tp * ep) != 0 when dp == -1
171+ sampler_config = SimpleNamespace (
172+ rollout_data_parallelism = - 1 ,
173+ rollout_tensor_parallelism = 3 ,
174+ rollout_expert_parallelism = 1 ,
175+ )
176+ with self .assertRaisesRegex (ValueError , "must be divisible by" ):
177+ train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 4 )
178+
179+ # num_devices % (tp * dp) != 0 when ep == -1
180+ sampler_config = SimpleNamespace (
181+ rollout_data_parallelism = 2 ,
182+ rollout_tensor_parallelism = 3 ,
183+ rollout_expert_parallelism = - 1 ,
184+ )
185+ with self .assertRaisesRegex (ValueError , "must be divisible by" ):
186+ train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 8 )
187+
188+ # num_devices % (dp * ep) != 0 when tp == -1
189+ sampler_config = SimpleNamespace (
190+ rollout_data_parallelism = 3 ,
191+ rollout_tensor_parallelism = - 1 ,
192+ rollout_expert_parallelism = 2 ,
193+ )
194+ with self .assertRaisesRegex (ValueError , "must be divisible by" ):
195+ train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 8 )
196+
197+ # tp * dp * ep != num_sampler_devices when all are positive
198+ sampler_config = SimpleNamespace (
199+ rollout_data_parallelism = 2 ,
200+ rollout_tensor_parallelism = 2 ,
201+ rollout_expert_parallelism = 1 ,
202+ )
203+ with self .assertRaisesRegex (ValueError , r"!= len\(sampler_devices\)" ):
204+ train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 8 )
205+
96206
97207if __name__ == "__main__" :
98208 unittest .main ()
0 commit comments