Skip to content

Commit 5604499

Browse files
Merge pull request #115 from DataScienceUIBK/copilot/add-retrievers-and-rerankers
Add TAS-B, UniCOIL, SPLADE-v2, API embedding retrievers, PRP, TART, RankGemma, RankMistral
2 parents a3c8a7c + fb27893 commit 5604499

12 files changed

Lines changed: 1340 additions & 7 deletions

README.md

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,6 +1154,30 @@ model = Reranking(method='vicuna_reranker', model_name='rank_vicuna_7b_v1')
11541154

11551155
# Zephyr Reranker
11561156
model = Reranking(method='zephyr_reranker', model_name='rank_zephyr_7b_v1_full')
1157+
1158+
# DuoT5 (pairwise T5-based reranker)
1159+
model = Reranking(method='duot5', model_name='duot5-base-msmarco')
1160+
1161+
# RankLLaMA (LLaMA-based passage reranker)
1162+
model = Reranking(method='rankllama', model_name='rankllama-v1-7b-lora-passage')
1163+
1164+
# DeAR (Decoder-only Autoregressive Reranker)
1165+
model = Reranking(method='dear_reranker', model_name='dear-3b-reranker-ce-v1')
1166+
1167+
# TART (Task-Aware Reranker with Instructions)
1168+
model = Reranking(method='tart', model_name='tart-full-flan-t5-xl')
1169+
1170+
# PRP (Pairwise Ranking Prompting) — local LLM
1171+
model = Reranking(method='prp', model_name='llamav3.1-8b')
1172+
1173+
# PRP — API-based LLM
1174+
model = Reranking(method='prp-api', model_name='gpt-4', api_key="gpt-api-key")
1175+
1176+
# RankGemma (Gemma-based listwise reranker)
1177+
model = Reranking(method='rankgemma', model_name='gemma-2-2b')
1178+
1179+
# RankMistral (Mistral-based listwise reranker)
1180+
model = Reranking(method='rankmistral', model_name='mistral-7b')
11571181
```
11581182
---
11591183

@@ -1453,7 +1477,13 @@ print("RAGAS (OpenAI):", {k: v for k, v in scores_openai.items() if k.startswith
14531477
- 🕒 **coCondenser**
14541478
- 🕒 **Spar**
14551479
- 🕒 **Dragon**
1456-
- 🕒 **Hybrid**
1480+
- 🕒 **Hybrid**
1481+
-**TAS-B**
1482+
-**UniCOIL**
1483+
-**SPLADE-v2**
1484+
-**OpenAI Embedding Retriever**
1485+
-**Cohere Embedding Retriever**
1486+
-**Voyage AI Retriever**
14571487
---
14581488

14591489
### **2️⃣ Rerankers**
@@ -1483,9 +1513,18 @@ print("RAGAS (OpenAI):", {k: v for k, v in scores_openai.items() if k.startswith
14831513
-**[LLM2VEC Reranker](https://arxiv.org/abs/2404.05961)**
14841514
-**[ECHO Reranker](https://arxiv.org/abs/2402.10866)**
14851515
-**[Incontext Reranker](https://arxiv.org/abs/2410.02642)**
1516+
-**[DuoT5](https://arxiv.org/abs/2101.05667)**
1517+
-**[RankLLaMA](https://arxiv.org/abs/2310.08319)**
1518+
-**[DeAR](https://arxiv.org/abs/2410.23089)**
14861519
- 🕒 **DynRank**
14871520
- 🕒 **ASRank**
1488-
- 🕒 **RankLlama**
1521+
-**PRP (Pairwise Ranking Prompting)**
1522+
-**RankMistral**
1523+
-**RankGemma**
1524+
- 🕒 **SetRank**
1525+
- 🕒 **Cohere Rerank API**
1526+
-**TART**
1527+
- 🕒 **PolyEncoder**
14891528

14901529
---
14911530

@@ -1513,7 +1552,7 @@ print("RAGAS (OpenAI):", {k: v for k, v in scores_openai.items() if k.startswith
15131552
- 🔥 **Unified Framework**: Combines **retrieval**, **re-ranking**, and **retrieval-augmented generation (RAG)** into a single modular toolkit.
15141553
- 📚 **Rich Dataset Support**: Includes **40+ benchmark datasets** with **pre-retrieved documents** for seamless experimentation.
15151554
- 🧲 **Diverse Retrieval Methods**: Supports **BM25, DPR, ANCE, BPR, ColBERT, BGE, Contriever, SFR, E5, GritLM, M2, Nomic, Instructor, RaDeR, ReasonIR, BGE-Reasoner and ReasonEmbed** for flexible retrieval strategies.
1516-
- 🎯 **Powerful Re-Ranking**: Implements **24 advanced models** with **41 sub-methods** to optimize ranking performance.
1555+
- 🎯 **Powerful Re-Ranking**: Implements **28 advanced models** with **44 sub-methods** to optimize ranking performance.
15171556
- 🏗️ **Prebuilt Indices**: Provides **Wikipedia and MS MARCO** corpora, eliminating indexing overhead and speeding up retrieval.
15181557
- 🔮 **Seamless RAG Integration**: Works with backends like **Hugging Face, OpenAI, vLLM, LiteLLM** inferening models like **GPT, LLAMA, T5, and Fusion-in-Decoder (FiD)** for multiple **retrieval-augmented generation** methods.
15191558
- 🛠 **Extensible & Modular**: Easily integrates **custom datasets, retrievers, ranking models, and RAG pipelines**.

rankify/models/prp_reranker.py

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
"""
2+
Pairwise Ranking Prompting (PRP) reranker.
3+
4+
References:
5+
- Qin et al. (2023): "Large Language Models are Effective Text Rankers
6+
with Pairwise Ranking Prompting"
7+
https://arxiv.org/abs/2306.17563
8+
"""
9+
10+
import copy
11+
from typing import List, Optional
12+
13+
import torch
14+
from tqdm import tqdm
15+
from transformers import AutoModelForCausalLM, AutoTokenizer
16+
17+
from rankify.models.base import BaseRanking
18+
from rankify.dataset.dataset import Document, Context
19+
20+
21+
class PRPReranker(BaseRanking):
22+
"""
23+
Pairwise Ranking Prompting (PRP) reranker.
24+
25+
For each query, compares document pairs by prompting an LLM to decide
26+
which of two passages (A or B) is more relevant. The final ranking is
27+
determined by the number of pairwise wins each document accumulates.
28+
29+
Two comparison modes are available:
30+
31+
* ``'allpairs'`` – compare every ordered pair (O(N²) LLM calls).
32+
* ``'bubblesort'`` – O(N²) comparisons in the worst case; faster for large
33+
sets when documents are partially ordered.
34+
35+
For **local** models (``method='prp'``) the winner is decided by
36+
comparing the next-token logit probability of the tokens ``"A"`` and
37+
``"B"``. For **API** models (``method='prp-api'``) the response text
38+
is parsed for the letter ``"A"`` or ``"B"``.
39+
40+
References:
41+
- Qin et al. (2023): https://arxiv.org/abs/2306.17563
42+
43+
Args:
44+
method (str): ``'prp'`` for local HuggingFace models,
45+
``'prp-api'`` for API-based LLMs.
46+
model_name (str): HuggingFace model ID (local) or API model name.
47+
api_key (str, optional): API key when ``method='prp-api'``.
48+
mode (str): ``'allpairs'`` (default) or ``'bubblesort'``.
49+
max_pairs (int): Maximum pairwise comparisons per query (allpairs only).
50+
Defaults to 100.
51+
api_endpoint (str, optional): OpenAI-compatible API base URL.
52+
device (str, optional): ``'cpu'`` or ``'cuda'``. Auto-detected when
53+
not supplied.
54+
55+
Example:
56+
```python
57+
from rankify.models.reranking import Reranking
58+
59+
model = Reranking(method='prp', model_name='llamav3.1-8b')
60+
model.rank(documents)
61+
```
62+
"""
63+
64+
PROMPT_TEMPLATE = (
65+
"Given the following query: {query}\n\n"
66+
"Document A: {doc_a}\n\n"
67+
"Document B: {doc_b}\n\n"
68+
"Which document is more relevant to the query? Answer with 'A' or 'B'."
69+
)
70+
71+
def __init__(
72+
self,
73+
method: Optional[str] = None,
74+
model_name: Optional[str] = None,
75+
api_key: Optional[str] = None,
76+
mode: str = "allpairs",
77+
max_pairs: int = 100,
78+
api_endpoint: str = "https://api.openai.com/v1",
79+
device: Optional[str] = None,
80+
**kwargs,
81+
):
82+
self.method = method or "prp"
83+
self.model_name = model_name
84+
self.api_key = api_key
85+
self.mode = mode
86+
self.max_pairs = max_pairs
87+
self.api_endpoint = api_endpoint
88+
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
89+
90+
self._model = None
91+
self._tokenizer = None
92+
self._token_id_a: Optional[int] = None
93+
self._token_id_b: Optional[int] = None
94+
95+
if self.method == "prp":
96+
self._load_local_model()
97+
98+
# ------------------------------------------------------------------
99+
# Model loading
100+
# ------------------------------------------------------------------
101+
102+
def _load_local_model(self) -> None:
103+
if self.model_name is None:
104+
raise ValueError("model_name must be provided for method='prp'")
105+
print(f"Loading PRP model: {self.model_name}")
106+
self._model = AutoModelForCausalLM.from_pretrained(
107+
self.model_name,
108+
torch_dtype=torch.float16,
109+
device_map="auto",
110+
low_cpu_mem_usage=True,
111+
trust_remote_code=True,
112+
)
113+
self._model.eval()
114+
self._tokenizer = AutoTokenizer.from_pretrained(
115+
self.model_name, use_fast=True, trust_remote_code=True
116+
)
117+
if self._tokenizer.pad_token is None:
118+
self._tokenizer.pad_token = self._tokenizer.eos_token
119+
120+
# Resolve token IDs for "A" and "B" (first token of each)
121+
self._token_id_a = self._tokenizer.encode("A", add_special_tokens=False)[0]
122+
self._token_id_b = self._tokenizer.encode("B", add_special_tokens=False)[0]
123+
124+
# ------------------------------------------------------------------
125+
# Pairwise comparison primitives
126+
# ------------------------------------------------------------------
127+
128+
def _compare_local(self, query: str, doc_a: str, doc_b: str) -> bool:
129+
"""Return True if doc_a is preferred over doc_b (local model)."""
130+
prompt = self.PROMPT_TEMPLATE.format(query=query, doc_a=doc_a, doc_b=doc_b)
131+
inputs = self._tokenizer(
132+
prompt, return_tensors="pt", truncation=True, max_length=2048
133+
).to(self._model.device)
134+
with torch.no_grad():
135+
logits = self._model(**inputs).logits
136+
last_logits = logits[0, -1] # shape: (vocab_size,)
137+
score_a = last_logits[self._token_id_a].item()
138+
score_b = last_logits[self._token_id_b].item()
139+
return score_a >= score_b
140+
141+
def _compare_api(self, query: str, doc_a: str, doc_b: str) -> bool:
142+
"""Return True if doc_a is preferred over doc_b (API model)."""
143+
try:
144+
import openai as _openai
145+
except ImportError:
146+
raise ImportError(
147+
"openai package is required for method='prp-api'. "
148+
"Install with: pip install openai"
149+
)
150+
prompt = self.PROMPT_TEMPLATE.format(query=query, doc_a=doc_a, doc_b=doc_b)
151+
client = _openai.OpenAI(api_key=self.api_key, base_url=self.api_endpoint)
152+
response = client.chat.completions.create(
153+
model=self.model_name,
154+
messages=[{"role": "user", "content": prompt}],
155+
max_tokens=1,
156+
temperature=0.0,
157+
)
158+
answer = response.choices[0].message.content.strip().upper()
159+
return answer.startswith("A")
160+
161+
def _compare(self, query: str, doc_a: str, doc_b: str) -> bool:
162+
"""Dispatch to local or API comparison."""
163+
if self.method == "prp-api":
164+
return self._compare_api(query, doc_a, doc_b)
165+
return self._compare_local(query, doc_a, doc_b)
166+
167+
# ------------------------------------------------------------------
168+
# Ranking modes
169+
# ------------------------------------------------------------------
170+
171+
def _rank_allpairs(self, query: str, contexts: List[Context]) -> List[Context]:
172+
"""All-pairs tournament ranking."""
173+
n = len(contexts)
174+
wins = [0] * n
175+
pairs = [(i, j) for i in range(n) for j in range(n) if i != j]
176+
# Limit comparisons if max_pairs is set
177+
if self.max_pairs and len(pairs) > self.max_pairs:
178+
import random
179+
pairs = random.sample(pairs, self.max_pairs)
180+
for i, j in pairs:
181+
if self._compare(query, contexts[i].text, contexts[j].text):
182+
wins[i] += 1
183+
ranked = sorted(range(n), key=lambda k: wins[k], reverse=True)
184+
return [contexts[k] for k in ranked]
185+
186+
def _rank_bubblesort(self, query: str, contexts: List[Context]) -> List[Context]:
187+
"""Bubble-sort based ranking (O(N log N) comparisons)."""
188+
docs = list(contexts)
189+
n = len(docs)
190+
for i in range(n):
191+
for j in range(0, n - i - 1):
192+
if not self._compare(query, docs[j].text, docs[j + 1].text):
193+
docs[j], docs[j + 1] = docs[j + 1], docs[j]
194+
return docs
195+
196+
# ------------------------------------------------------------------
197+
# BaseRanking interface
198+
# ------------------------------------------------------------------
199+
200+
def rank(self, documents: List[Document]) -> List[Document]:
201+
"""
202+
Rerank documents using pairwise comparisons.
203+
204+
Args:
205+
documents (List[Document]): Documents to rerank.
206+
207+
Returns:
208+
List[Document]: Documents with populated ``reorder_contexts``.
209+
"""
210+
for document in tqdm(documents, desc="PRP reranking"):
211+
if not document.contexts:
212+
document.reorder_contexts = []
213+
continue
214+
ctx_copy = copy.deepcopy(document.contexts)
215+
query = document.question.question
216+
217+
if self.mode == "bubblesort":
218+
ranked = self._rank_bubblesort(query, ctx_copy)
219+
else:
220+
ranked = self._rank_allpairs(query, ctx_copy)
221+
222+
# Assign synthetic scores based on rank position
223+
for rank_pos, ctx in enumerate(ranked):
224+
ctx.score = float(len(ranked) - rank_pos)
225+
226+
document.reorder_contexts = ranked
227+
return documents
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""
2+
RankGemma: Gemma-based listwise passage reranker.
3+
4+
Extends RankGPT with Gemma-specific model defaults and configuration.
5+
Uses the same sliding-window permutation ranking approach as RankGPT.
6+
7+
References:
8+
- RankGPT: Sun et al. (2023): "Is ChatGPT Good at Search?"
9+
https://arxiv.org/abs/2304.09542
10+
- Gemma: https://huggingface.co/google/gemma-2-2b-it
11+
"""
12+
13+
from typing import Optional
14+
15+
from rankify.models.rankgpt import RankGPT
16+
17+
18+
class RankGemmaReranker(RankGPT):
19+
"""
20+
Gemma-based listwise passage reranker.
21+
22+
A thin wrapper around :class:`~rankify.models.rankgpt.RankGPT` that
23+
provides Gemma-specific model aliases and sane defaults.
24+
25+
Supported ``model_name`` aliases:
26+
27+
* ``'gemma-2-2b'`` → ``google/gemma-2-2b-it``
28+
* ``'gemma-2-9b'`` → ``google/gemma-2-9b-it``
29+
* ``'gemma-2-27b'`` → ``google/gemma-2-27b-it``
30+
31+
Any raw HuggingFace model ID is also accepted (e.g.
32+
``'google/gemma-2-2b-it'``).
33+
34+
Args:
35+
method (str, optional): Reranking method name.
36+
model_name (str, optional): Model alias or full HuggingFace ID.
37+
api_key (str, optional): Unused for local models; kept for interface
38+
compatibility.
39+
**kwargs: Additional keyword arguments forwarded to RankGPT
40+
(e.g. ``window_size``, ``step``).
41+
42+
Example:
43+
```python
44+
from rankify.models.reranking import Reranking
45+
46+
model = Reranking(method='rankgemma', model_name='gemma-2-2b')
47+
model.rank(documents)
48+
```
49+
"""
50+
51+
DEFAULT_MODELS = {
52+
"gemma-2-2b": "google/gemma-2-2b-it",
53+
"gemma-2-9b": "google/gemma-2-9b-it",
54+
"gemma-2-27b": "google/gemma-2-27b-it",
55+
}
56+
57+
def __init__(
58+
self,
59+
method: Optional[str] = None,
60+
model_name: Optional[str] = None,
61+
api_key: Optional[str] = None,
62+
**kwargs,
63+
):
64+
resolved = self.DEFAULT_MODELS.get(model_name, model_name)
65+
super().__init__(
66+
method=method or "rankgemma",
67+
model_name=resolved,
68+
api_key=api_key,
69+
**kwargs,
70+
)

0 commit comments

Comments
 (0)