Skip to content

Commit 9559e31

Browse files
committed
feat: Add ability to evaluate ragas metrics asynchronously
1 parent cac1043 commit 9559e31

3 files changed

Lines changed: 208 additions & 5 deletions

File tree

integrations/ragas/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ ignore_missing_imports = true
8282

8383
[tool.ruff]
8484
line-length = 120
85-
exclude = ["example", "tests"]
85+
exclude = ["example"]
8686

8787
[tool.ruff.lint]
8888
select = [
@@ -151,7 +151,7 @@ ban-relative-imports = "all"
151151

152152
[tool.ruff.lint.per-file-ignores]
153153
# Tests can use magic values, assertions, and relative imports
154-
"tests/**/*" = ["D", "PLR2004", "S101", "TID252", "ANN"]
154+
"tests/**/*" = ["D", "PLR2004", "S101", "TID252", "ANN", "ARG"]
155155

156156
[tool.coverage.run]
157157
source = ["haystack_integrations"]

integrations/ragas/src/haystack_integrations/components/evaluators/ragas/evaluator.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import inspect
2+
from asyncio import Semaphore, gather
23
from typing import Any, Union, cast, get_args, get_origin
34

45
from haystack import Document, component, default_from_dict, default_to_dict
@@ -50,17 +51,20 @@ class RagasEvaluator:
5051
```
5152
"""
5253

53-
def __init__(self, ragas_metrics: list[SimpleBaseMetric]) -> None:
54+
def __init__(self, ragas_metrics: list[SimpleBaseMetric], concurrency_limit: int = 4) -> None:
5455
"""
5556
Constructs a new Ragas evaluator.
5657
5758
:param ragas_metrics: A list of modern Ragas metrics from `ragas.metrics.collections`.
5859
Each metric must be fully configured (including its LLM) at construction time.
5960
Available metrics can be found in the
6061
[Ragas documentation](https://docs.ragas.io/en/stable/concepts/metrics/available_metrics/).
62+
:param concurrency_limit:
63+
The maximum number of metric evaluations that should be allowed to run concurrently.
6164
"""
6265
self._validate_inputs(ragas_metrics)
6366
self.metrics = ragas_metrics
67+
self.concurrency_limit = concurrency_limit
6468

6569
@staticmethod
6670
def _validate_inputs(metrics: list[SimpleBaseMetric]) -> None:
@@ -148,6 +152,57 @@ def run(
148152

149153
return {"result": results}
150154

155+
@component.output_types(result=dict[str, dict[str, MetricResult]])
156+
async def run_async(
157+
self,
158+
query: str | None = None,
159+
response: list[ChatMessage] | str | None = None,
160+
documents: list[Document | str] | None = None,
161+
reference_contexts: list[str] | None = None,
162+
multi_responses: list[str] | None = None,
163+
reference: str | None = None,
164+
rubrics: dict[str, str] | None = None,
165+
) -> dict[str, dict[str, MetricResult]]:
166+
"""
167+
Asynchronously evaluates the provided inputs against each metric and returns the results.
168+
169+
:param query: The input query from the user.
170+
:param response: A list of ChatMessage responses (typically from a language model or agent).
171+
:param documents: A list of Haystack Document or strings that were retrieved for the query.
172+
:param reference_contexts: A list of reference contexts that should have been retrieved for the query.
173+
:param multi_responses: List of multiple responses generated for the query.
174+
:param reference: A string reference answer for the query.
175+
:param rubrics: A dictionary of evaluation rubric, where keys represent the score
176+
and the values represent the corresponding evaluation criteria.
177+
:return: A dictionary with key `result` mapping metric names to their `MetricResult`.
178+
"""
179+
processed_docs = self._process_documents(documents)
180+
processed_response = self._process_response(response)
181+
182+
try:
183+
sample = SingleTurnSample(
184+
user_input=query,
185+
retrieved_contexts=processed_docs,
186+
reference_contexts=reference_contexts,
187+
response=processed_response,
188+
multi_responses=multi_responses,
189+
reference=reference,
190+
rubrics=rubrics,
191+
)
192+
except ValidationError as e:
193+
self._handle_conversion_error(e)
194+
195+
sem = Semaphore(max(1, self.concurrency_limit))
196+
197+
async def _runner(metric: SimpleBaseMetric) -> tuple[str, MetricResult]:
198+
async with sem:
199+
return metric.name, await self._score_metric_async(metric, sample)
200+
201+
pairs = await gather(*[_runner(m) for m in self.metrics])
202+
results: dict[str, MetricResult] = dict(pairs)
203+
204+
return {"result": results}
205+
151206
def _score_metric(self, metric: SimpleBaseMetric, sample: SingleTurnSample) -> MetricResult:
152207
"""
153208
Score a metric by inspecting its ascore() signature and passing only matching sample fields.
@@ -168,6 +223,26 @@ def _score_metric(self, metric: SimpleBaseMetric, sample: SingleTurnSample) -> M
168223
kwargs = {k: v for k, v in sample_dict.items() if k in valid_params and v is not None}
169224
return metric.score(**kwargs)
170225

226+
async def _score_metric_async(self, metric: SimpleBaseMetric, sample: SingleTurnSample) -> MetricResult:
227+
"""
228+
Score a metric by inspecting its ascore() signature and passing only matching sample fields.
229+
230+
:param metric: A SimpleBaseMetric instance to score.
231+
:param sample: The SingleTurnSample holding all available input fields.
232+
:return: MetricResult from the metric.
233+
"""
234+
sig = inspect.signature(metric.ascore)
235+
excluded = {"self", "callbacks"}
236+
valid_params = {
237+
name
238+
for name, param in sig.parameters.items()
239+
if name not in excluded
240+
and param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
241+
}
242+
sample_dict = sample.model_dump()
243+
kwargs = {k: v for k, v in sample_dict.items() if k in valid_params and v is not None}
244+
return await metric.ascore(**kwargs)
245+
171246
def _process_documents(self, documents: list[Document | str] | None) -> list[str] | None:
172247
"""
173248
Process and validate input documents.

integrations/ragas/tests/test_evaluator.py

Lines changed: 130 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import inspect
12
import os
2-
from unittest.mock import MagicMock
3+
from unittest.mock import AsyncMock, MagicMock
34

45
import pytest
56
from haystack import Document, Pipeline
@@ -47,6 +48,20 @@ async def ascore(user_input: str, response: str, retrieved_contexts: list) -> Me
4748
return metric
4849

4950

51+
def make_metric_async(name: str, score: float = 0.8, reason: str = "test reason") -> MagicMock:
52+
"""Create a mock SimpleBaseMetric with a concrete ascore signature for inspect.signature."""
53+
metric = MagicMock(spec=SimpleBaseMetric)
54+
metric.name = name
55+
56+
async def ascore(user_input: str, response: str, retrieved_contexts: list) -> MetricResult:
57+
return MetricResult(value=score, reason=reason)
58+
59+
mock_ascore = AsyncMock(return_value=MetricResult(value=score, reason=reason))
60+
mock_ascore.__signature__ = inspect.signature(ascore)
61+
metric.ascore = mock_ascore
62+
return metric
63+
64+
5065
class TestInit:
5166
def test_init(self, monkeypatch):
5267
monkeypatch.setenv("OPENAI_API_KEY", "test")
@@ -67,7 +82,7 @@ def test_init_with_multiple_metrics(self, monkeypatch):
6782
assert len(evaluator.metrics) == 2
6883

6984
def test_invalid_metrics_raises_type_error(self):
70-
with pytest.raises(TypeError, match="All items in ragas_metrics must be instances of SimpleBaseMetric."):
85+
with pytest.raises(TypeError, match=r"All items in ragas_metrics must be instances of SimpleBaseMetric."):
7186
RagasEvaluator(ragas_metrics=["not_a_metric"])
7287

7388

@@ -167,6 +182,119 @@ def test_run_raises_on_invalid_input_types(self, invalid_input, field_name, erro
167182
assert error_message in str(exc_info.value)
168183

169184

185+
class TestRunAsync:
186+
@pytest.mark.asyncio
187+
async def test_run_async_returns_result_by_metric_name(self) -> None:
188+
metric = make_metric_async("faithfulness", score=0.9)
189+
evaluator = RagasEvaluator(ragas_metrics=[metric])
190+
output = await evaluator.run_async(
191+
query="Which is the most popular global sport?",
192+
response="Football is the most popular sport.",
193+
documents=["Football is undoubtedly the world's most popular sport."],
194+
)
195+
assert "result" in output
196+
assert "faithfulness" in output["result"]
197+
result = output["result"]["faithfulness"]
198+
assert isinstance(result, MetricResult)
199+
assert result.value == 0.9
200+
201+
@pytest.mark.asyncio
202+
async def test_run_async_scores_all_metrics(self) -> None:
203+
metrics = [make_metric_async("faithfulness", 0.9), make_metric_async("answer_relevancy", 0.7)]
204+
evaluator = RagasEvaluator(ragas_metrics=metrics)
205+
output = await evaluator.run_async(query="test?", response="answer", documents=["doc"])
206+
assert set(output["result"].keys()) == {"faithfulness", "answer_relevancy"}
207+
assert output["result"]["faithfulness"].value == 0.9
208+
assert output["result"]["answer_relevancy"].value == 0.7
209+
210+
@pytest.mark.asyncio
211+
async def test_run_async_calls_ascore_on_each_metric(self) -> None:
212+
metric_a = make_metric_async("faithfulness")
213+
metric_b = make_metric_async("answer_relevancy")
214+
evaluator = RagasEvaluator(ragas_metrics=[metric_a, metric_b])
215+
await evaluator.run_async(query="test?", response="answer", documents=["doc"])
216+
metric_a.ascore.assert_called_once()
217+
metric_b.ascore.assert_called_once()
218+
219+
@pytest.mark.asyncio
220+
async def test_score_metric_async_passes_only_matching_params(self) -> None:
221+
"""Metric that only needs user_input + response should not receive retrieved_contexts."""
222+
metric = MagicMock(spec=SimpleBaseMetric)
223+
metric.name = "selective_metric"
224+
225+
async def ascore(user_input: str, response: str) -> MetricResult:
226+
return MetricResult(value=0.5, reason="ok")
227+
228+
metric.ascore = ascore
229+
230+
evaluator = RagasEvaluator(ragas_metrics=[metric])
231+
await evaluator.run_async(query="test?", response="answer", documents=["doc"], reference="ref")
232+
# Only user_input and response should have been passed — not retrieved_contexts or reference
233+
# We wrap ascore to capture kwargs
234+
captured = {}
235+
236+
async def capturing_ascore(user_input: str, response: str) -> MetricResult:
237+
captured.update({"user_input": user_input, "response": response})
238+
return MetricResult(value=0.5, reason="ok")
239+
240+
metric.ascore = capturing_ascore
241+
await evaluator.run_async(query="test?", response="answer", documents=["doc"], reference="ref")
242+
assert set(captured.keys()) == {"user_input", "response"}
243+
244+
@pytest.mark.asyncio
245+
async def test_score_metric_async_omits_none_fields(self) -> None:
246+
metric = make_metric_async("faithfulness")
247+
evaluator = RagasEvaluator(ragas_metrics=[metric])
248+
await evaluator.run_async(query="test?", response="answer") # no documents → retrieved_contexts=None
249+
_, kwargs = metric.ascore.call_args
250+
assert "retrieved_contexts" not in kwargs
251+
252+
@pytest.mark.asyncio
253+
async def test_run_async_accepts_document_objects(self) -> None:
254+
metric = make_metric_async("faithfulness")
255+
evaluator = RagasEvaluator(ragas_metrics=[metric])
256+
await evaluator.run_async(
257+
query="test?",
258+
response="answer",
259+
documents=[Document(content="some content"), Document(content="more content")],
260+
)
261+
_, kwargs = metric.ascore.call_args
262+
assert kwargs["retrieved_contexts"] == ["some content", "more content"]
263+
264+
@pytest.mark.asyncio
265+
async def test_run_async_accepts_string_documents(self):
266+
metric = make_metric_async("faithfulness")
267+
evaluator = RagasEvaluator(ragas_metrics=[metric])
268+
await evaluator.run_async(query="test?", response="answer", documents=["doc one", "doc two"])
269+
_, kwargs = metric.ascore.call_args
270+
assert kwargs["retrieved_contexts"] == ["doc one", "doc two"]
271+
272+
@pytest.mark.asyncio
273+
@pytest.mark.parametrize(
274+
"invalid_input,field_name,error_message",
275+
[
276+
(["Invalid query type"], "query", "'query' field expected"),
277+
([123, ["Invalid document"]], "documents", "'documents' must be a list"),
278+
(["score_1"], "rubrics", "'rubrics' field expected"),
279+
],
280+
)
281+
async def test_run_async_raises_on_invalid_input_types(self, invalid_input, field_name, error_message):
282+
evaluator = RagasEvaluator(ragas_metrics=[make_metric_async("faithfulness")])
283+
query = "Which is the most popular global sport?"
284+
documents = ["Football is the most popular sport."]
285+
response = "Football is the most popular sport in the world"
286+
287+
with pytest.raises(ValueError) as exc_info:
288+
if field_name == "query":
289+
await evaluator.run_async(query=invalid_input, documents=documents, response=response)
290+
elif field_name == "documents":
291+
await evaluator.run_async(query=query, documents=invalid_input, response=response)
292+
elif field_name == "rubrics":
293+
await evaluator.run_async(query=query, rubrics=invalid_input, documents=documents, response=response)
294+
295+
assert error_message in str(exc_info.value)
296+
297+
170298
class TestSerialization:
171299
def test_to_dict(self, monkeypatch):
172300
monkeypatch.setenv("OPENAI_API_KEY", "test")

0 commit comments

Comments
 (0)