Skip to content

Commit 0b1a50a

Browse files
rohan-pandeyy authored and rahulharpal1603 committed
feat(backend): add model management API and hardware detection
1 parent 894daca commit 0b1a50a

6 files changed

Lines changed: 571 additions & 18 deletions

File tree

backend/app/routes/models.py

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
import os
2+
import json
3+
import uuid
4+
import asyncio
5+
from typing import Dict, Optional
6+
from dataclasses import dataclass, field
7+
from datetime import datetime, timezone
8+
from fastapi import APIRouter, HTTPException, status
9+
from pydantic import BaseModel
10+
from fastapi.responses import StreamingResponse
11+
from app.models.model_registry import MODEL_REGISTRY, TIER_MODELS, get_model_path
12+
from app.utils.hardware_detect import get_hardware_info
13+
from app.utils.model_downloader import ensure_model
14+
import logging
15+
16+
logger = logging.getLogger(__name__)
17+
18+
router = APIRouter()
19+
20+
REQUIRED_MODELS = ["facenet"]
21+
22+
# Global dict to track download tasks
23+
def _utc_now() -> datetime:
    """Return the current moment as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


@dataclass
class DownloadTaskEntry:
    """Bookkeeping record for one background download tracked by task id."""

    # Queue the background download pushes progress/terminal messages onto;
    # the SSE progress endpoint reads from it.
    queue: asyncio.Queue
    # The asyncio task running the download; cancelled on client disconnect
    # or by the stale-task sweep.
    task: asyncio.Task
    # UTC creation time; _cleanup_stale_tasks compares it against its max age.
    created_at: datetime = field(default_factory=_utc_now)
30+
31+
download_tasks: Dict[str, DownloadTaskEntry] = {}
32+
33+
async def _cleanup_stale_tasks(max_age_minutes: int = 10):
    """
    Periodically evict download entries older than ``max_age_minutes``.

    Runs forever (until cancelled): every 5 minutes it removes every entry of
    ``download_tasks`` whose ``created_at`` is older than the limit and
    cancels the entry's task if it is still running.

    NOTE(review): age is measured from creation, not from last activity, so a
    download that legitimately runs longer than the limit looks like it would
    be cancelled too - confirm this is intended.
    """
    max_age_seconds = max_age_minutes * 60
    while True:
        await asyncio.sleep(300)  # run every 5 minutes
        sweep_time = datetime.now(timezone.utc)

        # Collect ids first so the dict is never mutated while iterating it.
        expired_ids = []
        for task_id, record in download_tasks.items():
            age_seconds = (sweep_time - record.created_at).total_seconds()
            if age_seconds > max_age_seconds:
                expired_ids.append(task_id)

        for task_id in expired_ids:
            record = download_tasks.pop(task_id, None)
            if record is None:
                # Already removed elsewhere (e.g. by the progress stream).
                continue
            if not record.task.done():
                record.task.cancel()
45+
46+
# Request body accepted by the "/setup" route.
class SetupRequest(BaseModel):
    # Name of the model tier to install; setup_models rejects any value that
    # is not a key of TIER_MODELS.
    tier: str
48+
49+
@router.get("/status")
def get_model_status():
    """
    Report the installation status of every model in the registry.

    A model counts as installed when the path returned by ``get_model_path``
    exists on disk. The response maps each registry key to its filename,
    installed flag, feature, tier and size.
    """

    def describe(model_key, spec):
        # Check the disk first, then copy the descriptive registry fields.
        installed = os.path.exists(get_model_path(model_key))
        return {
            "name": spec["filename"],
            "installed": installed,
            "feature": spec["feature"],
            "tier": spec["tier"],
            "size_mb": spec["size_mb"],
        }

    models = {
        model_key: describe(model_key, spec)
        for model_key, spec in MODEL_REGISTRY.items()
    }
    return {"success": True, "data": models}
69+
70+
@router.get("/hardware")
def get_hardware_recommendation():
    """
    Return the detected hardware information from ``get_hardware_info``.

    Raises:
        HTTPException: 500 when hardware detection raises any exception.
    """
    try:
        detected = get_hardware_info()
    except Exception as e:
        logger.error(f"Failed to get hardware info: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to detect hardware: {str(e)}"
        )
    else:
        return {"success": True, "data": detected}
87+
88+
@router.delete("/{model_key}")
def delete_model(model_key: str):
    """
    Delete a specific model file from disk.

    Args:
        model_key: Registry key of the model to delete.

    Returns:
        A success payload, both when the file was removed and when it was
        already absent (the operation is treated as a no-op in that case).

    Raises:
        HTTPException: 404 if ``model_key`` is not in ``MODEL_REGISTRY``;
            500 if the file exists but cannot be removed.
    """
    if model_key not in MODEL_REGISTRY:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Model key '{model_key}' not found in registry."
        )

    path = get_model_path(model_key)
    # Attempt the removal directly instead of checking existence first: a
    # separate exists() check races with concurrent deletes and would turn a
    # harmless "already gone" into a 500 error.
    try:
        os.remove(path)
    except FileNotFoundError:
        return {"success": True, "message": f"Model {model_key} already not present."}
    except OSError as e:
        logger.error(f"Failed to delete model {model_key}: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to delete model: {str(e)}"
        ) from e
    return {"success": True, "message": f"Model {model_key} deleted successfully."}
112+
113+
114+
@router.post("/setup")
async def setup_models(request: SetupRequest):
    """
    Start background downloads for a tier plus the required models.

    The models are downloaded one after another by a single background task.
    Progress messages are pushed onto a queue that the progress endpoint
    streams to the client.

    Args:
        request: Body carrying the tier name; must be a key of ``TIER_MODELS``.

    Returns:
        A payload containing the ``task_id`` used to track overall progress.

    Raises:
        HTTPException: 400 if the tier is unknown.
    """
    if request.tier not in TIER_MODELS:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid tier '{request.tier}'. Valid tiers are: {list(TIER_MODELS.keys())}"
        )

    # Copy so the registry's own list is never mutated.
    requested = list(TIER_MODELS[request.tier])
    if request.tier != "required":
        requested += REQUIRED_MODELS
    # De-duplicate while preserving order: a tier that already lists a
    # required model must not download it twice or inflate total_models.
    models_to_download = list(dict.fromkeys(requested))

    task_id = str(uuid.uuid4())
    queue = asyncio.Queue()
    total_models = len(models_to_download)

    def make_progress_callback(model_key: str, position: int):
        # Bind model_key/position per model so each callback reports its own
        # model rather than whatever the loop variable holds later.
        def progress_callback(percent: float, downloaded: int, total: int):
            queue.put_nowait({
                "status": "downloading",
                "model_key": model_key,
                "model_index": position,
                "total_models": total_models,
                "percent": percent,
                "downloaded": downloaded,
                "total": total
            })

        return progress_callback

    async def background_setup():
        try:
            for idx, model_key in enumerate(models_to_download):
                await ensure_model(
                    model_key,
                    progress_callback=make_progress_callback(model_key, idx + 1),
                )

            queue.put_nowait({"status": "complete"})
        except Exception as e:
            # Log with traceback, then surface the failure to the stream.
            logger.exception(f"Error during setup download: {e}")
            queue.put_nowait({"status": "error", "message": str(e)})

    # Start the setup in the background and register it for progress lookup.
    task = asyncio.create_task(background_setup())
    download_tasks[task_id] = DownloadTaskEntry(queue=queue, task=task)

    return {
        "success": True,
        "task_id": task_id,
        "message": f"Setup started for tier '{request.tier}'"
    }
175+
176+
177+
@router.post("/download/{model_key}")
async def start_download_model(model_key: str):
    """
    Kick off a background download of one model and return its task id.

    Raises:
        HTTPException: 404 if ``model_key`` is not in ``MODEL_REGISTRY``.
    """
    if model_key not in MODEL_REGISTRY:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Model key '{model_key}' not found in registry."
        )

    new_task_id = str(uuid.uuid4())
    progress_queue = asyncio.Queue()

    def report_progress(percent: float, downloaded: int, total: int):
        # Forward each progress tick to whoever streams this task.
        progress_queue.put_nowait({
            "status": "downloading",
            "model_key": model_key,
            "percent": percent,
            "downloaded": downloaded,
            "total": total
        })

    async def run_download():
        try:
            await ensure_model(model_key, progress_callback=report_progress)
            progress_queue.put_nowait({"status": "complete", "model_key": model_key})
        except Exception as e:
            logger.error(f"Error downloading model {model_key}: {e}")
            progress_queue.put_nowait({"status": "error", "message": str(e)})

    # Launch the download without blocking the request, then register it so
    # the progress endpoint can find it by id.
    running = asyncio.create_task(run_download())
    download_tasks[new_task_id] = DownloadTaskEntry(queue=progress_queue, task=running)

    response = {"success": True}
    response["task_id"] = new_task_id
    response["message"] = f"Download started for {model_key}"
    return response
217+
218+
@router.get("/download/{task_id}/progress")
async def download_progress(task_id: str):
    """
    Stream Server-Sent Events with the progress of a download task.

    Each queued message is sent as a ``data:`` event. The stream ends after a
    message whose status is ``complete`` or ``error``. While the queue stays
    silent for 60 seconds a ``heartbeat`` event is sent instead.

    If the background task has finished without leaving a terminal message
    (for example because it was cancelled, which bypasses its error handler),
    a synthetic ``error`` event is emitted and the stream is closed rather
    than sending heartbeats forever.

    Args:
        task_id: Identifier returned by the setup/download endpoints.

    Raises:
        HTTPException: 404 if the task id is unknown.
    """
    entry = download_tasks.get(task_id)
    if entry is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Task ID not found or already completed."
        )

    async def event_generator():
        try:
            while True:
                # A finished task with nothing left to read will never
                # produce a terminal message: report it and stop.
                if entry.task.done() and entry.queue.empty():
                    lost = {
                        "status": "error",
                        "message": "Download task ended unexpectedly.",
                    }
                    yield f"data: {json.dumps(lost)}\n\n"
                    break

                try:
                    msg = await asyncio.wait_for(entry.queue.get(), timeout=60.0)
                except asyncio.TimeoutError:
                    yield "event: heartbeat\ndata: {}\n\n"
                    continue

                yield f"data: {json.dumps(msg)}\n\n"
                if msg["status"] in ("complete", "error"):
                    # Terminal message delivered; close the stream.
                    break
        except asyncio.CancelledError:
            # Client disconnected - cancel the download task too
            entry.task.cancel()
            raise  # always re-raise CancelledError
        finally:
            # Runs on: normal completion, CancelledError, any other exception
            download_tasks.pop(task_id, None)

    return StreamingResponse(event_generator(), media_type="text/event-stream")
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import shutil
2+
import subprocess
3+
4+
import onnxruntime as ort
5+
import psutil
6+
7+
8+
def detect_physical_gpu() -> list[str]:
    """
    List GPUs reported by the ``nvidia-smi`` command-line tool.

    Detection does not consult ONNX Runtime providers. When ``nvidia-smi`` is
    not on PATH, fails, or times out (5 seconds), no GPUs are reported.

    Returns:
        list[str]: GPU names, one per non-blank output line; empty if none.
    """
    smi_path = shutil.which("nvidia-smi")
    if not smi_path:
        return []

    query_cmd = [smi_path, "--query-gpu=name", "--format=csv,noheader"]
    try:
        completed = subprocess.run(
            query_cmd,
            capture_output=True,
            text=True,
            check=True,
            timeout=5,
        )
    except (OSError, subprocess.SubprocessError):
        # Best effort: any launch failure, non-zero exit or timeout means
        # "no GPU detected" rather than an error.
        return []

    names: list[str] = []
    for raw_line in completed.stdout.splitlines():
        name = raw_line.strip()
        if name:
            names.append(name)
    return names
36+
37+
38+
def detect_hardware_tier() -> str:
    """
    Recommend a model tier from total RAM and GPU presence.

    Returns:
        'medium' when a GPU is detected or total RAM is at least 8 GB,
        'small' when total RAM is at least 4 GB, otherwise 'nano'.
    """
    # Total RAM in GB
    total_ram_gb = psutil.virtual_memory().total / (1024**3)

    # Physical GPU detection, independent of runtime providers.
    has_gpu = bool(detect_physical_gpu())

    if has_gpu or total_ram_gb >= 8:
        return "medium"
    if total_ram_gb >= 4:
        return "small"
    return "nano"
55+
56+
57+
def get_hardware_info() -> dict:
    """
    Collect hardware details together with the recommended model tier.

    The result holds total RAM in GB (rounded to 2 decimals), whether any GPU
    was detected and the detected names, the ONNX Runtime providers available
    in this process, and the tier chosen by ``detect_hardware_tier``.
    """
    detected_gpus = detect_physical_gpu()
    total_ram_bytes = psutil.virtual_memory().total

    info = {}
    info["ram_gb"] = round(total_ram_bytes / (1024**3), 2)
    info["gpu_detected"] = bool(detected_gpus)
    info["gpu_names"] = detected_gpus
    info["available_providers"] = ort.get_available_providers()
    # detect_hardware_tier performs its own RAM/GPU detection internally.
    info["recommended_tier"] = detect_hardware_tier()
    return info

backend/main.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import multiprocessing
66
import os
77
import json
8+
import asyncio
89

910
from app.config.settings import DATABASE_PATH, THUMBNAIL_IMAGES_PATH
1011
from uvicorn import Config, Server
@@ -28,6 +29,7 @@
2829
from app.routes.user_preferences import router as user_preferences_router
2930
from app.routes.memories import router as memories_router
3031
from app.routes.shutdown import router as shutdown_router
32+
from app.routes.models import router as models_router, _cleanup_stale_tasks
3133
from fastapi.openapi.utils import get_openapi
3234
from app.logging.setup_logging import (
3335
configure_uvicorn_logging,
@@ -62,9 +64,14 @@ async def lifespan(app: FastAPI):
6264
# Create ProcessPoolExecutor and attach it to app.state
6365
app.state.executor = ProcessPoolExecutor(max_workers=1)
6466

67+
# Start the SSE model download cleanup task
68+
cleanup_task = asyncio.create_task(_cleanup_stale_tasks())
69+
6570
try:
6671
yield
6772
finally:
73+
cleanup_task.cancel()
74+
await asyncio.gather(cleanup_task, return_exceptions=True)
6875
app.state.executor.shutdown(wait=True)
6976

7077

@@ -142,6 +149,7 @@ async def root():
142149
memories_router
143150
) # Memories router (prefix already defined in router)
144151
app.include_router(shutdown_router, tags=["Shutdown"])
152+
app.include_router(models_router, prefix="/models", tags=["Models"])
145153

146154

147155
# Entry point for running with: python3 main.py

0 commit comments

Comments
 (0)