Skip to content

Commit a26bff2

Browse files
authored
feat: Add run_async method to Ragas Integration (#3244)
1 parent 8aaf30a commit a26bff2

3 files changed

Lines changed: 353 additions & 8 deletions

File tree

integrations/ragas/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ ignore_missing_imports = true
8282

8383
[tool.ruff]
8484
line-length = 120
85-
exclude = ["example", "tests"]
85+
exclude = ["example"]
8686

8787
[tool.ruff.lint]
8888
select = [
@@ -151,7 +151,7 @@ ban-relative-imports = "all"
151151

152152
[tool.ruff.lint.per-file-ignores]
153153
# Tests can use magic values, assertions, and relative imports
154-
"tests/**/*" = ["D", "PLR2004", "S101", "TID252", "ANN"]
154+
"tests/**/*" = ["D", "PLR2004", "S101", "TID252", "ANN", "ARG"]
155155

156156
[tool.coverage.run]
157157
source = ["haystack_integrations"]

integrations/ragas/src/haystack_integrations/components/evaluators/ragas/evaluator.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import inspect
2+
from asyncio import Semaphore, gather
23
from typing import Any, Union, cast, get_args, get_origin
34

45
from haystack import Document, component, default_from_dict, default_to_dict
@@ -50,17 +51,20 @@ class RagasEvaluator:
5051
```
5152
"""
5253

53-
def __init__(self, ragas_metrics: list[SimpleBaseMetric], concurrency_limit: int = 4) -> None:
    """
    Create a new Ragas evaluator from a list of pre-configured metrics.

    :param ragas_metrics: Modern Ragas metrics taken from `ragas.metrics.collections`.
        Each metric must arrive fully configured (LLM included); see the
        [Ragas documentation](https://docs.ragas.io/en/stable/concepts/metrics/available_metrics/)
        for the available metrics.
    :param concurrency_limit: Upper bound on how many metric evaluations may run
        concurrently. Only consulted by `run_async`; the synchronous `run` ignores it.
    """
    # Reject unsupported metric objects up front so misconfiguration fails fast.
    self._validate_inputs(ragas_metrics)
    self.metrics = ragas_metrics
    self.concurrency_limit = concurrency_limit
6468

6569
@staticmethod
6670
def _validate_inputs(metrics: list[SimpleBaseMetric]) -> None:
@@ -148,6 +152,57 @@ def run(
148152

149153
return {"result": results}
150154

155+
@component.output_types(result=dict[str, dict[str, MetricResult]])
async def run_async(
    self,
    query: str | None = None,
    response: list[ChatMessage] | str | None = None,
    documents: list[Document | str] | None = None,
    reference_contexts: list[str] | None = None,
    multi_responses: list[str] | None = None,
    reference: str | None = None,
    rubrics: dict[str, str] | None = None,
) -> dict[str, dict[str, MetricResult]]:
    """
    Asynchronously evaluates the provided inputs against each metric and returns the results.

    :param query: The input query from the user.
    :param response: A list of ChatMessage responses (typically from a language model or agent).
    :param documents: A list of Haystack Document or strings that were retrieved for the query.
    :param reference_contexts: A list of reference contexts that should have been retrieved for the query.
    :param multi_responses: List of multiple responses generated for the query.
    :param reference: A string reference answer for the query.
    :param rubrics: A dictionary of evaluation rubric, where keys represent the score
        and the values represent the corresponding evaluation criteria.
    :return: A dictionary with key `result` mapping metric names to their `MetricResult`.
    """
    contexts = self._process_documents(documents)
    answer = self._process_response(response)

    try:
        sample = SingleTurnSample(
            user_input=query,
            retrieved_contexts=contexts,
            reference_contexts=reference_contexts,
            response=answer,
            multi_responses=multi_responses,
            reference=reference,
            rubrics=rubrics,
        )
    except ValidationError as e:
        # Translate pydantic validation failures into the component's own error type.
        self._handle_conversion_error(e)

    # Bound concurrent metric evaluations; a configured limit below 1 is clamped to 1.
    semaphore = Semaphore(max(1, self.concurrency_limit))

    async def _evaluate(metric: SimpleBaseMetric) -> tuple[str, MetricResult]:
        async with semaphore:
            score = await self._score_metric_async(metric, sample)
        return metric.name, score

    scored_pairs = await gather(*(_evaluate(metric) for metric in self.metrics))
    return {"result": dict(scored_pairs)}
205+
151206
def _score_metric(self, metric: SimpleBaseMetric, sample: SingleTurnSample) -> MetricResult:
152207
"""
153208
Score a metric by inspecting its ascore() signature and passing only matching sample fields.
@@ -168,6 +223,26 @@ def _score_metric(self, metric: SimpleBaseMetric, sample: SingleTurnSample) -> M
168223
kwargs = {k: v for k, v in sample_dict.items() if k in valid_params and v is not None}
169224
return metric.score(**kwargs)
170225

226+
async def _score_metric_async(self, metric: SimpleBaseMetric, sample: SingleTurnSample) -> MetricResult:
    """
    Score a metric by inspecting its ascore() signature and passing only matching sample fields.

    :param metric: A SimpleBaseMetric instance to score.
    :param sample: The SingleTurnSample holding all available input fields.
    :return: MetricResult from the metric.
    """
    skip_names = {"self", "callbacks"}
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    # Collect the named parameters ascore() will actually accept.
    accepted: set[str] = set()
    for param_name, param in inspect.signature(metric.ascore).parameters.items():
        if param_name in skip_names:
            continue
        if param.kind in variadic_kinds:
            continue
        accepted.add(param_name)

    # Forward only the sample fields the metric accepts, dropping unset (None) values.
    payload = {
        field: value
        for field, value in sample.model_dump().items()
        if field in accepted and value is not None
    }
    return await metric.ascore(**payload)
245+
171246
def _process_documents(self, documents: list[Document | str] | None) -> list[str] | None:
172247
"""
173248
Process and validate input documents.

0 commit comments

Comments (0)