Skip to content

Commit 4f2f3cb

Browse files
committed
feat: ReRanker endpoint using local models.
1 parent 65c64ed commit 4f2f3cb

6 files changed

Lines changed: 93 additions & 1 deletion

File tree

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ The following environment variables are required to run the application:
9292
- `AWS_SESSION_TOKEN`: (Optional) may be needed for bedrock embeddings
9393
- `GOOGLE_APPLICATION_CREDENTIALS`: (Optional) needed for Google VertexAI embeddings. This should be a path to a service account credential file in JSON format, as accepted by [langchain](https://python.langchain.com/api_reference/google_vertexai/index.html)
9494
- `RAG_CHECK_EMBEDDING_CTX_LENGTH` (Optional) Default is true, disabling this will send raw input to the embedder, use this for custom embedding models.
95+
- `SIMPLE_RERANKER_MODEL_NAME` (Optional) defaults to `mixedbread-ai/mxbai-rerank-large-v1`, more options at (https://github.com/AnswerDotAI/rerankers)
96+
- `SIMPLE_RERANKER_MODEL_TYPE` (Optional) defaults to `cross-encoder`, more options at (https://github.com/AnswerDotAI/rerankers)
9597

9698
Make sure to set these environment variables before running the application. You can set them in a `.env` file or as system environment variables.
9799

app/models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,9 @@ class QueryMultipleBody(BaseModel):
4242
query: str
4343
file_ids: List[str]
4444
k: int = 4
45+
46+
47+
class QueryMultipleDocs(BaseModel):
48+
query: str
49+
docs: List[str]
50+
k: int = 4

app/routes/document_routes.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import aiofiles.os
77
from shutil import copyfileobj
88
from typing import List, Iterable
9+
from rerankers import Reranker, Document as ReRankDocument
910
from fastapi import (
1011
APIRouter,
1112
Request,
@@ -29,6 +30,7 @@
2930
QueryRequestBody,
3031
DocumentResponse,
3132
QueryMultipleBody,
33+
QueryMultipleDocs,
3234
)
3335
from app.services.vector_store.async_pg_vector import AsyncPgVector
3436
from app.utils.document_loader import (
@@ -40,7 +42,10 @@
4042
from app.utils.health import is_health_ok
4143

4244
router = APIRouter()
43-
45+
reranker_instance = Reranker(
46+
model_name=os.getenv("SIMPLE_RERANKER_MODEL_NAME", "mixedbread-ai/mxbai-rerank-large-v1"),
47+
model_type=os.getenv("SIMPLE_RERANKER_MODEL_TYPE", "cross-encoder"),
48+
)
4449

4550
def get_user_id(request: Request, entity_id: str = None) -> str:
4651
"""Extract user ID from request or entity_id."""
@@ -702,6 +707,43 @@ async def query_embeddings_by_file_ids(request: Request, body: QueryMultipleBody
702707
)
703708
raise HTTPException(status_code=500, detail=str(e))
704709

710+
@router.post("/rerank")
711+
async def rerank_documents_by_query(request: Request, body: QueryMultipleDocs):
712+
"""
713+
Rerank documents based on relevance to a query using a reranking model.
714+
715+
Args:
716+
request: The FastAPI request object
717+
body: Contains query string, list of documents, and optional k value
718+
719+
Returns:
720+
List of ranked documents with their scores
721+
"""
722+
723+
try:
724+
if not body.docs:
725+
raise HTTPException(status_code=400, detail="docs list cannot be empty")
726+
docs = []
727+
for i, d in enumerate(body.docs):
728+
docs.append(ReRankDocument(text=d, doc_id=i))
729+
730+
top_k = body.k
731+
732+
results = reranker_instance.rank(query=body.query, docs=docs)
733+
items = results.top_k(top_k) if top_k else results
734+
735+
return [
736+
{"text": getattr(r.document, "text", None), "score": r.score} for r in items
737+
]
738+
except Exception as e:
739+
logger.error(
740+
"Error in reranking documents | Query: %s | Error: %s | Traceback: %s",
741+
body.query,
742+
str(e),
743+
traceback.format_exc(),
744+
)
745+
raise HTTPException(status_code=500, detail=str(e))
746+
705747

706748
@router.post("/text")
707749
async def extract_text_from_file(

docker-compose.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@ services:
1717
- DB_PORT=5432
1818
ports:
1919
- "8000:8000"
20+
runtime: ${DOCKER_RUNTIME:-runc}
2021
volumes:
2122
- ./uploads:/app/uploads
23+
- ~/.cache/huggingface:/root/.cache/huggingface:rw
2224
depends_on:
2325
- db
2426
env_file:

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,6 @@ python-magic==0.4.27
3737
python-pptx==1.0.2
3838
xlrd==2.0.2
3939
pydantic==2.9.2
40+
rerankers[transformers]==0.6.0
41+
rerankers[flashrank]==0.6.0
4042
chardet==5.2.0

tests/test_main.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,41 @@ def test_extract_text_from_file(tmp_path, auth_headers):
263263
assert json_data["file_id"] == "test_text_123"
264264
assert json_data["filename"] == "test_text_extraction.txt"
265265
assert json_data["known_type"] is True # text files are known types
266+
267+
def test_query_rerank(auth_headers):
268+
# Successful reranking with string documents
269+
data = {
270+
"query": "I love you",
271+
"docs": ["I hate you", "I really like you"],
272+
"k": 1
273+
}
274+
response = client.post("/rerank", json=data, headers=auth_headers)
275+
assert response.status_code == 200
276+
json_data = response.json()
277+
assert isinstance(json_data, list)
278+
assert len(json_data) == 1
279+
doc = json_data[0]
280+
assert doc["text"] == "I really like you"
281+
282+
# Handling of the k parameter (top_k filtering)
283+
data = {
284+
"query": "I love you",
285+
"docs": ["I hate you", "I really like you", "I love you too"],
286+
"k": 2
287+
}
288+
response = client.post("/rerank", json=data, headers=auth_headers)
289+
assert response.status_code == 200
290+
json_data = response.json()
291+
assert isinstance(json_data, list)
292+
assert len(json_data) == 2
293+
assert json_data[0]["text"] == "I really like you"
294+
assert json_data[1]["text"] == "I love you too"
295+
296+
# Error handling for invalid inputs
297+
data = {
298+
"query": "I love you",
299+
"docs": [123, 456],
300+
"k": 1
301+
}
302+
response = client.post("/rerank", json=data, headers=auth_headers)
303+
assert response.status_code == 422

0 commit comments

Comments
 (0)