@@ -52,6 +52,150 @@ def approx_token_count(text: str) -> int:
5252 return max (1 , len (re .findall (r"\S+" , text )))
5353
5454
55+ def _extract_boxed (text : str ) -> str | None :
56+ """Extract the last \\ boxed{...} from a string, handling nested braces."""
57+ results = []
58+ i = 0
59+ while i < len (text ):
60+ idx = text .find ("\\ boxed{" , i )
61+ if idx == - 1 :
62+ break
63+ start = idx + len ("\\ boxed{" )
64+ depth = 1
65+ j = start
66+ while j < len (text ) and depth > 0 :
67+ if text [j ] == "{" :
68+ depth += 1
69+ elif text [j ] == "}" :
70+ depth -= 1
71+ j += 1
72+ if depth == 0 :
73+ results .append (text [start :j - 1 ].strip ())
74+ i = j
75+ return results [- 1 ] if results else None
76+
77+
78+ def _normalize_math (s : str ) -> str :
79+ """Normalize a math answer string for comparison."""
80+ if s is None :
81+ return ""
82+ s = s .strip ()
83+ if s .startswith ("$" ) and s .endswith ("$" ):
84+ s = s [1 :- 1 ].strip ()
85+ s = re .sub (r"\\text\s*\{([^}]*)\}" , r"\1" , s )
86+ s = re .sub (r"\\mathrm\s*\{([^}]*)\}" , r"\1" , s )
87+ for cmd in [r"\left" , r"\right" , r"\displaystyle" , r"\tfrac" , r"\dfrac" ]:
88+ s = s .replace (cmd , "" )
89+ s = re .sub (r"\s+" , " " , s ).strip ()
90+ s = s .rstrip (".," )
91+ return s
92+
93+
94+ def _math_equiv (pred : str , gold : str ) -> bool :
95+ """Check if two math answers are equivalent."""
96+ if pred is None or gold is None :
97+ return False
98+ p = _normalize_math (pred )
99+ g = _normalize_math (gold )
100+ if p == g :
101+ return True
102+ try :
103+ pf = float (p .replace ("," , "" ))
104+ gf = float (g .replace ("," , "" ))
105+ return abs (pf - gf ) < 1e-6
106+ except (ValueError , TypeError ):
107+ pass
108+ frac_pat = re .compile (r"\\?frac\s*\{([^}]+)\}\s*\{([^}]+)\}" )
109+ for s , other in [(p , g ), (g , p )]:
110+ m = frac_pat .search (s )
111+ if m :
112+ try :
113+ val = float (m .group (1 )) / float (m .group (2 ))
114+ oval = float (other .replace ("," , "" ))
115+ if abs (val - oval ) < 1e-6 :
116+ return True
117+ except (ValueError , ZeroDivisionError ):
118+ pass
119+ return False
120+
121+
122+ def _extract_numeric_answer (text : str ) -> str | None :
123+ """Extract a numeric answer from model output for GSM-style problems."""
124+ think_end = text .rfind ("</think>" )
125+ answer_text = text [think_end + len ("</think>" ):] if think_end >= 0 else text
126+
127+ # #### <number>
128+ m = re .search (r'####\s*([+-]?\d[\d,]*\.?\d*)' , answer_text )
129+ if m :
130+ return m .group (1 ).replace ("," , "" )
131+
132+ # \boxed{<number>}
133+ boxed = _extract_boxed (answer_text )
134+ if boxed :
135+ cleaned = boxed .replace ("," , "" ).strip ()
136+ if re .match (r'^[+-]?\d+\.?\d*$' , cleaned ):
137+ return cleaned
138+
139+ # "the answer is <number>"
140+ m = re .search (
141+ r'(?:answer\s+is|result\s+is|equals?|there\s+are|we\s+get)\s*\$?\s*\\?(?:boxed\{)?([+-]?\d[\d,]*\.?\d*)' ,
142+ answer_text , re .IGNORECASE )
143+ if m :
144+ return m .group (1 ).replace ("," , "" )
145+
146+ # **<number>**
147+ m = re .search (r'\*\*([+-]?\d[\d,]*\.?\d*)\*\*' , answer_text )
148+ if m :
149+ return m .group (1 ).replace ("," , "" )
150+
151+ # Last standalone number
152+ nums = re .findall (r'(?<![.\d])([+-]?\d[\d,]*\.?\d*)(?![.\d])' , answer_text )
153+ if nums :
154+ return nums [- 1 ].replace ("," , "" )
155+
156+ return None
157+
158+
159+ def score_gold_answer (case : dict [str , Any ], text : str ) -> tuple [bool | None , str ]:
160+ """Score model output against gold_answer if present.
161+
162+ Returns (correct_or_None, detail_str). None means no gold_answer to check.
163+ """
164+ gold = case .get ("gold_answer" )
165+ if gold is None :
166+ return None , ""
167+
168+ suite = case .get ("suite" , "" )
169+ think_end = text .rfind ("</think>" )
170+ answer_text = text [think_end + len ("</think>" ):] if think_end >= 0 else text
171+
172+ if suite == "gsm" :
173+ pred = _extract_numeric_answer (text )
174+ if pred is None :
175+ return False , f"no numeric answer found, gold={ gold } "
176+ try :
177+ correct = abs (float (pred ) - float (gold )) < 1e-6
178+ except (ValueError , TypeError ):
179+ correct = pred .strip () == gold .strip ()
180+ return correct , f"pred={ pred } gold={ gold } "
181+ else :
182+ # Math-style: extract \boxed{} and compare
183+ pred = _extract_boxed (answer_text )
184+ if not pred :
185+ pred = _extract_boxed (text )
186+ if not pred :
187+ # Fallback: bold pattern
188+ m = re .search (
189+ r'(?:answer\s+is|result\s+is|equals?)\s*\*\*(.+?)\*\*' ,
190+ answer_text , re .IGNORECASE )
191+ if m :
192+ pred = m .group (1 ).strip ().rstrip ("." )
193+ if not pred :
194+ return False , f"no answer found, gold={ gold } "
195+ correct = _math_equiv (pred , gold )
196+ return correct , f"pred={ pred } gold={ gold } "
197+
198+
55199def expected_pass (case : dict [str , Any ], text : str ) -> tuple [bool , list [str ]]:
56200 failures : list [str ] = []
57201 for needle in case .get ("expect_contains" , []):
@@ -142,6 +286,7 @@ def run_case(
142286 token_source = "approx_words"
143287 prompt_tokens = usage .get ("prompt_tokens" )
144288 pass_expected , failures = expected_pass (case , text )
289+ gold_correct , gold_detail = score_gold_answer (case , text )
145290 runs .append (
146291 {
147292 "elapsed_s" : elapsed ,
@@ -151,23 +296,29 @@ def run_case(
151296 "token_count_source" : token_source ,
152297 "expected_pass" : pass_expected ,
153298 "expected_failures" : failures ,
299+ "gold_correct" : gold_correct ,
300+ "gold_detail" : gold_detail ,
154301 "text" : text ,
155302 "usage" : usage ,
156303 }
157304 )
158305
159306 tok_s_values = [r ["tok_s" ] for r in runs ]
160307 elapsed_values = [r ["elapsed_s" ] for r in runs ]
308+ gold_results = [r ["gold_correct" ] for r in runs if r ["gold_correct" ] is not None ]
161309 return {
162310 "id" : case ["id" ],
163311 "description" : case .get ("description" , "" ),
164312 "expect_contains" : case .get ("expect_contains" , []),
165313 "expect_regex" : case .get ("expect_regex" , []),
314+ "gold_answer" : case .get ("gold_answer" ),
166315 "runs" : runs ,
167316 "mean_tok_s" : statistics .mean (tok_s_values ),
168317 "median_tok_s" : statistics .median (tok_s_values ),
169318 "mean_elapsed_s" : statistics .mean (elapsed_values ),
170319 "expected_pass" : all (r ["expected_pass" ] for r in runs ),
320+ "gold_correct" : all (gold_results ) if gold_results else None ,
321+ "gold_detail" : runs [- 1 ].get ("gold_detail" , "" ),
171322 "text" : runs [- 1 ]["text" ],
172323 "completion_tokens" : runs [- 1 ]["completion_tokens" ],
173324 "prompt_tokens" : runs [- 1 ]["prompt_tokens" ],
@@ -179,20 +330,25 @@ def cmd_run(args: argparse.Namespace) -> int:
179330 cases = load_cases (Path (args .prompts ))
180331 results = []
181332 for case in cases :
182- print (f"[bench] { args .name } : { case ['id' ]} " , flush = True )
183- results .append (
184- run_case (
185- case = case ,
186- base_url = args .url ,
187- api_key = args .api_key ,
188- model = args .model ,
189- max_tokens = args .max_tokens ,
190- temperature = args .temperature ,
191- timeout = args .timeout ,
192- repeats = args .repeats ,
193- )
333+ print (f"[bench] { args .name } : { case ['id' ]} " , end = "" , flush = True )
334+ result = run_case (
335+ case = case ,
336+ base_url = args .url ,
337+ api_key = args .api_key ,
338+ model = args .model ,
339+ max_tokens = args .max_tokens ,
340+ temperature = args .temperature ,
341+ timeout = args .timeout ,
342+ repeats = args .repeats ,
194343 )
195-
344+ results .append (result )
345+ if result ["gold_correct" ] is not None :
346+ mark = "🎯" if result ["gold_correct" ] else "✗"
347+ print (f" { mark } { result ['gold_detail' ]} " , flush = True )
348+ else :
349+ print (flush = True )
350+
351+ scored = [r for r in results if r ["gold_correct" ] is not None ]
196352 report = {
197353 "name" : args .name ,
198354 "url" : args .url ,
@@ -206,13 +362,18 @@ def cmd_run(args: argparse.Namespace) -> int:
206362 "summary" : {
207363 "cases" : len (results ),
208364 "expected_pass" : sum (1 for r in results if r ["expected_pass" ]),
365+ "gold_correct" : sum (1 for r in scored if r ["gold_correct" ]),
366+ "gold_scored" : len (scored ),
209367 "mean_tok_s" : statistics .mean ([r ["mean_tok_s" ] for r in results ]) if results else 0.0 ,
210368 },
211369 }
212370 out = Path (args .json_out )
213371 out .parent .mkdir (parents = True , exist_ok = True )
214372 out .write_text (json .dumps (report , indent = 2 , sort_keys = True ), encoding = "utf-8" )
215373 print (f"[bench] wrote { out } " )
374+ if scored :
375+ print (f"[bench] correctness: { report ['summary' ]['gold_correct' ]} /{ len (scored )} "
376+ f" ({ report ['summary' ]['gold_correct' ]/ len (scored )* 100 :.0f} %)" )
216377 return 0 if report ["summary" ]["expected_pass" ] == len (results ) else 1
217378
218379
0 commit comments