
Commit 4353b80

Fix HELM EEE instance metric rows
1 parent 4f15d50 commit 4353b80

3 files changed

Lines changed: 376 additions & 76 deletions


every_eval_ever/converters/helm/instance_level_adapter.py

Lines changed: 133 additions & 67 deletions
@@ -35,6 +35,60 @@ def _require_helm_dependencies() -> None:
         )
 
 
+def _score_from_stat(stat) -> float | None:
+    value = getattr(stat, 'mean', None)
+    if value is None:
+        count = getattr(stat, 'count', None)
+        total = getattr(stat, 'sum', None)
+        if count:
+            value = total / count
+    if value is None:
+        return None
+    try:
+        return float(value)
+    except (TypeError, ValueError):
+        return None
+
+
+# Metric names whose per-instance score is a correctness signal in [0, 1]
+# where ``score > 0`` reasonably maps to ``is_correct=True``. Anything not
+# in this allowlist (token counts, runtime, finish-reason flags, logprobs,
+# etc.) gets ``is_correct=False`` because we have no correctness claim
+# from a bookkeeping/resource metric. Keep this list tight and named after
+# the actual HELM stat names — broaden only for verified correctness
+# semantics.
+_BINARY_CORRECTNESS_METRIC_NAMES: frozenset[str] = frozenset({
+    'exact_match',
+    'quasi_exact_match',
+    'prefix_exact_match',
+    'quasi_prefix_exact_match',
+    'exact_match@5',
+    'quasi_exact_match@5',
+    'prefix_exact_match@5',
+    'quasi_prefix_exact_match@5',
+    'ifeval_strict_accuracy',
+    'chain_of_thought_correctness',
+    'math_equiv',
+    'math_equiv_chain_of_thought',
+})
+
+
+def _is_correct_for_metric(metric_name: str | None, score: float) -> bool:
+    """Decide ``is_correct`` honestly per metric name.
+
+    For correctness metrics in the allowlist, the HELM convention is that
+    score==1.0 means correct and 0.0 means wrong, so any positive score
+    rounds up to "correct". For anything else (bookkeeping / resource
+    stats, or graded metrics like rouge_l/bleu where >0 is not a correctness
+    signal) we deliberately do not claim correctness.
+    """
+    if metric_name is None:
+        return False
+    if metric_name in _BINARY_CORRECTNESS_METRIC_NAMES:
+        return score > 0
+    return False
+
+
 class HELMInstanceLevelDataAdapter:
     def __init__(
         self,
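
Both helpers are pure functions, so their contract is easy to exercise in isolation. A minimal sketch, using a hypothetical stand-in for HELM's Stat object (the real class carries more fields than the three read here):

from dataclasses import dataclass


@dataclass
class FakeStat:
    # Stand-in for a HELM Stat; field names mirror what _score_from_stat reads.
    mean: float | None = None
    count: int | None = None
    sum: float | None = None


# mean wins when present; otherwise sum/count, guarding against count == 0.
assert _score_from_stat(FakeStat(mean=1.0)) == 1.0
assert _score_from_stat(FakeStat(count=4, sum=3.0)) == 0.75
assert _score_from_stat(FakeStat()) is None

# Only allowlisted metric names may claim correctness.
assert _is_correct_for_metric('exact_match', 1.0) is True
assert _is_correct_for_metric('rouge_l', 0.8) is False  # graded, not binary
assert _is_correct_for_metric(None, 1.0) is False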
@@ -97,27 +151,23 @@ def convert_instance_level_logs(
             reasoning_traces = extract_all_reasonings(state)
             if isinstance(reasoning_traces, str):
                 reasoning_traces = [reasoning_traces]
+            if reasoning_traces is None:
+                reasoning_traces = []
+            reasoning_traces = [
+                trace for trace in reasoning_traces if isinstance(trace, str)
+            ]
 
-            is_correct = False
-            score = 0.0
-            if inst_stats:
-                em_stat = next(
-                    (
-                        s
-                        for s in inst_stats.stats
-                        if s.name.name == 'exact_match'
-                    ),
-                    None,
+            metric_stats = list(inst_stats.stats) if inst_stats else []
+            if not metric_stats:
+                correct_completions = sum(
+                    1 for c in completions if c.strip() in correct_refs
                 )
-                if em_stat:
-                    score = em_stat.mean
-                    is_correct = em_stat.mean > 0
-                else: # TODO check for more specific tasks
-                    correct_completions = sum(
-                        1 for c in completions if c.strip() in correct_refs
-                    )
-                    score = correct_completions / len(completions)
-                    is_correct = score > 0
+                fallback_score = (
+                    correct_completions / len(completions)
+                    if completions
+                    else 0.0
+                )
+                metric_stats = [None]
 
             token_usage = None
             if inst_stats:
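
The `metric_stats = [None]` sentinel is what keeps the emission loop in the next hunk uniform: when HELM recorded no per-instance stats, the loop still runs exactly once and takes the fallback branch, reproducing the legacy one-row-per-sample behavior. A sketch of the pattern with invented helper names:

# Control flow only; rows_for, make_metric_row and make_fallback_row are
# illustrative, not adapter API.
def rows_for(per_instance_stats):
    stats = list(per_instance_stats) if per_instance_stats else []
    for stat in stats or [None]:  # [None] => exactly one fallback iteration
        if stat is None:
            yield make_fallback_row()
        else:
            yield make_metric_row(stat)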
@@ -155,56 +205,72 @@ def convert_instance_level_logs(
                     total_tokens=int(p_tokens + c_tokens),
                 )
 
-            instance_level_logs.append(
-                InstanceLevelEvaluationLog(
-                    schema_version=SCHEMA_VERSION,
-                    evaluation_id=self.evaluation_id,
-                    model_id=model_id,
-                    evaluation_name=evaluation_name,
-                    sample_id=str(state.instance.id),
-                    sample_hash=sha256_string(
-                        state.request.prompt + (correct_refs[0] if correct_refs else '')
-                    ), # TODO use all references
-                    interaction_type=InteractionType.single_turn,
-                    input=Input(
-                        raw=state.request.prompt,
-                        reference=correct_refs if correct_refs else [],
-                        choices=(
-                            list(state.output_mapping.values())
-                            if state.output_mapping
-                            else [
-                                ref.output.text
-                                for ref in state.instance.references
-                            ]
+            for stat in metric_stats:
+                if stat is None:
+                    metric_name = None
+                    score = fallback_score
+                    # Fallback path: ``score`` here is an exact-match
+                    # proxy from completion-vs-reference matching, so
+                    # the correctness claim is honest in the same sense
+                    # as the legacy single-row behavior.
+                    is_correct = score > 0
+                else:
+                    metric_name = getattr(getattr(stat, 'name', None), 'name', None)
+                    score = _score_from_stat(stat)
+                    if score is None:
+                        continue
+                    is_correct = _is_correct_for_metric(metric_name, score)
+                instance_level_logs.append(
+                    InstanceLevelEvaluationLog(
+                        schema_version=SCHEMA_VERSION,
+                        evaluation_id=self.evaluation_id,
+                        model_id=model_id,
+                        evaluation_name=evaluation_name,
+                        evaluation_result_id=metric_name,
+                        sample_id=str(state.instance.id),
+                        sample_hash=sha256_string(
+                            state.request.prompt + (correct_refs[0] if correct_refs else '')
+                        ), # TODO use all references
+                        interaction_type=InteractionType.single_turn,
+                        input=Input(
+                            raw=state.request.prompt,
+                            reference=correct_refs if correct_refs else [],
+                            choices=(
+                                list(state.output_mapping.values())
+                                if state.output_mapping
+                                else [
+                                    ref.output.text
+                                    for ref in state.instance.references
+                                ]
+                            ),
                         ),
-                    ),
-                    output=Output(
-                        raw=completions, reasoning_trace=reasoning_traces
-                    ),
-                    answer_attribution=[
-                        AnswerAttributionItem(
-                            turn_idx=0,
-                            source='output.raw',
-                            extracted_value=state.result.completions[
-                                0
-                            ].text.strip()
-                            if state.result and state.result.completions
-                            else '',
-                            extraction_method='exact_match',
-                            is_terminal=True,
-                        )
-                    ],
-                    evaluation=Evaluation(
-                        score=float(score), is_correct=is_correct
-                    ),
-                    token_usage=token_usage,
-                    performance=Performance(
-                        generation_time_ms=state.result.request_time * 1000
-                        if state.result.request_time
-                        else None
-                    ),
+                        output=Output(
+                            raw=completions, reasoning_trace=reasoning_traces
+                        ),
+                        answer_attribution=[
+                            AnswerAttributionItem(
+                                turn_idx=0,
+                                source='output.raw',
+                                extracted_value=state.result.completions[
+                                    0
+                                ].text.strip()
+                                if state.result and state.result.completions
+                                else '',
+                                extraction_method='exact_match',
+                                is_terminal=True,
+                            )
+                        ],
+                        evaluation=Evaluation(
+                            score=float(score), is_correct=is_correct
+                        ),
+                        token_usage=token_usage,
+                        performance=Performance(
+                            generation_time_ms=state.result.request_time * 1000
+                            if state.result.request_time
+                            else None
+                        ),
+                    )
                 )
-            )
 
         self._save_json(instance_level_logs)
         return self.path, len(instance_level_logs)
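
Net effect: the adapter now emits one log row per (sample, stat) pair, with `evaluation_result_id` carrying the metric name (or None on the fallback path). Illustrative shape for a sample that carried one correctness stat and one bookkeeping stat (all values invented):

rows = [
    {'sample_id': 'id42', 'evaluation_result_id': 'exact_match',
     'score': 1.0, 'is_correct': True},
    {'sample_id': 'id42', 'evaluation_result_id': 'num_output_tokens',
     'score': 57.0, 'is_correct': False},  # not allowlisted: no correctness claim
]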

tests/test_helm_adapter.py

Lines changed: 8 additions & 3 deletions
@@ -76,7 +76,10 @@ def test_mmlu_eval():
 
     assert converted_eval.detailed_evaluation_results is not None
     assert converted_eval.detailed_evaluation_results.format is not None
-    assert converted_eval.detailed_evaluation_results.total_rows == 10
+    # Per-(sample, metric) emission: each of the 10 samples produces one
+    # row per non-empty stat, so total_rows is much larger than the
+    # legacy "one row per sample" count.
+    assert converted_eval.detailed_evaluation_results.total_rows >= 10
 
 
 def test_hellswag_eval():
@@ -117,7 +120,8 @@ def test_hellswag_eval():
 
     assert converted_eval.detailed_evaluation_results is not None
     assert converted_eval.detailed_evaluation_results.format is not None
-    assert converted_eval.detailed_evaluation_results.total_rows == 10
+    # Per-(sample, metric): >= sample count, not equal to it.
+    assert converted_eval.detailed_evaluation_results.total_rows >= 10
 
 
 def test_narrativeqa_eval():
@@ -154,7 +158,8 @@ def test_narrativeqa_eval():
 
     assert converted_eval.detailed_evaluation_results is not None
    assert converted_eval.detailed_evaluation_results.format is not None
-    assert converted_eval.detailed_evaluation_results.total_rows == 5
+    # Per-(sample, metric): >= sample count, not equal to it.
+    assert converted_eval.detailed_evaluation_results.total_rows >= 5
 
 
 def test_missing_model_deployment_falls_back_to_model():
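
The `>=` bounds are deliberately loose: the exact row count depends on which stats HELM recorded for each instance, and rows are dropped when a stat's score cannot be parsed. A hypothetical tighter check would derive an upper bound from the stats themselves (stats_for and samples are invented names, not part of the suite):

# Each sample yields at most one row per stat, and exactly one on the
# no-stats fallback path; unparseable scores are skipped via `continue`.
upper = sum(max(1, len(stats_for(s))) for s in samples)
assert converted_eval.detailed_evaluation_results.total_rows <= upper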
