From bcf223811cacd8acbfe3b6d7b74a4027b2f73689 Mon Sep 17 00:00:00 2001 From: xodn348 Date: Fri, 22 May 2026 05:19:48 +0000 Subject: [PATCH] fix(evaluate): export EvaluationModuleError and wrap _compute failures Fixes #758 --- src/evaluate/__init__.py | 10 +++++++++- src/evaluate/module.py | 14 ++++++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/evaluate/__init__.py b/src/evaluate/__init__.py index a8c25bd9..d52146d6 100644 --- a/src/evaluate/__init__.py +++ b/src/evaluate/__init__.py @@ -45,7 +45,15 @@ from .info import ComparisonInfo, EvaluationModuleInfo, MeasurementInfo, MetricInfo from .inspect import inspect_evaluation_module, list_evaluation_modules from .loading import load -from .module import CombinedEvaluations, Comparison, EvaluationModule, Measurement, Metric, combine +from .module import ( + CombinedEvaluations, + Comparison, + EvaluationModule, + EvaluationModuleError, + Measurement, + Metric, + combine, +) from .saving import save from .utils import * from .utils import gradio, logging diff --git a/src/evaluate/module.py b/src/evaluate/module.py index ca38b9b1..461e85a4 100644 --- a/src/evaluate/module.py +++ b/src/evaluate/module.py @@ -13,7 +13,8 @@ # limitations under the License. # Lint as: python3 -""" EvaluationModule base class.""" +"""EvaluationModule base class.""" + import collections import itertools import os @@ -41,6 +42,10 @@ logger = get_logger(__name__) +class EvaluationModuleError(Exception): + """Raised when an EvaluationModule's compute step fails.""" + + class FileFreeLock(BaseFileLock): """Thread lock until a file **cannot** be locked""" @@ -464,7 +469,12 @@ def compute(self, *, predictions=None, references=None, **kwargs) -> Optional[di inputs = {input_name: self.data[input_name][:] for input_name in self._feature_names()} with temp_seed(self.seed): - output = self._compute(**inputs, **compute_kwargs) + try: + output = self._compute(**inputs, **compute_kwargs) + except EvaluationModuleError: + raise + except Exception as e: + raise EvaluationModuleError(f"Metric '{self.name}' computation failed: {e}") from e if self.buf_writer is not None: self.buf_writer = None