Skip to content

Commit 6513acd

Browse files
committed
Add unit test for prompt filtering and refactor rl_train to extract get_datasets
1 parent 6b6dbc2 commit 6513acd

2 files changed

Lines changed: 128 additions & 45 deletions

File tree

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 53 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -282,37 +282,11 @@ def get_rollout_kwargs_for_parallelism(sampler_config, num_sampler_devices):
282282
return rollout_kwargs
283283

284284

285-
def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
286-
"""
287-
Run RL training with the provided configuration.
288-
289-
Args:
290-
trainer_config: MaxText configuration for the trainer.
291-
sampler_config: MaxText configuration for the sampler.
292-
trainer_devices: JAX devices for the trainer.
293-
sampler_devices: JAX devices for the sampler.
294-
"""
295-
if not trainer_config.debug.rl:
296-
# Apply filter to suppress noisy logs
297-
noise_filter = max_logging.NoisyLogFilter()
298-
logging.getLogger().addFilter(noise_filter)
299-
absl_logging.get_absl_logger().addFilter(noise_filter)
300-
301-
max_logging.log("Starting RL Training")
302-
max_logging.log(f"Ensuring TensorBoard log directory exists: {trainer_config.tensorboard_dir}")
303-
if not epath.Path(trainer_config.tensorboard_dir).exists():
304-
epath.Path(trainer_config.tensorboard_dir).mkdir(parents=True, exist_ok=True)
305-
306-
if not epath.Path(trainer_config.checkpoint_dir).exists():
307-
epath.Path(trainer_config.checkpoint_dir).mkdir(parents=True)
308-
309-
# Number of training steps.
310-
max_train_steps = int(
311-
trainer_config.num_batches
312-
* trainer_config.rl.num_iterations
313-
* trainer_config.train_fraction
314-
* trainer_config.num_epoch
315-
)
285+
def get_datasets(
286+
model_tokenizer,
287+
trainer_config,
288+
) -> tuple[grain.IterDataset, grain.IterDataset]:
289+
"""Handles loading, templating, filtering, and batching of train/test datasets."""
316290
# ====== Data ======
317291
# Setup data directories
318292
home = os.path.expanduser("~") + "/"
@@ -323,9 +297,6 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
323297
if not os.path.exists(test_data_dir):
324298
os.makedirs(test_data_dir)
325299

326-
# Create model tokenizer
327-
model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path)
328-
329300
# Load datasets
330301
if trainer_config.dataset_name == "huggingface:nvidia/OpenMathInstruct-2":
331302
import datasets # pylint: disable=import-outside-toplevel
@@ -334,7 +305,6 @@ def prepare_openinstructmath2_dataset(
334305
split: str = "train_1M",
335306
seed: int = 42,
336307
test_size: float = 0.05,
337-
output_key: str = "expected_answer",
338308
):
339309
"""Load and split the OpenMathInstruct-2 dataset into train and validation sets using HF's train_test_split."""
340310
max_logging.log(
@@ -422,16 +392,54 @@ def _filter_long_prompts(x):
422392

423393
if trainer_config.debug.rl:
424394
# Let's see how one batch of the dataset looks like!
425-
if trainer_config.debug.rl:
426-
for i, ele in enumerate(train_dataset):
427-
if i >= 5:
428-
break
429-
pprint(ele)
430-
if trainer_config.debug.rl:
431-
for i, ele in enumerate(test_dataset):
432-
if i >= 5:
433-
break
434-
pprint(ele)
395+
for i, ele in enumerate(train_dataset):
396+
if i >= 5:
397+
break
398+
pprint(ele)
399+
for i, ele in enumerate(test_dataset):
400+
if i >= 5:
401+
break
402+
pprint(ele)
403+
404+
return train_dataset, test_dataset
405+
406+
407+
def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
408+
"""
409+
Run RL training with the provided configuration.
410+
411+
Args:
412+
trainer_config: MaxText configuration for the trainer.
413+
sampler_config: MaxText configuration for the sampler.
414+
trainer_devices: JAX devices for the trainer.
415+
sampler_devices: JAX devices for the sampler.
416+
"""
417+
if not trainer_config.debug.rl:
418+
# Apply filter to suppress noisy logs
419+
noise_filter = max_logging.NoisyLogFilter()
420+
logging.getLogger().addFilter(noise_filter)
421+
absl_logging.get_absl_logger().addFilter(noise_filter)
422+
423+
max_logging.log("Starting RL Training")
424+
max_logging.log(f"Ensuring TensorBoard log directory exists: {trainer_config.tensorboard_dir}")
425+
if not epath.Path(trainer_config.tensorboard_dir).exists():
426+
epath.Path(trainer_config.tensorboard_dir).mkdir(parents=True, exist_ok=True)
427+
428+
if not epath.Path(trainer_config.checkpoint_dir).exists():
429+
epath.Path(trainer_config.checkpoint_dir).mkdir(parents=True)
430+
431+
# Number of training steps.
432+
max_train_steps = int(
433+
trainer_config.num_batches
434+
* trainer_config.rl.num_iterations
435+
* trainer_config.train_fraction
436+
* trainer_config.num_epoch
437+
)
438+
# ====== Data ======
439+
# Create model tokenizer
440+
model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path)
441+
442+
train_dataset, test_dataset = get_datasets(model_tokenizer, trainer_config)
435443

