Skip to content

Commit 3d489f1

Browse files
Copilot and thinkall authored
Add validation and clear error messages for custom_metric parameter (#1500)
* Initial plan * Add validation and documentation for custom_metric parameter Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> * Refactor validation into reusable method and improve error handling Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> * Apply pre-commit formatting fixes Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> Co-authored-by: Li Jiang <bnujli@gmail.com>
1 parent c64eeb5 commit 3d489f1

2 files changed

Lines changed: 64 additions & 1 deletion

File tree

flaml/automl/automl.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ def custom_metric(
156156
"pred_time": pred_time,
157157
}
158158
```
159+
**Note:** When passing a custom metric function, pass the function itself
160+
(e.g., `metric=custom_metric`), not the result of calling it
161+
(e.g., `metric=custom_metric(...)`). FLAML will call your function
162+
internally during the training process.
159163
task: A string of the task type, e.g.,
160164
'classification', 'regression', 'ts_forecast', 'rank',
161165
'seq-classification', 'seq-regression', 'summarization',
@@ -370,6 +374,8 @@ def custom_metric(
370374
settings["n_splits"] = settings.get("n_splits", N_SPLITS)
371375
settings["auto_augment"] = settings.get("auto_augment", True)
372376
settings["metric"] = settings.get("metric", "auto")
377+
# Validate that custom metric is callable if not a string
378+
self._validate_metric_parameter(settings["metric"], allow_auto=True)
373379
settings["estimator_list"] = settings.get("estimator_list", "auto")
374380
settings["log_file_name"] = settings.get("log_file_name", "")
375381
settings["max_iter"] = settings.get("max_iter") # no budget by default
@@ -462,6 +468,28 @@ def __setstate__(self, state):
462468
except Exception:
463469
mi.mlflow_client = None
464470

471+
@staticmethod
472+
def _validate_metric_parameter(metric, allow_auto=True):
473+
"""Validate that the metric parameter is either a string or a callable function.
474+
475+
Args:
476+
metric: The metric parameter to validate.
477+
allow_auto: Whether to allow "auto" as a valid string value.
478+
479+
Raises:
480+
ValueError: If metric is not a string or callable function.
481+
"""
482+
if allow_auto and metric == "auto":
483+
return
484+
if not isinstance(metric, str) and not callable(metric):
485+
raise ValueError(
486+
f"The 'metric' parameter must be either a string or a callable function, "
487+
f"but got {type(metric).__name__}. "
488+
f"If you defined a custom_metric function, make sure to pass the function itself "
489+
f"(e.g., metric=custom_metric) and not the result of calling it "
490+
f"(e.g., metric=custom_metric(...))."
491+
)
492+
465493
def get_params(self, deep: bool = False) -> dict:
466494
return self._settings.copy()
467495

@@ -1810,6 +1838,10 @@ def custom_metric(
18101838
"pred_time": pred_time,
18111839
}
18121840
```
1841+
**Note:** When passing a custom metric function, pass the function itself
1842+
(e.g., `metric=custom_metric`), not the result of calling it
1843+
(e.g., `metric=custom_metric(...)`). FLAML will call your function
1844+
internally during the training process.
18131845
task: A string of the task type, e.g.,
18141846
'classification', 'regression', 'ts_forecast_regression',
18151847
'ts_forecast_classification', 'rank', 'seq-classification',
@@ -2095,7 +2127,7 @@ def cv_score_agg_func(val_loss_folds, log_metrics_folds):
20952127
split_ratio = split_ratio or self._settings.get("split_ratio")
20962128
n_splits = n_splits or self._settings.get("n_splits")
20972129
auto_augment = self._settings.get("auto_augment") if auto_augment is None else auto_augment
2098-
metric = metric or self._settings.get("metric")
2130+
metric = self._settings.get("metric") if metric is None else metric
20992131
estimator_list = estimator_list or self._settings.get("estimator_list")
21002132
log_file_name = self._settings.get("log_file_name") if log_file_name is None else log_file_name
21012133
max_iter = self._settings.get("max_iter") if max_iter is None else max_iter
@@ -2334,6 +2366,9 @@ def cv_score_agg_func(val_loss_folds, log_metrics_folds):
23342366
and (self._min_sample_size * SAMPLE_MULTIPLY_FACTOR < self._state.data_size[0])
23352367
)
23362368

2369+
# Validate metric parameter before processing
2370+
self._validate_metric_parameter(metric, allow_auto=True)
2371+
23372372
metric = task.default_metric(metric)
23382373
self._state.metric = metric
23392374

test/automl/test_multiclass.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,34 @@ def test_custom_metric(self):
278278
except ImportError:
279279
pass
280280

281+
def test_invalid_custom_metric(self):
282+
"""Test that proper error is raised when custom_metric is called instead of passed."""
283+
from sklearn.datasets import load_iris
284+
285+
X_train, y_train = load_iris(return_X_y=True)
286+
287+
# Test with non-callable metric in __init__
288+
with self.assertRaises(ValueError) as context:
289+
automl = AutoML(metric=123) # passing an int instead of function
290+
self.assertIn("must be either a string or a callable function", str(context.exception))
291+
self.assertIn("but got int", str(context.exception))
292+
293+
# Test with non-callable metric in fit
294+
automl = AutoML()
295+
with self.assertRaises(ValueError) as context:
296+
automl.fit(X_train=X_train, y_train=y_train, metric=[], task="classification", time_budget=1)
297+
self.assertIn("must be either a string or a callable function", str(context.exception))
298+
self.assertIn("but got list", str(context.exception))
299+
300+
# Test with tuple (simulating result of calling a function that returns tuple)
301+
with self.assertRaises(ValueError) as context:
302+
automl = AutoML()
303+
automl.fit(
304+
X_train=X_train, y_train=y_train, metric=(0.5, {"loss": 0.5}), task="classification", time_budget=1
305+
)
306+
self.assertIn("must be either a string or a callable function", str(context.exception))
307+
self.assertIn("but got tuple", str(context.exception))
308+
281309
def test_classification(self, as_frame=False):
282310
automl_experiment = AutoML()
283311
automl_settings = {

0 commit comments

Comments (0)