@@ -110,6 +110,13 @@ def __init__(self, estimator, X, y, scoring=None, cv=None):
110110 self ._scoring = make_scorer (scoring )
111111 self .scorer_ = self ._scoring
112112
113+ # Set the sign of the scoring function
114+ if hasattr (self ._scoring , "_score" ):
115+ score_func = self ._scoring ._score_func
116+ _sign = _guess_sign_of_sklmetric (score_func )
117+ _sign_str = "higher" if _sign == 1 else "lower"
118+ self .set_tags (** {"property:higher_or_lower_is_better" : _sign_str })
119+
113120 def _paramnames (self ):
114121 """Return the parameter names of the search.
115122
@@ -235,3 +242,80 @@ def _get_score_params(self):
235242 score_params_regress = {"C" : 1.0 , "kernel" : "linear" }
236243 score_params_defaults = {"C" : 1.0 , "kernel" : "linear" }
237244 return [score_params_classif , score_params_regress , score_params_defaults ]
245+
246+
247+ def _guess_sign_of_sklmetric (scorer ):
248+ """Guess the sign of a sklearn metric scorer.
249+
250+ Parameters
251+ ----------
252+ scorer : callable
253+ The sklearn metric scorer to guess the sign for.
254+
255+ Returns
256+ -------
257+ int
258+ 1 if higher scores are better, -1 if lower scores are better.
259+ """
260+ HIGHER_IS_BETTER = {
261+ # Classification
262+ "accuracy_score" : True ,
263+ "auc" : True ,
264+ "average_precision_score" : True ,
265+ "balanced_accuracy_score" : True ,
266+ "brier_score_loss" : False ,
267+ "class_likelihood_ratios" : False ,
268+ "cohen_kappa_score" : True ,
269+ "d2_log_loss_score" : True ,
270+ "dcg_score" : True ,
271+ "f1_score" : True ,
272+ "fbeta_score" : True ,
273+ "hamming_loss" : False ,
274+ "hinge_loss" : False ,
275+ "jaccard_score" : True ,
276+ "log_loss" : False ,
277+ "matthews_corrcoef" : True ,
278+ "ndcg_score" : True ,
279+ "precision_score" : True ,
280+ "recall_score" : True ,
281+ "roc_auc_score" : True ,
282+ "top_k_accuracy_score" : True ,
283+ "zero_one_loss" : False ,
284+
285+ # Regression
286+ "d2_absolute_error_score" : True ,
287+ "d2_pinball_score" : True ,
288+ "d2_tweedie_score" : True ,
289+ "explained_variance_score" : True ,
290+ "max_error" : False ,
291+ "mean_absolute_error" : False ,
292+ "mean_absolute_percentage_error" : False ,
293+ "mean_gamma_deviance" : False ,
294+ "mean_pinball_loss" : False ,
295+ "mean_poisson_deviance" : False ,
296+ "mean_squared_error" : False ,
297+ "mean_squared_log_error" : False ,
298+ "mean_tweedie_deviance" : False ,
299+ "median_absolute_error" : False ,
300+ "r2_score" : True ,
301+ "root_mean_squared_error" : False ,
302+ "root_mean_squared_log_error" : False ,
303+ }
304+
305+ scorer_name = getattr (scorer , "__name__" , None )
306+
307+ if hasattr (scorer , "greater_is_better" ):
308+ return 1 if scorer .greater_is_better else - 1
309+ elif scorer_name in HIGHER_IS_BETTER :
310+ return 1 if HIGHER_IS_BETTER [scorer_name ] else - 1
311+ elif scorer_name .endswith ("_score" ):
312+ # If the scorer name ends with "_score", we assume higher is better
313+ return 1
314+ elif scorer_name .endswith ("_loss" ) or scorer_name .endswith ("_deviance" ):
315+ # If the scorer name ends with "_loss", we assume lower is better
316+ return - 1
317+ elif scorer_name .endswith ("_error" ):
318+ return - 1
319+ else :
320+ # If we cannot determine the sign, we assume lower is better
321+ return - 1
0 commit comments