11
2- # third party imports
32from typing import Callable , Optional
43
4+ # third party imports
55import numpy as np
66import pandas as pd
77from scipy import stats
88from sklearn .metrics import roc_auc_score , mean_squared_error
99from numpy import sqrt
1010from sklearn .linear_model import LogisticRegression , LinearRegression
11+ from sklearn .metrics import roc_curve
1112
1213# custom imports
1314import cobra .utils as utils
15+ from cobra .evaluation import ClassificationEvaluator
1416
1517class LogisticRegressionModel :
1618 """Wrapper around the LogisticRegression class, with additional methods
@@ -148,8 +150,8 @@ def score_model(self, X: pd.DataFrame) -> np.ndarray:
148150 def evaluate (self , X : pd .DataFrame , y : pd .Series ,
149151 split : str = None ,
150152 metric : Optional [Callable ]= None ) -> float :
151- """Evaluate the model on a given data set (X, y). The optional split
152- parameter is to indicate that the data set belongs to
153+ """Evaluate the model on a given dataset (X, y). The optional split
154+ parameter is to indicate that the dataset belongs to
153155 (train, selection, validation), so that the computation on these sets
154156 can be cached!
155157
@@ -164,7 +166,7 @@ def evaluate(self, X: pd.DataFrame, y: pd.Series,
164166 metric: Callable (function), optional
165167 Function that computes an evaluation metric to evaluate the model's
166168 performances, instead of the default metric (AUC).
167- The function should require y_true and y_pred arguments.
169+ The function should require y_true and y_pred (binary output) arguments.
168170 Metric functions from sklearn can be used, for example, see
169171 https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics.
170172
@@ -173,20 +175,25 @@ def evaluate(self, X: pd.DataFrame, y: pd.Series,
173175 float
174176 The performance score of the model (AUC by default).
175177 """
178+ if metric is not None : # decouple from _eval_metrics_by_split attribute
179+ y_pred = self .score_model (X )
176180
177- if (split is None ) or (split not in self ._eval_metrics_by_split ):
181+ fpr , tpr , thresholds = roc_curve (y_true = y , y_score = y_pred )
182+ cutoff = (ClassificationEvaluator ._compute_optimal_cutoff (fpr , tpr , thresholds ))
183+ y_pred_b = np .array ([0 if pred <= cutoff else 1 for pred in y_pred ])
178184
179- y_pred = self . score_model ( X )
185+ performance = metric ( y_true = y , y_pred = y_pred_b )
180186
181- if metric is None :
187+ return performance
188+ else :
189+ if (split is None ) or (split not in self ._eval_metrics_by_split ):
190+ y_pred = self .score_model (X )
182191 performance = roc_auc_score (y_true = y , y_score = y_pred )
183- else :
184- performance = metric (y_true = y , y_pred = y_pred )
185192
186- if split is None :
187- return performance
188- else :
189- self ._eval_metrics_by_split [split ] = performance
193+ if split is None :
194+ return performance
195+ else :
196+ self ._eval_metrics_by_split [split ] = performance
190197
191198 return self ._eval_metrics_by_split [split ]
192199
@@ -371,8 +378,8 @@ def score_model(self, X: pd.DataFrame) -> np.ndarray:
371378 def evaluate (self , X : pd .DataFrame , y : pd .Series ,
372379 split : str = None ,
373380 metric : Optional [Callable ]= None ) -> float :
374- """Evaluate the model on a given data set (X, y). The optional split
375- parameter is to indicate that the data set belongs to
381+ """Evaluate the model on a given dataset (X, y). The optional split
382+ parameter is to indicate that the dataset belongs to
376383 (train, selection, validation), so that the computation on these sets
377384 can be cached!
378385
@@ -396,19 +403,20 @@ def evaluate(self, X: pd.DataFrame, y: pd.Series,
396403 float
397404 The performance score of the model (RMSE by default).
398405 """
399-
400- if (split is None ) or (split not in self ._eval_metrics_by_split ):
401-
406+ if metric is not None : # decouple from _eval_metrics_by_split attribute
402407 y_pred = self .score_model (X )
403- if metric is None :
408+ performance = metric (y_true = y , y_pred = y_pred )
409+
410+ return performance
411+ else :
412+ if (split is None ) or (split not in self ._eval_metrics_by_split ):
413+ y_pred = self .score_model (X )
404414 performance = sqrt (mean_squared_error (y_true = y , y_pred = y_pred ))
405- else :
406- performance = metric (y_true = y , y_pred = y_pred )
407415
408- if split is None :
409- return performance
410- else :
411- self ._eval_metrics_by_split [split ] = performance
416+ if split is None :
417+ return performance
418+ else :
419+ self ._eval_metrics_by_split [split ] = performance
412420
413421 return self ._eval_metrics_by_split [split ]
414422
0 commit comments