diff --git a/src/maxtext/trainers/post_train/rl/utils_rl.py b/src/maxtext/trainers/post_train/rl/utils_rl.py index c37db48c0a..782983cefa 100644 --- a/src/maxtext/trainers/post_train/rl/utils_rl.py +++ b/src/maxtext/trainers/post_train/rl/utils_rl.py @@ -229,7 +229,16 @@ def check_answer(prompts, completions, answer, tmvp_config, **kargs): value. """ match_format = get_match_format_regex(tmvp_config) - extracted_responses = [guess.group(1) if (guess := match_format.search(c)) is not None else None for c in completions] + answer_fallback = get_answer_fallback_regex(tmvp_config) + + extracted_responses = [] + for c in completions: + full_match = match_format.search(c) + if full_match is not None: + extracted_responses.append(full_match.group(1)) + else: + fallback_matches = answer_fallback.findall(c) + extracted_responses.append(fallback_matches[-1].strip() if fallback_matches else None) scores = [] for guess, true_answer in zip(extracted_responses, answer): @@ -408,7 +417,16 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs): # Extract full answer content from solution tags (not just first number) match_format = get_match_format_regex(tmvp_config) - extracted_responses = [guess.group(1) if (guess := match_format.search(c)) is not None else None for c in completions] + answer_fallback = get_answer_fallback_regex(tmvp_config) + + extracted_responses = [] + for c in completions: + full_match = match_format.search(c) + if full_match is not None: + extracted_responses.append(full_match.group(1)) + else: + fallback_matches = answer_fallback.findall(c) + extracted_responses.append(fallback_matches[-1].strip() if fallback_matches else None) scores = [] if tmvp_config.debug.rl: diff --git a/tests/post_training/unit/rl_utils_test.py b/tests/post_training/unit/rl_utils_test.py index c8aaa7dd83..60bb01a51d 100644 --- a/tests/post_training/unit/rl_utils_test.py +++ b/tests/post_training/unit/rl_utils_test.py @@ -236,12 +236,12 @@ def 
test_extraction_fails_no_tags(self): @pytest.mark.cpu_only def test_extraction_fails_answer_tags_only(self): - """<answer> tag alone (no <reasoning> block) is not matched by the regex, score 0.""" + """<answer> tag alone (no <reasoning> block) is matched by the regex as a fallback, score 1.5.""" scores = self._check( completions=["<answer>42</answer>"], answer=["42"], ) - self.assertEqual(scores[0], 0) + self.assertEqual(scores[0], 1.5) @pytest.mark.cpu_only def test_extraction_fails_reasoning_tags_only(self):