@@ -156,6 +156,10 @@ def custom_metric(
156156 "pred_time": pred_time,
157157 }
158158 ```
159+ **Note:** When passing a custom metric function, pass the function itself
160+ (e.g., `metric=custom_metric`), not the result of calling it
161+ (e.g., `metric=custom_metric(...)`). FLAML will call your function
162+ internally during the training process.
159163 task: A string of the task type, e.g.,
160164 'classification', 'regression', 'ts_forecast', 'rank',
161165 'seq-classification', 'seq-regression', 'summarization',
@@ -370,6 +374,8 @@ def custom_metric(
370374 settings ["n_splits" ] = settings .get ("n_splits" , N_SPLITS )
371375 settings ["auto_augment" ] = settings .get ("auto_augment" , True )
372376 settings ["metric" ] = settings .get ("metric" , "auto" )
377+ # Validate that custom metric is callable if not a string
378+ self ._validate_metric_parameter (settings ["metric" ], allow_auto = True )
373379 settings ["estimator_list" ] = settings .get ("estimator_list" , "auto" )
374380 settings ["log_file_name" ] = settings .get ("log_file_name" , "" )
375381 settings ["max_iter" ] = settings .get ("max_iter" ) # no budget by default
@@ -462,6 +468,28 @@ def __setstate__(self, state):
462468 except Exception :
463469 mi .mlflow_client = None
464470
471+ @staticmethod
472+ def _validate_metric_parameter (metric , allow_auto = True ):
473+ """Validate that the metric parameter is either a string or a callable function.
474+
475+ Args:
476+ metric: The metric parameter to validate.
477+ allow_auto: Whether to allow "auto" as a valid string value.
478+
479+ Raises:
480+ ValueError: If metric is not a string or callable function.
481+ """
482+ if allow_auto and metric == "auto" :
483+ return
484+ if not isinstance (metric , str ) and not callable (metric ):
485+ raise ValueError (
486+ f"The 'metric' parameter must be either a string or a callable function, "
487+ f"but got { type (metric ).__name__ } . "
488+ f"If you defined a custom_metric function, make sure to pass the function itself "
489+ f"(e.g., metric=custom_metric) and not the result of calling it "
490+ f"(e.g., metric=custom_metric(...))."
491+ )
492+
465493 def get_params (self , deep : bool = False ) -> dict :
466494 return self ._settings .copy ()
467495
@@ -1810,6 +1838,10 @@ def custom_metric(
18101838 "pred_time": pred_time,
18111839 }
18121840 ```
1841+ **Note:** When passing a custom metric function, pass the function itself
1842+ (e.g., `metric=custom_metric`), not the result of calling it
1843+ (e.g., `metric=custom_metric(...)`). FLAML will call your function
1844+ internally during the training process.
18131845 task: A string of the task type, e.g.,
18141846 'classification', 'regression', 'ts_forecast_regression',
18151847 'ts_forecast_classification', 'rank', 'seq-classification',
@@ -2095,7 +2127,7 @@ def cv_score_agg_func(val_loss_folds, log_metrics_folds):
20952127 split_ratio = split_ratio or self ._settings .get ("split_ratio" )
20962128 n_splits = n_splits or self ._settings .get ("n_splits" )
20972129 auto_augment = self ._settings .get ("auto_augment" ) if auto_augment is None else auto_augment
2098- metric = metric or self ._settings .get ("metric" )
2130+ metric = self ._settings .get ("metric" ) if metric is None else metric
20992131 estimator_list = estimator_list or self ._settings .get ("estimator_list" )
21002132 log_file_name = self ._settings .get ("log_file_name" ) if log_file_name is None else log_file_name
21012133 max_iter = self ._settings .get ("max_iter" ) if max_iter is None else max_iter
@@ -2334,6 +2366,9 @@ def cv_score_agg_func(val_loss_folds, log_metrics_folds):
23342366 and (self ._min_sample_size * SAMPLE_MULTIPLY_FACTOR < self ._state .data_size [0 ])
23352367 )
23362368
2369+ # Validate metric parameter before processing
2370+ self ._validate_metric_parameter (metric , allow_auto = True )
2371+
23372372 metric = task .default_metric (metric )
23382373 self ._state .metric = metric
23392374
0 commit comments