Skip to content

Commit bcf2238

Browse files
committed
fix(evaluate): export EvaluationModuleError and wrap _compute failures
Fixes #758
1 parent a7dd338 commit bcf2238

2 files changed

Lines changed: 21 additions & 3 deletions

File tree

src/evaluate/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,15 @@
4545
from .info import ComparisonInfo, EvaluationModuleInfo, MeasurementInfo, MetricInfo
4646
from .inspect import inspect_evaluation_module, list_evaluation_modules
4747
from .loading import load
48-
from .module import CombinedEvaluations, Comparison, EvaluationModule, Measurement, Metric, combine
48+
from .module import (
49+
CombinedEvaluations,
50+
Comparison,
51+
EvaluationModule,
52+
EvaluationModuleError,
53+
Measurement,
54+
Metric,
55+
combine,
56+
)
4957
from .saving import save
5058
from .utils import *
5159
from .utils import gradio, logging

src/evaluate/module.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
# limitations under the License.
1414

1515
# Lint as: python3
16-
""" EvaluationModule base class."""
16+
"""EvaluationModule base class."""
17+
1718
import collections
1819
import itertools
1920
import os
@@ -41,6 +42,10 @@
4142
logger = get_logger(__name__)
4243

4344

45+
class EvaluationModuleError(Exception):
46+
"""Raised when an EvaluationModule's compute step fails."""
47+
48+
4449
class FileFreeLock(BaseFileLock):
4550
"""Thread lock until a file **cannot** be locked"""
4651

@@ -464,7 +469,12 @@ def compute(self, *, predictions=None, references=None, **kwargs) -> Optional[di
464469

465470
inputs = {input_name: self.data[input_name][:] for input_name in self._feature_names()}
466471
with temp_seed(self.seed):
467-
output = self._compute(**inputs, **compute_kwargs)
472+
try:
473+
output = self._compute(**inputs, **compute_kwargs)
474+
except EvaluationModuleError:
475+
raise
476+
except Exception as e:
477+
raise EvaluationModuleError(f"Metric '{self.name}' computation failed: {e}") from e
468478

469479
if self.buf_writer is not None:
470480
self.buf_writer = None

0 commit comments

Comments
 (0)