-
Notifications
You must be signed in to change notification settings - Fork 101
Expand file tree
/
Copy pathcross_encoder_model.py
More file actions
81 lines (64 loc) · 2.3 KB
/
cross_encoder_model.py
File metadata and controls
81 lines (64 loc) · 2.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, Optional
@dataclass
class RerankResult:
    """
    A single re-ranked document together with its score and provenance.

    Attributes:
        text (str): The document text content.
        score (float): The relevance score from the cross encoder.
        corpus_id (int): The original index in the input document list.
        metadata (Optional[dict]): Original document metadata (source, etc.).
            Defaults to None when no metadata is supplied.
    """

    text: str
    score: float
    corpus_id: int
    metadata: Optional[dict] = None
class CrossEncoderModel(ABC):
    """
    Abstract base class for cross encoder models.

    This class defines a blueprint for implementing cross encoder models with a
    consistent interface for loading the model, retrieving it, and re-ranking
    documents against a query.

    Attributes:
        model_name (str): The name of the model.
        model (Any): The loaded model instance, as returned by ``load()``.
    """

    def __init__(self, model_name: str) -> None:
        """
        Initializes a CrossEncoderModel instance.

        The model is loaded eagerly: ``load()`` is called during construction,
        after ``model_name`` has been assigned, so implementations of
        ``load()`` may read ``self.model_name``.

        Args:
            model_name (str): The name of the model to be loaded.
        """
        self.model_name: str = model_name
        # load() is invoked here, so any other state a subclass's load() relies on
        # must already exist before this constructor is called.
        self.model: Any = self.load()

    @abstractmethod
    def load(self) -> Any:
        """
        Abstract method to load the cross encoder model.

        This method must be implemented by any concrete subclass to define the
        loading process for the specific model.

        Returns:
            Any: The loaded model instance.
        """

    def get_model(self) -> Any:
        """
        Retrieves the loaded cross encoder model.

        Returns:
            Any: The loaded model instance stored in ``self.model`` (the value
            returned by ``load()``), not the wrapper object itself.
        """
        return self.model

    @abstractmethod
    def predict(self, query: str, documents: List[str], top_k: int) -> List[RerankResult]:
        """
        Re-ranks the given documents against the query and returns the top_k most relevant.

        Args:
            query (str): The input query.
            documents (List[str]): The list of document texts to rank.
            top_k (int): The number of top results to return.

        Returns:
            List[RerankResult]: The top_k re-ranked results with scores and corpus IDs.
        """