Skip to content

Commit efb761d

Browse files
committed
fix: replace empty api.py with working FastAPI endpoint
1 parent b6c5364 commit efb761d

1 file changed

Lines changed: 188 additions & 0 deletions

File tree

api.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
"""
2+
Oasis Security – Crime Predictor API
3+
FastAPI + MLflow production-ready endpoint
4+
"""
5+
6+
from contextlib import asynccontextmanager
7+
from typing import Dict, Optional
8+
9+
import mlflow
10+
import mlflow.lightgbm
11+
import numpy as np
12+
import os
13+
import pandas as pd
14+
import uvicorn
15+
16+
from fastapi import FastAPI, HTTPException
17+
from pydantic import BaseModel
18+
19+
# ---------------------------------------------------------------------------
20+
# Config MLflow
21+
# ---------------------------------------------------------------------------
22+
MLFLOW_URI = os.getenv("MLFLOW_TRACKING_URI", "http://localhost:5000")
23+
mlflow.set_tracking_uri(MLFLOW_URI)
24+
mlflow.set_experiment("crime_predictor_prod")
25+
26+
# ---------------------------------------------------------------------------
27+
# Lifespan : chargement modèle au démarrage
28+
# ---------------------------------------------------------------------------
29+
predictor = None
30+
31+
32+
@asynccontextmanager
33+
async def lifespan(app: FastAPI):
34+
global predictor
35+
print("🚀 Chargement modèle...")
36+
try:
37+
model_uri = "models:/crime_predictor_prod/Production"
38+
predictor = mlflow.lightgbm.load_model(model_uri)
39+
print("✅ Modèle chargé depuis MLflow Registry")
40+
except Exception:
41+
# Fallback : modèle local sérialisé
42+
from models.crime_predictor.src.model import CrimeRatePredictor
43+
predictor = CrimeRatePredictor()
44+
predictor.load("models/crime_predictor/artifacts/crime_predictor.pkl")
45+
print("✅ Modèle chargé depuis fichier local")
46+
yield
47+
print("🛑 API shutdown")
48+
49+
50+
# ---------------------------------------------------------------------------
51+
# App
52+
# ---------------------------------------------------------------------------
53+
app = FastAPI(
54+
title="Oasis Security – Crime Predictor API",
55+
version="2.0.0",
56+
description="Prédiction du taux de délinquance par région (pour 100 000 habitants)",
57+
lifespan=lifespan,
58+
)
59+
60+
61+
# ---------------------------------------------------------------------------
62+
# Schémas
63+
# ---------------------------------------------------------------------------
64+
class PredictionRequest(BaseModel):
65+
year: int = 2030
66+
indicateur: str
67+
region: str
68+
lag1: Optional[float] = 250.0
69+
lag2: Optional[float] = 245.0
70+
71+
model_config = {"json_schema_extra": {
72+
"example": {
73+
"year": 2030,
74+
"indicateur": "Coups et blessures volontaires",
75+
"region": "R11",
76+
"lag1": 280.5,
77+
"lag2": 275.0,
78+
}
79+
}}
80+
81+
82+
# ---------------------------------------------------------------------------
83+
# Endpoints
84+
# ---------------------------------------------------------------------------
85+
@app.get("/health", tags=["Monitoring"])
86+
async def health():
87+
"""Vérifie que l'API et le modèle sont opérationnels."""
88+
return {
89+
"status": "healthy",
90+
"model_loaded": predictor is not None,
91+
"model_version": "v2.0",
92+
"mlflow_uri": MLFLOW_URI,
93+
}
94+
95+
96+
@app.post("/predict", response_model=Dict, tags=["Prédiction"])
97+
async def predict(request: PredictionRequest):
98+
"""
99+
Prédit le taux de délinquance pour un indicateur et une région donnés.
100+
101+
- **year** : année cible (ex. 2030)
102+
- **indicateur** : catégorie de crime (ex. "Coups et blessures volontaires")
103+
- **region** : code région INSEE (ex. "R11" pour Île-de-France)
104+
- **lag1 / lag2** : taux des 2 années précédentes (optionnel, valeurs par défaut utilisées si absent)
105+
"""
106+
if predictor is None:
107+
raise HTTPException(status_code=503, detail="Modèle non chargé")
108+
109+
with mlflow.start_run(nested=True) as run:
110+
try:
111+
lag1 = request.lag1 or 250.0
112+
lag2 = request.lag2 or 245.0
113+
114+
features = pd.DataFrame([{
115+
"year_sin": np.sin(2 * np.pi * request.year / 10),
116+
"year_cos": np.cos(2 * np.pi * request.year / 10),
117+
"year_trend": (request.year - 2016) / 9,
118+
"lag1": lag1,
119+
"lag2": lag2,
120+
"roll_mean_3": (lag1 + lag2 + 240.0) / 3,
121+
"region_mean": 250.0,
122+
"ind_code": hash(request.indicateur) % 100,
123+
"reg_code": int(request.region.replace("R", "")),
124+
}])
125+
126+
pred = float(predictor.predict(features)[0])
127+
128+
# Observabilité MLflow
129+
mlflow.log_params({
130+
"indicateur": request.indicateur,
131+
"region": request.region,
132+
"year": request.year,
133+
})
134+
mlflow.log_metric("prediction", pred)
135+
136+
niveau = (
137+
"🚨 Risque élevé" if pred > 400 else
138+
"⚠️ Risque modéré" if pred > 300 else
139+
"✅ Risque faible"
140+
)
141+
142+
return {
143+
"prediction": round(pred, 2),
144+
"unit": "taux / 100 000 habitants",
145+
"year": request.year,
146+
"indicateur": request.indicateur,
147+
"region": request.region,
148+
"interpretation": niveau,
149+
"mlflow_run_id": run.info.run_id,
150+
}
151+
152+
except Exception as e:
153+
mlflow.log_metric("error", 1)
154+
raise HTTPException(status_code=500, detail=str(e))
155+
156+
157+
@app.get("/leaderboard", tags=["Analyse"])
158+
async def leaderboard():
159+
"""
160+
Retourne le top 5 des combinaisons région/indicateur
161+
avec les prédictions 2030 les plus élevées (risques prioritaires).
162+
"""
163+
try:
164+
client = mlflow.MlflowClient()
165+
runs = client.search_runs(
166+
experiment_ids=["0"],
167+
order_by=["metrics.prediction DESC"],
168+
max_results=50,
169+
)
170+
return {
171+
"top_risks": [
172+
{
173+
"indicateur": r.data.params.get("indicateur", "N/A"),
174+
"region": r.data.params.get("region", "N/A"),
175+
"pred_2030": r.data.metrics.get("prediction", 0),
176+
}
177+
for r in runs
178+
][:5]
179+
}
180+
except Exception as e:
181+
raise HTTPException(status_code=500, detail=str(e))
182+
183+
184+
# ---------------------------------------------------------------------------
185+
# Lancement direct
186+
# ---------------------------------------------------------------------------
187+
if __name__ == "__main__":
188+
uvicorn.run("api:app", host="0.0.0.0", port=8000, reload=False)

0 commit comments

Comments
 (0)