Skip to content

Commit b28fa2b

Browse files
author
Dylan Huang
committed
assertion error means finished
1 parent ffcb08d commit b28fa2b

2 files changed

Lines changed: 45 additions & 21 deletions

File tree

eval_protocol/pytest/evaluation_test.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import math
44
import os
55
import statistics
6-
from typing import Any, Callable, Dict, List, Optional
6+
from typing import Any, Callable, Dict, List, Literal, Optional
77

88
import pytest
99

@@ -29,6 +29,7 @@
2929
aggregate,
3030
create_dynamically_parameterized_wrapper,
3131
execute_function,
32+
log_eval_status_and_rows,
3233
)
3334
from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci
3435

@@ -76,7 +77,7 @@ def evaluation_test( # noqa: C901
7677
aggregation_method: How to aggregate scores across rows.
7778
threshold_of_success: If set, fail the test if the aggregated score is
7879
below this threshold.
79-
num_runs: Number of times to repeat the evaluation.
80+
num_runs: Number of times to repeat the rollout and evaluations.
8081
max_dataset_rows: Limit dataset to the first N rows.
8182
mcp_config_path: Path to MCP config file that follows MCPMultiClientConfiguration schema
8283
max_concurrent_rollouts: Maximum number of concurrent rollouts to run in parallel.
@@ -250,6 +251,11 @@ def wrapper_body(**kwargs):
250251
eval_metadata = None
251252
all_results: List[EvaluationRow] = []
252253

254+
def _log_eval_error(
255+
status: Literal["finished", "error"], rows: Optional[List[EvaluationRow]] | None, passed: bool
256+
) -> None:
257+
log_eval_status_and_rows(eval_metadata, rows, status, passed, default_logger)
258+
253259
try:
254260
# Handle dataset loading
255261
data: List[EvaluationRow] = []
@@ -542,25 +548,11 @@ def _extract_effort_tag(params: dict) -> str | None:
542548
agg_score >= threshold_of_success
543549
), f"Aggregated score {agg_score:.3f} below threshold {threshold_of_success}"
544550

551+
except AssertionError:
552+
_log_eval_error("finished", data if "data" in locals() else None, passed=False)
553+
raise
545554
except Exception:
546-
# Update eval metadata status to error and log it
547-
if eval_metadata is not None:
548-
eval_metadata.status = "error"
549-
eval_metadata.passed = False
550-
551-
# Create a minimal result row to log the error if we don't have any results yet
552-
if not data:
553-
error_row = EvaluationRow(messages=[], eval_metadata=eval_metadata, evaluation_result=None)
554-
default_logger.log(error_row)
555-
else:
556-
# Update existing results with error status
557-
for r in data:
558-
if r.eval_metadata is not None:
559-
r.eval_metadata.status = "error"
560-
r.eval_metadata.passed = False
561-
default_logger.log(r)
562-
563-
# Re-raise the exception to maintain pytest behavior
555+
_log_eval_error("error", data if "data" in locals() else None, passed=False)
564556
raise
565557

566558
return create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names)

eval_protocol/pytest/utils.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import asyncio
22
import inspect
3-
from typing import Any, Callable, List, Literal
3+
from typing import Any, Callable, List, Literal, Optional
4+
5+
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
6+
from eval_protocol.models import EvalMetadata, EvaluationRow
47

58

69
def execute_function(func: Callable, **kwargs) -> Any:
@@ -92,3 +95,32 @@ def wrapper(**kwargs):
9295
wrapper.__signature__ = inspect.Signature(parameters)
9396

9497
return wrapper
98+
99+
100+
def log_eval_status_and_rows(
101+
eval_metadata: Optional[EvalMetadata],
102+
rows: Optional[List[EvaluationRow]] | None,
103+
status: Literal["finished", "error"],
104+
passed: bool,
105+
logger: DatasetLogger,
106+
) -> None:
107+
"""Update eval status and emit rows to the given logger.
108+
109+
If no rows are provided, emits a minimal placeholder row so downstream
110+
consumers still observe a terminal status.
111+
"""
112+
if eval_metadata is None:
113+
return
114+
115+
eval_metadata.status = status
116+
eval_metadata.passed = passed
117+
118+
rows_to_log: List[EvaluationRow] = rows or []
119+
if not rows_to_log:
120+
error_row = EvaluationRow(messages=[], eval_metadata=eval_metadata, evaluation_result=None)
121+
logger.log(error_row)
122+
else:
123+
for r in rows_to_log:
124+
if r.eval_metadata is not None:
125+
r.eval_metadata.status = status
126+
logger.log(r)

0 commit comments

Comments
 (0)