From b2eb6e36186f76153a7b6a4d5790bd76ef4d3816 Mon Sep 17 00:00:00 2001
From: Igor Tsvetkov
Date: Wed, 11 Mar 2026 16:51:16 -0700
Subject: [PATCH] Add unit test for prompt filtering

---
 tests/post_training/unit/train_rl_test.py | 75 +++++++++++++++++++++++
 1 file changed, 75 insertions(+)

diff --git a/tests/post_training/unit/train_rl_test.py b/tests/post_training/unit/train_rl_test.py
index af66d52a98..e5a4aa8ec2 100644
--- a/tests/post_training/unit/train_rl_test.py
+++ b/tests/post_training/unit/train_rl_test.py
@@ -16,6 +16,7 @@
 import unittest
 from unittest import mock
+import grain
 import pytest
 from types import SimpleNamespace
 import jax
@@ -288,6 +289,80 @@ def test_get_rollout_kwargs_errors(self):
     with self.assertRaisesRegex(ValueError, r"!= len\(sampler_devices\)"):
       train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8)
 
+  @pytest.mark.cpu_only
+  def test_prompt_filtering(self):
+    """Tests that prompts longer than max_prefill_predict_length are filtered out."""
+    # Set up mocks.
+    mock_tokenizer = mock.MagicMock()
+
+    # "short" tokenizes to 5 tokens (under the limit of 10); anything else to 15 (over it).
+    def tokenize_side_effect(text):
+      if text == "short":
+        return [0] * 5
+      else:
+        return [0] * 15
+
+    mock_tokenizer.tokenize.side_effect = tokenize_side_effect
+
+    # Mock dataset contents.
+    train_data = [{"prompts": "short"}, {"prompts": "long"}, {"prompts": "short"}, {"prompts": "long"}]
+    test_data = [{"prompts": "short"}, {"prompts": "long"}]
+    train_map_ds = grain.MapDataset.source(train_data)
+    test_map_ds = grain.MapDataset.source(test_data)
+
+    def get_dataset_side_effect(model_tokenizer, config, data_dir, split, data_files=None, dataset_name=None):
+      if split == "train":
+        return train_map_ds
+      else:
+        return test_map_ds
+
+    # Trainer config.
+    trainer_config = SimpleNamespace(
+        debug=SimpleNamespace(rl=False),
+        tokenizer_path="dummy_path",
+        dataset_name="dummy_dataset",
+        train_split="train",
+        eval_split="eval",
+        hf_train_files=None,
+        hf_eval_files=None,
+        max_prefill_predict_length=10,
+        batch_size=2,
+        num_batches=2,
+        train_fraction=1.0,
+        num_epoch=1,
+        num_test_batches=1,
+    )
+
+    # Patch dataset loading and filesystem access so prepare_datasets runs hermetically.
+    with (
+        mock.patch("maxtext.trainers.post_train.rl.train_rl.get_dataset", side_effect=get_dataset_side_effect),
+        mock.patch("maxtext.trainers.post_train.rl.train_rl.os.makedirs"),
+        mock.patch("maxtext.trainers.post_train.rl.train_rl.os.path.exists", return_value=True),
+    ):
+      train_dataset, test_dataset = train_rl.prepare_datasets(trainer_config, mock_tokenizer)
+
+    # Check the filtered train dataset.
+    elements = list(train_dataset)
+    # dataset_size = batch_size * num_batches = 4. Indices [0, 1, 2, 3] are
+    # [short, long, short, long]; filtering keeps [short, short], and
+    # batch(2) returns 1 batch of 2 elements.
+    self.assertEqual(len(elements), 1)
+    batch = elements[0]
+    self.assertEqual(len(batch["prompts"]), 2)
+    for prompt in batch["prompts"]:
+      self.assertEqual(prompt, "short")
+
+    # Check the filtered test dataset.
+    test_elements = list(test_dataset)
+    # num_test_batches=1, batch_size=2 -> test dataset_size = 2. Indices [0, 1]
+    # are [short, long]; filtering keeps [short], and batch(2) returns
+    # 1 batch of 1 element.
+    self.assertEqual(len(test_elements), 1)
+    test_batch = test_elements[0]
+    self.assertEqual(len(test_batch["prompts"]), 1)
+    self.assertEqual(test_batch["prompts"][0], "short")
+
 
 if __name__ == "__main__":
   unittest.main()
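
Note for reviewers (not part of the patch): the implementation of prepare_datasets is not shown here, so below is a minimal sketch of the filtering behavior this test asserts, assuming grain's MapDataset/IterDataset API. The helper name filter_long_prompts and its signature are hypothetical; the real code in train_rl may differ.

    import grain

    def filter_long_prompts(ds, tokenizer, max_prefill_predict_length, batch_size):
      """Hypothetical sketch: drop over-length prompts, then batch."""
      # Keep only examples whose tokenized prompt fits in the prefill window.
      # Whether the real code uses <= or < at the boundary is an assumption.
      filtered = ds.filter(lambda ex: len(tokenizer.tokenize(ex["prompts"])) <= max_prefill_predict_length)
      # Converting to an IterDataset before batching skips filtered-out
      # elements, so a trailing batch may hold fewer than batch_size items.
      return filtered.to_iter_dataset().batch(batch_size, drop_remainder=False)

Under these assumptions the train split [short, long, short, long] filters down to two elements and batches into one batch of 2, while the test split [short, long] filters down to one element and yields one batch of 1, matching the assertions in the test.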