Skip to content

Commit a2a6eff

Browse files
committed
port parsing fallback from evaluation to rewards
1 parent c4b5e64 commit a2a6eff

1 file changed

Lines changed: 20 additions & 2 deletions

File tree

src/maxtext/trainers/post_train/rl/utils_rl.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,16 @@ def check_answer(prompts, completions, answer, tmvp_config, **kargs):
229229
value.
230230
"""
231231
match_format = get_match_format_regex(tmvp_config)
232-
extracted_responses = [guess.group(1) if (guess := match_format.search(c)) is not None else None for c in completions]
232+
answer_fallback = get_answer_fallback_regex(tmvp_config)
233+
234+
extracted_responses = []
235+
for c in completions:
236+
full_match = match_format.search(c)
237+
if full_match is not None:
238+
extracted_responses.append(full_match.group(1))
239+
else:
240+
fallback_matches = answer_fallback.findall(c)
241+
extracted_responses.append(fallback_matches[-1].strip() if fallback_matches else None)
233242

234243
scores = []
235244
for guess, true_answer in zip(extracted_responses, answer):
@@ -408,7 +417,16 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
408417

409418
# Extract full answer content from solution tags (not just first number)
410419
match_format = get_match_format_regex(tmvp_config)
411-
extracted_responses = [guess.group(1) if (guess := match_format.search(c)) is not None else None for c in completions]
420+
answer_fallback = get_answer_fallback_regex(tmvp_config)
421+
422+
extracted_responses = []
423+
for c in completions:
424+
full_match = match_format.search(c)
425+
if full_match is not None:
426+
extracted_responses.append(full_match.group(1))
427+
else:
428+
fallback_matches = answer_fallback.findall(c)
429+
extracted_responses.append(fallback_matches[-1].strip() if fallback_matches else None)
412430

413431
scores = []
414432
if tmvp_config.debug.rl:

0 commit comments

Comments
 (0)