436444
# Load reference model
437445
max_logging.log("Creating reference model and also meshes for reference and rollout")

tests/unit/train_rl_test.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import unittest
1818
from unittest import mock
19+
import grain
1920
import pytest
2021
from types import SimpleNamespace
2122
import jax
@@ -203,6 +204,80 @@ def test_get_rollout_kwargs_errors(self):
203204
with self.assertRaisesRegex(ValueError, r"!= len\(sampler_devices\)"):
204205
train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8)
205206

207+
@pytest.mark.cpu_only
208+
def test_prompt_filtering(self):
209+
"""Test that prompts longer than max_prefill_predict_length are filtered out."""
210+
# Setup mocks
211+
mock_tokenizer = mock.MagicMock()
212+
213+
# Define tokenizer side effect
214+
def tokenize_side_effect(text):
215+
if text == "short":
216+
return [0] * 5
217+
else:
218+
return [0] * 15
219+
220+
mock_tokenizer.tokenize.side_effect = tokenize_side_effect
221+
222+
# Define dataset mock data
223+
train_data = [{"prompts": "short"}, {"prompts": "long"}, {"prompts": "short"}, {"prompts": "long"}]
224+
test_data = [{"prompts": "short"}, {"prompts": "long"}]
225+
train_map_ds = grain.MapDataset.source(train_data)
226+
test_map_ds = grain.MapDataset.source(test_data)
227+
228+
def get_dataset_side_effect(model_tokenizer, config, data_dir, split, data_files=None, dataset_name=None):
229+
if split == "train":
230+
return train_map_ds
231+
else:
232+
return test_map_ds
233+
234+
# Configs
235+
trainer_config = SimpleNamespace(
236+
debug=SimpleNamespace(rl=False),
237+
tokenizer_path="dummy_path",
238+
dataset_name="dummy_dataset",
239+
train_split="train",
240+
eval_split="eval",
241+
hf_train_files=None,
242+
hf_eval_files=None,
243+
max_prefill_predict_length=10,
244+
batch_size=2,
245+
num_batches=2,
246+
train_fraction=1.0,
247+
num_epoch=1,
248+
num_test_batches=1,
249+
)
250+
251+
# Patch everything!
252+
with (
253+
mock.patch("maxtext.trainers.post_train.rl.train_rl.get_dataset", side_effect=get_dataset_side_effect),
254+
mock.patch("maxtext.trainers.post_train.rl.train_rl.os.makedirs"),
255+
mock.patch("maxtext.trainers.post_train.rl.train_rl.os.path.exists", return_value=True),
256+
):
257+
train_dataset, test_dataset = train_rl.get_datasets(mock_tokenizer, trainer_config)
258+
259+
# Check filtered train dataset
260+
elements = list(train_dataset)
261+
# dataset_size = 4. Indices [0,1,2,3] are [short, long, short, long].
262+
# Filtered results: [short, short].
263+
# batch(2) will return 1 batch of 2 elements.
264+
self.assertEqual(len(elements), 1)
265+
batch = elements[0]
266+
self.assertEqual(len(batch["prompts"]), 2)
267+
for prompt in batch["prompts"]:
268+
self.assertEqual(prompt, "short")
269+
270+
# Check filtered test dataset
271+
test_elements = list(test_dataset)
272+
# test_data indices [0,1] are [short, long].
273+
# num_test_batches=1, batch_size=2 -> test dataset_size = 2.
274+
# Filtering results: [short].
275+
# batch(2) will return 1 batch of 1 element.
276+
self.assertEqual(len(test_elements), 1)
277+
test_batch = test_elements[0]
278+
self.assertEqual(len(test_batch["prompts"]), 1)
279+
self.assertEqual(test_batch["prompts"][0], "short")
280+
206281

207282
if __name__ == "__main__":
208283
unittest.main()

0 commit comments

Comments (0)