Skip to content

Commit b2eb6e3

Browse files
committed
Add unit test for prompt filtering
1 parent 93e2feb commit b2eb6e3

1 file changed

Lines changed: 75 additions & 0 deletions

File tree

tests/post_training/unit/train_rl_test.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import unittest
1818
from unittest import mock
19+
import grain
1920
import pytest
2021
from types import SimpleNamespace
2122
import jax
@@ -288,6 +289,80 @@ def test_get_rollout_kwargs_errors(self):
288289
with self.assertRaisesRegex(ValueError, r"!= len\(sampler_devices\)"):
289290
train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8)
290291

292+
@pytest.mark.cpu_only
293+
def test_prompt_filtering(self):
294+
"""Test that prompts longer than max_prefill_predict_length are filtered out."""
295+
# Setup mocks
296+
mock_tokenizer = mock.MagicMock()
297+
298+
# Define tokenizer side effect
299+
def tokenize_side_effect(text):
300+
if text == "short":
301+
return [0] * 5
302+
else:
303+
return [0] * 15
304+
305+
mock_tokenizer.tokenize.side_effect = tokenize_side_effect
306+
307+
# Define dataset mock data
308+
train_data = [{"prompts": "short"}, {"prompts": "long"}, {"prompts": "short"}, {"prompts": "long"}]
309+
test_data = [{"prompts": "short"}, {"prompts": "long"}]
310+
train_map_ds = grain.MapDataset.source(train_data)
311+
test_map_ds = grain.MapDataset.source(test_data)
312+
313+
def get_dataset_side_effect(model_tokenizer, config, data_dir, split, data_files=None, dataset_name=None):
314+
if split == "train":
315+
return train_map_ds
316+
else:
317+
return test_map_ds
318+
319+
# Configs
320+
trainer_config = SimpleNamespace(
321+
debug=SimpleNamespace(rl=False),
322+
tokenizer_path="dummy_path",
323+
dataset_name="dummy_dataset",
324+
train_split="train",
325+
eval_split="eval",
326+
hf_train_files=None,
327+
hf_eval_files=None,
328+
max_prefill_predict_length=10,
329+
batch_size=2,
330+
num_batches=2,
331+
train_fraction=1.0,
332+
num_epoch=1,
333+
num_test_batches=1,
334+
)
335+
336+
# Patch everything!
337+
with (
338+
mock.patch("maxtext.trainers.post_train.rl.train_rl.get_dataset", side_effect=get_dataset_side_effect),
339+
mock.patch("maxtext.trainers.post_train.rl.train_rl.os.makedirs"),
340+
mock.patch("maxtext.trainers.post_train.rl.train_rl.os.path.exists", return_value=True),
341+
):
342+
train_dataset, test_dataset = train_rl.prepare_datasets(trainer_config, mock_tokenizer)
343+
344+
# Check filtered train dataset
345+
elements = list(train_dataset)
346+
# dataset_size = 4. Indices [0,1,2,3] are [short, long, short, long].
347+
# Filtered results: [short, short].
348+
# batch(2) will return 1 batch of 2 elements.
349+
self.assertEqual(len(elements), 1)
350+
batch = elements[0]
351+
self.assertEqual(len(batch["prompts"]), 2)
352+
for prompt in batch["prompts"]:
353+
self.assertEqual(prompt, "short")
354+
355+
# Check filtered test dataset
356+
test_elements = list(test_dataset)
357+
# test_data indices [0,1] are [short, long].
358+
# num_test_batches=1, batch_size=2 -> test dataset_size = 2.
359+
# Filtering results: [short].
360+
# batch(2) will return 1 batch of 1 element.
361+
self.assertEqual(len(test_elements), 1)
362+
test_batch = test_elements[0]
363+
self.assertEqual(len(test_batch["prompts"]), 1)
364+
self.assertEqual(test_batch["prompts"][0], "short")
365+
291366

292367
# Allow running this test module directly (e.g. `python train_rl_test.py`)
# in addition to discovery via pytest.
if __name__ == "__main__":
  unittest.main()

0 commit comments

Comments
 (0)