Skip to content

Commit bd6e94f

Browse files
committed
Add unit test for prompt filtering
1 parent b842fe3 commit bd6e94f

1 file changed

Lines changed: 75 additions & 0 deletions

File tree

tests/unit/train_rl_test.py

Lines changed: 75 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616

1717
import unittest
1818
from unittest import mock
19+
import grain
1920
import pytest
2021
from types import SimpleNamespace
2122
import jax
@@ -286,6 +287,80 @@ def test_get_rollout_kwargs_errors(self):
286287
with self.assertRaisesRegex(ValueError, r"!= len\(sampler_devices\)"):
287288
train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8)
288289

290+
@pytest.mark.cpu_only
def test_prompt_filtering(self):
    """Check that prepare_datasets drops prompts exceeding max_prefill_predict_length.

    The mocked tokenizer yields 5 tokens for the text "short" and 15 tokens
    for any other text, while the config sets max_prefill_predict_length to
    10. Only "short" prompts are therefore expected in the batches of both
    the train and the test dataset.
    """
    # Tokenizer stand-in: the token count depends only on the prompt text.
    tokenizer = mock.MagicMock()
    tokenizer.tokenize.side_effect = lambda text: [0] * (5 if text == "short" else 15)

    # In-memory sources alternating short and long prompts.
    train_source = grain.MapDataset.source(
        [{"prompts": label} for label in ("short", "long", "short", "long")]
    )
    eval_source = grain.MapDataset.source(
        [{"prompts": label} for label in ("short", "long")]
    )

    def fake_get_dataset(model_tokenizer, config, data_dir, split, data_files=None, dataset_name=None):
        # Only `split` matters here: "train" gets the four-row source and
        # every other split name gets the two-row source.
        return train_source if split == "train" else eval_source

    # Minimal trainer config; only the attributes set below are provided.
    rl_config = SimpleNamespace(
        debug=SimpleNamespace(rl=False),
        tokenizer_path="dummy_path",
        dataset_name="dummy_dataset",
        train_split="train",
        eval_split="eval",
        hf_train_files=None,
        hf_eval_files=None,
        max_prefill_predict_length=10,
        batch_size=2,
        num_batches=2,
        train_fraction=1.0,
        num_epoch=1,
        num_test_batches=1,
    )

    # Swap out dataset loading and the filesystem helpers referenced through
    # the train_rl module so the call below uses the in-memory sources.
    with (
        mock.patch("maxtext.trainers.post_train.rl.train_rl.get_dataset", side_effect=fake_get_dataset),
        mock.patch("maxtext.trainers.post_train.rl.train_rl.os.makedirs"),
        mock.patch("maxtext.trainers.post_train.rl.train_rl.os.path.exists", return_value=True),
    ):
        train_dataset, test_dataset = train_rl.prepare_datasets(rl_config, tokenizer)

    # Train split: the two short prompts should come back as one batch of two.
    train_batches = list(train_dataset)
    self.assertEqual(len(train_batches), 1)
    train_prompts = train_batches[0]["prompts"]
    self.assertEqual(len(train_prompts), 2)
    for kept_prompt in train_prompts:
        self.assertEqual(kept_prompt, "short")

    # Test split: the single short prompt should come back as one partial
    # batch holding one element.
    test_batches = list(test_dataset)
    self.assertEqual(len(test_batches), 1)
    test_prompts = test_batches[0]["prompts"]
    self.assertEqual(len(test_prompts), 1)
    self.assertEqual(test_prompts[0], "short")
363+
289364

290365
# Script entry point: run the tests through unittest when this file is executed directly.
if __name__ == "__main__":
291366
unittest.main()

0 commit comments

Comments
 (0)