Skip to content

Commit 792a8ee

Browse files
committed
Fix pylint protected-access on _rl_train_impl test from merged main
1 parent fd9717a commit 792a8ee

1 file changed

Lines changed: 41 additions & 11 deletions

File tree

tests/post_training/unit/train_rl_test.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,10 @@ def test_setup_configs_and_devices_pathways_split(self):
5858
# Following the pattern in distillation_checkpointing_test.py for mocking jax objects
5959
with (
6060
mock.patch.object(jax, "devices", return_value=mock_devices),
61-
mock.patch("maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic", return_value=mock_config),
61+
mock.patch(
62+
"maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic",
63+
return_value=mock_config,
64+
),
6265
):
6366
trainer_config, sampler_config, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices(
6467
["dummy", "dummy"]
@@ -87,7 +90,10 @@ def test_setup_configs_and_devices_pathways_fractional_split(self):
8790

8891
with (
8992
mock.patch.object(jax, "devices", return_value=mock_devices),
90-
mock.patch("maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic", return_value=mock_config),
93+
mock.patch(
94+
"maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic",
95+
return_value=mock_config,
96+
),
9197
):
9298
_, _, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices(["dummy", "dummy"])
9399

@@ -189,7 +195,10 @@ def test_get_rollout_kwargs_no_dp(self):
189195
"tensor_parallel_size": 2,
190196
"expert_parallel_size": 4,
191197
}
192-
self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 16), expected_result)
198+
self.assertEqual(
199+
train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 16),
200+
expected_result,
201+
)
193202

194203
@pytest.mark.cpu_only
195204
def test_get_rollout_kwargs_auto_tp(self):
@@ -204,7 +213,10 @@ def test_get_rollout_kwargs_auto_tp(self):
204213
"tensor_parallel_size": 2,
205214
"expert_parallel_size": 1,
206215
}
207-
self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4), expected_result)
216+
self.assertEqual(
217+
train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4),
218+
expected_result,
219+
)
208220

209221
@pytest.mark.cpu_only
210222
def test_get_rollout_kwargs_fixed_tp_dp(self):
@@ -219,7 +231,10 @@ def test_get_rollout_kwargs_fixed_tp_dp(self):
219231
"tensor_parallel_size": 2,
220232
"expert_parallel_size": 1,
221233
}
222-
self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4), expected_result)
234+
self.assertEqual(
235+
train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4),
236+
expected_result,
237+
)
223238

224239
@pytest.mark.cpu_only
225240
def test_get_rollout_kwargs_auto_ep(self):
@@ -235,7 +250,10 @@ def test_get_rollout_kwargs_auto_ep(self):
235250
"tensor_parallel_size": 2,
236251
"expert_parallel_size": 2,
237252
}
238-
self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8), expected_result)
253+
self.assertEqual(
254+
train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8),
255+
expected_result,
256+
)
239257

240258
@pytest.mark.cpu_only
241259
def test_get_rollout_kwargs_errors(self):
@@ -307,7 +325,10 @@ def tokenize_side_effect(text):
307325
{"question": "short", "answer": "a3"},
308326
{"question": "long", "answer": "a4"},
309327
]
310-
test_data = [{"question": "short", "answer": "a5"}, {"question": "long", "answer": "a6"}]
328+
test_data = [
329+
{"question": "short", "answer": "a5"},
330+
{"question": "long", "answer": "a6"},
331+
]
311332
train_map_ds = grain.MapDataset.source(train_data)
312333
test_map_ds = grain.MapDataset.source(test_data)
313334

@@ -346,8 +367,14 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config
346367
)
347368

348369
with (
349-
mock.patch("maxtext.trainers.post_train.rl.train_rl.get_dataset", side_effect=get_dataset_side_effect),
350-
mock.patch("maxtext.trainers.post_train.rl.utils_rl.process_data", side_effect=get_filtered_data_side_effect),
370+
mock.patch(
371+
"maxtext.trainers.post_train.rl.train_rl.get_dataset",
372+
side_effect=get_dataset_side_effect,
373+
),
374+
mock.patch(
375+
"maxtext.trainers.post_train.rl.utils_rl.process_data",
376+
side_effect=get_filtered_data_side_effect,
377+
),
351378
):
352379
train_dataset, test_dataset = train_rl.prepare_datasets(trainer_config, mock_tokenizer)
353380

@@ -378,7 +405,10 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config
378405
def test_prepare_datasets_with_split(self, mock_load):
379406
mock_ds = mock.MagicMock()
380407
mock_split_result = {
381-
"train": [{"question": "q1", "answer": "a1"}, {"question": "q2", "answer": "a2"}],
408+
"train": [
409+
{"question": "q1", "answer": "a1"},
410+
{"question": "q2", "answer": "a2"},
411+
],
382412
"test": [{"question": "q3", "answer": "a3"}],
383413
}
384414
mock_ds.train_test_split.return_value = mock_split_result
@@ -480,7 +510,7 @@ def test_rl_train_invalid_vocab_tiling(self, mock_setup):
480510
mock_setup.return_value = (mock_config, mock_config, [], [])
481511

482512
with self.assertRaisesRegex(ValueError, "Vocab Tiling is not supported with RL"):
483-
train_rl._rl_train_impl([], {})
513+
train_rl._rl_train_impl([], {}) # pylint: disable=protected-access
484514

485515

486516
class TokenizerChatTemplateTest(unittest.TestCase):

0 commit comments

Comments
 (0)