Skip to content

Commit 055aa15

Browse files
committed
MLflow implementation
1 parent 63bf925 commit 055aa15

4 files changed

Lines changed: 1048 additions & 7 deletions

File tree

network_security/components/model_trainer.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import sys
22
from pathlib import Path
33

4+
import mlflow
45
from sklearn.ensemble import (
56
AdaBoostClassifier,
67
GradientBoostingClassifier,
@@ -40,6 +41,17 @@ def __init__(
4041
except Exception as e:
4142
raise NetworkSecurityException(e, sys)
4243

44+
def track_mlflow(self, best_model, classificationmetric):
45+
with mlflow.start_run():
46+
f1_score = classificationmetric.f1_score
47+
precision_score = classificationmetric.precision_score
48+
recall_score = classificationmetric.recall_score
49+
50+
mlflow.log_metric("f1_score", f1_score)
51+
mlflow.log_metric("precision", precision_score)
52+
mlflow.log_metric("recall_score", recall_score)
53+
mlflow.sklearn.log_model(best_model, "model")
54+
4355
def train_model(
4456
self,
4557
X_train: object,
@@ -57,20 +69,20 @@ def train_model(
5769
params = {
5870
"Decision Tree": {
5971
"criterion": ["gini", "entropy", "log_loss"],
60-
"splitter": ["best", "random"],
61-
"max_features": ["sqrt", "log2"],
72+
# "splitter": ["best", "random"],
73+
# "max_features": ["sqrt", "log2"],
6274
},
6375
"Random Forest": {
64-
"criterion": ["gini", "entropy", "log_loss"],
65-
"max_features": ["sqrt", "log2", None],
76+
# "criterion": ["gini", "entropy", "log_loss"],
77+
# "max_features": ["sqrt", "log2", None],
6678
"n_estimators": [8, 16, 32, 128, 256],
6779
},
6880
"Gradient Boosting": {
69-
"loss": ["log_loss", "exponential"],
81+
# "loss": ["log_loss", "exponential"],
7082
"learning_rate": [0.1, 0.01, 0.05, 0.001],
7183
"subsample": [0.6, 0.7, 0.75, 0.85, 0.9],
72-
"criterion": ["squared_error", "friedman_mse"],
73-
"max_features": ["auto", "sqrt", "log2"],
84+
# "criterion": ["squared_error", "friedman_mse"],
85+
# "max_features": ["auto", "sqrt", "log2"],
7486
"n_estimators": [8, 16, 32, 64, 128, 256],
7587
},
7688
"Logistic Regression": {},
@@ -102,13 +114,17 @@ def train_model(
102114
y_true=y_train,
103115
y_pred=y_train_pred,
104116
)
117+
## Track the training experiements with mlflow
118+
self.track_mlflow(best_model, classification_train_metric)
105119

106120
y_test_pred = best_model.predict(X_test)
107121
classification_test_metric = get_classification_score(
108122
y_true=y_test,
109123
y_pred=y_test_pred,
110124
)
111125

126+
self.track_mlflow(best_model, classification_test_metric)
127+
112128
preprocessor = load_object(
113129
file_path=self.data_transformation_artifact.transformed_object_file_path,
114130
)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ requires-python = ">=3.12"
77
dependencies = [
88
"certifi>=2025.6.15",
99
"dill>=0.4.0",
10+
"mlflow>=3.1.0",
1011
"numpy>=2.3.0",
1112
"pandas>=2.3.0",
1213
"pyaml>=25.5.0",

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@ pymongo[srv]==3.12
77
scikit-learn
88
dill
99
pyaml
10+
mlflow
1011

1112
# -e .

0 commit comments

Comments
 (0)