Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions src/raglight/cross_encoder/cross_encoder_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,24 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, List
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class RerankResult:
    """
    A single reranked document together with its score and origin.

    Attributes:
        text (str): The document text content.
        score (float): The relevance score from the cross encoder.
        corpus_id (int): The original index in the input document list.
        metadata (Optional[dict]): Original document metadata (source, etc.).
            Defaults to None when no metadata is attached.
    """

    text: str
    score: float
    corpus_id: int
    # Optional so rerankers that only know text/score/index can omit it.
    metadata: Optional[dict] = None


class CrossEncoderModel(ABC):
Expand Down Expand Up @@ -48,7 +66,7 @@ def get_model(self) -> CrossEncoderModel:
return self.model

@abstractmethod
def predict(self, query: str, documents: List[str], top_k: int) -> List[str]:
def predict(self, query: str, documents: List[str], top_k: int) -> List[RerankResult]:
"""
Re-ranks the given documents against the query and returns the top_k most relevant.

Expand All @@ -58,6 +76,6 @@ def predict(self, query: str, documents: List[str], top_k: int) -> List[str]:
top_k (int): The number of top results to return.

Returns:
List[str]: The top_k re-ranked document texts.
List[RerankResult]: The top_k re-ranked results with scores and corpus IDs.
"""
pass
22 changes: 16 additions & 6 deletions src/raglight/cross_encoder/huggingface_cross_encoder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations
from typing import List
from typing_extensions import override
from .cross_encoder_model import CrossEncoderModel
from .cross_encoder_model import CrossEncoderModel, RerankResult
from sentence_transformers import CrossEncoder


Expand Down Expand Up @@ -39,22 +39,32 @@ def load(self) -> HuggingfaceCrossEncoderModel:
return CrossEncoder(self.model_name)

@override
def predict(self, query: str, documents: List[str], top_k: int) -> List[RerankResult]:
    """
    Re-ranks the documents against the query and returns the most relevant results.

    Args:
        query (str): The input query.
        documents (List[str]): The list of document texts to rank.
        top_k (int): The number of top results to return.

    Returns:
        List[RerankResult]: The top_k re-ranked results with scores and corpus IDs.
        An empty list is returned when ``documents`` is empty.
    """
    # Nothing to rank: skip the model call entirely.
    if not documents:
        return []

    # rank returns a list of dicts: [{'corpus_id': int, 'score': float, 'text': str}, ...]
    results = self.model.rank(
        query=query, documents=documents, top_k=top_k, return_documents=True
    )

    # Wrap each raw dict in a RerankResult; score and corpus_id are coerced to
    # plain Python float/int. metadata is not passed and keeps its default.
    return [
        RerankResult(
            text=res["text"],
            score=float(res["score"]),
            corpus_id=int(res["corpus_id"]),
        )
        for res in results
    ]
Loading