|
16 | 16 | """ |
17 | 17 | RL Evaluation Module. |
18 | 18 | """ |
19 | | -from math_verify import parse |
| 19 | +import collections |
| 20 | +import json |
| 21 | + |
20 | 22 | from tqdm.auto import tqdm |
21 | 23 | from tunix.rl.rollout.base_rollout import RolloutConfig |
22 | 24 |
|
@@ -86,85 +88,87 @@ def generate_responses( |
86 | 88 | return multiple_call_responses |
87 | 89 |
|
88 | 90 |
|
89 | | -def score_responses(tmvp_config, question, responses, answer): |
90 | | - """ |
91 | | - Score a set of responses for a single question. |
| 91 | +def _score_single(extracted, response, answers, tmvp_config, match_format): |
| 92 | + """Score one (extracted answer, raw response) pair. Returns (is_correct, is_partially_correct, has_correct_format).""" |
| 93 | + has_correct_format = match_format.search(response) is not None |
| 94 | + try: |
| 95 | + is_correct, is_partially_correct = utils_rl.check_correctness(extracted, answers, tmvp_config) |
| 96 | + if tmvp_config.debug.rl: |
| 97 | + max_logging.log(f"Result has_correct_format: {has_correct_format}") |
| 98 | + max_logging.log(f"Result is_correct: {is_correct}") |
| 99 | + max_logging.log(f"Result is_partially_correct: {is_partially_correct}") |
| 100 | + except Exception as e: # pylint: disable=broad-exception-caught |
| 101 | + is_correct, is_partially_correct = False, False |
| 102 | + if tmvp_config.debug.rl: |
| 103 | + max_logging.log(f"Evaluation Exception: {e} — SKIPPED") |
| 104 | + return is_correct, is_partially_correct, has_correct_format |
92 | 105 |
|
93 | | - Args: |
94 | | - tmvp_config: Configuration object |
95 | | - question: The evaluation question |
96 | | - responses: List of generated responses for this question |
97 | | - answer: The correct answer |
| 106 | + |
| 107 | +def score_responses(tmvp_config, question, responses, answers): |
| 108 | + """Score a set of responses for a single question. |
98 | 109 |
|
99 | 110 | Returns: |
100 | | - Tuple of (is_correct, is_partially_correct, has_correct_format) |
| 111 | + Tuple of (is_correct, is_partially_correct, has_correct_format). |
101 | 112 | """ |
102 | | - match_format = utils_rl.get_match_format_regex(tmvp_config) |
103 | | - answer_fallback = utils_rl.get_answer_fallback_regex(tmvp_config) |
104 | | - |
105 | 113 | if tmvp_config.debug.rl: |
106 | 114 | max_logging.log("========================================") |
107 | 115 | max_logging.log(f"Evaluation Question: {question}") |
108 | | - max_logging.log(f"Evaluation Answer: {answer}") |
| 116 | + max_logging.log(f"Evaluation Answer: {answers}") |
109 | 117 | max_logging.log(f"Evaluation Responses: {responses}") |
110 | 118 | max_logging.log("========================================") |
111 | 119 |
|
112 | | - is_correct = False |
113 | | - is_partially_correct = False |
114 | | - has_correct_format = False |
115 | | - |
116 | | - for response in responses: |
117 | | - # Extract answer: prefer the full format match; fall back to the last |
118 | | - # <answer>...</answer> tag if full format match is not found, so result |
119 | | - # scoring is decoupled from format. |
120 | | - full_match = match_format.search(response) |
121 | | - if full_match is not None: |
122 | | - extracted_response = full_match.group(1) |
123 | | - else: |
124 | | - # Find the *last* occurrence of the answer tag (most likely the final answer). |
125 | | - fallback_matches = answer_fallback.findall(response) |
126 | | - extracted_response = fallback_matches[-1].strip() if fallback_matches else "-1000000" |
| 120 | + eval_mode = getattr(tmvp_config, "eval_mode", "pass") |
| 121 | + match_format = utils_rl.get_match_format_regex(tmvp_config) |
| 122 | + extracted_responses = [utils_rl.extract_answer(r, tmvp_config) for r in responses] |
| 123 | + |
| 124 | + if not extracted_responses: |
| 125 | + return False, False, False |
| 126 | + |
| 127 | + if eval_mode == "maj": |
| 128 | + # extract the single-most frequent response |
| 129 | + counter = collections.Counter(extracted_responses) |
| 130 | + majority = counter.most_common(1)[0][0] |
127 | 131 | if tmvp_config.debug.rl: |
128 | | - used = "full format" if full_match is not None else "answer-tag fallback" |
129 | | - max_logging.log(f"Evaluation extracted_response ({used}): {extracted_response}") |
130 | | - |
131 | | - # Check exact correctness |
132 | | - try: |
133 | | - # Fix LaTeX escaping issues for both ground truth and extracted answer |
134 | | - norm_answer = utils_rl.fix_latex_escaping(answer) |
135 | | - norm_extracted = utils_rl.fix_latex_escaping(extracted_response) |
136 | | - # Normalize Normalize for certain datasets and parse |
137 | | - if "DAPO" in tmvp_config.dataset_name or "OpenMathInstruct" in tmvp_config.dataset_name: |
138 | | - norm_extracted = utils_rl.normalize_final_answer(norm_extracted).strip() |
139 | | - norm_answer = utils_rl.normalize_final_answer(answer).strip() |
140 | | - is_correct = utils_rl.math_verify_func([utils_rl.boxed(norm_answer)], [utils_rl.boxed(norm_extracted)])[0] > 0.1 |
141 | | - if tmvp_config.debug.rl: |
142 | | - # is_correct is a tuple, if first value is 1.0 means it's a match; |
143 | | - # 0.0 means a mismatch. e.g. (0.0, (['3', '3'], ['3/5', '\\frac{3}{5}'])) |
144 | | - max_logging.log(f"Result is_correct: {is_correct}") |
145 | | - |
146 | | - val_extracted = parse(utils_rl.boxed(norm_extracted)) |
147 | | - val_answer = parse(utils_rl.boxed(norm_answer)) |
148 | | - |
149 | | - # Check partial correctness if values can be extracted (within 10%) |
150 | | - if val_extracted and val_answer: |
151 | | - ratio = (val_extracted[0] + utils_rl.EPSILON) / (val_answer[0] + utils_rl.EPSILON) |
152 | | - is_partially_correct = 0.9 <= ratio <= 1.1 |
153 | | - |
154 | | - except Exception as e: |
155 | | - if tmvp_config.debug.rl: |
156 | | - max_logging.log(f"Evaluation Exception: {e}") |
157 | | - max_logging.log("SKIPPED") |
158 | | - |
159 | | - # Check format correctness (requires the full <reasoning>...</reasoning><answer>...</answer> structure) |
160 | | - if full_match is not None: |
161 | | - has_correct_format = True |
162 | | - |
163 | | - # Early exit if all criteria are met |
164 | | - if is_correct and is_partially_correct and has_correct_format: |
165 | | - break |
| 132 | + max_logging.log(f"Majority Response: {majority} (Count: {counter[majority]})") |
166 | 133 |
|
167 | | - return is_correct, is_partially_correct, has_correct_format |
| 134 | + # Check the format for the majority response |
| 135 | + has_correct_format = any( |
| 136 | + match_format.search(responses[idx]) is not None |
| 137 | + for idx, response in enumerate(extracted_responses) |
| 138 | + if response == majority |
| 139 | + ) |
| 140 | + is_correct, is_partially_correct, _ = _score_single(majority, responses[0], answers, tmvp_config, match_format) |
| 141 | + return is_correct, is_partially_correct, has_correct_format |
| 142 | + |
| 143 | + if eval_mode == "pass": |
| 144 | + result = False, False, False |
| 145 | + for extracted, response in zip(extracted_responses, responses): |
| 146 | + result = _score_single(extracted, response, answers, tmvp_config, match_format) |
| 147 | + # Early exit if all criteria are met |
| 148 | + if all(result): |
| 149 | + return result |
| 150 | + return result |
| 151 | + |
| 152 | + if eval_mode == "pass_at_1": |
| 153 | + # Estimate pass@1: fraction of N samples that are correct per problem. |
| 154 | + # Returns floats in [0, 1] instead of booleans. |
| 155 | + scores = [ |
| 156 | + _score_single(extracted_response, response, answers, tmvp_config, match_format) |
| 157 | + for extracted_response, response in zip(extracted_responses, responses) |
| 158 | + ] |
| 159 | + n_samples = len(scores) |
| 160 | + frac_correct = sum(s[0] for s in scores) / n_samples |
| 161 | + frac_partial = sum(s[1] for s in scores) / n_samples |
| 162 | + frac_format = sum(s[2] for s in scores) / n_samples |
| 163 | + if tmvp_config.debug.rl: |
| 164 | + max_logging.log(f"Result has_correct_format: {frac_format*n_samples:.0f}/{n_samples}") |
| 165 | + max_logging.log(f"Result is_correct: {frac_correct*n_samples:.0f}/{n_samples}") |
| 166 | + max_logging.log(f"Result is_partially_correct: {frac_partial*n_samples:.0f}/{n_samples}") |
| 167 | + return frac_correct, frac_partial, frac_format |
| 168 | + |
| 169 | + if tmvp_config.debug.rl: |
| 170 | + max_logging.log(f"Unknown eval mode: {eval_mode}") |
| 171 | + raise ValueError(f"Unknown eval_mode: {eval_mode!r}") |
168 | 172 |
|
169 | 173 |
|
170 | 174 | def evaluate( |
@@ -210,28 +214,29 @@ def evaluate( |
210 | 214 |
|
211 | 215 | # Score each question-answer pair |
212 | 216 | for question, responses, answer in zip(questions, multiple_call_responses, answers): |
| 217 | + # decode the json-encoded list of acceptable answers |
| 218 | + answer = list(dict.fromkeys(json.loads(answer))) |
213 | 219 | is_correct, is_partially_correct, has_correct_format = score_responses( |
214 | 220 | tmvp_config=tmvp_config, |
215 | 221 | question=question, |
216 | 222 | responses=responses, |
217 | | - answer=answer, |
| 223 | + answers=answer, |
218 | 224 | ) |
219 | 225 |
|
220 | | - # Update counters |
221 | | - if is_correct: |
222 | | - corr += 1 |
223 | | - if corr_lst and make_lst: |
| 226 | + # Update counters. For "pass" and "maj" modes, scores are booleans |
| 227 | + # (True=1, False=0). For "pass_at_1" mode, scores are floats in [0, 1] |
| 228 | + # representing the fraction of samples correct. Using += works for both: |
| 229 | + # bool is a subtype of int in Python, so True += is the same as += 1. |
| 230 | + corr += is_correct |
| 231 | + partially_corr += is_partially_correct |
| 232 | + corr_format += has_correct_format |
| 233 | + |
| 234 | + if make_lst: |
| 235 | + if corr_lst and is_correct: |
224 | 236 | response_lst.append((question, answer, responses)) |
225 | | - else: |
226 | | - if not corr_lst and make_lst: |
| 237 | + elif not corr_lst and not is_correct: |
227 | 238 | response_lst.append((question, answer, responses)) |
228 | 239 |
|
229 | | - if is_partially_correct: |
230 | | - partially_corr += 1 |
231 | | - |
232 | | - if has_correct_format: |
233 | | - corr_format += 1 |
234 | | - |
235 | 240 | total += 1 |
236 | 241 |
|
237 | 242 | # Print progress every 10 items |
|
0 commit comments