|
11 | 11 | from sklearn.preprocessing import StandardScaler |
12 | 12 | from sklearn.utils.multiclass import check_classification_targets |
13 | 13 | from sklearn.utils.validation import ( |
14 | | - check_X_y, |
15 | | - check_array, |
16 | 14 | check_is_fitted, |
17 | 15 | _check_sample_weight, |
| 16 | + validate_data, |
18 | 17 | ) |
| 18 | +from sklearn.utils._tags import (Tags, ClassifierTags, TargetTags, InputTags) |
19 | 19 | from .Splitter import Splitter, Snode, Siterator |
20 | 20 | from ._version import __version__ |
21 | 21 |
|
22 | 22 |
|
23 | | -class Stree(BaseEstimator, ClassifierMixin): |
| 23 | +class Stree(ClassifierMixin, BaseEstimator): |
24 | 24 | """ |
25 | 25 | Estimator that is based on binary trees of svm nodes |
26 | 26 | can deal with sample_weights in predict, used in boosting sklearn methods |
@@ -179,15 +179,32 @@ def __call__(self) -> None: |
179 | 179 | ensembles""" |
180 | 180 | pass |
181 | 181 |
|
182 | | - def _more_tags(self) -> dict: |
183 | | - """Required by sklearn to supply features of the classifier |
184 | | - make mandatory the labels array |
185 | | -
|
186 | | - :return: the tag required |
187 | | - :rtype: dict |
188 | | - """ |
189 | | - return {"requires_y": True} |
190 | | - |
| 182 | + def __sklearn_tags__(self): |
| 183 | + return Tags( |
| 184 | + estimator_type="classifier", |
| 185 | + target_tags=TargetTags( |
| 186 | + required=True, |
| 187 | + multi_output=False, |
| 188 | + ), |
| 189 | + classifier_tags=ClassifierTags( |
| 190 | + multi_class=True, |
| 191 | + multi_label=False, |
| 192 | + poor_score=False, |
| 193 | + ), |
| 194 | + input_tags=InputTags( |
| 195 | + sparse=False, |
| 196 | + one_d_array=False, |
| 197 | + two_d_array=True, |
| 198 | + three_d_array=False, |
| 199 | + categorical=True, |
| 200 | + string=True, |
| 201 | + pairwise=False, |
| 202 | + ), |
| 203 | + requires_fit=True, |
| 204 | + array_api_support=False, |
| 205 | + non_deterministic=False, |
| 206 | + _skip_test=False, |
| 207 | + ) |
191 | 208 | def fit( |
192 | 209 | self, |
193 | 210 | X: np.ndarray, |
@@ -245,7 +262,7 @@ def fit( |
245 | 262 | if self.kernel not in kernels: |
246 | 263 | raise ValueError(f"Kernel {self.kernel} not in {kernels}") |
247 | 264 | check_classification_targets(y) |
248 | | - X, y = check_X_y(X, y) |
| 265 | + X, y = validate_data(self, X, y) |
249 | 266 | sample_weight = _check_sample_weight( |
250 | 267 | sample_weight, X, dtype=np.float64 |
251 | 268 | ) |
@@ -435,12 +452,7 @@ def check_predict(self, X) -> np.array: |
435 | 452 | """ |
436 | 453 | check_is_fitted(self, ["tree_"]) |
437 | 454 | # Input validation |
438 | | - X = check_array(X) |
439 | | - if X.shape[1] != self.n_features_: |
440 | | - raise ValueError( |
441 | | - f"Expected {self.n_features_} features but got " |
442 | | - f"({X.shape[1]})" |
443 | | - ) |
| 455 | + X = validate_data(self, X, reset=False) |
444 | 456 | return X |
445 | 457 |
|
446 | 458 | def predict_proba(self, X: np.array) -> np.array: |
|
0 commit comments