diff --git a/imbens/ensemble/base.py b/imbens/ensemble/base.py
index 5e67ae0..f63cb67 100644
--- a/imbens/ensemble/base.py
+++ b/imbens/ensemble/base.py
@@ -42,8 +42,13 @@
     check_random_state,
     column_or_1d,
     has_fit_parameter,
-    validate_data,
 )
+try:
+    from sklearn.utils.validation import validate_data
+    HAS_VALIDATE_DATA = True
+except ImportError:
+    HAS_VALIDATE_DATA = False
+
 
 TRAINING_LOG_HEAD_TITLES = {
     "iter": "#Estimators",
@@ -516,6 +521,55 @@ def _validate_y(self, y):
         self.n_classes_ = len(self.classes_)
 
         return y
+
+    def _validate_data(
+        self,
+        X="no_validation",
+        y="no_validation",
+        reset=True,
+        validate_separately=False,
+        **check_params,
+    ):
+        """Validate input data on scikit-learn both before and after 1.6.
+
+        Uses the public ``validate_data`` function when it could be imported
+        (scikit-learn >= 1.6) and otherwise falls back to the parent class's
+        ``_validate_data`` method.
+
+        Parameters
+        ----------
+        X, y, reset, validate_separately
+            Forwarded unchanged to the underlying validator.
+        **check_params
+            Extra keyword arguments for the underlying validator. Callers
+            use the scikit-learn >= 1.6 name ``ensure_all_finite``; on the
+            fallback path it is renamed to ``force_all_finite``, the name
+            older releases accept.
+
+        Returns
+        -------
+        The result of the underlying scikit-learn validator.
+        """
+        if HAS_VALIDATE_DATA:
+            return validate_data(
+                self,
+                X=X,
+                y=y,
+                reset=reset,
+                validate_separately=validate_separately,
+                **check_params,
+            )
+        # scikit-learn < 1.6 only knows this option as ``force_all_finite``.
+        if "ensure_all_finite" in check_params:
+            check_params["force_all_finite"] = check_params.pop("ensure_all_finite")
+        return super()._validate_data(
+            X=X,
+            y=y,
+            reset=reset,
+            validate_separately=validate_separately,
+            **check_params,
+        )
+
 
     def _validate_estimator(self, default):
         """Check the estimator, sampler and the n_estimator attribute.
@@ -572,7 +626,7 @@ def fit(self, X, y, *, sample_weight=None, **kwargs):
         self.random_state = check_random_state(self.random_state)
 
         # Convert data (X is required to be 2d and indexable)
-        X, y = validate_data(self, X, y, **self.check_x_y_args)
+        X, y = self._validate_data(X=X, y=y, **self.check_x_y_args)
         if sample_weight is not None:
             sample_weight = _check_sample_weight(sample_weight, X, dtype=np.float64)
             sample_weight /= sample_weight.sum()
@@ -704,9 +758,8 @@ def decision_function(self, X):
         check_is_fitted(self)
 
         # Check data
-        X = validate_data(
-            self,
-            X,
+        X = self._validate_data(
+            X=X,
             accept_sparse=["csr", "csc"],
             dtype=None,
             ensure_all_finite=False,