@@ -58,7 +58,10 @@ def test_setup_configs_and_devices_pathways_split(self):
5858 # Following the pattern in distillation_checkpointing_test.py for mocking jax objects
5959 with (
6060 mock .patch .object (jax , "devices" , return_value = mock_devices ),
61- mock .patch ("maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic" , return_value = mock_config ),
61+ mock .patch (
62+ "maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic" ,
63+ return_value = mock_config ,
64+ ),
6265 ):
6366 trainer_config , sampler_config , trainer_devices , sampler_devices = model_creation_utils .setup_configs_and_devices (
6467 ["dummy" , "dummy" ]
@@ -87,7 +90,10 @@ def test_setup_configs_and_devices_pathways_fractional_split(self):
8790
8891 with (
8992 mock .patch .object (jax , "devices" , return_value = mock_devices ),
90- mock .patch ("maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic" , return_value = mock_config ),
93+ mock .patch (
94+ "maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic" ,
95+ return_value = mock_config ,
96+ ),
9197 ):
9298 _ , _ , trainer_devices , sampler_devices = model_creation_utils .setup_configs_and_devices (["dummy" , "dummy" ])
9399
@@ -189,7 +195,10 @@ def test_get_rollout_kwargs_no_dp(self):
189195 "tensor_parallel_size" : 2 ,
190196 "expert_parallel_size" : 4 ,
191197 }
192- self .assertEqual (train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 16 ), expected_result )
198+ self .assertEqual (
199+ train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 16 ),
200+ expected_result ,
201+ )
193202
194203 @pytest .mark .cpu_only
195204 def test_get_rollout_kwargs_auto_tp (self ):
@@ -204,7 +213,10 @@ def test_get_rollout_kwargs_auto_tp(self):
204213 "tensor_parallel_size" : 2 ,
205214 "expert_parallel_size" : 1 ,
206215 }
207- self .assertEqual (train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 4 ), expected_result )
216+ self .assertEqual (
217+ train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 4 ),
218+ expected_result ,
219+ )
208220
209221 @pytest .mark .cpu_only
210222 def test_get_rollout_kwargs_fixed_tp_dp (self ):
@@ -219,7 +231,10 @@ def test_get_rollout_kwargs_fixed_tp_dp(self):
219231 "tensor_parallel_size" : 2 ,
220232 "expert_parallel_size" : 1 ,
221233 }
222- self .assertEqual (train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 4 ), expected_result )
234+ self .assertEqual (
235+ train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 4 ),
236+ expected_result ,
237+ )
223238
224239 @pytest .mark .cpu_only
225240 def test_get_rollout_kwargs_auto_ep (self ):
@@ -235,7 +250,10 @@ def test_get_rollout_kwargs_auto_ep(self):
235250 "tensor_parallel_size" : 2 ,
236251 "expert_parallel_size" : 2 ,
237252 }
238- self .assertEqual (train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 8 ), expected_result )
253+ self .assertEqual (
254+ train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 8 ),
255+ expected_result ,
256+ )
239257
240258 @pytest .mark .cpu_only
241259 def test_get_rollout_kwargs_errors (self ):
@@ -307,7 +325,10 @@ def tokenize_side_effect(text):
307325 {"question" : "short" , "answer" : "a3" },
308326 {"question" : "long" , "answer" : "a4" },
309327 ]
310- test_data = [{"question" : "short" , "answer" : "a5" }, {"question" : "long" , "answer" : "a6" }]
328+ test_data = [
329+ {"question" : "short" , "answer" : "a5" },
330+ {"question" : "long" , "answer" : "a6" },
331+ ]
311332 train_map_ds = grain .MapDataset .source (train_data )
312333 test_map_ds = grain .MapDataset .source (test_data )
313334
@@ -346,8 +367,14 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config
346367 )
347368
348369 with (
349- mock .patch ("maxtext.trainers.post_train.rl.train_rl.get_dataset" , side_effect = get_dataset_side_effect ),
350- mock .patch ("maxtext.trainers.post_train.rl.utils_rl.process_data" , side_effect = get_filtered_data_side_effect ),
370+ mock .patch (
371+ "maxtext.trainers.post_train.rl.train_rl.get_dataset" ,
372+ side_effect = get_dataset_side_effect ,
373+ ),
374+ mock .patch (
375+ "maxtext.trainers.post_train.rl.utils_rl.process_data" ,
376+ side_effect = get_filtered_data_side_effect ,
377+ ),
351378 ):
352379 train_dataset , test_dataset = train_rl .prepare_datasets (trainer_config , mock_tokenizer )
353380
@@ -378,7 +405,10 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config
378405 def test_prepare_datasets_with_split (self , mock_load ):
379406 mock_ds = mock .MagicMock ()
380407 mock_split_result = {
381- "train" : [{"question" : "q1" , "answer" : "a1" }, {"question" : "q2" , "answer" : "a2" }],
408+ "train" : [
409+ {"question" : "q1" , "answer" : "a1" },
410+ {"question" : "q2" , "answer" : "a2" },
411+ ],
382412 "test" : [{"question" : "q3" , "answer" : "a3" }],
383413 }
384414 mock_ds .train_test_split .return_value = mock_split_result
@@ -480,7 +510,7 @@ def test_rl_train_invalid_vocab_tiling(self, mock_setup):
480510 mock_setup .return_value = (mock_config , mock_config , [], [])
481511
482512 with self .assertRaisesRegex (ValueError , "Vocab Tiling is not supported with RL" ):
483- train_rl ._rl_train_impl ([], {})
513+ train_rl ._rl_train_impl ([], {}) # pylint: disable=protected-access
484514
485515
486516class TokenizerChatTemplateTest (unittest .TestCase ):
0 commit comments