@@ -110,6 +110,13 @@ def __init__(self, estimator, X, y, scoring=None, cv=None):
110110 self ._scoring = make_scorer (scoring )
111111 self .scorer_ = self ._scoring
112112
113+ # Set the sign of the scoring function
114+ if hasattr (self ._scoring , "_score" ):
115+ score_func = self ._scoring ._score_func
116+ _sign = _guess_sign_of_sklmetric (score_func )
117+ _sign_str = "higher" if _sign == 1 else "lower"
118+ self .set_tags (** {"property:higher_or_lower_is_better" : _sign_str })
119+
113120 def _paramnames (self ):
114121 """Return the parameter names of the search.
115122
@@ -120,18 +127,18 @@ def _paramnames(self):
120127 """
121128 return list (self .estimator .get_params ().keys ())
122129
123- def _score (self , params ):
124- """Score the parameters.
130+ def _evaluate (self , params ):
131+ """Evaluate the parameters.
125132
126133 Parameters
127134 ----------
128135 params : dict with string keys
129- Parameters to score .
136+ Parameters to evaluate .
130137
131138 Returns
132139 -------
133140 float
134- The score of the parameters.
141+ The value of the parameters as per evaluation .
135142 dict
136143 Additional metadata about the search.
137144 """
@@ -221,10 +228,11 @@ def get_test_params(cls, parameter_set="default"):
221228
222229 @classmethod
223230 def _get_score_params (self ):
224- """Return settings for testing the score function . Used in tests only.
231+ """Return settings for testing score/evaluate functions . Used in tests only.
225232
226- Returns a list, the i-th element corresponds to self.get_test_params()[i].
227- It should be a valid call for self.score.
233+ Returns a list, the i-th element should be valid arguments for
234+ self.evaluate and self.score, of an instance constructed with
235+ self.get_test_params()[i].
228236
229237 Returns
230238 -------
@@ -235,3 +243,80 @@ def _get_score_params(self):
235243 score_params_regress = {"C" : 1.0 , "kernel" : "linear" }
236244 score_params_defaults = {"C" : 1.0 , "kernel" : "linear" }
237245 return [score_params_classif , score_params_regress , score_params_defaults ]
246+
247+
248+ def _guess_sign_of_sklmetric (scorer ):
249+ """Guess the sign of a sklearn metric scorer.
250+
251+ Parameters
252+ ----------
253+ scorer : callable
254+ The sklearn metric scorer to guess the sign for.
255+
256+ Returns
257+ -------
258+ int
259+ 1 if higher scores are better, -1 if lower scores are better.
260+ """
261+ HIGHER_IS_BETTER = {
262+ # Classification
263+ "accuracy_score" : True ,
264+ "auc" : True ,
265+ "average_precision_score" : True ,
266+ "balanced_accuracy_score" : True ,
267+ "brier_score_loss" : False ,
268+ "class_likelihood_ratios" : False ,
269+ "cohen_kappa_score" : True ,
270+ "d2_log_loss_score" : True ,
271+ "dcg_score" : True ,
272+ "f1_score" : True ,
273+ "fbeta_score" : True ,
274+ "hamming_loss" : False ,
275+ "hinge_loss" : False ,
276+ "jaccard_score" : True ,
277+ "log_loss" : False ,
278+ "matthews_corrcoef" : True ,
279+ "ndcg_score" : True ,
280+ "precision_score" : True ,
281+ "recall_score" : True ,
282+ "roc_auc_score" : True ,
283+ "top_k_accuracy_score" : True ,
284+ "zero_one_loss" : False ,
285+
286+ # Regression
287+ "d2_absolute_error_score" : True ,
288+ "d2_pinball_score" : True ,
289+ "d2_tweedie_score" : True ,
290+ "explained_variance_score" : True ,
291+ "max_error" : False ,
292+ "mean_absolute_error" : False ,
293+ "mean_absolute_percentage_error" : False ,
294+ "mean_gamma_deviance" : False ,
295+ "mean_pinball_loss" : False ,
296+ "mean_poisson_deviance" : False ,
297+ "mean_squared_error" : False ,
298+ "mean_squared_log_error" : False ,
299+ "mean_tweedie_deviance" : False ,
300+ "median_absolute_error" : False ,
301+ "r2_score" : True ,
302+ "root_mean_squared_error" : False ,
303+ "root_mean_squared_log_error" : False ,
304+ }
305+
306+ scorer_name = getattr (scorer , "__name__" , None )
307+
308+ if hasattr (scorer , "greater_is_better" ):
309+ return 1 if scorer .greater_is_better else - 1
310+ elif scorer_name in HIGHER_IS_BETTER :
311+ return 1 if HIGHER_IS_BETTER [scorer_name ] else - 1
312+ elif scorer_name .endswith ("_score" ):
313+ # If the scorer name ends with "_score", we assume higher is better
314+ return 1
315+ elif scorer_name .endswith ("_loss" ) or scorer_name .endswith ("_deviance" ):
316+ # If the scorer name ends with "_loss", we assume lower is better
317+ return - 1
318+ elif scorer_name .endswith ("_error" ):
319+ return - 1
320+ else :
321+ # If we cannot determine the sign, we assume lower is better
322+ return - 1
0 commit comments