|
| 1 | +""" |
| 2 | +Pairwise Ranking Prompting (PRP) reranker. |
| 3 | +
|
| 4 | +References: |
| 5 | + - Qin et al. (2023): "Large Language Models are Effective Text Rankers |
| 6 | + with Pairwise Ranking Prompting" |
| 7 | + https://arxiv.org/abs/2306.17563 |
| 8 | +""" |
| 9 | + |
| 10 | +import copy |
| 11 | +from typing import List, Optional |
| 12 | + |
| 13 | +import torch |
| 14 | +from tqdm import tqdm |
| 15 | +from transformers import AutoModelForCausalLM, AutoTokenizer |
| 16 | + |
| 17 | +from rankify.models.base import BaseRanking |
| 18 | +from rankify.dataset.dataset import Document, Context |
| 19 | + |
| 20 | + |
class PRPReranker(BaseRanking):
    """
    Pairwise Ranking Prompting (PRP) reranker.

    For each query, compares document pairs by prompting an LLM to decide
    which of two passages (A or B) is more relevant. The final ranking is
    determined by the number of pairwise wins each document accumulates.

    Two comparison modes are available:

    * ``'allpairs'`` – compare every ordered pair (O(N²) LLM calls),
      randomly subsampled down to ``max_pairs`` when there are more.
    * ``'bubblesort'`` – adjacent-pair bubble sort; at most N(N-1)/2
      comparisons, stopping early once a full pass performs no swap, so
      partially ordered inputs need fewer calls.

    Any other ``mode`` value falls back to ``'allpairs'``.

    For **local** models (``method='prp'``) the winner is decided by
    comparing the next-token logits of the tokens ``"A"`` and ``"B"``.
    For **API** models (``method='prp-api'``) the response text is parsed
    for a leading ``"A"``; anything else counts as a preference for B.

    References:
        - Qin et al. (2023): https://arxiv.org/abs/2306.17563

    Args:
        method (str): ``'prp'`` for local HuggingFace models,
            ``'prp-api'`` for API-based LLMs. Defaults to ``'prp'``.
        model_name (str): HuggingFace model ID (local) or API model name.
        api_key (str, optional): API key when ``method='prp-api'``.
        mode (str): ``'allpairs'`` (default) or ``'bubblesort'``.
        max_pairs (int): Maximum pairwise comparisons per query (allpairs
            only). Defaults to 100. A falsy value (``0``/``None``) disables
            the cap.
        api_endpoint (str, optional): OpenAI-compatible API base URL.
        device (str, optional): ``'cpu'`` or ``'cuda'``. Auto-detected when
            not supplied. The value is stored on the instance; the local
            model itself is loaded with ``device_map="auto"``.

    Example:
        ```python
        from rankify.models.reranking import Reranking

        model = Reranking(method='prp', model_name='llamav3.1-8b')
        model.rank(documents)
        ```
    """

    PROMPT_TEMPLATE = (
        "Given the following query: {query}\n\n"
        "Document A: {doc_a}\n\n"
        "Document B: {doc_b}\n\n"
        "Which document is more relevant to the query? Answer with 'A' or 'B'."
    )

    def __init__(
        self,
        method: Optional[str] = None,
        model_name: Optional[str] = None,
        api_key: Optional[str] = None,
        mode: str = "allpairs",
        max_pairs: int = 100,
        api_endpoint: str = "https://api.openai.com/v1",
        device: Optional[str] = None,
        **kwargs,
    ):
        self.method = method or "prp"
        self.model_name = model_name
        self.api_key = api_key
        self.mode = mode
        self.max_pairs = max_pairs
        self.api_endpoint = api_endpoint
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        self._model = None
        self._tokenizer = None
        self._token_id_a: Optional[int] = None
        self._token_id_b: Optional[int] = None
        # Lazily created OpenAI-compatible client, reused across comparisons.
        self._client = None

        if self.method == "prp":
            self._load_local_model()

    # ------------------------------------------------------------------
    # Model loading
    # ------------------------------------------------------------------

    def _load_local_model(self) -> None:
        """
        Load the causal LM and tokenizer and resolve the "A"/"B" token IDs.

        Raises:
            ValueError: If ``model_name`` was not provided.
        """
        if self.model_name is None:
            raise ValueError("model_name must be provided for method='prp'")
        print(f"Loading PRP model: {self.model_name}")
        self._model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        self._model.eval()
        self._tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, use_fast=True, trust_remote_code=True
        )
        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = self._tokenizer.eos_token

        # Resolve token IDs for "A" and "B" (first token of each)
        self._token_id_a = self._tokenizer.encode("A", add_special_tokens=False)[0]
        self._token_id_b = self._tokenizer.encode("B", add_special_tokens=False)[0]

    # ------------------------------------------------------------------
    # Pairwise comparison primitives
    # ------------------------------------------------------------------

    def _compare_local(self, query: str, doc_a: str, doc_b: str) -> bool:
        """
        Return True if doc_a is preferred over doc_b (local model).

        The decision compares the logits of the "A" and "B" tokens at the
        last prompt position; ties go to doc_a.

        Raises:
            RuntimeError: If no local model has been loaded (the instance
                was created with a method other than ``'prp'``).
        """
        if self._model is None or self._tokenizer is None:
            raise RuntimeError(
                "Local PRP model is not loaded; construct the reranker with "
                f"method='prp' or method='prp-api' (got method={self.method!r})."
            )
        prompt = self.PROMPT_TEMPLATE.format(query=query, doc_a=doc_a, doc_b=doc_b)
        # NOTE(review): if the tokenizer truncates on the right (the usual
        # default), a prompt longer than 2048 tokens loses its trailing
        # instruction; confirm passages are short enough upstream.
        inputs = self._tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=2048
        ).to(self._model.device)
        with torch.no_grad():
            logits = self._model(**inputs).logits
        last_logits = logits[0, -1]  # logits at the final prompt position
        score_a = last_logits[self._token_id_a].item()
        score_b = last_logits[self._token_id_b].item()
        return score_a >= score_b

    def _get_api_client(self):
        """
        Return the OpenAI-compatible client, creating it on first use.

        Raises:
            ImportError: If the ``openai`` package is not installed.
        """
        client = getattr(self, "_client", None)
        if client is None:
            try:
                import openai as _openai
            except ImportError as err:
                raise ImportError(
                    "openai package is required for method='prp-api'. "
                    "Install with: pip install openai"
                ) from err
            client = _openai.OpenAI(api_key=self.api_key, base_url=self.api_endpoint)
            self._client = client
        return client

    def _compare_api(self, query: str, doc_a: str, doc_b: str) -> bool:
        """
        Return True if doc_a is preferred over doc_b (API model).

        The reply is requested with ``max_tokens=1`` and ``temperature=0.0``;
        doc_a wins only when the reply starts with "A" (case-insensitive).

        Raises:
            ImportError: If the ``openai`` package is not installed.
        """
        client = self._get_api_client()
        prompt = self.PROMPT_TEMPLATE.format(query=query, doc_a=doc_a, doc_b=doc_b)
        response = client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1,
            temperature=0.0,
        )
        # The message content may be missing; treat that as "not A".
        content = response.choices[0].message.content
        answer = (content or "").strip().upper()
        return answer.startswith("A")

    def _compare(self, query: str, doc_a: str, doc_b: str) -> bool:
        """Dispatch to the API comparison for 'prp-api', else the local one."""
        if self.method == "prp-api":
            return self._compare_api(query, doc_a, doc_b)
        return self._compare_local(query, doc_a, doc_b)

    # ------------------------------------------------------------------
    # Ranking modes
    # ------------------------------------------------------------------

    def _rank_allpairs(self, query: str, contexts: List[Context]) -> List[Context]:
        """
        All-pairs tournament ranking.

        Every ordered pair (i, j), i != j, is a candidate comparison; when
        there are more than ``max_pairs`` a random subset of that size is
        used. The preferred document of each comparison earns one win, and
        contexts are returned by descending win count (ties keep their
        input order).
        """
        n = len(contexts)
        wins = [0] * n
        pairs = [(i, j) for i in range(n) for j in range(n) if i != j]
        # Limit comparisons if max_pairs is set
        if self.max_pairs and len(pairs) > self.max_pairs:
            import random

            pairs = random.sample(pairs, self.max_pairs)
        for i, j in pairs:
            a_preferred = self._compare(query, contexts[i].text, contexts[j].text)
            # Credit whichever side won; dropping B-side wins would bias the
            # tally toward documents that happened to be shown in slot A.
            wins[i if a_preferred else j] += 1
        ranked = sorted(range(n), key=wins.__getitem__, reverse=True)
        return [contexts[k] for k in ranked]

    def _rank_bubblesort(self, query: str, contexts: List[Context]) -> List[Context]:
        """
        Bubble-sort based ranking (O(N²) comparisons in the worst case).

        Adjacent documents are swapped whenever the model prefers the later
        one. Sorting stops as soon as a full pass performs no swap.
        """
        docs = list(contexts)
        n = len(docs)
        for i in range(n - 1):
            swapped = False
            for j in range(n - i - 1):
                if not self._compare(query, docs[j].text, docs[j + 1].text):
                    docs[j], docs[j + 1] = docs[j + 1], docs[j]
                    swapped = True
            # A swap-free pass means every adjacent pair already compared as
            # ordered; further passes would only repeat those comparisons
            # (assuming the comparator answers consistently).
            if not swapped:
                break
        return docs

    # ------------------------------------------------------------------
    # BaseRanking interface
    # ------------------------------------------------------------------

    def rank(self, documents: List[Document]) -> List[Document]:
        """
        Rerank documents using pairwise comparisons.

        Each document's contexts are deep-copied before ranking, so the
        original ``contexts`` list is left untouched. Ranked copies receive
        a synthetic ``score`` equal to ``len(ranked) - position`` (best
        first).

        Args:
            documents (List[Document]): Documents to rerank.

        Returns:
            List[Document]: The same documents with populated
            ``reorder_contexts`` (an empty list when there are no contexts).
        """
        for document in tqdm(documents, desc="PRP reranking"):
            if not document.contexts:
                document.reorder_contexts = []
                continue
            ctx_copy = copy.deepcopy(document.contexts)
            query = document.question.question

            if self.mode == "bubblesort":
                ranked = self._rank_bubblesort(query, ctx_copy)
            else:
                ranked = self._rank_allpairs(query, ctx_copy)

            # Assign synthetic scores based on rank position
            total = len(ranked)
            for rank_pos, ctx in enumerate(ranked):
                ctx.score = float(total - rank_pos)

            document.reorder_contexts = ranked
        return documents
0 commit comments