11
22# third party imports
3+ from typing import Callable , Optional
4+
35import numpy as np
46import pandas as pd
57from scipy import stats
@@ -144,7 +146,8 @@ def score_model(self, X: pd.DataFrame) -> np.ndarray:
144146 return self .logit .predict_proba (X [self .predictors ])[:, 1 ]
145147
146148 def evaluate (self , X : pd .DataFrame , y : pd .Series ,
147- split : str = None ) -> float :
149+ split : str = None ,
150+ metric : Optional [Callable ]= None ) -> float :
148151 """Evaluate the model on a given data set (X, y). The optional split
149152 parameter indicates which split the data set belongs to
150153 (train, selection, validation), so that the computation on these sets
@@ -158,18 +161,27 @@ def evaluate(self, X: pd.DataFrame, y: pd.Series,
158161 Dataset containing the target of each observation.
159162 split : str, optional
160163 Split name of the dataset (e.g. "train", "selection", or "validation").
164+ metric : Callable, optional
165+ Function that computes an evaluation metric to assess the model's
166+ performance, instead of the default metric (AUC).
167+ The function is called with the keyword arguments y_true and y_pred, where y_pred holds predicted probabilities, not class labels.
168+ Metric functions from sklearn can be used, for example, see
169+ https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics.
161170
162171 Returns
163172 -------
164173 float
165- The performance score of the model (AUC).
174+ The performance score of the model (AUC by default).
166175 """
167176
168177 if (split is None ) or (split not in self ._eval_metrics_by_split ):
169178
170179 y_pred = self .score_model (X )
171180
172- performance = roc_auc_score (y_true = y , y_score = y_pred )
181+ if metric is None :
182+ performance = roc_auc_score (y_true = y , y_score = y_pred )
183+ else :
184+ performance = metric (y_true = y , y_pred = y_pred )
173185
174186 if split is None :
175187 return performance
@@ -357,7 +369,8 @@ def score_model(self, X: pd.DataFrame) -> np.ndarray:
357369 return self .linear .predict (X [self .predictors ])
358370
359371 def evaluate (self , X : pd .DataFrame , y : pd .Series ,
360- split : str = None ) -> float :
372+ split : str = None ,
373+ metric : Optional [Callable ]= None ) -> float :
361374 """Evaluate the model on a given data set (X, y). The optional split
362375 parameter indicates which split the data set belongs to
363376 (train, selection, validation), so that the computation on these sets
@@ -371,18 +384,26 @@ def evaluate(self, X: pd.DataFrame, y: pd.Series,
371384 Dataset containing the target of each observation.
372385 split : str, optional
373386 Split name of the dataset (e.g. "train", "selection", or "validation").
387+ metric : Callable, optional
388+ Function that computes an evaluation metric to assess the model's
389+ performance, instead of the default metric (RMSE).
390+ The function is called with the keyword arguments y_true and y_pred.
391+ Metric functions from sklearn can be used, for example, see
392+ https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics.
374393
375394 Returns
376395 -------
377396 float
378- The performance score of the model (RMSE).
397+ The performance score of the model (RMSE by default).
379398 """
380399
381400 if (split is None ) or (split not in self ._eval_metrics_by_split ):
382401
383402 y_pred = self .score_model (X )
384-
385- performance = sqrt (mean_squared_error (y_true = y , y_pred = y_pred ))
403+ if metric is None :
404+ performance = sqrt (mean_squared_error (y_true = y , y_pred = y_pred ))
405+ else :
406+ performance = metric (y_true = y , y_pred = y_pred )
386407
387408 if split is None :
388409 return performance
0 commit comments