1515# pylint: disable=bare-except, consider-using-generator, chained-comparison, broad-exception-caught
1616"""RL Utils Module."""
1717import concurrent
18+ import itertools
1819import json
1920import re
2021import optax
9495]
9596
9697
97- def math_verify_func (golds , predictions , timeout = 5 ):
98- """Wrapper around math_verify's verify function to handle timeouts and exceptions gracefully."""
99- with concurrent .futures .ProcessPoolExecutor (max_workers = 1 ) as executor :
100- future = executor .submit (verify_math , golds , predictions )
101- try :
102- return future .result (timeout = timeout )
103- except concurrent .futures .TimeoutError :
104- return 0.0
105- except Exception :
106- return 0.0
98+ def math_verify_func (items , timeout = 5 ):
99+ """Verifies a batch of math problems, handling timeouts and exceptions."""
100+ with concurrent .futures .ThreadPoolExecutor () as executor :
101+ future_to_index = {executor .submit (verify_math , golds , predictions ): idx for idx , (_ , golds , predictions ) in enumerate (items )}
102+ results = [0.0 ] * len (items )
103+ for future in concurrent .futures .as_completed (future_to_index ):
104+ index = future_to_index [future ]
105+ try :
106+ results [index ] = future .result (timeout = timeout )
107+ except (concurrent .futures .TimeoutError , Exception ):
108+ max_logging .log (f"math_verify_func failed for golds: { items [index ][1 ]} and predictions: { items [index ][2 ]} " )
109+ return results
107110
108111
109112def verify_math (golds , predictions ):
110113 """Runs mathematical expression evaluation on ground-truth and predictions."""
111114
112- extracted_predictions = [parse (pred , PRED_EXTRACTION_TARGET , parsing_timeout = None ) for pred in predictions ]
113- extracted_golds = [parse (gold , GOLD_EXTRACTION_TARGET , parsing_timeout = None ) for gold in golds ]
115+ extracted_predictions = list (itertools .chain .from_iterable (parse (pred , PRED_EXTRACTION_TARGET , parsing_timeout = None ) for pred in predictions ))
116+ extracted_golds = list (itertools .chain .from_iterable (parse (gold , GOLD_EXTRACTION_TARGET , parsing_timeout = None ) for gold in golds ))
117+ # If no predictions or golds were extracted, return 0.0
118+ if not extracted_predictions or not extracted_golds :
119+ return 0.0
114120
115121 return max (
116122 [
@@ -372,7 +378,8 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
372378 question = kargs ["question" ]
373379
374380 # Extract full answer content from solution tags (not just first number)
375- extracted_responses = [extract_answer (c , tmvp_config )[0 ] for c in completions ]
381+ extracted_responses = [extract_answer (c , tmvp_config ) for c in completions ]
382+ true_answers = [list (dict .fromkeys (json .loads (acceptable_answers ))) if isinstance (acceptable_answers , str ) else acceptable_answers for acceptable_answers in answer ]
376383
377384 if tmvp_config .debug .rl :
378385 max_logging .log ("START ============================" )
@@ -382,79 +389,63 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
382389 max_logging .log (f"Extracted: { extracted_responses [0 ]} " )
383390 max_logging .log ("END ==============================" )
384391
385- scores = []
386- for guess , acceptable_answers in zip (extracted_responses , answer ):
392+ scores = [tmvp_config .penalty_incorrect_format ] * len (completions ) # Default to penalty for incorrect format
393+ math_verify_queue = []
394+ for gen_idx , (guess , unique_answers ) in enumerate (zip (extracted_responses , true_answers )):
387395 if guess is None :
388- scores . append ( 0 )
396+ scores [ gen_idx ] = 0
389397 continue
390398
391- norm_guess = preprocess_math_string (tmvp_config .dataset_name , guess )
392-
393- # decode the json-encoded list of acceptable answers
394- answers = json .loads (acceptable_answers ) if isinstance (acceptable_answers , str ) else acceptable_answers
395- unique_answers = list (dict .fromkeys (answers ))
396- max_score = tmvp_config .penalty_incorrect_format
399+ has_exact_match = False
397400 for true_answer in unique_answers :
398- norm_answer = preprocess_math_string (tmvp_config .dataset_name , true_answer )
399-
400401 # 1. Check for exact or whitespace-normalized match first for a quick reward
401402 if guess == true_answer :
402- max_score = max (max_score , tmvp_config .reward_exact_answer )
403- continue
403+ scores [ gen_idx ] = max (scores [ gen_idx ] , tmvp_config .reward_exact_answer )
404+ has_exact_match = True
404405 elif guess .strip () == true_answer .strip ():
405- max_score = max (max_score , tmvp_config .reward_white_space_format_match )
406- continue
406+ scores [gen_idx ] = max (scores [gen_idx ], tmvp_config .reward_white_space_format_match )
407+ has_exact_match = True
408+
409+ if not has_exact_match :
410+ norm_guess = preprocess_math_string (tmvp_config .dataset_name , guess )
411+ norm_answers = []
412+ for true_answer in unique_answers :
413+ norm_answer = preprocess_math_string (tmvp_config .dataset_name , true_answer )
414+ norm_answers .append (boxed (norm_answer ))
415+ math_verify_queue .append ((gen_idx , norm_answers , [boxed (norm_guess )]))
416+
417+ if math_verify_queue :
418+ # 2. Try math_verify for robust mathematical correctness checking
419+ math_verify_results = math_verify_func (math_verify_queue )
420+ for (gen_idx , norm_answers , norm_guess ), score in zip (math_verify_queue , math_verify_results ):
421+ if score > 0.1 :
422+ scores [gen_idx ] = max (scores [gen_idx ], tmvp_config .reward_exact_answer )
423+ print ("-------- Found a math_verify match -----------" )
407424 else :
425+ # 3. As a fallback, try numeric comparison if both can be parsed as numbers
408426 try :
409- # 2. Try math_verify for robust comparison
410- if math_verify_func ([boxed (norm_answer )], [boxed (norm_guess )]) > 0.1 :
411- max_score = max (max_score , tmvp_config .reward_exact_answer )
412- continue
413- except (TimeoutException , Exception ):
414- if tmvp_config .debug .rl :
415- max_logging .log (
416- f"math_verify_func failed for gold: { norm_answer } and prediction: "
417- f"{ norm_guess } , falling back to numeric comparison."
418- )
419-
420- # 3. As a fallback, try numeric comparison if both can be parsed as numbers
421- try :
422- predictions = parse (boxed (norm_guess ))
423- golds = parse (boxed (norm_answer ))
424- for gold in golds :
425- for pred in predictions :
426- ratio = (float (pred ) + EPSILON ) / (float (gold ) + EPSILON )
427- if 0.9 <= ratio <= 1.1 :
428- max_score = max (max_score , tmvp_config .reward_ratio_guess_to_answer_high )
429- elif 0.8 <= ratio <= 1.2 :
430- max_score = max (max_score , tmvp_config .reward_ratio_guess_to_answer_low )
431- else :
432- max_score = max (max_score , tmvp_config .penalty_incorrect_answer ) # Penalize wrong answers
433- except :
434- max_score = max (max_score , tmvp_config .penalty_incorrect_format ) # Penalize if we can't parse numbers at all
435-
436- scores .append (max_score )
427+ predictions = parse (boxed (norm_guess ), PRED_EXTRACTION_TARGET , parsing_timeout = None )
428+ golds = list (itertools .chain .from_iterable (parse (boxed (norm_answer ), GOLD_EXTRACTION_TARGET , parsing_timeout = None ) for norm_answer in norm_answers ))
429+ for gold in golds :
430+ for pred in predictions :
431+ ratio = (float (pred ) + EPSILON ) / (float (gold ) + EPSILON )
432+ if 0.9 <= ratio <= 1.1 :
433+ scores [gen_idx ] = max (scores [gen_idx ], tmvp_config .reward_ratio_guess_to_answer_high )
434+ elif 0.8 <= ratio <= 1.2 :
435+ scores [gen_idx ] = max (scores [gen_idx ], tmvp_config .reward_ratio_guess_to_answer_low )
436+ else :
437+ scores [gen_idx ] = max (scores [gen_idx ], tmvp_config .penalty_incorrect_answer ) # Penalize wrong answers
438+ except :
439+ scores [gen_idx ] = max (scores [gen_idx ], tmvp_config .penalty_incorrect_format ) # Penalize if we can't parse numbers at all
437440 return scores
438441
439442
440443def extract_answer (response , tmvp_config ) -> str | None :
441444 """Function to extract the answer from the text based on the tmvp_config format."""
442- match_format = get_match_format_regex (tmvp_config )
443445 answer_fallback = get_answer_fallback_regex (tmvp_config )
444-
445- full_match = match_format .search (response )
446- if full_match is not None :
447- extracted_response = full_match .group (1 )
448- else :
449- # Find the *last* occurrence of the answer tag (most likely the final answer).
450- fallback_matches = answer_fallback .findall (response )
451- extracted_response = fallback_matches [- 1 ].strip () if fallback_matches else "-1000000"
452-
453- match_method = "full format" if full_match is not None else "answer-tag fallback"
454- if tmvp_config .debug .rl :
455- max_logging .log (f"Evaluation extracted_response ({ match_method } ): { extracted_response } " )
456-
457- return extracted_response , match_method
446+ fallback_matches = answer_fallback .findall (response )
447+ extracted_response = fallback_matches [- 1 ].strip () if fallback_matches else "-1000000"
448+ return extracted_response
458449
459450
460451def extract_hash_answer (text : str ) -> str | None :
@@ -466,26 +457,23 @@ def extract_hash_answer(text: str) -> str | None:
466457
467458def check_correctness (extracted_response , acceptable_answers , tmvp_config ):
468459 """Handles math verification and partial correctness logic."""
469- for answer in acceptable_answers :
470- # Check exact correctness first
471- norm_answer = preprocess_math_string (tmvp_config .dataset_name , answer )
472- norm_extracted = preprocess_math_string (tmvp_config .dataset_name , extracted_response )
473- is_correct = math_verify_func ([boxed (norm_answer )], [boxed (norm_extracted )]) > 0.1
474- if is_correct :
475- return True , True # Exact correctness implies partial correctness
460+ norm_answers = []
461+ norm_response = preprocess_math_string (tmvp_config .dataset_name , extracted_response )
462+ # Check exact correctness first
463+ for answer in acceptable_answers :
464+ norm_answers .append (preprocess_math_string (tmvp_config .dataset_name , answer ))
465+ is_correct = math_verify_func ([(0 , [boxed (norm_answer ) for norm_answer in norm_answers ], [boxed (norm_response )])])[0 ] > 0.1
466+ if is_correct :
467+ return True , True # Exact correctness implies partial correctness
476468
477469 # Check partial correctness if values can be extracted (within 10%)
478- for answer in acceptable_answers :
479- norm_answer = preprocess_math_string (tmvp_config .dataset_name , answer )
480- norm_extracted = preprocess_math_string (tmvp_config .dataset_name , extracted_response )
481- val_extracted = parse (boxed (norm_extracted ))
482- val_answer = parse (boxed (norm_answer ))
483- if val_extracted and val_answer :
484- is_partially_correct = any (
485- 0.9 <= (float (pred ) + EPSILON ) / (float (gold ) + EPSILON ) <= 1.1 for pred in val_extracted for gold in val_answer
486- )
487- if is_partially_correct :
488- return False , True # Not exactly correct, but partially correct
470+ predictions = parse (boxed (norm_response ), PRED_EXTRACTION_TARGET , parsing_timeout = None )
471+ golds = list (itertools .chain .from_iterable (parse (boxed (norm_answer ), GOLD_EXTRACTION_TARGET , parsing_timeout = None ) for norm_answer in norm_answers ))
472+ is_partially_correct = any (
473+ 0.9 <= (float (pred ) + EPSILON ) / (float (gold ) + EPSILON ) <= 1.1 for pred in predictions for gold in golds
474+ )
475+ if is_partially_correct :
476+ return False , True # Not exactly correct, but partially correct
489477
490478 return False , False # Not correct at all
491479
0 commit comments