|
16 | 16 | """ |
17 | 17 | RL Evaluation Module. |
18 | 18 | """ |
19 | | -from math_verify import parse |
| 19 | +import collections |
| 20 | +import json |
| 21 | +import re |
| 22 | +from typing import Any, Optional |
| 23 | + |
20 | 24 | from tqdm.auto import tqdm |
21 | 25 | from tunix.rl.rollout.base_rollout import RolloutConfig |
22 | 26 |
|
@@ -86,85 +90,97 @@ def generate_responses( |
86 | 90 | return multiple_call_responses |
87 | 91 |
|
88 | 92 |
|
89 | | -def score_responses(tmvp_config, question, responses, answer): |
90 | | - """ |
91 | | - Score a set of responses for a single question. |
| 93 | +def _score_single( |
| 94 | + extracted_response: str, |
| 95 | + raw_response: str, |
| 96 | + answers: list[str], |
| 97 | + tmvp_config: Any, |
| 98 | + match_format: re.Pattern[str], |
| 99 | +) -> tuple[bool, bool, bool]: |
| 100 | + """Score one (extracted answer, raw response) pair. Returns (is_correct, is_partially_correct, has_correct_format).""" |
| 101 | + has_correct_format = match_format.search(raw_response) is not None |
| 102 | + try: |
| 103 | + is_correct, is_partially_correct = utils_rl.check_correctness(extracted_response, answers, tmvp_config) |
| 104 | + if tmvp_config.debug.rl: |
| 105 | + max_logging.log(f"Result has_correct_format: {has_correct_format}") |
| 106 | + max_logging.log(f"Result is_correct: {is_correct}") |
| 107 | + max_logging.log(f"Result is_partially_correct: {is_partially_correct}") |
| 108 | + except Exception as e: # pylint: disable=broad-exception-caught |
| 109 | + is_correct, is_partially_correct = False, False |
| 110 | + if tmvp_config.debug.rl: |
| 111 | + max_logging.log(f"Evaluation Exception: {e} — SKIPPED") |
| 112 | + return is_correct, is_partially_correct, has_correct_format |
| 113 | + |
| 114 | + |
| 115 | +def score_responses(tmvp_config, question, responses, answers): |
| 116 | + """Score a set of responses for a single question. |
92 | 117 |
|
93 | 118 | Args: |
94 | 119 | tmvp_config: Configuration object |
95 | 120 | question: The evaluation question |
96 | 121 | responses: List of generated responses for this question |
97 | | - answer: The correct answer |
| 122 | + answers: List of correct answers |
98 | 123 |
|
99 | 124 | Returns: |
100 | 125 | Tuple of (is_correct, is_partially_correct, has_correct_format) |
101 | 126 | """ |
102 | | - match_format = utils_rl.get_match_format_regex(tmvp_config) |
103 | | - answer_fallback = utils_rl.get_answer_fallback_regex(tmvp_config) |
104 | | - |
105 | 127 | if tmvp_config.debug.rl: |
106 | 128 | max_logging.log("========================================") |
107 | 129 | max_logging.log(f"Evaluation Question: {question}") |
108 | | - max_logging.log(f"Evaluation Answer: {answer}") |
| 130 | + max_logging.log(f"Evaluation Answer: {answers}") |
109 | 131 | max_logging.log(f"Evaluation Responses: {responses}") |
110 | 132 | max_logging.log("========================================") |
111 | 133 |
|
112 | | - is_correct = False |
113 | | - is_partially_correct = False |
114 | | - has_correct_format = False |
115 | | - |
116 | | - for response in responses: |
117 | | - # Extract answer: prefer the full format match; fall back to the last |
118 | | - # <answer>...</answer> tag if full format match is not found, so result |
119 | | - # scoring is decoupled from format. |
120 | | - full_match = match_format.search(response) |
121 | | - if full_match is not None: |
122 | | - extracted_response = full_match.group(1) |
123 | | - else: |
124 | | - # Find the *last* occurrence of the answer tag (most likely the final answer). |
125 | | - fallback_matches = answer_fallback.findall(response) |
126 | | - extracted_response = fallback_matches[-1].strip() if fallback_matches else "-1000000" |
| 134 | + eval_mode = getattr(tmvp_config, "eval_mode", "pass") |
| 135 | + match_format = utils_rl.get_match_format_regex(tmvp_config) |
| 136 | + extracted_responses = [utils_rl.extract_answer(r, tmvp_config) for r in responses] |
| 137 | + |
| 138 | + if not extracted_responses: |
| 139 | + return False, False, False |
| 140 | + |
| 141 | + if eval_mode == "maj": |
| 142 | + # extract the single-most frequent response |
| 143 | + counter = collections.Counter(extracted_responses) |
| 144 | + majority = counter.most_common(1)[0][0] |
127 | 145 | if tmvp_config.debug.rl: |
128 | | - used = "full format" if full_match is not None else "answer-tag fallback" |
129 | | - max_logging.log(f"Evaluation extracted_response ({used}): {extracted_response}") |
130 | | - |
131 | | - # Check exact correctness |
132 | | - try: |
133 | | - # Fix LaTeX escaping issues for both ground truth and extracted answer |
134 | | - norm_answer = utils_rl.fix_latex_escaping(answer) |
135 | | - norm_extracted = utils_rl.fix_latex_escaping(extracted_response) |
136 | | - # Normalize Normalize for certain datasets and parse |
137 | | - if "DAPO" in tmvp_config.dataset_name or "OpenMathInstruct" in tmvp_config.dataset_name: |
138 | | - norm_extracted = utils_rl.normalize_final_answer(norm_extracted).strip() |
139 | | - norm_answer = utils_rl.normalize_final_answer(answer).strip() |
140 | | - is_correct = utils_rl.math_verify_func([utils_rl.boxed(norm_answer)], [utils_rl.boxed(norm_extracted)])[0] > 0.1 |
141 | | - if tmvp_config.debug.rl: |
142 | | - # is_correct is a tuple, if first value is 1.0 means it's a match; |
143 | | - # 0.0 means a mismatch. e.g. (0.0, (['3', '3'], ['3/5', '\\frac{3}{5}'])) |
144 | | - max_logging.log(f"Result is_correct: {is_correct}") |
145 | | - |
146 | | - val_extracted = parse(utils_rl.boxed(norm_extracted)) |
147 | | - val_answer = parse(utils_rl.boxed(norm_answer)) |
148 | | - |
149 | | - # Check partial correctness if values can be extracted (within 10%) |
150 | | - if val_extracted and val_answer: |
151 | | - ratio = (val_extracted[0] + utils_rl.EPSILON) / (val_answer[0] + utils_rl.EPSILON) |
152 | | - is_partially_correct = 0.9 <= ratio <= 1.1 |
153 | | - |
154 | | - except Exception as e: |
155 | | - if tmvp_config.debug.rl: |
156 | | - max_logging.log(f"Evaluation Exception: {e}") |
157 | | - max_logging.log("SKIPPED") |
158 | | - |
159 | | - # Check format correctness (requires the full <reasoning>...</reasoning><answer>...</answer> structure) |
160 | | - if full_match is not None: |
161 | | - has_correct_format = True |
162 | | - |
163 | | - # Early exit if all criteria are met |
164 | | - if is_correct and is_partially_correct and has_correct_format: |
165 | | - break |
| 146 | + max_logging.log(f"Majority Response: {majority} (Count: {counter[majority]})") |
166 | 147 |
|
167 | | - return is_correct, is_partially_correct, has_correct_format |
| 148 | + # Check the format for the majority response |
| 149 | + has_correct_format = any( |
| 150 | + match_format.search(responses[idx]) is not None |
| 151 | + for idx, response in enumerate(extracted_responses) |
| 152 | + if response == majority |
| 153 | + ) |
| 154 | + is_correct, is_partially_correct, _ = _score_single(majority, responses[0], answers, tmvp_config, match_format) |
| 155 | + return is_correct, is_partially_correct, has_correct_format |
| 156 | + |
| 157 | + if eval_mode == "pass": |
| 158 | + result = False, False, False |
| 159 | + for extracted, response in zip(extracted_responses, responses): |
| 160 | + result = _score_single(extracted, response, answers, tmvp_config, match_format) |
| 161 | + # Early exit if all criteria are met |
| 162 | + if all(result): |
| 163 | + return result |
| 164 | + return result |
| 165 | + |
| 166 | + if eval_mode == "pass_at_1": |
| 167 | + # Estimate pass@1: fraction of N samples that are correct per problem. |
| 168 | + # Returns floats in [0, 1] instead of booleans. |
| 169 | + scores = [ |
| 170 | + _score_single(extracted_response, response, answers, tmvp_config, match_format) |
| 171 | + for extracted_response, response in zip(extracted_responses, responses) |
| 172 | + ] |
| 173 | + n_samples = len(scores) |
| 174 | + frac_correct = sum(s[0] for s in scores) / n_samples |
| 175 | + frac_partial = sum(s[1] for s in scores) / n_samples |
| 176 | + frac_format = sum(s[2] for s in scores) / n_samples |
| 177 | + if tmvp_config.debug.rl: |
| 178 | + max_logging.log(f"{frac_correct*n_samples:.0f}/{n_samples} correct") |
| 179 | + max_logging.log(f"{frac_partial*n_samples:.0f}/{n_samples} partial") |
| 180 | + max_logging.log(f"{frac_format*n_samples:.0f}/{n_samples} format") |
| 181 | + return frac_correct, frac_partial, frac_format |
| 182 | + |
| 183 | + raise ValueError(f"Unknown eval_mode: {eval_mode!r}") |
168 | 184 |
|
169 | 185 |
|
170 | 186 | def evaluate( |
@@ -210,28 +226,29 @@ def evaluate( |
210 | 226 |
|
211 | 227 | # Score each question-answer pair |
212 | 228 | for question, responses, answer in zip(questions, multiple_call_responses, answers): |
| 229 | + # decode the json-encoded list of acceptable answers |
| 230 | + answer = list(dict.fromkeys(json.loads(answer))) |
213 | 231 | is_correct, is_partially_correct, has_correct_format = score_responses( |
214 | 232 | tmvp_config=tmvp_config, |
215 | 233 | question=question, |
216 | 234 | responses=responses, |
217 | | - answer=answer, |
| 235 | + answers=answer, |
218 | 236 | ) |
219 | 237 |
|
220 | | - # Update counters |
221 | | - if is_correct: |
222 | | - corr += 1 |
223 | | - if corr_lst and make_lst: |
| 238 | + # Update counters. For "pass" and "maj" modes, scores are booleans |
| 239 | + # (True=1, False=0). For "pass_at_1" mode, scores are floats in [0, 1] |
| 240 | + # representing the fraction of samples correct. Using += works for both: |
| 241 | + # bool is a subtype of int in Python, so True += is the same as += 1. |
| 242 | + corr += is_correct |
| 243 | + partially_corr += is_partially_correct |
| 244 | + corr_format += has_correct_format |
| 245 | + |
| 246 | + if make_lst: |
| 247 | + if corr_lst and is_correct: |
224 | 248 | response_lst.append((question, answer, responses)) |
225 | | - else: |
226 | | - if not corr_lst and make_lst: |
| 249 | + elif not corr_lst and not is_correct: |
227 | 250 | response_lst.append((question, answer, responses)) |
228 | 251 |
|
229 | | - if is_partially_correct: |
230 | | - partially_corr += 1 |
231 | | - |
232 | | - if has_correct_format: |
233 | | - corr_format += 1 |
234 | | - |
235 | 252 | total += 1 |
236 | 253 |
|
237 | 254 | # Print progress every 10 items |
|
0 commit comments