3 changes: 2 additions & 1 deletion .github/workflows/ragas.yml
@@ -103,7 +103,8 @@ jobs:
name: coverage-comment-ragas
path: python-coverage-comment-action-ragas.txt

# No integration tests yet — add integration-cov-append-retry + combined coverage step when needed
- name: Run integration tests
run: hatch run test:integration-cov-append-retry

- name: Run unit tests with lowest direct dependencies
if: github.event_name != 'push'
9 changes: 8 additions & 1 deletion integrations/ragas/pyproject.toml
@@ -23,7 +23,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["haystack-ai>=2.22.0", "ragas>=0.2.6,<0.3.0"]
dependencies = ["haystack-ai>=2.22.0", "ragas>=0.4.3"]

[project.urls]
Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/ragas"
@@ -164,3 +164,10 @@ parallel = false
omit = ["*/tests/*", "*/__init__.py"]
show_missing = true
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]

[tool.pytest.ini_options]
addopts = "--strict-markers"
markers = [
"integration: integration tests",
]
log_cli = true
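
Since `--strict-markers` turns any unregistered marker into a collection error, tests must opt in via the registered `integration` marker; a minimal sketch (test name and body are illustrative):

```python
import pytest


@pytest.mark.integration
def test_evaluator_against_live_backend():
    # Selected by `pytest -m integration`, skipped by `pytest -m "not integration"`.
    ...
```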
@@ -1,19 +1,14 @@
import re
import inspect
from typing import Any, Union, cast, get_args, get_origin

from haystack import Document, component
from haystack import Document, component, default_from_dict, default_to_dict
from haystack.dataclasses import ChatMessage
from pydantic import ValidationError

from ragas import evaluate
from ragas.dataset_schema import (
EvaluationDataset,
EvaluationResult,
SingleTurnSample,
)
from ragas.embeddings import BaseRagasEmbeddings
from ragas.llms import BaseRagasLLM
from ragas.metrics import Metric
from haystack_integrations.components.evaluators.ragas.utils import _deserialize_metric, _serialize_metric
from ragas.dataset_schema import SingleTurnSample
from ragas.metrics.base import SimpleBaseMetric
from ragas.metrics.result import MetricResult


@component
@@ -23,19 +18,21 @@ class RagasEvaluator:

See the [Ragas framework](https://docs.ragas.io/) for more details.

This component supports the modern Ragas metrics API (`ragas.metrics.collections`).
Each metric must be a `SimpleBaseMetric` instance with its LLM configured at construction time.

Usage example:
```python
from haystack.components.generators import OpenAIGenerator
from openai import AsyncOpenAI
from ragas.llms import llm_factory
from ragas.metrics.collections import Faithfulness
from haystack_integrations.components.evaluators.ragas import RagasEvaluator
from ragas.metrics import ContextPrecision
from ragas.llms import HaystackLLMWrapper

llm = OpenAIGenerator(model="gpt-4o-mini")
evaluator_llm = HaystackLLMWrapper(llm)
client = AsyncOpenAI()
llm = llm_factory("gpt-4o-mini", client=client)

evaluator = RagasEvaluator(
ragas_metrics=[ContextPrecision()],
evaluator_llm=evaluator_llm
ragas_metrics=[Faithfulness(llm=llm)],
)
output = evaluator.run(
query="Which is the most popular global sport?",
@@ -53,52 +50,59 @@ class RagasEvaluator:
```
"""

def __init__(
self,
ragas_metrics: list[Metric],
evaluator_llm: BaseRagasLLM | None = None,
evaluator_embedding: BaseRagasEmbeddings | None = None,
) -> None:
def __init__(self, ragas_metrics: list[SimpleBaseMetric]) -> None:
"""
Constructs a new Ragas evaluator.

:param ragas_metrics: A list of evaluation metrics from the [Ragas](https://docs.ragas.io/) library.
:param evaluator_llm: A language model used by metrics that require LLMs for evaluation.
:param evaluator_embedding: An embedding model used by metrics that require embeddings for evaluation.
:param ragas_metrics: A list of modern Ragas metrics from `ragas.metrics.collections`.
Each metric must be fully configured (including its LLM) at construction time.
Available metrics can be found in the
[Ragas documentation](https://docs.ragas.io/en/stable/concepts/metrics/available_metrics/).
"""
self._validate_inputs(ragas_metrics, evaluator_llm, evaluator_embedding)
self._validate_inputs(ragas_metrics)
self.metrics = ragas_metrics
self.llm = evaluator_llm
self.embedding = evaluator_embedding

def _validate_inputs(
self,
metrics: list[Metric],
llm: BaseRagasLLM | None,
embedding: BaseRagasEmbeddings | None,
) -> None:
@staticmethod
def _validate_inputs(metrics: list[SimpleBaseMetric]) -> None:
"""
Validate input parameters.

:param metrics: List of Ragas metrics to validate
:param llm: Language model to validate
:param embedding: Embedding model to validate

:param metrics: List of Ragas metrics to validate.
:return: None.
"""
if not all(isinstance(metric, Metric) for metric in metrics):
error_message = "All items in ragas_metrics must be instances of Metric class."
if not all(isinstance(metric, SimpleBaseMetric) for metric in metrics):
error_message = "All items in ragas_metrics must be instances of SimpleBaseMetric."
raise TypeError(error_message)

if llm is not None and not isinstance(llm, BaseRagasLLM):
error_message = f"Expected evaluator_llm to be BaseRagasLLM, got {type(llm).__name__}"
raise TypeError(error_message)
def to_dict(self) -> dict[str, Any]:
"""
Serialize this component to a dictionary.

if embedding is not None and not isinstance(embedding, BaseRagasEmbeddings):
error_message = f"Expected evaluator_embedding to be BaseRagasEmbeddings, got {type(embedding).__name__}"
raise TypeError(error_message)
:returns:
Dictionary with serialized data.
"""
return default_to_dict(self, ragas_metrics=[_serialize_metric(m) for m in self.metrics])

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "RagasEvaluator":
"""
Deserialize this component from a dictionary.

Metrics are reconstructed from their stored class path and LLM/embedding
configuration. Only the `openai` provider is supported for automatic
deserialization; the API key is read from the `OPENAI_API_KEY` environment
variable at load time.

:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
metrics_data = data.get("init_parameters", {}).get("ragas_metrics", [])
data["init_parameters"]["ragas_metrics"] = [_deserialize_metric(m) for m in metrics_data]
return default_from_dict(cls, data)
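
A minimal round-trip sketch under the constraints stated in `from_dict` (model name is illustrative; `OPENAI_API_KEY` must be set in the environment when loading):

```python
from openai import AsyncOpenAI
from ragas.llms import llm_factory
from ragas.metrics.collections import Faithfulness

from haystack_integrations.components.evaluators.ragas import RagasEvaluator

llm = llm_factory("gpt-4o-mini", client=AsyncOpenAI())
evaluator = RagasEvaluator(ragas_metrics=[Faithfulness(llm=llm)])

data = evaluator.to_dict()  # each metric stored as class path plus LLM/embedding config
restored = RagasEvaluator.from_dict(data)  # rebuilds the metric; reads OPENAI_API_KEY
```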

@component.output_types(result=EvaluationResult)
@component.output_types(result=dict[str, dict[str, MetricResult]])
def run(
self,
query: str | None = None,
@@ -108,9 +112,9 @@ def run(
multi_responses: list[str] | None = None,
reference: str | None = None,
rubrics: dict[str, str] | None = None,
) -> dict[str, Any]:
) -> dict[str, dict[str, MetricResult]]:
"""
Evaluates the provided query against the documents and returns the evaluation result.
Evaluates the provided inputs against each metric and returns the results.

:param query: The input query from the user.
:param response: A list of ChatMessage responses (typically from a language model or agent).
@@ -120,7 +124,7 @@ def run(
:param reference: A string reference answer for the query.
:param rubrics: A dictionary of evaluation rubrics, where keys represent the score
and the values represent the corresponding evaluation criteria.
:return: A dictionary containing the evaluation result.
:return: A dictionary with key `result` mapping metric names to their `MetricResult`.
"""
processed_docs = self._process_documents(documents)
processed_response = self._process_response(response)
@@ -135,30 +139,41 @@ def run(
reference=reference,
rubrics=rubrics,
)

except (ValueError, ValidationError) as e:
except ValidationError as e:
self._handle_conversion_error(e)

dataset = EvaluationDataset([sample])
results: dict[str, MetricResult] = {}
for metric in self.metrics:
results[metric.name] = self._score_metric(metric, sample)

try:
result = evaluate(
dataset=dataset,
metrics=self.metrics,
llm=self.llm,
embeddings=self.embedding,
)
except (ValueError, ValidationError) as e:
self._handle_evaluation_error(e)
return {"result": results}

return {"result": result}
def _score_metric(self, metric: SimpleBaseMetric, sample: SingleTurnSample) -> MetricResult:
"""
Score a metric by inspecting its ascore() signature and passing only matching sample fields.

:param metric: A SimpleBaseMetric instance to score.
:param sample: The SingleTurnSample holding all available input fields.
:return: MetricResult from the metric.
"""
sig = inspect.signature(metric.ascore)
excluded = {"self", "callbacks"}
valid_params = {
name
for name, param in sig.parameters.items()
if name not in excluded
and param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
}
sample_dict = sample.model_dump()
kwargs = {k: v for k, v in sample_dict.items() if k in valid_params and v is not None}
return metric.score(**kwargs)
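
A self-contained sketch of the signature-filtering idea used above, with a hypothetical `ascore()` that accepts only some sample fields:

```python
import inspect


async def ascore(user_input: str, response: str) -> float:  # hypothetical metric signature
    return 1.0


sample_dict = {"user_input": "Q?", "response": "A.", "reference": None}
valid_params = {
    name
    for name, param in inspect.signature(ascore).parameters.items()
    if name not in {"self", "callbacks"}
    and param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
}
kwargs = {k: v for k, v in sample_dict.items() if k in valid_params and v is not None}
print(kwargs)  # {'user_input': 'Q?', 'response': 'A.'} -- `reference` is dropped
```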

def _process_documents(self, documents: list[Document | str] | None) -> list[str] | None:
"""
Process and validate input documents.

:param documents: List of Documents or strings to process
:return: List of document contents as strings or None
:param documents: List of Documents or strings to process.
:return: List of document contents as strings or None.
"""
if documents is None:
return None
@@ -178,10 +193,10 @@ def _process_response(self, response: list[ChatMessage] | str | None) -> str | None:
"""
Process response into expected format.

:param response: Response to process
:return: None or Processed response string
:param response: Response to process.
:return: None or processed response string.
"""
if isinstance(response, list): # Check if response is a list
if isinstance(response, list):
if all(isinstance(item, ChatMessage) and item.text for item in response):
return response[0].text
return None
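
Quick illustration of the visible list branch (hypothetical direct call):

```python
from haystack.dataclasses import ChatMessage

msgs = [ChatMessage.from_assistant("Hi"), ChatMessage.from_assistant("Bye")]
print(evaluator._process_response(msgs))  # "Hi" -- only the first message's text is used
print(evaluator._process_response([ChatMessage.from_user("")]))  # None: empty text fails the check
```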
@@ -191,15 +206,12 @@ def _process_response(self, response: list[ChatMessage] | str | None) -> str | None:

def _handle_conversion_error(self, error: Exception) -> None:
"""
Handle evaluation errors with improved messages.
Re-raise pydantic validation errors from SingleTurnSample with Haystack-friendly field names.

:params error: Original error
:params error: Original error.
"""
if isinstance(error, ValidationError):
field_mapping = {
"user_input": "query",
"retrieved_contexts": "documents",
}
field_mapping = {"user_input": "query", "retrieved_contexts": "documents"}
for err in error.errors():
# loc is a tuple of strings and ints but according to pydantic docs, the first element is a string
# https://docs.pydantic.dev/latest/errors/errors/
@@ -217,26 +229,6 @@ def _handle_conversion_error(self, error: Exception) -> None:
)
raise ValueError(error_message)
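
Illustrative effect of the remapping (hypothetical call; exact message text may differ): a pydantic failure on `user_input` should surface under the `run()` parameter name `query`:

```python
try:
    evaluator.run(query=123, documents=["Some context."])  # query must be a string
except ValueError as e:
    print(e)  # expected to mention 'query' rather than SingleTurnSample's 'user_input'
```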

def _handle_evaluation_error(self, error: Exception) -> None:
error_message = str(error)
columns_match = re.search(r"additional columns \[(.*?)\]", error_message)
field_mapping = {
"user_input": "query",
"retrieved_contexts": "documents",
}
if columns_match:
columns_str = columns_match.group(1)
columns = [col.strip().strip("'") for col in columns_str.split(",")]

mapped_columns = [field_mapping.get(col, col) for col in columns]
updated_columns_str = "[" + ", ".join(f"'{col}'" for col in mapped_columns) + "]"

# Update the list of columns in the error message
updated_error_message = error_message.replace(
columns_match.group(0), f"additional columns {updated_columns_str}"
)
raise ValueError(updated_error_message)

def _get_expected_type_description(self, expected_type: Any) -> str:
"""Helper method to get a description of the expected type."""
if get_origin(expected_type) is Union:
@@ -252,21 +244,20 @@ def _get_expected_type_description(self, expected_type: Any) -> str:
value_type_name = getattr(value_type, "__name__", str(value_type))
return f"a dictionary with keys of type {key_type_name} and values of type {value_type_name}"
else:
# Handle non-generic types or unknown types gracefully
return getattr(expected_type, "__name__", str(expected_type))

def _get_example_input(self, field: str) -> str:
"""
Helper method to get an example input based on the field.

:param field: Arguement used to make SingleTurnSample.
:param field: Argument used to make SingleTurnSample.
:returns: Example usage for the field.
"""
examples = {
"query": "A string query like 'Question?'",
"documents": "[Document(content='Example content')]",
"reference_contexts": "['Example string 1', 'Example string 2']",
"response": "ChatMessage(_content='Hi', _role='assistant')",
"response": "ChatMessage.from_assistant('Hi')",
"multi_responses": "['Response 1', 'Response 2']",
"reference": "'A reference string'",
"rubrics": "{'score1': 'high_similarity'}",