|
7 | 7 | from shutil import copyfileobj |
8 | 8 | from typing import List, Iterable, TYPE_CHECKING |
9 | 9 | from concurrent.futures import ThreadPoolExecutor |
| 10 | +from rerankers import Reranker, Document as ReRankDocument |
10 | 11 | from fastapi import ( |
11 | 12 | APIRouter, |
12 | 13 | Request, |
|
43 | 44 | QueryRequestBody, |
44 | 45 | DocumentResponse, |
45 | 46 | QueryMultipleBody, |
| 47 | + QueryMultipleDocs, |
46 | 48 | ) |
47 | 49 | from app.services.vector_store.async_pg_vector import AsyncPgVector |
48 | 50 | from app.utils.document_loader import ( |
|
54 | 56 | from app.utils.health import is_health_ok |
55 | 57 |
|
56 | 58 | router = APIRouter() |
57 | | - |
| 59 | +reranker_instance = Reranker( |
| 60 | + model_name=os.getenv("SIMPLE_RERANKER_MODEL_NAME", "ms-marco-MiniLM-L-12-v2"), |
| 61 | + model_type=os.getenv("SIMPLE_RERANKER_MODEL_TYPE", "flashrank"), |
| 62 | +) |
58 | 63 |
|
59 | 64 | def calculate_num_batches(total: int, batch_size: int) -> int: |
60 | 65 | """Calculate the number of batches needed to process total items.""" |
@@ -1002,6 +1007,43 @@ async def query_embeddings_by_file_ids(request: Request, body: QueryMultipleBody |
1002 | 1007 | ) |
1003 | 1008 | raise HTTPException(status_code=500, detail=str(e)) |
1004 | 1009 |
|
| 1010 | +@router.post("/rerank") |
| 1011 | +async def rerank_documents_by_query(request: Request, body: QueryMultipleDocs): |
| 1012 | + """ |
| 1013 | + Rerank documents based on relevance to a query using a reranking model. |
| 1014 | +
|
| 1015 | + Args: |
| 1016 | + request: The FastAPI request object |
| 1017 | + body: Contains query string, list of documents, and optional k value |
| 1018 | +
|
| 1019 | + Returns: |
| 1020 | + List of ranked documents with their scores |
| 1021 | + """ |
| 1022 | + |
| 1023 | + try: |
| 1024 | + if not body.docs: |
| 1025 | + raise HTTPException(status_code=400, detail="docs list cannot be empty") |
| 1026 | + docs = [] |
| 1027 | + for i, d in enumerate(body.docs): |
| 1028 | + docs.append(ReRankDocument(text=d, doc_id=i)) |
| 1029 | + |
| 1030 | + top_k = body.k |
| 1031 | + |
| 1032 | + results = reranker_instance.rank(query=body.query, docs=docs) |
| 1033 | + items = results.top_k(top_k) if top_k else results |
| 1034 | + |
| 1035 | + return [ |
| 1036 | + {"text": getattr(r.document, "text", None), "score": r.score} for r in items |
| 1037 | + ] |
| 1038 | + except Exception as e: |
| 1039 | + logger.error( |
| 1040 | + "Error in reranking documents | Query: %s | Error: %s | Traceback: %s", |
| 1041 | + body.query, |
| 1042 | + str(e), |
| 1043 | + traceback.format_exc(), |
| 1044 | + ) |
| 1045 | + raise HTTPException(status_code=500, detail=str(e)) |
| 1046 | + |
1005 | 1047 |
|
1006 | 1048 | @router.post("/text") |
1007 | 1049 | async def extract_text_from_file( |
|
0 commit comments