Skip to content

Commit a4d56f3

Browse files
committed
feat: ReRanker endpoint using local models.
1 parent 1d6ef08 commit a4d56f3

6 files changed

Lines changed: 94 additions & 1 deletion

File tree

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,9 @@ The following environment variables are required to run the application:
121121
- `AWS_SESSION_TOKEN`: (Optional) may be needed for bedrock embeddings
122122
- `GOOGLE_APPLICATION_CREDENTIALS`: (Optional) needed for Google VertexAI embeddings. This should be a path to a service account credential file in JSON format, as accepted by [langchain](https://python.langchain.com/api_reference/google_vertexai/index.html)
123123
- `RAG_CHECK_EMBEDDING_CTX_LENGTH` (Optional) Default is true, disabling this will send raw input to the embedder, use this for custom embedding models.
124+
- `SIMPLE_RERANKER_MODEL_NAME` (Optional) defaults to `ms-marco-MiniLM-L-12-v2`, more options at (https://github.com/AnswerDotAI/rerankers)
125+
- `SIMPLE_RERANKER_MODEL_TYPE` (Optional) defaults to `flashrank`, more options at (https://github.com/AnswerDotAI/rerankers)
126+
124127

125128
Make sure to set these environment variables before running the application. You can set them in a `.env` file or as system environment variables.
126129

app/models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,9 @@ class QueryMultipleBody(BaseModel):
4242
query: str
4343
file_ids: List[str]
4444
k: int = 4
45+
46+
47+
class QueryMultipleDocs(BaseModel):
48+
query: str
49+
docs: List[str]
50+
k: int = 4

app/routes/document_routes.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from shutil import copyfileobj
88
from typing import List, Iterable, TYPE_CHECKING
99
from concurrent.futures import ThreadPoolExecutor
10+
from rerankers import Reranker, Document as ReRankDocument
1011
from fastapi import (
1112
APIRouter,
1213
Request,
@@ -43,6 +44,7 @@
4344
QueryRequestBody,
4445
DocumentResponse,
4546
QueryMultipleBody,
47+
QueryMultipleDocs,
4648
)
4749
from app.services.vector_store.async_pg_vector import AsyncPgVector
4850
from app.utils.document_loader import (
@@ -54,7 +56,10 @@
5456
from app.utils.health import is_health_ok
5557

5658
router = APIRouter()
57-
59+
reranker_instance = Reranker(
60+
model_name=os.getenv("SIMPLE_RERANKER_MODEL_NAME", "ms-marco-MiniLM-L-12-v2"),
61+
model_type=os.getenv("SIMPLE_RERANKER_MODEL_TYPE", "flashrank"),
62+
)
5863

5964
def calculate_num_batches(total: int, batch_size: int) -> int:
6065
"""Calculate the number of batches needed to process total items."""
@@ -1002,6 +1007,43 @@ async def query_embeddings_by_file_ids(request: Request, body: QueryMultipleBody
10021007
)
10031008
raise HTTPException(status_code=500, detail=str(e))
10041009

1010+
@router.post("/rerank")
1011+
async def rerank_documents_by_query(request: Request, body: QueryMultipleDocs):
1012+
"""
1013+
Rerank documents based on relevance to a query using a reranking model.
1014+
1015+
Args:
1016+
request: The FastAPI request object
1017+
body: Contains query string, list of documents, and optional k value
1018+
1019+
Returns:
1020+
List of ranked documents with their scores
1021+
"""
1022+
1023+
try:
1024+
if not body.docs:
1025+
raise HTTPException(status_code=400, detail="docs list cannot be empty")
1026+
docs = []
1027+
for i, d in enumerate(body.docs):
1028+
docs.append(ReRankDocument(text=d, doc_id=i))
1029+
1030+
top_k = body.k
1031+
1032+
results = reranker_instance.rank(query=body.query, docs=docs)
1033+
items = results.top_k(top_k) if top_k else results
1034+
1035+
return [
1036+
{"text": getattr(r.document, "text", None), "score": r.score} for r in items
1037+
]
1038+
except Exception as e:
1039+
logger.error(
1040+
"Error in reranking documents | Query: %s | Error: %s | Traceback: %s",
1041+
body.query,
1042+
str(e),
1043+
traceback.format_exc(),
1044+
)
1045+
raise HTTPException(status_code=500, detail=str(e))
1046+
10051047

10061048
@router.post("/text")
10071049
async def extract_text_from_file(

docker-compose.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@ services:
1717
- DB_PORT=5432
1818
ports:
1919
- "8000:8000"
20+
runtime: ${DOCKER_RUNTIME:-runc}
2021
volumes:
2122
- ./uploads:/app/uploads
23+
- ~/.cache/huggingface:/root/.cache/huggingface:rw
2224
depends_on:
2325
- db
2426
env_file:

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,7 @@ python-magic==0.4.27
3737
python-pptx==1.0.2
3838
xlrd==2.0.2
3939
pydantic==2.9.2
40+
rerankers[transformers]==0.6.0
41+
rerankers[flashrank]==0.6.0
4042
chardet==5.2.0
4143
tenacity>=9.0.0

tests/test_main.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,3 +275,41 @@ def test_extract_text_from_file(tmp_path, auth_headers):
275275
assert json_data["file_id"] == "test_text_123"
276276
assert json_data["filename"] == "test_text_extraction.txt"
277277
assert json_data["known_type"] is True # text files are known types
278+
279+
def test_query_rerank(auth_headers):
280+
# Successful reranking with string documents
281+
data = {
282+
"query": "I love you",
283+
"docs": ["I hate you", "I really like you"],
284+
"k": 1
285+
}
286+
response = client.post("/rerank", json=data, headers=auth_headers)
287+
assert response.status_code == 200
288+
json_data = response.json()
289+
assert isinstance(json_data, list)
290+
assert len(json_data) == 1
291+
doc = json_data[0]
292+
assert doc["text"] == "I really like you"
293+
294+
# Handling of the k parameter (top_k filtering)
295+
data = {
296+
"query": "I love you",
297+
"docs": ["I hate you", "I really like you", "I love you too"],
298+
"k": 2
299+
}
300+
response = client.post("/rerank", json=data, headers=auth_headers)
301+
assert response.status_code == 200
302+
json_data = response.json()
303+
assert isinstance(json_data, list)
304+
assert len(json_data) == 2
305+
assert json_data[0]["text"] == "I really like you"
306+
assert json_data[1]["text"] == "I love you too"
307+
308+
# Error handling for invalid inputs
309+
data = {
310+
"query": "I love you",
311+
"docs": [123, 456],
312+
"k": 1
313+
}
314+
response = client.post("/rerank", json=data, headers=auth_headers)
315+
assert response.status_code == 422

0 commit comments

Comments
 (0)