Skip to content

Commit 18c26ad

Browse files
committed
remove conditional in preprocess_math_string
1 parent f9e2ea3 commit 18c26ad

5 files changed

Lines changed: 46 additions & 6 deletions

File tree

src/maxtext/configs/post_train/rl.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ generation_configs:
147147
num_eval_passes: 1 # Number of generation passes during evaluation
148148
eval_corr_lst: False # If True, only include correct responses in the list during evaluation
149149
eval_make_lst: False # If True, return a list of (question, answer, responses) during evaluation
150+
eval_mode: "pass" # Evaluation mode ("pass" for pass@K, "maj" for majority voting maj@K)
150151

151152
# ====== Inference ======
152153
# === Generation during GRPO training ===

src/maxtext/configs/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1775,6 +1775,10 @@ class RLEvaluation(BaseModel):
17751775
False,
17761776
description="If True, return a list of (question, answer, responses) during evaluation.",
17771777
)
1778+
eval_mode: Literal["pass", "maj"] = Field(
1779+
"pass",
1780+
description="Evaluation mode to use ('pass' for pass@K, 'maj' for maj@K).",
1781+
)
17781782

17791783

17801784
class Reward(BaseModel):

src/maxtext/trainers/post_train/rl/evaluate_rl.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"""
1717
RL Evaluation Module.
1818
"""
19+
import collections
1920
import json
2021

2122
from tqdm.auto import tqdm
@@ -87,7 +88,7 @@ def generate_responses(
8788
return multiple_call_responses
8889

8990

90-
def score_responses(tmvp_config, question, responses, answers):
91+
def score_responses(tmvp_config, question, responses, answers, eval_mode="pass"):
9192
"""
9293
Score a set of responses for a single question.
9394
@@ -96,6 +97,7 @@ def score_responses(tmvp_config, question, responses, answers):
9697
question: The evaluation question
9798
responses: List of generated responses for this question
9899
answers: List of acceptable answers for this question
100+
eval_mode: The evaluation mode to use ("pass" for pass@K, "maj" for maj@K)
99101
100102
Returns:
101103
Tuple of (is_correct, is_partially_correct, has_correct_format)
@@ -112,6 +114,35 @@ def score_responses(tmvp_config, question, responses, answers):
112114
is_partially_correct = False
113115
has_correct_format = False
114116

117+
if eval_mode == "maj":
118+
extracted_answers = []
119+
for response in responses:
120+
match_format = utils_rl.get_match_format_regex(tmvp_config)
121+
if match_format.search(response) is not None:
122+
has_correct_format = True
123+
124+
extracted_response = utils_rl.extract_answer(response, tmvp_config)
125+
extracted_answers.append(extracted_response)
126+
127+
if not extracted_answers:
128+
return False, False, False
129+
130+
counter = collections.Counter(extracted_answers)
131+
majority_answer = counter.most_common(1)[0][0]
132+
133+
try:
134+
is_correct, is_partially_correct = utils_rl.check_correctness(majority_answer, answers, tmvp_config)
135+
if tmvp_config.debug.rl:
136+
max_logging.log(f"Majority Answer: {majority_answer} (Count: {counter[majority_answer]})")
137+
max_logging.log(f"Result is_correct: {is_correct}")
138+
max_logging.log(f"Result is_partially_correct: {is_partially_correct}")
139+
except Exception as e:
140+
if tmvp_config.debug.rl:
141+
max_logging.log(f"Evaluation Exception on majority answer: {e}")
142+
max_logging.log("SKIPPED")
143+
144+
return is_correct, is_partially_correct, has_correct_format
145+
115146
for response in responses:
116147
match_format = utils_rl.get_match_format_regex(tmvp_config)
117148
if match_format.search(response) is not None:
@@ -144,6 +175,7 @@ def evaluate(
144175
num_passes=1,
145176
corr_lst=False,
146177
make_lst=False,
178+
eval_mode=None,
147179
):
148180
"""
149181
Computes accuracy and percentage of outputs matching the format.
@@ -155,10 +187,14 @@ def evaluate(
155187
num_passes: Number of generation passes
156188
corr_lst: If True, only include correct responses in the list
157189
make_lst: If True, return a list of (question, answer, responses)
190+
eval_mode: Override for the evaluation mode ("pass" or "maj").
158191
159192
Returns:
160193
Tuple of statistics and optionally the response list
161194
"""
195+
if eval_mode is None:
196+
eval_mode = getattr(tmvp_config, "eval_mode", "pass")
197+
162198
response_lst = []
163199
corr = 0
164200
partially_corr = 0
@@ -187,6 +223,7 @@ def evaluate(
187223
question=question,
188224
responses=responses,
189225
answers=answer,
226+
eval_mode=eval_mode,
190227
)
191228

192229
# Update counters

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,7 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
726726
num_passes=trainer_config.num_eval_passes,
727727
corr_lst=trainer_config.eval_corr_lst,
728728
make_lst=trainer_config.eval_make_lst,
729+
eval_mode=getattr(trainer_config, "eval_mode", "pass"),
729730
)
730731
max_logging.warning(f"Pre RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")
731732

@@ -755,6 +756,7 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
755756
num_passes=trainer_config.num_eval_passes,
756757
corr_lst=trainer_config.eval_corr_lst,
757758
make_lst=trainer_config.eval_make_lst,
759+
eval_mode=getattr(trainer_config, "eval_mode", "pass"),
758760
)
759761
max_logging.warning(f"Post RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")
760762

src/maxtext/trainers/post_train/rl/utils_rl.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -275,11 +275,7 @@ def normalize_final_answer(final_answer: str) -> str:
275275
def preprocess_math_string(dataset_name, text) -> str:
276276
"""Fix common formatting issues in text."""
277277
# Normalize for certain datasets and parse
278-
if any(
279-
name in dataset_name
280-
for name in ["DAPO", "OpenMathInstruct", "OpenMathReasoning", "OpenR1-Math-220k", "CuratedThoughts", "MATH-500"]
281-
):
282-
text = normalize_final_answer(text).strip()
278+
text = normalize_final_answer(text).strip()
283279
# Fix LaTeX escaping issues
284280
text = fix_latex_escaping(text)
285281
return text

0 commit comments

Comments (0)