|
1 | 1 | # pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, too-many-lines |
2 | 2 | """Scikit-Learn Wrapper interface for XGBoost.""" |
| 3 | +import collections |
3 | 4 | import copy |
4 | 5 | import json |
5 | 6 | import os |
@@ -432,7 +433,7 @@ def task(i: int) -> float: |
432 | 433 | - ``one_output_per_tree``: One model for each target. |
433 | 434 | - ``multi_output_tree``: Use multi-target trees. |
434 | 435 |
|
435 | | - eval_metric : {Optional[Union[str, List[str], Callable]]} |
| 436 | + eval_metric : {Optional[Union[str, List[Union[str, Callable]], Callable]]} |
436 | 437 |
|
437 | 438 | .. versionadded:: 1.6.0 |
438 | 439 |
|
@@ -763,7 +764,7 @@ def __init__( |
763 | 764 | max_cat_to_onehot: Optional[int] = None, |
764 | 765 | max_cat_threshold: Optional[int] = None, |
765 | 766 | multi_strategy: Optional[str] = None, |
766 | | - eval_metric: Optional[Union[str, List[str], Callable]] = None, |
| 767 | + eval_metric: Optional[Union[str, List[Union[str, Callable]], Callable]] = None, |
767 | 768 | early_stopping_rounds: Optional[int] = None, |
768 | 769 | callbacks: Optional[List[TrainingCallback]] = None, |
769 | 770 | **kwargs: Any, |
@@ -1103,14 +1104,42 @@ def _duplicated(parameter: str) -> None: |
1103 | 1104 |
|
1104 | 1105 | # - configure callable evaluation metric |
1105 | 1106 | metric: Optional[Metric] = None |
| 1107 | + |
| 1108 | + def custom_metric(m: Callable) -> Metric: |
| 1109 | + if self._get_type() == "ranker": |
| 1110 | + wrapped = ltr_metric_decorator(m, self.n_jobs) |
| 1111 | + else: |
| 1112 | + wrapped = _metric_decorator(m) |
| 1113 | + return wrapped |
| 1114 | + |
| 1115 | + def invalid_type(m: Any) -> None: |
| 1116 | + msg = f"Invalid type for the `eval_metric`: {type(m)}" |
| 1117 | + raise TypeError(msg) |
| 1118 | + |
1106 | 1119 | if self.eval_metric is not None: |
1107 | 1120 | if callable(self.eval_metric): |
1108 | | - if self._get_type() == "ranker": |
1109 | | - metric = ltr_metric_decorator(self.eval_metric, self.n_jobs) |
1110 | | - else: |
1111 | | - metric = _metric_decorator(self.eval_metric) |
1112 | | - else: |
| 1121 | + metric = custom_metric(self.eval_metric) |
| 1122 | + elif isinstance(self.eval_metric, str): |
1113 | 1123 | params.update({"eval_metric": self.eval_metric}) |
| 1124 | + else: |
| 1125 | + # A sequence of metrics |
| 1126 | + if not isinstance(self.eval_metric, collections.abc.Sequence): |
| 1127 | + invalid_type(self.eval_metric) |
| 1128 | + # Could be a list of strings or callables |
| 1129 | + builtin_metrics: List[str] = [] |
| 1130 | + for m in self.eval_metric: |
| 1131 | + if callable(m): |
| 1132 | + if metric is not None: |
| 1133 | + raise NotImplementedError( |
| 1134 | + "Using multiple custom metrics is not yet supported." |
| 1135 | + ) |
| 1136 | + metric = custom_metric(m) |
| 1137 | + elif isinstance(m, str): |
| 1138 | + builtin_metrics.append(m) |
| 1139 | + else: |
| 1140 | + invalid_type(m) |
| 1141 | + if builtin_metrics: |
| 1142 | + params.update({"eval_metric": builtin_metrics}) |
1114 | 1143 |
|
1115 | 1144 | if feature_weights is not None: |
1116 | 1145 | _deprecated("feature_weights") |
|
0 commit comments