Skip to content

Commit f76ae69

Browse files
committed
temp
1 parent 6a80012 commit f76ae69

2 files changed

Lines changed: 83 additions & 64 deletions

File tree

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 50 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -121,29 +121,18 @@ def get_dataset(
121121
if dataset_name is None:
122122
raise ValueError("dataset_name must be provided")
123123

124-
if dataset_name.startswith("huggingface:"):
125-
import datasets # pylint: disable=import-outside-toplevel
126-
127-
if data_files is None:
128-
hf_dataset_name = dataset_name.replace("huggingface:", "")
129-
data = datasets.load_dataset(hf_dataset_name, split=split, cache_dir=data_dir)
130-
if tmvp_config.debug.rl:
131-
max_logging.log(f"Loaded Hugging Face dataset {hf_dataset_name} with split {split}. Size: {len(data)}")
132-
else: # data_files have been provided, useful for using slices of large datasets like nvidia/OpenMathInstruct-2
133-
data = datasets.load_dataset(
134-
"parquet",
135-
data_files={tmvp_config.train_split: data_files},
136-
split=split,
137-
cache_dir=data_dir,
138-
)
139-
else:
140-
builder_kwargs = {"file_format": tfds.core.FileFormat.ARRAY_RECORD}
141-
data = tfds.data_source(
142-
dataset_name,
124+
import datasets # pylint: disable=import-outside-toplevel
125+
126+
if data_files is None:
127+
data = datasets.load_dataset(dataset_name, split=split, cache_dir=data_dir)
128+
if tmvp_config.debug.rl:
129+
max_logging.log(f"Loaded Hugging Face dataset {dataset_name} with split {split}. Size: {len(data)}")
130+
else: # data_files have been provided, useful for using slices of large datasets like nvidia/OpenMathInstruct-2
131+
data = datasets.load_dataset(
132+
"parquet",
133+
data_files={tmvp_config.train_split: data_files},
143134
split=split,
144-
data_dir=data_dir,
145-
builder_kwargs=builder_kwargs,
146-
download=True,
135+
cache_dir=data_dir,
147136
)
148137

149138
template_config = load_template_from_file(tmvp_config.chat_template_path)
@@ -300,6 +289,37 @@ def get_max_train_steps(trainer_config):
300289
)
301290

302291

292+
def prepare_train_and_eval_dataset(
293+
trainer_config,
294+
seed: int = 42,
295+
test_size: float = 0.05,
296+
):
297+
"""Load and split the dataset into train and validation sets using HF's train_test_split."""
298+
import datasets # pylint: disable=import-outside-toplevel
299+
300+
max_logging.log(
301+
"WARNING: For reproducible experiments, preprocess the dataset once and "
302+
"define your own HfDataset subclass that directly uses the preprocessed datasets."
303+
)
304+
305+
original_ds = datasets.load_dataset(
306+
"parquet",
307+
data_files={trainer_config.train_split: trainer_config.hf_train_files},
308+
split=trainer_config.train_split,
309+
)
310+
311+
if "OpenMathReasoning" in trainer_config.dataset_name:
312+
original_ds = original_ds.filter(lambda x: x.get("problem_type") == "has_answer_extracted")
313+
314+
# Split into train and validation sets using HF's train_test_split
315+
split_ds = original_ds.train_test_split(test_size=test_size, seed=seed)
316+
317+
return {
318+
"train": split_ds["train"],
319+
"validation": split_ds["test"],
320+
}
321+
322+
303323
def prepare_datasets(trainer_config, model_tokenizer):
304324
"""Setup and return train and test datasets."""
305325
home = os.path.expanduser("~") + "/"
@@ -311,39 +331,16 @@ def prepare_datasets(trainer_config, model_tokenizer):
311331
os.makedirs(test_data_dir)
312332

313333
# Prepare train and test data from training data for certain datasets
314-
if trainer_config.dataset_name in ["nvidia/OpenMathInstruct-2", "nvidia/OpenMathReasoning", "open-r1/OpenR1-Math-220k", "bethgelab/CuratedThoughts"]:
334+
eval_dataset_name = getattr(trainer_config, "eval_dataset_name", None)
335+
if trainer_config.dataset_name in [
336+
"nvidia/OpenMathInstruct-2",
337+
"nvidia/OpenMathReasoning",
338+
"open-r1/OpenR1-Math-220k",
339+
"bethgelab/CuratedThoughts",
340+
] and (not eval_dataset_name or eval_dataset_name == trainer_config.dataset_name):
315341
import datasets # pylint: disable=import-outside-toplevel
316342

317-
def prepare_train_and_eval_dataset(
318-
seed: int = 42,
319-
test_size: float = 0.05,
320-
):
321-
"""Load and split the dataset into train and validation sets using HF's train_test_split."""
322-
max_logging.log(
323-
"WARNING: For reproducible experiments, preprocess the dataset once and "
324-
"define your own HfDataset subclass that directly uses the preprocessed datasets."
325-
)
326-
327-
# Load the original dataset
328-
original_ds = datasets.load_dataset(
329-
"parquet",
330-
data_files={trainer_config.train_split: trainer_config.hf_train_files},
331-
split=trainer_config.train_split,
332-
)
333-
334-
if "OpenMathReasoning" in trainer_config.dataset_name:
335-
original_ds = original_ds.filter(lambda x: x.get("problem_type") == "has_answer_extracted")
336-
337-
338-
# Split into train and validation sets using HF's train_test_split
339-
split_ds = original_ds.train_test_split(test_size=test_size, seed=seed)
340-
341-
return {
342-
"train": split_ds["train"],
343-
"validation": split_ds["test"],
344-
}
345-
346-
splits = prepare_train_and_eval_dataset()
343+
splits = prepare_train_and_eval_dataset(trainer_config)
347344
template_config = load_template_from_file(trainer_config.chat_template_path)
348345

349346
train_dataset = (

src/maxtext/trainers/post_train/rl/utils_rl.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,13 @@
4242
LatexExtractionConfig(),
4343
)
4444

45+
4546
def math_verify_func(items, timeout=5):
4647
"""Verifies a batch of math problems, handling timeouts and exceptions."""
4748
with concurrent.futures.ThreadPoolExecutor() as executor:
48-
future_to_index = {executor.submit(verify_math, golds, predictions): idx for idx, (_, golds, predictions) in enumerate(items)}
49+
future_to_index = {
50+
executor.submit(verify_math, golds, predictions): idx for idx, (_, golds, predictions) in enumerate(items)
51+
}
4952
results = [0.0] * len(items)
5053
for future in concurrent.futures.as_completed(future_to_index):
5154
index = future_to_index[future]
@@ -59,8 +62,12 @@ def math_verify_func(items, timeout=5):
5962
def verify_math(golds, predictions):
6063
"""Runs mathematical expression evaluation on ground-truth and predictions."""
6164

62-
extracted_predictions = list(itertools.chain.from_iterable(parse(pred, PRED_EXTRACTION_TARGET, parsing_timeout=None) for pred in predictions))
63-
extracted_golds = list(itertools.chain.from_iterable(parse(gold, GOLD_EXTRACTION_TARGET, parsing_timeout=None) for gold in golds))
65+
extracted_predictions = list(
66+
itertools.chain.from_iterable(parse(pred, PRED_EXTRACTION_TARGET, parsing_timeout=None) for pred in predictions)
67+
)
68+
extracted_golds = list(
69+
itertools.chain.from_iterable(parse(gold, GOLD_EXTRACTION_TARGET, parsing_timeout=None) for gold in golds)
70+
)
6471
# If no predictions or golds were extracted, return 0.0
6572
if not extracted_predictions or not extracted_golds:
6673
return 0.0
@@ -72,6 +79,7 @@ def verify_math(golds, predictions):
7279
]
7380
)
7481

