Skip to content

Commit 17dda16

Browse files
SurbhiJainUSC authored and A9isha committed
Add open-r1/OpenR1-Math-220k dataset to RL
1 parent 04a07ed commit 17dda16

1 file changed

Lines changed: 7 additions & 2 deletions

File tree

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -305,7 +305,12 @@ def prepare_datasets(trainer_config, model_tokenizer):
305305
os.makedirs(test_data_dir)
306306

307307
# Prepare train and test data from training data for certain datasets
308-
if trainer_config.dataset_name in ["nvidia/OpenMathInstruct-2", "nvidia/OpenMathReasoning", "open-r1/OpenR1-Math-220k", "bethgelab/CuratedThoughts"]:
308+
if trainer_config.dataset_name in [
309+
"nvidia/OpenMathInstruct-2",
310+
"nvidia/OpenMathReasoning",
311+
"open-r1/OpenR1-Math-220k",
312+
"bethgelab/CuratedThoughts",
313+
]:
309314
import datasets # pylint: disable=import-outside-toplevel
310315

311316
def prepare_train_and_eval_dataset(
@@ -325,7 +330,7 @@ def prepare_train_and_eval_dataset(
325330
)
326331

327332
if "OpenMathReasoning" in trainer_config.dataset_name:
328-
original_ds = original_ds.filter(lambda x: x.get("problem_type") == "has_answer_extracted")
333+
original_ds = original_ds.filter(lambda x: x.get("problem_type") == "has_answer_extracted")
329334

330335
# Split into train and validation sets using HF's train_test_split
331336
split_ds = original_ds.train_test_split(test_size=test_size, seed=seed)

0 commit comments

Comments
 (0)