11import sys
22from pathlib import Path
33
4+ import mlflow
45from sklearn .ensemble import (
56 AdaBoostClassifier ,
67 GradientBoostingClassifier ,
@@ -40,6 +41,17 @@ def __init__(
4041 except Exception as e :
4142 raise NetworkSecurityException (e , sys )
4243
44+ def track_mlflow (self , best_model , classificationmetric ):
45+ with mlflow .start_run ():
46+ f1_score = classificationmetric .f1_score
47+ precision_score = classificationmetric .precision_score
48+ recall_score = classificationmetric .recall_score
49+
50+ mlflow .log_metric ("f1_score" , f1_score )
51+ mlflow .log_metric ("precision" , precision_score )
52+ mlflow .log_metric ("recall_score" , recall_score )
53+ mlflow .sklearn .log_model (best_model , "model" )
54+
4355 def train_model (
4456 self ,
4557 X_train : object ,
@@ -57,20 +69,20 @@ def train_model(
5769 params = {
5870 "Decision Tree" : {
5971 "criterion" : ["gini" , "entropy" , "log_loss" ],
60- "splitter" : ["best" , "random" ],
61- "max_features" : ["sqrt" , "log2" ],
72+ # "splitter": ["best", "random"],
73+ # "max_features": ["sqrt", "log2"],
6274 },
6375 "Random Forest" : {
64- "criterion" : ["gini" , "entropy" , "log_loss" ],
65- "max_features" : ["sqrt" , "log2" , None ],
76+ # "criterion": ["gini", "entropy", "log_loss"],
77+ # "max_features": ["sqrt", "log2", None],
6678 "n_estimators" : [8 , 16 , 32 , 128 , 256 ],
6779 },
6880 "Gradient Boosting" : {
69- "loss" : ["log_loss" , "exponential" ],
81+ # "loss": ["log_loss", "exponential"],
7082 "learning_rate" : [0.1 , 0.01 , 0.05 , 0.001 ],
7183 "subsample" : [0.6 , 0.7 , 0.75 , 0.85 , 0.9 ],
72- "criterion" : ["squared_error" , "friedman_mse" ],
73- "max_features" : ["auto" , "sqrt" , "log2" ],
84+ # "criterion": ["squared_error", "friedman_mse"],
85+ # "max_features": ["auto", "sqrt", "log2"],
7486 "n_estimators" : [8 , 16 , 32 , 64 , 128 , 256 ],
7587 },
7688 "Logistic Regression" : {},
@@ -102,13 +114,17 @@ def train_model(
102114 y_true = y_train ,
103115 y_pred = y_train_pred ,
104116 )
117+ ## Track the training experiements with mlflow
118+ self .track_mlflow (best_model , classification_train_metric )
105119
106120 y_test_pred = best_model .predict (X_test )
107121 classification_test_metric = get_classification_score (
108122 y_true = y_test ,
109123 y_pred = y_test_pred ,
110124 )
111125
126+ self .track_mlflow (best_model , classification_test_metric )
127+
112128 preprocessor = load_object (
113129 file_path = self .data_transformation_artifact .transformed_object_file_path ,
114130 )
0 commit comments