@@ -35,6 +35,60 @@ def _require_helm_dependencies() -> None:
 )


+def _score_from_stat(stat) -> float | None:
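+    """Extract a numeric score from a HELM ``Stat``-like object.
+
+    Prefer ``mean``; fall back to ``sum / count`` when only raw aggregates
+    are present; return ``None`` if no numeric value can be recovered.
+    """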
+    value = getattr(stat, 'mean', None)
+    if value is None:
+        count = getattr(stat, 'count', None)
+        total = getattr(stat, 'sum', None)
+        # Guard against a missing ``sum``: ``None / count`` would raise.
+        if count and total is not None:
+            value = total / count
+    if value is None:
+        return None
+    try:
+        return float(value)
+    except (TypeError, ValueError):
+        return None
+
+
+# Metric names whose per-instance score is a correctness signal in [0, 1],
+# where ``score > 0`` reasonably maps to ``is_correct=True``. Anything not
+# in this allowlist (token counts, runtime, finish-reason flags, logprobs,
+# etc.) gets ``is_correct=False``, because a bookkeeping/resource metric
+# carries no correctness claim. Keep this list tight and named after the
+# actual HELM stat names; broaden it only for verified correctness
+# semantics.
+_BINARY_CORRECTNESS_METRIC_NAMES: frozenset[str] = frozenset({
+    'exact_match',
+    'quasi_exact_match',
+    'prefix_exact_match',
+    'quasi_prefix_exact_match',
+    'exact_match@5',
+    'quasi_exact_match@5',
+    'prefix_exact_match@5',
+    'quasi_prefix_exact_match@5',
+    'ifeval_strict_accuracy',
+    'chain_of_thought_correctness',
+    'math_equiv',
+    'math_equiv_chain_of_thought',
+})
+
+
+def _is_correct_for_metric(metric_name: str | None, score: float) -> bool:
+    """Decide ``is_correct`` honestly per metric name.
+
+    For correctness metrics in the allowlist, the HELM convention is that
+    ``score == 1.0`` means correct and ``0.0`` means wrong, so any positive
+    score rounds up to "correct". For anything else (bookkeeping/resource
+    stats, or graded metrics like ``rouge_l``/``bleu`` where ``> 0`` is not
+    a correctness signal) we deliberately do not claim correctness.
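+
+    Illustrative doctest (hypothetical scores):
+
+    >>> _is_correct_for_metric('exact_match', 1.0)
+    True
+    >>> _is_correct_for_metric('rouge_l', 0.4)  # graded metric: no claim
+    False
+    >>> _is_correct_for_metric(None, 1.0)
+    False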
+    """
+    if metric_name is None:
+        return False
+    if metric_name in _BINARY_CORRECTNESS_METRIC_NAMES:
+        return score > 0
+    return False
+
+
 class HELMInstanceLevelDataAdapter:
     def __init__(
         self,
@@ -97,27 +151,23 @@ def convert_instance_level_logs(
             reasoning_traces = extract_all_reasonings(state)
             if isinstance(reasoning_traces, str):
                 reasoning_traces = [reasoning_traces]
+            if reasoning_traces is None:
+                reasoning_traces = []
+            reasoning_traces = [
+                trace for trace in reasoning_traces if isinstance(trace, str)
+            ]

-            is_correct = False
-            score = 0.0
-            if inst_stats:
-                em_stat = next(
-                    (
-                        s
-                        for s in inst_stats.stats
-                        if s.name.name == 'exact_match'
-                    ),
-                    None,
+            metric_stats = list(inst_stats.stats) if inst_stats else []
+            if not metric_stats:
+                correct_completions = sum(
+                    1 for c in completions if c.strip() in correct_refs
                 )
-                if em_stat:
-                    score = em_stat.mean
-                    is_correct = em_stat.mean > 0
-                else:  # TODO check for more specific tasks
-                    correct_completions = sum(
-                        1 for c in completions if c.strip() in correct_refs
-                    )
-                    score = correct_completions / len(completions)
-                    is_correct = score > 0
+                fallback_score = (
+                    correct_completions / len(completions)
+                    if completions
+                    else 0.0
+                )
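+                # Sentinel: a single ``None`` entry produces exactly one
+                # fallback-scored log row in the per-metric loop below.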
+                metric_stats = [None]

             token_usage = None
             if inst_stats:
@@ -155,56 +205,72 @@ def convert_instance_level_logs(
                     total_tokens=int(p_tokens + c_tokens),
                 )

-            instance_level_logs.append(
-                InstanceLevelEvaluationLog(
-                    schema_version=SCHEMA_VERSION,
-                    evaluation_id=self.evaluation_id,
-                    model_id=model_id,
-                    evaluation_name=evaluation_name,
-                    sample_id=str(state.instance.id),
-                    sample_hash=sha256_string(
-                        state.request.prompt + (correct_refs[0] if correct_refs else '')
-                    ),  # TODO use all references
-                    interaction_type=InteractionType.single_turn,
-                    input=Input(
-                        raw=state.request.prompt,
-                        reference=correct_refs if correct_refs else [],
-                        choices=(
-                            list(state.output_mapping.values())
-                            if state.output_mapping
-                            else [
-                                ref.output.text
-                                for ref in state.instance.references
-                            ]
+            for stat in metric_stats:
+                if stat is None:
+                    metric_name = None
+                    score = fallback_score
+                    # Fallback path: ``score`` here is an exact-match proxy
+                    # from completion-vs-reference matching, so the
+                    # correctness claim is honest in the same sense as the
+                    # legacy single-row behavior.
+                    is_correct = score > 0
+                else:
+                    metric_name = getattr(
+                        getattr(stat, 'name', None), 'name', None
+                    )
+                    score = _score_from_stat(stat)
+                    if score is None:
+                        continue
+                    is_correct = _is_correct_for_metric(metric_name, score)
+                instance_level_logs.append(
+                    InstanceLevelEvaluationLog(
+                        schema_version=SCHEMA_VERSION,
+                        evaluation_id=self.evaluation_id,
+                        model_id=model_id,
+                        evaluation_name=evaluation_name,
+                        evaluation_result_id=metric_name,
+                        sample_id=str(state.instance.id),
+                        sample_hash=sha256_string(
+                            state.request.prompt + (correct_refs[0] if correct_refs else '')
+                        ),  # TODO use all references
+                        interaction_type=InteractionType.single_turn,
+                        input=Input(
+                            raw=state.request.prompt,
+                            reference=correct_refs if correct_refs else [],
+                            choices=(
+                                list(state.output_mapping.values())
+                                if state.output_mapping
+                                else [
+                                    ref.output.text
+                                    for ref in state.instance.references
+                                ]
+                            ),
                         ),
-                    ),
-                    output=Output(
-                        raw=completions, reasoning_trace=reasoning_traces
-                    ),
-                    answer_attribution=[
-                        AnswerAttributionItem(
-                            turn_idx=0,
-                            source='output.raw',
-                            extracted_value=state.result.completions[
-                                0
-                            ].text.strip()
-                            if state.result and state.result.completions
-                            else ' ',
-                            extraction_method='exact_match',
-                            is_terminal=True,
-                        )
-                    ],
-                    evaluation=Evaluation(
-                        score=float(score), is_correct=is_correct
-                    ),
-                    token_usage=token_usage,
-                    performance=Performance(
-                        generation_time_ms=state.result.request_time * 1000
-                        if state.result.request_time
-                        else None
-                    ),
+                        output=Output(
+                            raw=completions, reasoning_trace=reasoning_traces
+                        ),
+                        answer_attribution=[
+                            AnswerAttributionItem(
+                                turn_idx=0,
+                                source='output.raw',
+                                extracted_value=state.result.completions[
+                                    0
+                                ].text.strip()
+                                if state.result and state.result.completions
+                                else '',
+                                extraction_method='exact_match',
+                                is_terminal=True,
+                            )
+                        ],
+                        evaluation=Evaluation(
+                            score=float(score), is_correct=is_correct
+                        ),
+                        token_usage=token_usage,
+                        performance=Performance(
+                            generation_time_ms=state.result.request_time * 1000
+                            if state.result.request_time
+                            else None
+                        ),
+                    )
                 )
-            )

         self._save_json(instance_level_logs)
         return self.path, len(instance_level_logs)
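
A minimal sketch (not part of the diff) of how the two new helpers compose,
using ``types.SimpleNamespace`` stand-ins for HELM ``Stat`` objects; the stat
names and field shapes below are assumptions chosen for illustration:

    from types import SimpleNamespace

    em = SimpleNamespace(name=SimpleNamespace(name='exact_match'), mean=1.0)
    tokens = SimpleNamespace(
        name=SimpleNamespace(name='num_output_tokens'),
        mean=None, count=1, sum=128,
    )

    for stat in (em, tokens):
        score = _score_from_stat(stat)
        print(stat.name.name, score, _is_correct_for_metric(stat.name.name, score))
    # exact_match 1.0 True            (allowlisted correctness metric)
    # num_output_tokens 128.0 False   (bookkeeping stat: no correctness claim)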