Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions tests/post_training/unit/train_rl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import unittest
from unittest import mock
import grain
import pytest
from types import SimpleNamespace
import jax
Expand Down Expand Up @@ -288,6 +289,80 @@ def test_get_rollout_kwargs_errors(self):
with self.assertRaisesRegex(ValueError, r"!= len\(sampler_devices\)"):
train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8)

@pytest.mark.cpu_only
def test_prompt_filtering(self):
  """Verify prepare_datasets drops prompts whose token length exceeds max_prefill_predict_length."""
  # Tokenizer stub: "short" tokenizes to 5 tokens (kept, 5 <= 10),
  # anything else to 15 tokens (dropped, 15 > 10).
  mock_tokenizer = mock.MagicMock()
  mock_tokenizer.tokenize.side_effect = lambda text: [0] * (5 if text == "short" else 15)

  # Alternating short/long prompts so filtering is observable in both splits.
  train_map_ds = grain.MapDataset.source([{"prompts": p} for p in ("short", "long", "short", "long")])
  test_map_ds = grain.MapDataset.source([{"prompts": p} for p in ("short", "long")])

  def fake_get_dataset(model_tokenizer, config, data_dir, split, data_files=None, dataset_name=None):
    # Route by split; prepare_datasets is expected to request "train" and "eval".
    return train_map_ds if split == "train" else test_map_ds

  # Minimal trainer config with only the fields prepare_datasets reads.
  trainer_config = SimpleNamespace(
      debug=SimpleNamespace(rl=False),
      tokenizer_path="dummy_path",
      dataset_name="dummy_dataset",
      train_split="train",
      eval_split="eval",
      hf_train_files=None,
      hf_eval_files=None,
      max_prefill_predict_length=10,
      batch_size=2,
      num_batches=2,
      train_fraction=1.0,
      num_epoch=1,
      num_test_batches=1,
  )

  # Patch out dataset loading and filesystem access so only the
  # filtering/batching logic under test runs.
  with (
      mock.patch("maxtext.trainers.post_train.rl.train_rl.get_dataset", side_effect=fake_get_dataset),
      mock.patch("maxtext.trainers.post_train.rl.train_rl.os.makedirs"),
      mock.patch("maxtext.trainers.post_train.rl.train_rl.os.path.exists", return_value=True),
  ):
    train_dataset, test_dataset = train_rl.prepare_datasets(trainer_config, mock_tokenizer)

  # Train split: 4 prompts [short, long, short, long] -> filtered to the
  # 2 short ones -> batch(2) yields exactly one full batch.
  train_batches = list(train_dataset)
  self.assertEqual(len(train_batches), 1)
  first_batch = train_batches[0]
  self.assertEqual(len(first_batch["prompts"]), 2)
  self.assertTrue(all(prompt == "short" for prompt in first_batch["prompts"]))

  # Eval split: num_test_batches=1 * batch_size=2 -> 2 prompts [short, long]
  # -> filtered to 1 -> batch(2) yields one partial batch of 1 element.
  eval_batches = list(test_dataset)
  self.assertEqual(len(eval_batches), 1)
  only_batch = eval_batches[0]
  self.assertEqual(len(only_batch["prompts"]), 1)
  self.assertEqual(only_batch["prompts"][0], "short")


# Allow running this test module directly (e.g. `python train_rl_test.py`).
if __name__ == "__main__":
  unittest.main()
Loading