|
6 | 6 | import aiofiles.os |
7 | 7 | from shutil import copyfileobj |
8 | 8 | from typing import List, Iterable |
| 9 | +from rerankers import Reranker, Document as ReRankDocument |
9 | 10 | from fastapi import ( |
10 | 11 | APIRouter, |
11 | 12 | Request, |
|
29 | 30 | QueryRequestBody, |
30 | 31 | DocumentResponse, |
31 | 32 | QueryMultipleBody, |
| 33 | + QueryMultipleDocs, |
32 | 34 | ) |
33 | 35 | from app.services.vector_store.async_pg_vector import AsyncPgVector |
34 | 36 | from app.utils.document_loader import ( |
|
40 | 42 | from app.utils.health import is_health_ok |
41 | 43 |
|
42 | 44 | router = APIRouter() |
43 | | - |
| 45 | +reranker_instance = Reranker( |
| 46 | + model_name=os.getenv("SIMPLE_RERANKER_MODEL_NAME", "mixedbread-ai/mxbai-rerank-large-v1"), |
| 47 | + model_type=os.getenv("SIMPLE_RERANKER_MODEL_TYPE", "cross-encoder"), |
| 48 | +) |
44 | 49 |
|
45 | 50 | def get_user_id(request: Request, entity_id: str = None) -> str: |
46 | 51 | """Extract user ID from request or entity_id.""" |
@@ -702,6 +707,43 @@ async def query_embeddings_by_file_ids(request: Request, body: QueryMultipleBody |
702 | 707 | ) |
703 | 708 | raise HTTPException(status_code=500, detail=str(e)) |
704 | 709 |
|
| 710 | +@router.post("/rerank") |
| 711 | +async def rerank_documents_by_query(request: Request, body: QueryMultipleDocs): |
| 712 | + """ |
| 713 | + Rerank documents based on relevance to a query using a reranking model. |
| 714 | +
|
| 715 | + Args: |
| 716 | + request: The FastAPI request object |
| 717 | + body: Contains query string, list of documents, and optional k value |
| 718 | +
|
| 719 | + Returns: |
| 720 | + List of ranked documents with their scores |
| 721 | + """ |
| 722 | + |
| 723 | + try: |
| 724 | + if not body.docs: |
| 725 | + raise HTTPException(status_code=400, detail="docs list cannot be empty") |
| 726 | + docs = [] |
| 727 | + for i, d in enumerate(body.docs): |
| 728 | + docs.append(ReRankDocument(text=d, doc_id=i)) |
| 729 | + |
| 730 | + top_k = body.k |
| 731 | + |
| 732 | + results = reranker_instance.rank(query=body.query, docs=docs) |
| 733 | + items = results.top_k(top_k) if top_k else results |
| 734 | + |
| 735 | + return [ |
| 736 | + {"text": getattr(r.document, "text", None), "score": r.score} for r in items |
| 737 | + ] |
| 738 | + except Exception as e: |
| 739 | + logger.error( |
| 740 | + "Error in reranking documents | Query: %s | Error: %s | Traceback: %s", |
| 741 | + body.query, |
| 742 | + str(e), |
| 743 | + traceback.format_exc(), |
| 744 | + ) |
| 745 | + raise HTTPException(status_code=500, detail=str(e)) |
| 746 | + |
705 | 747 |
|
706 | 748 | @router.post("/text") |
707 | 749 | async def extract_text_from_file( |
|
0 commit comments