|
3 | 3 | import math |
4 | 4 | import os |
5 | 5 | import statistics |
6 | | -from typing import Any, Callable, Dict, List, Optional |
| 6 | +from typing import Any, Callable, Dict, List, Literal, Optional |
7 | 7 |
|
8 | 8 | import pytest |
9 | 9 |
|
|
29 | 29 | aggregate, |
30 | 30 | create_dynamically_parameterized_wrapper, |
31 | 31 | execute_function, |
| 32 | + log_eval_status_and_rows, |
32 | 33 | ) |
33 | 34 | from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci |
34 | 35 |
|
@@ -76,7 +77,7 @@ def evaluation_test( # noqa: C901 |
76 | 77 | aggregation_method: How to aggregate scores across rows. |
77 | 78 | threshold_of_success: If set, fail the test if the aggregated score is |
78 | 79 | below this threshold. |
79 | | - num_runs: Number of times to repeat the evaluation. |
| 80 | + num_runs: Number of times to repeat the rollout and evaluations. |
80 | 81 | max_dataset_rows: Limit dataset to the first N rows. |
81 | 82 | mcp_config_path: Path to MCP config file that follows MCPMultiClientConfiguration schema |
82 | 83 | max_concurrent_rollouts: Maximum number of concurrent rollouts to run in parallel. |
@@ -250,6 +251,11 @@ def wrapper_body(**kwargs): |
250 | 251 | eval_metadata = None |
251 | 252 | all_results: List[EvaluationRow] = [] |
252 | 253 |
|
| 254 | + def _log_eval_error( |
| 255 | + status: Literal["finished", "error"], rows: Optional[List[EvaluationRow]] | None, passed: bool |
| 256 | + ) -> None: |
| 257 | + log_eval_status_and_rows(eval_metadata, rows, status, passed, default_logger) |
| 258 | + |
253 | 259 | try: |
254 | 260 | # Handle dataset loading |
255 | 261 | data: List[EvaluationRow] = [] |
@@ -542,25 +548,11 @@ def _extract_effort_tag(params: dict) -> str | None: |
542 | 548 | agg_score >= threshold_of_success |
543 | 549 | ), f"Aggregated score {agg_score:.3f} below threshold {threshold_of_success}" |
544 | 550 |
|
| 551 | + except AssertionError: |
| 552 | + _log_eval_error("finished", data if "data" in locals() else None, passed=False) |
| 553 | + raise |
545 | 554 | except Exception: |
546 | | - # Update eval metadata status to error and log it |
547 | | - if eval_metadata is not None: |
548 | | - eval_metadata.status = "error" |
549 | | - eval_metadata.passed = False |
550 | | - |
551 | | - # Create a minimal result row to log the error if we don't have any results yet |
552 | | - if not data: |
553 | | - error_row = EvaluationRow(messages=[], eval_metadata=eval_metadata, evaluation_result=None) |
554 | | - default_logger.log(error_row) |
555 | | - else: |
556 | | - # Update existing results with error status |
557 | | - for r in data: |
558 | | - if r.eval_metadata is not None: |
559 | | - r.eval_metadata.status = "error" |
560 | | - r.eval_metadata.passed = False |
561 | | - default_logger.log(r) |
562 | | - |
563 | | - # Re-raise the exception to maintain pytest behavior |
| 555 | + _log_eval_error("error", data if "data" in locals() else None, passed=False) |
564 | 556 | raise |
565 | 557 |
|
566 | 558 | return create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names) |
|
0 commit comments