Skip to content

Commit 66ca97f

Browse files
committed
add test
1 parent 034ce38 commit 66ca97f

1 file changed

Lines changed: 60 additions & 1 deletion

File tree

tests/test_evaluation_postprocess.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,17 @@
22

33
from unittest.mock import Mock, patch
44

5-
from eval_protocol.models import EvaluationRow, EvaluateResult, EvalMetadata, ExecutionMetadata, InputMetadata, Message
5+
import pytest
6+
7+
from eval_protocol.models import (
8+
EvaluationRow,
9+
EvaluateResult,
10+
EvalMetadata,
11+
EvaluationThreshold,
12+
ExecutionMetadata,
13+
InputMetadata,
14+
Message,
15+
)
616
from eval_protocol.pytest.evaluation_test_postprocess import postprocess
717
from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci
818

@@ -206,6 +216,55 @@ def test_all_invalid_scores(self):
206216
# Should still call logger.log for all rows
207217
assert mock_logger.log.call_count == 2
208218

219+
@patch.dict("os.environ", {"EP_NO_UPLOAD": "1"})  # Disable uploads for this test
def test_threshold_all_zero_scores_fail(self):
    """A run whose rows all score 0.0 must fail a success threshold of 0.01.

    Two rows scored 0.0 are fed to ``postprocess`` with mean aggregation;
    the call is expected to raise ``AssertionError`` with a message that
    mentions "below threshold".
    """
    zero_rows = [self.create_test_row(0.0) for _ in range(2)]
    logger_stub = Mock()
    success_bar = EvaluationThreshold(success=0.01, standard_error=None)

    call_kwargs = {
        "all_results": [zero_rows],  # a single run holding both rows
        "aggregation_method": "mean",
        "threshold": success_bar,
        "active_logger": logger_stub,
        "mode": "pointwise",
        "completion_params": {"model": "test-model"},
        "test_func_name": "test_threshold_all_zero",
        "num_runs": 1,
        "experiment_duration_seconds": 10.0,
    }

    with pytest.raises(AssertionError) as excinfo:
        postprocess(**call_kwargs)

    # Sanity check: the failure should be the threshold assertion, not an
    # unrelated one.
    failure_text = str(excinfo.value)
    assert "below threshold" in failure_text
245+
@patch.dict("os.environ", {"EP_NO_UPLOAD": "1"})  # Disable uploads for this test
def test_threshold_equal_score_passes(self):
    """A score exactly equal to the success threshold (0.01) must pass.

    A single row scored 0.01 is fed to ``postprocess`` with mean
    aggregation; the test passes as long as the call does not raise.
    """
    boundary_row = self.create_test_row(0.01)
    logger_stub = Mock()
    success_bar = EvaluationThreshold(success=0.01, standard_error=None)

    call_kwargs = {
        "all_results": [[boundary_row]],  # one run containing one row
        "aggregation_method": "mean",
        "threshold": success_bar,
        "active_logger": logger_stub,
        "mode": "pointwise",
        "completion_params": {"model": "test-model"},
        "test_func_name": "test_threshold_equal_score",
        "num_runs": 1,
        "experiment_duration_seconds": 10.0,
    }

    # No assertion needed: any exception raised here fails the test.
    postprocess(**call_kwargs)
267+
209268

210269
class TestBootstrapEquivalence:
211270
def test_bootstrap_equivalence_pandas_vs_pure_python(self):

0 commit comments

Comments
 (0)