@@ -379,8 +379,8 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
379379
380380 # Extract full answer content from solution tags (not just first number)
381381 extracted_responses = [extract_answer (c , tmvp_config ) for c in completions ]
382- true_answers = [list (dict .fromkeys (json . loads ( acceptable_answers ))) if isinstance ( acceptable_answers , str ) else acceptable_answers for acceptable_answers in answer ]
383-
382+ true_answers = [list (dict .fromkeys (acceptable_answers )) for acceptable_answers in answer ]
383+
384384 if tmvp_config .debug .rl :
385385 max_logging .log ("START ============================" )
386386 max_logging .log (f"Question: { question [0 ]} " )
@@ -417,24 +417,26 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
417417 if math_verify_queue :
418418 # 2. Try math_verify for robust mathematical correctness checking
419419 math_verify_results = math_verify_func (math_verify_queue )
420- for (gen_idx , norm_answers , norm_guess ), score in zip (math_verify_queue , math_verify_results ):
420+ for (gen_idx , norm_answers , norm_guesses ), score in zip (math_verify_queue , math_verify_results ):
421421 if score > 0.1 :
422422 scores [gen_idx ] = max (scores [gen_idx ], tmvp_config .reward_exact_answer )
423- print ("-------- Found a math_verify match -----------" )
424423 else :
425424 # 3. As a fallback, try numeric comparison if both can be parsed as numbers
426425 try :
427- predictions = parse (boxed ( norm_guess ) , PRED_EXTRACTION_TARGET , parsing_timeout = None )
428- golds = list (itertools .chain .from_iterable (parse (boxed ( norm_answer ) , GOLD_EXTRACTION_TARGET , parsing_timeout = None ) for norm_answer in norm_answers ))
426+ predictions = parse (norm_guesses [ 0 ] , PRED_EXTRACTION_TARGET , parsing_timeout = None )
427+ golds = list (itertools .chain .from_iterable (parse (norm_answer , GOLD_EXTRACTION_TARGET , parsing_timeout = None ) for norm_answer in norm_answers ))
429428 for gold in golds :
430429 for pred in predictions :
431- ratio = (float (pred ) + EPSILON ) / (float (gold ) + EPSILON )
432- if 0.9 <= ratio <= 1.1 :
433- scores [gen_idx ] = max (scores [gen_idx ], tmvp_config .reward_ratio_guess_to_answer_high )
434- elif 0.8 <= ratio <= 1.2 :
435- scores [gen_idx ] = max (scores [gen_idx ], tmvp_config .reward_ratio_guess_to_answer_low )
436- else :
437- scores [gen_idx ] = max (scores [gen_idx ], tmvp_config .penalty_incorrect_answer ) # Penalize wrong answers
430+ try :
431+ ratio = (float (pred ) + EPSILON ) / (float (gold ) + EPSILON )
432+ if 0.9 <= ratio <= 1.1 :
433+ scores [gen_idx ] = max (scores [gen_idx ], tmvp_config .reward_ratio_guess_to_answer_high )
434+ elif 0.8 <= ratio <= 1.2 :
435+ scores [gen_idx ] = max (scores [gen_idx ], tmvp_config .reward_ratio_guess_to_answer_low )
436+ else :
437+ scores [gen_idx ] = max (scores [gen_idx ], tmvp_config .penalty_incorrect_answer )
438+ except :
439+ scores [gen_idx ] = max (scores [gen_idx ], tmvp_config .penalty_incorrect_answer )
438440 except :
439441 scores [gen_idx ] = max (scores [gen_idx ], tmvp_config .penalty_incorrect_format ) # Penalize if we can't parse numbers at all
440442 return scores
@@ -462,20 +464,22 @@ def check_correctness(extracted_response, acceptable_answers, tmvp_config):
462464 # Check exact correctness first
463465 for answer in acceptable_answers :
464466 norm_answers .append (preprocess_math_string (tmvp_config .dataset_name , answer ))
465- is_correct = math_verify_func ([( 0 , [ boxed (norm_answer ) for norm_answer in norm_answers ], [boxed (norm_response )])])[ 0 ] > 0.1
467+ is_correct = verify_math ([ boxed (norm_answer ) for norm_answer in norm_answers ], [boxed (norm_response )]) > 0.1
466468 if is_correct :
467469 return True , True # Exact correctness implies partial correctness
468470
469471 # Check partial correctness if values can be extracted (within 10%)
470- predictions = parse (boxed (norm_response ), PRED_EXTRACTION_TARGET , parsing_timeout = None )
471- golds = list (itertools .chain .from_iterable (parse (boxed (norm_answer ), GOLD_EXTRACTION_TARGET , parsing_timeout = None ) for norm_answer in norm_answers ))
472- is_partially_correct = any (
473- 0.9 <= (float (pred ) + EPSILON ) / (float (gold ) + EPSILON ) <= 1.1 for pred in predictions for gold in golds
474- )
475- if is_partially_correct :
476- return False , True # Not exactly correct, but partially correct
472+ is_partially_correct = False
473+ try :
474+ predictions = parse (boxed (norm_response ), PRED_EXTRACTION_TARGET , parsing_timeout = None )
475+ golds = list (itertools .chain .from_iterable (parse (boxed (norm_answer ), GOLD_EXTRACTION_TARGET , parsing_timeout = None ) for norm_answer in norm_answers ))
476+ is_partially_correct = any (
477+ 0.9 <= (float (pred ) + EPSILON ) / (float (gold ) + EPSILON ) <= 1.1 for pred in predictions for gold in golds
478+ )
479+ except :
480+ pass
477481
478- return False , False # Not correct at all
482+ return False , is_partially_correct
479483
480484
481485def get_optimizer (tmvp_config , max_train_steps ):
0 commit comments