Skip to content

Commit 2ccc1d6

Browse files
committed
training pipeline api
1 parent 168cf91 commit 2ccc1d6

4 files changed

Lines changed: 629 additions & 0 deletions

File tree

app.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import os
2+
import sys
3+
from urllib.parse import quote_plus
4+
5+
import certifi
6+
import pymongo
7+
from dotenv import load_dotenv
8+
from fastapi import FastAPI
9+
from fastapi.middleware.cors import CORSMiddleware
10+
from fastapi.responses import Response
11+
from fastapi.templating import Jinja2Templates
12+
from pymongo.mongo_client import MongoClient
13+
from pymongo.server_api import ServerApi
14+
from starlette.responses import RedirectResponse
15+
from uvicorn import run as app_run
16+
17+
from network_security.constant.training_pipeline import (
18+
DATA_INGESTION_COLLECTION_NAME,
19+
DATA_INGESTION_DATABASE_NAME,
20+
)
21+
from network_security.exception.exception import NetworkSecurityException
22+
from network_security.logging.logger import logging
23+
from network_security.pipeline.training_pipeline import TrainingPipeline
24+
from network_security.utils.main_utils.utils import load_object
25+
from network_security.utils.ml_utils.model.estimator import NetworkModel
26+
27+
ca = certifi.where()
28+
29+
30+
load_dotenv()
31+
username = os.getenv("MONGO_DB_USERNAME")
32+
password = os.getenv("MONGO_DB_PASSWORD")
33+
34+
username = quote_plus(username)
35+
password = quote_plus(password)
36+
37+
mongo_db_url: str = f"mongodb+srv://{username}:{password}@cluster0.l5ee6dv.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"
38+
39+
client = MongoClient(mongo_db_url, server_api=ServerApi("1"))
40+
41+
client = pymongo.MongoClient(mongo_db_url, tlsCAFile=ca)
42+
43+
44+
database = client[DATA_INGESTION_DATABASE_NAME]
45+
collection = database[DATA_INGESTION_COLLECTION_NAME]
46+
47+
app = FastAPI()
48+
origins = ["*"]
49+
50+
app.add_middleware(
51+
CORSMiddleware,
52+
allow_origins=origins,
53+
allow_credentials=True,
54+
allow_methods=["*"],
55+
allow_headers=["*"],
56+
)
57+
58+
59+
templates = Jinja2Templates(directory="./templates")
60+
61+
62+
@app.get("/", tags=["authentication"])
63+
async def index() -> RedirectResponse:
64+
return RedirectResponse(url="/docs")
65+
66+
67+
@app.get("/train")
68+
async def train_route() -> Response:
69+
try:
70+
train_pipeline = TrainingPipeline()
71+
train_pipeline.run_pipeline()
72+
return Response("Training is successful")
73+
except Exception as e:
74+
raise NetworkSecurityException(e, sys)
75+
76+
77+
if __name__ == "__main__":
78+
app_run(app, host="0.0.0.0", port=8000)

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ readme = "README.md"
66
requires-python = ">=3.12"
77
dependencies = [
88
"certifi>=2025.6.15",
9+
"dagshub>=0.5.10",
910
"dill>=0.4.0",
11+
"fastapi>=0.115.13",
1012
"mlflow>=3.1.0",
1113
"numpy>=2.3.0",
1214
"pandas>=2.3.0",
@@ -15,4 +17,5 @@ dependencies = [
1517
"python-dotenv>=1.1.0",
1618
"scikit-learn>=1.7.0",
1719
"setuptools>=80.9.0",
20+
"uvicorn>=0.34.3",
1821
]

requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,9 @@ scikit-learn
88
dill
99
pyaml
1010
mlflow
11+
dagshub
12+
fastapi
13+
uvicorn
14+
1115

1216
# -e .

0 commit comments

Comments
 (0)