1616"""
1717RL Evaluation Module.
1818"""
19- from math_verify import parse
19+ import collections
20+ import json
21+
2022from tqdm .auto import tqdm
2123from tunix .rl .rollout .base_rollout import RolloutConfig
2224
@@ -86,85 +88,82 @@ def generate_responses(
8688 return multiple_call_responses
8789
8890
89- def score_responses (tmvp_config , question , responses , answer ):
90- """
91- Score a set of responses for a single question.
def _score_single(extracted, response, answers, tmvp_config, match_format):
  """Score one (extracted answer, raw response) pair.

  Args:
    extracted: Answer string already extracted from `response`.
    response: Raw generated response; only used for the format check.
    answers: Ground-truth answer(s), forwarded unchanged to
      `utils_rl.check_correctness`.
    tmvp_config: Configuration object; `debug.rl` toggles verbose logging.
    match_format: Compiled regex; a successful `search` on `response` means
      the response has the expected format.

  Returns:
    Tuple of (is_correct, is_partially_correct, has_correct_format). The first
    two are whatever `utils_rl.check_correctness` returns, or False/False if
    anything inside the scoring step raises.
  """
  format_ok = match_format.search(response) is not None
  try:
    exact, partial = utils_rl.check_correctness(extracted, answers, tmvp_config)
    if tmvp_config.debug.rl:
      max_logging.log(f"Result has_correct_format: {format_ok}")
      max_logging.log(f"Result is_correct: {exact}")
      max_logging.log(f"Result is_partially_correct: {partial}")
  except Exception as err:  # pylint: disable=broad-exception-caught
    # Best-effort scoring: any failure counts the sample as incorrect rather
    # than aborting the whole evaluation run.
    exact = partial = False
    if tmvp_config.debug.rl:
      max_logging.log(f"Evaluation Exception: {err} — SKIPPED")
  return exact, partial, format_ok
92105
93- Args:
94- tmvp_config: Configuration object
95- question: The evaluation question
96- responses: List of generated responses for this question
97- answer: The correct answer
106+
def score_responses(tmvp_config, question, responses, answers):
  """Score a set of responses for a single question.

  The aggregation strategy is read from `tmvp_config.eval_mode` (defaults to
  "pass" when the attribute is absent):

  * "pass": each criterion is satisfied if at least one response satisfies it.
  * "maj": the most frequent extracted answer is scored; the format criterion
    is satisfied if any response that produced that answer matches the format
    regex.
  * "pass_at_1": each criterion is the fraction of responses that satisfy it.

  Args:
    tmvp_config: Configuration object; `debug.rl` toggles verbose logging.
    question: The evaluation question (only used for logging).
    responses: List of generated responses for this question.
    answers: Ground-truth answer(s), forwarded to `_score_single`.

  Returns:
    Tuple of (is_correct, is_partially_correct, has_correct_format). Booleans
    for "pass" and "maj" (and whenever `responses` is empty); floats in [0, 1]
    for "pass_at_1".

  Raises:
    ValueError: If `eval_mode` is not one of "pass", "maj" or "pass_at_1".
  """
  if tmvp_config.debug.rl:
    max_logging.log("========================================")
    max_logging.log(f"Evaluation Question: {question}")
    max_logging.log(f"Evaluation Answer: {answers}")
    max_logging.log(f"Evaluation Responses: {responses}")
    max_logging.log("========================================")

  eval_mode = getattr(tmvp_config, "eval_mode", "pass")
  match_format = utils_rl.get_match_format_regex(tmvp_config)
  extracted_responses = [utils_rl.extract_answer(r, tmvp_config) for r in responses]

  if not extracted_responses:
    return False, False, False

  if eval_mode == "maj":
    # Pick the single most frequent extracted answer.
    counter = collections.Counter(extracted_responses)
    majority, majority_count = counter.most_common(1)[0]
    if tmvp_config.debug.rl:
      max_logging.log(f"Majority Response: {majority} (Count: {majority_count})")

    # Raw responses that actually produced the majority answer. This is never
    # empty because `majority` was drawn from `extracted_responses`.
    majority_sources = [
        response
        for extracted, response in zip(extracted_responses, responses)
        if extracted == majority
    ]
    has_correct_format = any(match_format.search(response) is not None for response in majority_sources)
    # Score against a response that yielded the majority answer, so the
    # per-sample debug output refers to a relevant response.
    is_correct, is_partially_correct, _ = _score_single(
        majority, majority_sources[0], answers, tmvp_config, match_format
    )
    return is_correct, is_partially_correct, has_correct_format

  if eval_mode == "pass":
    # Accumulate each criterion across responses, so a success on an earlier
    # response is not overwritten by a later failing one.
    is_correct = is_partially_correct = has_correct_format = False
    for extracted, response in zip(extracted_responses, responses):
      single = _score_single(extracted, response, answers, tmvp_config, match_format)
      is_correct = is_correct or single[0]
      is_partially_correct = is_partially_correct or single[1]
      has_correct_format = has_correct_format or single[2]
      # Early exit once all criteria are met.
      if is_correct and is_partially_correct and has_correct_format:
        break
    return is_correct, is_partially_correct, has_correct_format

  if eval_mode == "pass_at_1":
    # Estimate pass@1: fraction of N samples that are correct per problem.
    # Returns floats in [0, 1] instead of booleans.
    scores = [
        _score_single(extracted, response, answers, tmvp_config, match_format)
        for extracted, response in zip(extracted_responses, responses)
    ]
    n_samples = len(scores)  # > 0: the empty case returned above
    n_correct = sum(s[0] for s in scores)
    n_partial = sum(s[1] for s in scores)
    n_format = sum(s[2] for s in scores)
    if tmvp_config.debug.rl:
      max_logging.log(
          f"pass@1: {n_correct:.0f}/{n_samples} correct, "
          f"{n_partial:.0f}/{n_samples} partial, "
          f"{n_format:.0f}/{n_samples} format"
      )
    return n_correct / n_samples, n_partial / n_samples, n_format / n_samples

  if tmvp_config.debug.rl:
    max_logging.log(f"Unknown eval mode: {eval_mode}")
  raise ValueError(f"Unknown eval_mode: {eval_mode!r}")
168167
169168
170169def evaluate (
@@ -210,28 +209,29 @@ def evaluate(
210209
211210 # Score each question-answer pair
212211 for question , responses , answer in zip (questions , multiple_call_responses , answers ):
212+ # decode the json-encoded list of acceptable answers
213+ answer = list (dict .fromkeys (json .loads (answer )))
213214 is_correct , is_partially_correct , has_correct_format = score_responses (
214215 tmvp_config = tmvp_config ,
215216 question = question ,
216217 responses = responses ,
217- answer = answer ,
218+ answers = answer ,
218219 )
219220
220- # Update counters
221- if is_correct :
222- corr += 1
223- if corr_lst and make_lst :
221+ # Update counters. For "pass" and "maj" modes, scores are booleans
222+ # (True=1, False=0). For "pass_at_1" mode, scores are floats in [0, 1]
223+ # representing the fraction of samples correct. Using += works for both:
224+ # bool is a subtype of int in Python, so True += is the same as += 1.
225+ corr += is_correct
226+ partially_corr += is_partially_correct
227+ corr_format += has_correct_format
228+
229+ if make_lst :
230+ if corr_lst and is_correct :
224231 response_lst .append ((question , answer , responses ))
225- else :
226- if not corr_lst and make_lst :
232+ elif not corr_lst and not is_correct :
227233 response_lst .append ((question , answer , responses ))
228234
229- if is_partially_correct :
230- partially_corr += 1
231-
232- if has_correct_format :
233- corr_format += 1
234-
235235 total += 1
236236
237237 # Print progress every 10 items
@@ -243,8 +243,8 @@ def evaluate(
243243
244244 # Prepare return values
245245 to_return = (
246- corr ,
247- total ,
246+ corr * num_passes ,
247+ total * num_passes ,
248248 corr / total * 100 if total > 0 else 0 ,
249249 partially_corr / total * 100 if total > 0 else 0 ,
250250 corr_format / total * 100 if total > 0 else 0 ,
0 commit comments