-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrag_server.py
More file actions
77 lines (61 loc) · 2.24 KB
/
rag_server.py
File metadata and controls
77 lines (61 loc) · 2.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import warnings
from typing import Optional

import chromadb
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
# Blanket-suppress every Python warning for the whole process.
warnings.filterwarnings("ignore")

# Persistent Chroma store on disk and the collection this server reads from.
DB_DIR = "./chroma_db"
COLLECTION_NAME = "sysreptor_rag"

app = FastAPI(title="SysReptor RAG Retrieval Server")

# Shared RAG handles, assigned by the startup hook; None until it has run.
client = collection = model = None
@app.on_event("startup")
async def startup_event():
    """Open the persistent Chroma store and load the embedding model.

    Assigns the module-level ``client``, ``collection`` and ``model`` handles.
    If the collection cannot be opened, the error is printed and startup
    continues; ``collection`` keeps its previous value (``None`` in a fresh
    process), which makes ``/query`` answer with HTTP 500.
    """
    global client, collection, model

    print("[*] Initializing RAG components...")
    client = chromadb.PersistentClient(path=DB_DIR)

    try:
        collection = client.get_collection(name=COLLECTION_NAME)
    except Exception as load_error:  # best-effort: keep the server up so /query can report it
        print(f"[-] Error loading collection: {load_error}. Make sure to run ingest_data.py first.")

    # NOTE(review): presumably this must be the same model ingest_data.py used,
    # so query vectors share the stored vectors' embedding space — confirm.
    model = SentenceTransformer("all-MiniLM-L6-v2")

    print("[+] RAG components initialized and ready.")
class QueryRequest(BaseModel):
    """JSON body accepted by ``POST /query``."""

    # Free-text search string; it is embedded and matched against the Chroma collection.
    query: str
    # Maximum number of nearest chunks to return (forwarded as Chroma's ``n_results``).
    num_results: int = 5
    # Optional exact match on the stored ``type`` metadata field,
    # e.g. 'finding', 'executive_summary', 'technical_summary'.
    # Declared Optional so an explicit JSON null validates like an omitted field
    # (a bare ``str = None`` rejects null under pydantic v2).
    filter_type: Optional[str] = None
@app.post("/query")
async def query_rag(request: QueryRequest):
    """Return the stored chunks most similar to ``request.query``.

    Embeds the query with the sentence-transformer loaded at startup and runs a
    nearest-neighbour search on the Chroma collection, optionally restricted to
    chunks whose ``type`` metadata equals ``request.filter_type``.

    Returns:
        ``{"results": [{"content": ..., "metadata": ..., "distance": ...}, ...]}``
        in the order Chroma returns them; a lower ``distance`` means more similar.

    Raises:
        HTTPException: 500 if startup did not provide the collection or the model;
            400 if the query is blank or ``num_results`` is not a positive integer.
    """
    if collection is None:
        raise HTTPException(status_code=500, detail="Chroma DB collection not initialized.")
    if model is None:
        raise HTTPException(status_code=500, detail="Embedding model not initialized.")
    if not request.query.strip():
        raise HTTPException(status_code=400, detail="Query cannot be empty.")
    # Chroma requires a positive n_results; reject bad input here as a client
    # error instead of letting Chroma's exception surface as an opaque 500.
    if request.num_results < 1:
        raise HTTPException(status_code=400, detail="num_results must be a positive integer.")

    # NOTE(review): encode() and query() are synchronous and run on the event loop
    # inside this async handler, so concurrent requests are serialized. Moving them
    # to a threadpool needs care: a shared HF tokenizer may not be thread-safe.
    query_embedding = model.encode([request.query]).tolist()

    # An empty/missing filter_type means "search all chunk types".
    where_clause = {"type": request.filter_type} if request.filter_type else None

    results = collection.query(
        query_embeddings=query_embedding,
        n_results=request.num_results,
        where=where_clause,
    )

    # A single query was sent, so each result field holds exactly one row.
    documents = results.get("documents", [[]])[0]
    metadatas = results.get("metadatas", [[]])[0]
    distances = results.get("distances", [[]])[0]

    return {
        "results": [
            {"content": doc, "metadata": meta, "distance": dist}  # lower distance = more similar
            for doc, meta, dist in zip(documents, metadatas, distances)
        ]
    }
if __name__ == "__main__":
    import os

    # Bind address and port can be overridden via RAG_SERVER_HOST / RAG_SERVER_PORT;
    # the defaults reproduce the original 0.0.0.0:8000.
    # NOTE(review): 0.0.0.0 listens on every interface and /query has no
    # authentication, so indexed report content is readable by anyone who can
    # reach this port — consider 127.0.0.1 unless remote access is intended.
    host = os.environ.get("RAG_SERVER_HOST", "0.0.0.0")
    port = int(os.environ.get("RAG_SERVER_PORT", "8000"))
    uvicorn.run(app, host=host, port=port)