1616"""
1717RL Evaluation Module.
1818"""
19- from math_verify import parse
19+ import collections
20+ import json
21+
2022from tqdm .auto import tqdm
2123from tunix .rl .rollout .base_rollout import RolloutConfig
2224
@@ -86,85 +88,91 @@ def generate_responses(
8688 return multiple_call_responses
8789
8890
def _score_single(extracted, response, answers, tmvp_config, match_format):
  """Score one (extracted answer, raw response) pair.

  Args:
    extracted: Answer extracted from `response`.
    response: Raw generated response; only used for the format check.
    answers: List of acceptable ground-truth answers.
    tmvp_config: Configuration object (reads `debug.rl`; forwarded to the checker).
    match_format: Pattern object exposing `.search`; a hit in `response` counts
      as correct format.

  Returns:
    Tuple of (is_correct, is_partially_correct, has_correct_format).
  """
  has_correct_format = match_format.search(response) is not None
  try:
    is_correct, is_partially_correct = utils_rl.check_correctness(extracted, answers, tmvp_config)
    if tmvp_config.debug.rl:
      max_logging.log(f"Result has_correct_format: {has_correct_format}")
      max_logging.log(f"Result is_correct: {is_correct}")
      max_logging.log(f"Result is_partially_correct: {is_partially_correct}")
  except Exception as e:  # pylint: disable=broad-exception-caught
    # Best-effort: a checker failure on one sample is scored as incorrect
    # instead of aborting the whole evaluation run.
    is_correct, is_partially_correct = False, False
    if tmvp_config.debug.rl:
      max_logging.log(f"Evaluation Exception: {e} — SKIPPED")
  return is_correct, is_partially_correct, has_correct_format


def score_responses(tmvp_config, question, responses, answers):
  """Score a set of responses for a single question.

  The aggregation is selected by `tmvp_config.eval_mode` ("pass" when the
  attribute is absent):
    * "pass": each criterion is True if at least one response satisfies it.
    * "maj": the most frequent extracted answer is scored; the format flag is
      True if any response that produced that answer matches the format.
    * "pass_at_1": each criterion is the fraction of responses satisfying it.

  Args:
    tmvp_config: Configuration object
    question: The evaluation question
    responses: List of generated responses for this question
    answers: List of correct answers

  Returns:
    Tuple of (is_correct, is_partially_correct, has_correct_format). Booleans
    for "pass" and "maj"; floats in [0, 1] for "pass_at_1". All False when
    `responses` is empty.

  Raises:
    ValueError: If `eval_mode` is not one of "pass", "maj", "pass_at_1"
      (only checked when there is at least one response).
  """
  if tmvp_config.debug.rl:
    max_logging.log("========================================")
    max_logging.log(f"Evaluation Question: {question}")
    max_logging.log(f"Evaluation Answer: {answers}")
    max_logging.log(f"Evaluation Responses: {responses}")
    max_logging.log("========================================")

  eval_mode = getattr(tmvp_config, "eval_mode", "pass")
  match_format = utils_rl.get_match_format_regex(tmvp_config)
  extracted_responses = [utils_rl.extract_answer(r, tmvp_config) for r in responses]

  if not extracted_responses:
    return False, False, False

  if eval_mode == "maj":
    # Score only the single most frequent extracted answer.
    counter = collections.Counter(extracted_responses)
    majority, majority_count = counter.most_common(1)[0]
    if tmvp_config.debug.rl:
      max_logging.log(f"Majority Response: {majority} (Count: {majority_count})")

    # Raw responses that produced the majority answer (never empty, since
    # `majority` was drawn from `extracted_responses`).
    majority_raw = [raw for extracted, raw in zip(extracted_responses, responses) if extracted == majority]
    has_correct_format = any(match_format.search(raw) is not None for raw in majority_raw)
    # The per-sample format flag is discarded here; the any() above decides it.
    is_correct, is_partially_correct, _ = _score_single(
        majority, majority_raw[0], answers, tmvp_config, match_format
    )
    return is_correct, is_partially_correct, has_correct_format

  if eval_mode == "pass":
    is_correct = is_partially_correct = has_correct_format = False
    for extracted, response in zip(extracted_responses, responses):
      correct, partial, well_formed = _score_single(extracted, response, answers, tmvp_config, match_format)
      # OR-accumulate so a hit on an earlier sample is not overwritten by a
      # later miss.
      is_correct = is_correct or bool(correct)
      is_partially_correct = is_partially_correct or bool(partial)
      has_correct_format = has_correct_format or bool(well_formed)
      # Early exit if all criteria are met
      if is_correct and is_partially_correct and has_correct_format:
        break
    return is_correct, is_partially_correct, has_correct_format

  if eval_mode == "pass_at_1":
    # Estimate pass@1: fraction of N samples that are correct per problem.
    # Returns floats in [0, 1] instead of booleans.
    scores = [
        _score_single(extracted, response, answers, tmvp_config, match_format)
        for extracted, response in zip(extracted_responses, responses)
    ]
    n_samples = len(scores)
    frac_correct = sum(s[0] for s in scores) / n_samples
    frac_partial = sum(s[1] for s in scores) / n_samples
    frac_format = sum(s[2] for s in scores) / n_samples
    if tmvp_config.debug.rl:
      max_logging.log(f"{frac_correct * n_samples:.0f}/{n_samples} correct")
      max_logging.log(f"{frac_partial * n_samples:.0f}/{n_samples} partial")
      max_logging.log(f"{frac_format * n_samples:.0f}/{n_samples} format")
    return frac_correct, frac_partial, frac_format

  raise ValueError(f"Unknown eval_mode: {eval_mode!r}")
168176
169177
170178def evaluate (
@@ -210,28 +218,29 @@ def evaluate(
210218
211219 # Score each question-answer pair
212220 for question , responses , answer in zip (questions , multiple_call_responses , answers ):
221+ # decode the json-encoded list of acceptable answers
222+ answer = list (dict .fromkeys (json .loads (answer )))
213223 is_correct , is_partially_correct , has_correct_format = score_responses (
214224 tmvp_config = tmvp_config ,
215225 question = question ,
216226 responses = responses ,
217- answer = answer ,
227+ answers = answer ,
218228 )
219229
220- # Update counters
221- if is_correct :
222- corr += 1
223- if corr_lst and make_lst :
230+ # Update counters. For "pass" and "maj" modes, scores are booleans
231+ # (True=1, False=0). For "pass_at_1" mode, scores are floats in [0, 1]
232+ # representing the fraction of samples correct. Using += works for both:
233+ # bool is a subtype of int in Python, so `+= True` is the same as `+= 1`.
234+ corr += is_correct
235+ partially_corr += is_partially_correct
236+ corr_format += has_correct_format
237+
238+ if make_lst :
239+ if corr_lst and is_correct :
224240 response_lst .append ((question , answer , responses ))
225- else :
226- if not corr_lst and make_lst :
241+ elif not corr_lst and not is_correct :
227242 response_lst .append ((question , answer , responses ))
228243
229- if is_partially_correct :
230- partially_corr += 1
231-
232- if has_correct_format :
233- corr_format += 1
234-
235244 total += 1
236245
237246 # Print progress every 10 items
@@ -243,8 +252,8 @@ def evaluate(
243252
244253 # Prepare return values
245254 to_return = (
246- corr ,
247- total ,
255+ corr * num_passes ,
256+ total * num_passes ,
248257 corr / total * 100 if total > 0 else 0 ,
249258 partially_corr / total * 100 if total > 0 else 0 ,
250259 corr_format / total * 100 if total > 0 else 0 ,
0 commit comments