|
16 | 16 |
|
17 | 17 | import unittest |
18 | 18 | from unittest import mock |
| 19 | +import grain |
19 | 20 | import pytest |
20 | 21 | from types import SimpleNamespace |
21 | 22 | import jax |
@@ -286,6 +287,80 @@ def test_get_rollout_kwargs_errors(self): |
286 | 287 | with self.assertRaisesRegex(ValueError, r"!= len\(sampler_devices\)"): |
287 | 288 | train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8) |
288 | 289 |
|
| 290 | + @pytest.mark.cpu_only |
| 291 | + def test_prompt_filtering(self): |
| 292 | + """Test that prompts longer than max_prefill_predict_length are filtered out.""" |
| 293 | + # Setup mocks |
| 294 | + mock_tokenizer = mock.MagicMock() |
| 295 | + |
| 296 | + # Define tokenizer side effect |
| 297 | + def tokenize_side_effect(text): |
| 298 | + if text == "short": |
| 299 | + return [0] * 5 |
| 300 | + else: |
| 301 | + return [0] * 15 |
| 302 | + |
| 303 | + mock_tokenizer.tokenize.side_effect = tokenize_side_effect |
| 304 | + |
| 305 | + # Define dataset mock data |
| 306 | + train_data = [{"prompts": "short"}, {"prompts": "long"}, {"prompts": "short"}, {"prompts": "long"}] |
| 307 | + test_data = [{"prompts": "short"}, {"prompts": "long"}] |
| 308 | + train_map_ds = grain.MapDataset.source(train_data) |
| 309 | + test_map_ds = grain.MapDataset.source(test_data) |
| 310 | + |
| 311 | + def get_dataset_side_effect(model_tokenizer, config, data_dir, split, data_files=None, dataset_name=None): |
| 312 | + if split == "train": |
| 313 | + return train_map_ds |
| 314 | + else: |
| 315 | + return test_map_ds |
| 316 | + |
| 317 | + # Configs |
| 318 | + trainer_config = SimpleNamespace( |
| 319 | + debug=SimpleNamespace(rl=False), |
| 320 | + tokenizer_path="dummy_path", |
| 321 | + dataset_name="dummy_dataset", |
| 322 | + train_split="train", |
| 323 | + eval_split="eval", |
| 324 | + hf_train_files=None, |
| 325 | + hf_eval_files=None, |
| 326 | + max_prefill_predict_length=10, |
| 327 | + batch_size=2, |
| 328 | + num_batches=2, |
| 329 | + train_fraction=1.0, |
| 330 | + num_epoch=1, |
| 331 | + num_test_batches=1, |
| 332 | + ) |
| 333 | + |
| 334 | + # Patch everything! |
| 335 | + with ( |
| 336 | + mock.patch("maxtext.trainers.post_train.rl.train_rl.get_dataset", side_effect=get_dataset_side_effect), |
| 337 | + mock.patch("maxtext.trainers.post_train.rl.train_rl.os.makedirs"), |
| 338 | + mock.patch("maxtext.trainers.post_train.rl.train_rl.os.path.exists", return_value=True), |
| 339 | + ): |
| 340 | + train_dataset, test_dataset = train_rl.prepare_datasets(trainer_config, mock_tokenizer) |
| 341 | + |
| 342 | + # Check filtered train dataset |
| 343 | + elements = list(train_dataset) |
| 344 | + # dataset_size = 4. Indices [0,1,2,3] are [short, long, short, long]. |
| 345 | + # Filtered results: [short, short]. |
| 346 | + # batch(2) will return 1 batch of 2 elements. |
| 347 | + self.assertEqual(len(elements), 1) |
| 348 | + batch = elements[0] |
| 349 | + self.assertEqual(len(batch["prompts"]), 2) |
| 350 | + for prompt in batch["prompts"]: |
| 351 | + self.assertEqual(prompt, "short") |
| 352 | + |
| 353 | + # Check filtered test dataset |
| 354 | + test_elements = list(test_dataset) |
| 355 | + # test_data indices [0,1] are [short, long]. |
| 356 | + # num_test_batches=1, batch_size=2 -> test dataset_size = 2. |
| 357 | + # Filtering results: [short]. |
| 358 | + # batch(2) will return 1 batch of 1 element. |
| 359 | + self.assertEqual(len(test_elements), 1) |
| 360 | + test_batch = test_elements[0] |
| 361 | + self.assertEqual(len(test_batch["prompts"]), 1) |
| 362 | + self.assertEqual(test_batch["prompts"][0], "short") |
| 363 | + |
289 | 364 |
|
# Allow running this test module directly (e.g. `python <this_file>.py`).
if __name__ == "__main__":
  unittest.main()
0 commit comments