82+
7583
def boxed(x):
7684
"""Wraps the input string in a LaTeX boxed command if it's not already wrapped."""
7785
return "\\boxed{" + x + "}" if not x.startswith("\\boxed{") else x
@@ -267,7 +275,10 @@ def normalize_final_answer(final_answer: str) -> str:
267275
def preprocess_math_string(dataset_name, text) -> str:
268276
"""Fix common formatting issues in text."""
269277
# Normalize for certain datasets and parse
270-
if any(name in dataset_name for name in ["DAPO", "OpenMathInstruct", "OpenMathReasoning", "OpenR1-Math-220k", "CuratedThoughts"]):
278+
if any(
279+
name in dataset_name
280+
for name in ["DAPO", "OpenMathInstruct", "OpenMathReasoning", "OpenR1-Math-220k", "CuratedThoughts", "MATH-500"]
281+
):
271282
text = normalize_final_answer(text).strip()
272283
# Fix LaTeX escaping issues
273284
text = fix_latex_escaping(text)
@@ -418,7 +429,11 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
418429
# 3. As a fallback, try numeric comparison if both can be parsed as numbers
419430
try:
420431
predictions = parse(norm_guesses[0], PRED_EXTRACTION_TARGET, parsing_timeout=None)
421-
golds = list(itertools.chain.from_iterable(parse(norm_answer, GOLD_EXTRACTION_TARGET, parsing_timeout=None) for norm_answer in norm_answers))
432+
golds = list(
433+
itertools.chain.from_iterable(
434+
parse(norm_answer, GOLD_EXTRACTION_TARGET, parsing_timeout=None) for norm_answer in norm_answers
435+
)
436+
)
422437
for gold in golds:
423438
for pred in predictions:
424439
try:
@@ -430,9 +445,11 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
430445
else:
431446
scores[gen_idx] = max(scores[gen_idx], tmvp_config.penalty_incorrect_answer)
432447
except:
433-
scores[gen_idx] = max(scores[gen_idx], tmvp_config.penalty_incorrect_answer)
448+
scores[gen_idx] = max(scores[gen_idx], tmvp_config.penalty_incorrect_answer)
434449
except:
435-
scores[gen_idx] = max(scores[gen_idx], tmvp_config.penalty_incorrect_format) # Penalize if we can't parse numbers at all
450+
scores[gen_idx] = max(
451+
scores[gen_idx], tmvp_config.penalty_incorrect_format
452+
) # Penalize if we can't parse numbers at all
436453
if tmvp_config.debug.rl:
437454
debug_log_path = epath.Path(tmvp_config.base_output_directory) / tmvp_config.run_name / "debug_rl_logs"
438455
debug_log_path.mkdir(parents=True, exist_ok=True)
@@ -469,10 +486,11 @@ def extract_hash_answer(text: str) -> str | None:
469486
def check_correctness(extracted_response, acceptable_answers, tmvp_config):
470487
"""Handles math verification and partial correctness logic."""
471488
norm_answers = []
472-
norm_response = preprocess_math_string(tmvp_config.dataset_name, extracted_response)
489+
dataset_name = tmvp_config.eval_dataset_name if tmvp_config.eval_dataset_name else tmvp_config.dataset_name
490+
norm_response = preprocess_math_string(dataset_name, extracted_response)
473491
# Check exact correctness first
474-
for answer in acceptable_answers:
475-
norm_answers.append(preprocess_math_string(tmvp_config.dataset_name, answer))
492+
for answer in acceptable_answers:
493+
norm_answers.append(preprocess_math_string(dataset_name, answer))
476494
is_correct = verify_math([boxed(norm_answer) for norm_answer in norm_answers], [boxed(norm_response)]) > 0.1
477495
if is_correct:
478496
return True, True # Exact correctness implies partial correctness
@@ -481,7 +499,11 @@ def check_correctness(extracted_response, acceptable_answers, tmvp_config):
481499
is_partially_correct = False
482500
try:
483501
predictions = parse(boxed(norm_response), PRED_EXTRACTION_TARGET, parsing_timeout=None)
484-
golds = list(itertools.chain.from_iterable(parse(boxed(norm_answer), GOLD_EXTRACTION_TARGET, parsing_timeout=None) for norm_answer in norm_answers))
502+
golds = list(
503+
itertools.chain.from_iterable(
504+
parse(boxed(norm_answer), GOLD_EXTRACTION_TARGET, parsing_timeout=None) for norm_answer in norm_answers
505+
)
506+
)
485507
is_partially_correct = any(
486508
0.9 <= (float(pred) + EPSILON) / (float(gold) + EPSILON) <= 1.1 for pred in predictions for gold in golds
487509
)

0 commit comments

Comments
 (0)