4242 LatexExtractionConfig (),
4343)
4444
45+
4546def math_verify_func (items , timeout = 5 ):
4647 """Verifies a batch of math problems, handling timeouts and exceptions."""
4748 with concurrent .futures .ThreadPoolExecutor () as executor :
48- future_to_index = {executor .submit (verify_math , golds , predictions ): idx for idx , (_ , golds , predictions ) in enumerate (items )}
49+ future_to_index = {
50+ executor .submit (verify_math , golds , predictions ): idx for idx , (_ , golds , predictions ) in enumerate (items )
51+ }
4952 results = [0.0 ] * len (items )
5053 for future in concurrent .futures .as_completed (future_to_index ):
5154 index = future_to_index [future ]
@@ -59,8 +62,12 @@ def math_verify_func(items, timeout=5):
5962def verify_math (golds , predictions ):
6063 """Runs mathematical expression evaluation on ground-truth and predictions."""
6164
62- extracted_predictions = list (itertools .chain .from_iterable (parse (pred , PRED_EXTRACTION_TARGET , parsing_timeout = None ) for pred in predictions ))
63- extracted_golds = list (itertools .chain .from_iterable (parse (gold , GOLD_EXTRACTION_TARGET , parsing_timeout = None ) for gold in golds ))
65+ extracted_predictions = list (
66+ itertools .chain .from_iterable (parse (pred , PRED_EXTRACTION_TARGET , parsing_timeout = None ) for pred in predictions )
67+ )
68+ extracted_golds = list (
69+ itertools .chain .from_iterable (parse (gold , GOLD_EXTRACTION_TARGET , parsing_timeout = None ) for gold in golds )
70+ )
6471 # If no predictions or golds were extracted, return 0.0
6572 if not extracted_predictions or not extracted_golds :
6673 return 0.0
@@ -72,6 +79,7 @@ def verify_math(golds, predictions):
7279 ]
7380 )
7481
82+
def boxed(x):
    r"""Wrap the input string in a LaTeX ``\boxed{...}`` command if it's not already wrapped.

    Args:
        x: Answer text to wrap.

    Returns:
        ``x`` unchanged when it already starts with ``\boxed{``; otherwise
        ``x`` enclosed as ``\boxed{x}``.
    """
    # The command must be exactly "\boxed{" (no space after the backslash);
    # "\ boxed{" is a control space followed by plain text, not the boxed command.
    prefix = "\\boxed{"
    if x.startswith(prefix):
        return x
    return prefix + x + "}"
@@ -267,7 +275,10 @@ def normalize_final_answer(final_answer: str) -> str:
267275def preprocess_math_string (dataset_name , text ) -> str :
268276 """Fix common formatting issues in text."""
269277 # Normalize for certain datasets and parse
270- if any (name in dataset_name for name in ["DAPO" , "OpenMathInstruct" , "OpenMathReasoning" , "OpenR1-Math-220k" , "CuratedThoughts" ]):
278+ if any (
279+ name in dataset_name
280+ for name in ["DAPO" , "OpenMathInstruct" , "OpenMathReasoning" , "OpenR1-Math-220k" , "CuratedThoughts" , "MATH-500" ]
281+ ):
271282 text = normalize_final_answer (text ).strip ()
272283 # Fix LaTeX escaping issues
273284 text = fix_latex_escaping (text )
@@ -418,7 +429,11 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
418429 # 3. As a fallback, try numeric comparison if both can be parsed as numbers
419430 try :
420431 predictions = parse (norm_guesses [0 ], PRED_EXTRACTION_TARGET , parsing_timeout = None )
421- golds = list (itertools .chain .from_iterable (parse (norm_answer , GOLD_EXTRACTION_TARGET , parsing_timeout = None ) for norm_answer in norm_answers ))
432+ golds = list (
433+ itertools .chain .from_iterable (
434+ parse (norm_answer , GOLD_EXTRACTION_TARGET , parsing_timeout = None ) for norm_answer in norm_answers
435+ )
436+ )
422437 for gold in golds :
423438 for pred in predictions :
424439 try :
@@ -430,9 +445,11 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
430445 else :
431446 scores [gen_idx ] = max (scores [gen_idx ], tmvp_config .penalty_incorrect_answer )
432447 except :
433- scores [gen_idx ] = max (scores [gen_idx ], tmvp_config .penalty_incorrect_answer )
448+ scores [gen_idx ] = max (scores [gen_idx ], tmvp_config .penalty_incorrect_answer )
434449 except :
435- scores [gen_idx ] = max (scores [gen_idx ], tmvp_config .penalty_incorrect_format ) # Penalize if we can't parse numbers at all
450+ scores [gen_idx ] = max (
451+ scores [gen_idx ], tmvp_config .penalty_incorrect_format
452+ ) # Penalize if we can't parse numbers at all
436453 if tmvp_config .debug .rl :
437454 debug_log_path = epath .Path (tmvp_config .base_output_directory ) / tmvp_config .run_name / "debug_rl_logs"
438455 debug_log_path .mkdir (parents = True , exist_ok = True )
@@ -469,10 +486,11 @@ def extract_hash_answer(text: str) -> str | None:
469486def check_correctness (extracted_response , acceptable_answers , tmvp_config ):
470487 """Handles math verification and partial correctness logic."""
471488 norm_answers = []
472- norm_response = preprocess_math_string (tmvp_config .dataset_name , extracted_response )
489+ dataset_name = tmvp_config .eval_dataset_name if tmvp_config .eval_dataset_name else tmvp_config .dataset_name
490+ norm_response = preprocess_math_string (dataset_name , extracted_response )
473491 # Check exact correctness first
474- for answer in acceptable_answers :
475- norm_answers .append (preprocess_math_string (tmvp_config . dataset_name , answer ))
492+ for answer in acceptable_answers :
493+ norm_answers .append (preprocess_math_string (dataset_name , answer ))
476494 is_correct = verify_math ([boxed (norm_answer ) for norm_answer in norm_answers ], [boxed (norm_response )]) > 0.1
477495 if is_correct :
478496 return True , True # Exact correctness implies partial correctness
@@ -481,7 +499,11 @@ def check_correctness(extracted_response, acceptable_answers, tmvp_config):
481499 is_partially_correct = False
482500 try :
483501 predictions = parse (boxed (norm_response ), PRED_EXTRACTION_TARGET , parsing_timeout = None )
484- golds = list (itertools .chain .from_iterable (parse (boxed (norm_answer ), GOLD_EXTRACTION_TARGET , parsing_timeout = None ) for norm_answer in norm_answers ))
502+ golds = list (
503+ itertools .chain .from_iterable (
504+ parse (boxed (norm_answer ), GOLD_EXTRACTION_TARGET , parsing_timeout = None ) for norm_answer in norm_answers
505+ )
506+ )
485507 is_partially_correct = any (
486508 0.9 <= (float (pred ) + EPSILON ) / (float (gold ) + EPSILON ) <= 1.1 for pred in predictions for gold in golds
487509 )
0 commit comments