|
16 | 16 | """ |
17 | 17 | RL Evaluation Module. |
18 | 18 | """ |
19 | | -from math_verify import parse |
| 19 | +import collections |
| 20 | +import json |
| 21 | + |
20 | 22 | from tqdm.auto import tqdm |
21 | 23 | from tunix.rl.rollout.base_rollout import RolloutConfig |
22 | 24 |
|
@@ -86,85 +88,128 @@ def generate_responses( |
86 | 88 | return multiple_call_responses |
87 | 89 |
|
88 | 90 |
|
89 | | -def score_responses(tmvp_config, question, responses, answer): |
| 91 | +def score_responses(tmvp_config, question, responses, answers): |
90 | 92 | """ |
91 | 93 | Score a set of responses for a single question. |
92 | 94 |
|
93 | 95 | Args: |
94 | 96 | tmvp_config: Configuration object |
95 | 97 | question: The evaluation question |
96 | 98 | responses: List of generated responses for this question |
97 | | - answer: The correct answer |
| 99 | + answers: List of acceptable answers for this question |
98 | 100 |
|
99 | 101 | Returns: |
100 | 102 | Tuple of (is_correct, is_partially_correct, has_correct_format) |
101 | 103 | """ |
102 | | - match_format = utils_rl.get_match_format_regex(tmvp_config) |
103 | | - answer_fallback = utils_rl.get_answer_fallback_regex(tmvp_config) |
104 | | - |
105 | 104 | if tmvp_config.debug.rl: |
106 | 105 | max_logging.log("========================================") |
107 | 106 | max_logging.log(f"Evaluation Question: {question}") |
108 | | - max_logging.log(f"Evaluation Answer: {answer}") |
| 107 | + max_logging.log(f"Evaluation Answer: {answers}") |
109 | 108 | max_logging.log(f"Evaluation Responses: {responses}") |
110 | 109 | max_logging.log("========================================") |
111 | 110 |
|
112 | | - is_correct = False |
113 | | - is_partially_correct = False |
114 | | - has_correct_format = False |
| 111 | + eval_mode = getattr(tmvp_config, "eval_mode", "pass") |
| 112 | + match_format = utils_rl.get_match_format_regex(tmvp_config) |
115 | 113 |
|
| 114 | + extracted_responses = [] |
116 | 115 | for response in responses: |
117 | | - # Extract answer: prefer the full format match; fall back to the last |
118 | | - # <answer>...</answer> tag if full format match is not found, so result |
119 | | - # scoring is decoupled from format. |
120 | | - full_match = match_format.search(response) |
121 | | - if full_match is not None: |
122 | | - extracted_response = full_match.group(1) |
123 | | - else: |
124 | | - # Find the *last* occurrence of the answer tag (most likely the final answer). |
125 | | - fallback_matches = answer_fallback.findall(response) |
126 | | - extracted_response = fallback_matches[-1].strip() if fallback_matches else "-1000000" |
| 116 | + extracted_response = utils_rl.extract_answer(response, tmvp_config) |
| 117 | + extracted_responses.append(extracted_response) |
| 118 | + |
| 119 | + if not extracted_responses: |
| 120 | + return False, False, False |
| 121 | + |
| 122 | + if eval_mode == "maj": |
| 123 | + # extract the single-most frequent response |
| 124 | + counter = collections.Counter(extracted_responses) |
| 125 | + majority_response = counter.most_common(1)[0][0] |
127 | 126 | if tmvp_config.debug.rl: |
128 | | - used = "full format" if full_match is not None else "answer-tag fallback" |
129 | | - max_logging.log(f"Evaluation extracted_response ({used}): {extracted_response}") |
| 127 | + max_logging.log(f"Majority Response: {majority_response} (Count: {counter[majority_response]})") |
| 128 | + |
| 129 | + # Check the format for the majority_response |
| 130 | + has_correct_format = False |
| 131 | + for idx, extracted_response in enumerate(extracted_responses): |
| 132 | + if extracted_response == majority_response: |
| 133 | + if match_format.search(responses[idx]) is not None: |
| 134 | + has_correct_format = True |
| 135 | + break |
130 | 136 |
|
131 | | - # Check exact correctness |
132 | 137 | try: |
133 | | - # Fix LaTeX escaping issues for both ground truth and extracted answer |
134 | | - norm_answer = utils_rl.fix_latex_escaping(answer) |
135 | | - norm_extracted = utils_rl.fix_latex_escaping(extracted_response) |
136 | | - # Normalize Normalize for certain datasets and parse |
137 | | - if "DAPO" in tmvp_config.dataset_name or "OpenMathInstruct" in tmvp_config.dataset_name: |
138 | | - norm_extracted = utils_rl.normalize_final_answer(norm_extracted).strip() |
139 | | - norm_answer = utils_rl.normalize_final_answer(answer).strip() |
140 | | - is_correct = utils_rl.math_verify_func([utils_rl.boxed(norm_answer)], [utils_rl.boxed(norm_extracted)])[0] > 0.1 |
| 138 | + is_correct, is_partially_correct = utils_rl.check_correctness(majority_response, answers, tmvp_config) |
141 | 139 | if tmvp_config.debug.rl: |
142 | | - # is_correct is a tuple, if first value is 1.0 means it's a match; |
143 | | - # 0.0 means a mismatch. e.g. (0.0, (['3', '3'], ['3/5', '\\frac{3}{5}'])) |
| 140 | + max_logging.log(f"Result has_correct_format: {has_correct_format}") |
144 | 141 | max_logging.log(f"Result is_correct: {is_correct}") |
145 | | - |
146 | | - val_extracted = parse(utils_rl.boxed(norm_extracted)) |
147 | | - val_answer = parse(utils_rl.boxed(norm_answer)) |
148 | | - |
149 | | - # Check partial correctness if values can be extracted (within 10%) |
150 | | - if val_extracted and val_answer: |
151 | | - ratio = (val_extracted[0] + utils_rl.EPSILON) / (val_answer[0] + utils_rl.EPSILON) |
152 | | - is_partially_correct = 0.9 <= ratio <= 1.1 |
153 | | - |
| 142 | + max_logging.log(f"Result is_partially_correct: {is_partially_correct}") |
154 | 143 | except Exception as e: |
| 144 | + is_correct, is_partially_correct = False, False |
155 | 145 | if tmvp_config.debug.rl: |
156 | | - max_logging.log(f"Evaluation Exception: {e}") |
| 146 | + max_logging.log(f"Evaluation Exception on majority answer: {e}") |
157 | 147 | max_logging.log("SKIPPED") |
| 148 | + return is_correct, is_partially_correct, has_correct_format |
158 | 149 |
|
159 | | - # Check format correctness (requires the full <reasoning>...</reasoning><answer>...</answer> structure) |
160 | | - if full_match is not None: |
161 | | - has_correct_format = True |
| 150 | + if eval_mode == "pass": |
| 151 | + for idx, response in enumerate(responses): |
| 152 | + is_correct, is_partially_correct, has_correct_format = False, False, False |
| 153 | + if match_format.search(response) is not None: |
| 154 | + has_correct_format = True |
| 155 | + |
| 156 | + # Check exact and partial correctness (within 10%) |
| 157 | + try: |
| 158 | + is_correct, is_partially_correct = utils_rl.check_correctness(extracted_responses[idx], answers, tmvp_config) |
| 159 | + if tmvp_config.debug.rl: |
| 160 | + max_logging.log(f"Result is_correct: {is_correct}") |
| 161 | + max_logging.log(f"Result is_partially_correct: {is_partially_correct}") |
| 162 | + except Exception as e: |
| 163 | + if tmvp_config.debug.rl: |
| 164 | + max_logging.log(f"Evaluation Exception: {e}") |
| 165 | + max_logging.log("SKIPPED") |
| 166 | + |
| 167 | + # Early exit if all criteria are met |
| 168 | + if is_correct and is_partially_correct and has_correct_format: |
| 169 | + return is_correct, is_partially_correct, has_correct_format |
| 170 | + return is_correct, is_partially_correct, has_correct_format |
| 171 | + |
| 172 | + if eval_mode == "pass_at_1": |
| 173 | + # Estimate pass@1: fraction of N samples that are correct per problem. |
| 174 | + # Returns floats in [0, 1] instead of booleans. |
| 175 | + n_samples = len(responses) |
| 176 | + n_correct = 0 |
| 177 | + n_partially_correct = 0 |
| 178 | + n_correct_format = 0 |
| 179 | + |
| 180 | + for idx, response in enumerate(responses): |
| 181 | + if match_format.search(response) is not None: |
| 182 | + n_correct_format += 1 |
| 183 | + |
| 184 | + try: |
| 185 | + sample_correct, sample_partial = utils_rl.check_correctness(extracted_responses[idx], answers, tmvp_config) |
| 186 | + if sample_correct: |
| 187 | + n_correct += 1 |
| 188 | + if sample_partial: |
| 189 | + n_partially_correct += 1 |
| 190 | + if tmvp_config.debug.rl: |
| 191 | + max_logging.log(f"Sample {idx}: correct={sample_correct}, partial={sample_partial}") |
| 192 | + except Exception as e: |
| 193 | + if tmvp_config.debug.rl: |
| 194 | + max_logging.log(f"Evaluation Exception on sample {idx}: {e}") |
| 195 | + max_logging.log("SKIPPED") |
| 196 | + |
| 197 | + frac_correct = n_correct / n_samples |
| 198 | + frac_partially_correct = n_partially_correct / n_samples |
| 199 | + frac_correct_format = n_correct_format / n_samples |
| 200 | + |
| 201 | + if tmvp_config.debug.rl: |
| 202 | + max_logging.log( |
| 203 | + f"pass@1: {n_correct}/{n_samples} correct, " |
| 204 | + f"{n_partially_correct}/{n_samples} partial, " |
| 205 | + f"{n_correct_format}/{n_samples} format" |
| 206 | + ) |
162 | 207 |
|
163 | | - # Early exit if all criteria are met |
164 | | - if is_correct and is_partially_correct and has_correct_format: |
165 | | - break |
| 208 | + return frac_correct, frac_partially_correct, frac_correct_format |
166 | 209 |
|
167 | | - return is_correct, is_partially_correct, has_correct_format |
| 210 | + if tmvp_config.debug.rl: |
| 211 | + max_logging.log(f"Unknown eval mode: {eval_mode}") |
| 212 | + return False, False, False |
168 | 213 |
|
169 | 214 |
|
170 | 215 | def evaluate( |
@@ -210,28 +255,29 @@ def evaluate( |
210 | 255 |
|
211 | 256 | # Score each question-answer pair |
212 | 257 | for question, responses, answer in zip(questions, multiple_call_responses, answers): |
| 258 | + # decode the json-encoded list of acceptable answers |
| 259 | + answer = list(dict.fromkeys(json.loads(answer))) |
213 | 260 | is_correct, is_partially_correct, has_correct_format = score_responses( |
214 | 261 | tmvp_config=tmvp_config, |
215 | 262 | question=question, |
216 | 263 | responses=responses, |
217 | | - answer=answer, |
| 264 | + answers=answer, |
218 | 265 | ) |
219 | 266 |
|
220 | | - # Update counters |
221 | | - if is_correct: |
222 | | - corr += 1 |
223 | | - if corr_lst and make_lst: |
| 267 | + # Update counters. For "pass" and "maj" modes, scores are booleans |
| 268 | + # (True=1, False=0). For "pass_at_1" mode, scores are floats in [0, 1] |
| 269 | + # representing the fraction of samples correct. Using += works for both: |
| 270 | + # bool is a subtype of int in Python, so True += is the same as += 1. |
| 271 | + corr += is_correct |
| 272 | + partially_corr += is_partially_correct |
| 273 | + corr_format += has_correct_format |
| 274 | + |
| 275 | + if make_lst: |
| 276 | + if corr_lst and is_correct: |
224 | 277 | response_lst.append((question, answer, responses)) |
225 | | - else: |
226 | | - if not corr_lst and make_lst: |
| 278 | + elif not corr_lst and not is_correct: |
227 | 279 | response_lst.append((question, answer, responses)) |
228 | 280 |
|
229 | | - if is_partially_correct: |
230 | | - partially_corr += 1 |
231 | | - |
232 | | - if has_correct_format: |
233 | | - corr_format += 1 |
234 | | - |
235 | 281 | total += 1 |
236 | 282 |
|
237 | 283 | # Print progress every 10 items |
|
0 commit comments