Skip to content

Commit 3d1d1d2

Browse files
authored
fix: Subclass sklearn.model_selection._RepeatedSplits and BaseShuffleSplit from BaseCrossValidator (#349)
1 parent 5ef1657 commit 3d1d1d2

File tree

15 files changed

+33
-44
lines changed

15 files changed

+33
-44
lines changed

stubs/sklearn/calibration.pyi

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ from .base import (
2222
)
2323
from .isotonic import IsotonicRegression
2424
from .model_selection import BaseCrossValidator, check_cv as check_cv, cross_val_predict as cross_val_predict
25-
from .model_selection._split import BaseShuffleSplit
2625
from .preprocessing import LabelEncoder as LabelEncoder, label_binarize as label_binarize
2726
from .svm import LinearSVC as LinearSVC
2827
from .utils import check_matplotlib_support as check_matplotlib_support, column_or_1d as column_or_1d, indexable as indexable
@@ -51,7 +50,7 @@ class CalibratedClassifierCV(ClassifierMixin, MetaEstimatorMixin, BaseEstimator)
5150
estimator: None | BaseEstimator = None,
5251
*,
5352
method: Literal["sigmoid", "isotonic"] = "sigmoid",
54-
cv: int | BaseCrossValidator | Iterable | None | str | BaseShuffleSplit = None,
53+
cv: int | BaseCrossValidator | Iterable | None | str = None,
5554
n_jobs: None | Int = None,
5655
ensemble: bool = True,
5756
base_estimator: str | BaseEstimator = "deprecated",

stubs/sklearn/covariance/_graph_lasso.pyi

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ from .._typing import ArrayLike, Float, Int, MatrixLike
88
from ..exceptions import ConvergenceWarning as ConvergenceWarning
99
from ..linear_model import lars_path_gram as lars_path_gram
1010
from ..model_selection import BaseCrossValidator, check_cv as check_cv, cross_val_score as cross_val_score
11-
from ..model_selection._split import BaseShuffleSplit
1211
from ..utils._param_validation import Interval as Interval, StrOptions as StrOptions
1312
from ..utils.parallel import Parallel as Parallel, delayed as delayed
1413
from ..utils.validation import check_random_state as check_random_state, check_scalar as check_scalar
@@ -117,7 +116,7 @@ class GraphicalLassoCV(BaseGraphicalLasso):
117116
*,
118117
alphas: ArrayLike | int = 4,
119118
n_refinements: Int = 4,
120-
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
119+
cv: int | BaseCrossValidator | Iterable | None = None,
121120
tol: Float = 1e-4,
122121
enet_tol: Float = 1e-4,
123122
max_iter: Int = 100,

stubs/sklearn/ensemble/_stacking.pyi

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ from ..exceptions import NotFittedError as NotFittedError
1919
from ..linear_model._logistic import LogisticRegression
2020
from ..linear_model._ridge import RidgeCV
2121
from ..model_selection import BaseCrossValidator, check_cv as check_cv, cross_val_predict as cross_val_predict
22-
from ..model_selection._split import BaseShuffleSplit
2322
from ..pipeline import Pipeline
2423
from ..preprocessing import LabelEncoder as LabelEncoder
2524
from ..utils import Bunch
@@ -78,7 +77,7 @@ class StackingClassifier(ClassifierMixin, _BaseStacking):
7877
estimators: Sequence[tuple[str, BaseEstimator]],
7978
final_estimator: None | BaseEstimator | LogisticRegression = None,
8079
*,
81-
cv: int | BaseCrossValidator | Iterable | None | str | BaseShuffleSplit = None,
80+
cv: int | BaseCrossValidator | Iterable | None | str = None,
8281
stack_method: Literal["auto", "predict_proba", "decision_function", "predict"] = "auto",
8382
n_jobs: None | Int = None,
8483
passthrough: bool = False,
@@ -108,7 +107,7 @@ class StackingRegressor(RegressorMixin, _BaseStacking):
108107
estimators: Sequence[tuple[str, BaseEstimator]] | list[tuple[str, Pipeline]],
109108
final_estimator: None | BaseEstimator | RidgeCV = None,
110109
*,
111-
cv: int | BaseCrossValidator | Iterable | None | str | BaseShuffleSplit = None,
110+
cv: int | BaseCrossValidator | Iterable | None | str = None,
112111
n_jobs: None | Int = None,
113112
passthrough: bool = False,
114113
verbose: Int = 0,

stubs/sklearn/feature_selection/_rfe.pyi

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ from ..base import BaseEstimator, MetaEstimatorMixin, clone as clone, is_classif
99
from ..linear_model._logistic import LogisticRegression
1010
from ..metrics import check_scoring as check_scoring
1111
from ..model_selection import BaseCrossValidator, check_cv as check_cv
12-
from ..model_selection._split import BaseShuffleSplit
1312
from ..utils._param_validation import HasMethods as HasMethods, Interval as Interval
1413
from ..utils.metaestimators import available_if as available_if
1514
from ..utils.parallel import Parallel as Parallel, delayed as delayed
@@ -73,7 +72,7 @@ class RFECV(RFE):
7372
*,
7473
step: float = 1,
7574
min_features_to_select: Int = 1,
76-
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
75+
cv: int | BaseCrossValidator | Iterable | None = None,
7776
scoring: None | str | Callable = None,
7877
verbose: Int = 0,
7978
n_jobs: None | int = None,

stubs/sklearn/feature_selection/_sequential.pyi

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ from .._typing import ArrayLike, Float, Int, MatrixLike
77
from ..base import BaseEstimator, MetaEstimatorMixin, clone as clone
88
from ..metrics import get_scorer_names as get_scorer_names
99
from ..model_selection import BaseCrossValidator, cross_val_score as cross_val_score
10-
from ..model_selection._split import BaseShuffleSplit
1110
from ..utils._param_validation import HasMethods as HasMethods, Hidden as Hidden, Interval as Interval, StrOptions as StrOptions
1211
from ..utils.validation import check_is_fitted as check_is_fitted
1312
from ._base import SelectorMixin
@@ -34,7 +33,7 @@ class SequentialFeatureSelector(SelectorMixin, MetaEstimatorMixin, BaseEstimator
3433
tol: None | Float = None,
3534
direction: Literal["forward", "backward"] = "forward",
3635
scoring: None | str | Callable = None,
37-
cv: Iterable | int | BaseShuffleSplit | BaseCrossValidator = 5,
36+
cv: Iterable | int | BaseCrossValidator = 5,
3837
n_jobs: None | Int = None,
3938
) -> None: ...
4039
def fit(

stubs/sklearn/linear_model/_coordinate_descent.pyi

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ from scipy.sparse._coo import coo_matrix
1313
from .._typing import ArrayLike, Float, Int, MatrixLike
1414
from ..base import MultiOutputMixin, RegressorMixin
1515
from ..model_selection import BaseCrossValidator, check_cv as check_cv
16-
from ..model_selection._split import BaseShuffleSplit
1716
from ..utils import check_array as check_array, check_scalar as check_scalar
1817
from ..utils._param_validation import Interval as Interval, StrOptions as StrOptions
1918
from ..utils.extmath import safe_sparse_dot as safe_sparse_dot
@@ -206,7 +205,7 @@ class LassoCV(RegressorMixin, LinearModelCV):
206205
max_iter: Int = 1000,
207206
tol: Float = 1e-4,
208207
copy_X: bool = True,
209-
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
208+
cv: int | BaseCrossValidator | Iterable | None = None,
210209
verbose: int | bool = False,
211210
n_jobs: None | Int = None,
212211
positive: bool = False,
@@ -241,7 +240,7 @@ class ElasticNetCV(RegressorMixin, LinearModelCV):
241240
precompute: Literal["auto"] | MatrixLike | bool = "auto",
242241
max_iter: Int = 1000,
243242
tol: Float = 1e-4,
244-
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
243+
cv: int | BaseCrossValidator | Iterable | None = None,
245244
copy_X: bool = True,
246245
verbose: int | bool = 0,
247246
n_jobs: None | Int = None,
@@ -333,7 +332,7 @@ class MultiTaskElasticNetCV(RegressorMixin, LinearModelCV):
333332
fit_intercept: bool = True,
334333
max_iter: Int = 1000,
335334
tol: Float = 1e-4,
336-
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
335+
cv: int | BaseCrossValidator | Iterable | None = None,
337336
copy_X: bool = True,
338337
verbose: int | bool = 0,
339338
n_jobs: None | Int = None,
@@ -370,7 +369,7 @@ class MultiTaskLassoCV(RegressorMixin, LinearModelCV):
370369
max_iter: Int = 1000,
371370
tol: Float = 1e-4,
372371
copy_X: bool = True,
373-
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
372+
cv: int | BaseCrossValidator | Iterable | None = None,
374373
verbose: int | bool = False,
375374
n_jobs: None | Int = None,
376375
random_state: RandomState | None | Int = None,

stubs/sklearn/linear_model/_least_angle.pyi

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ from .._typing import ArrayLike, Float, Int, MatrixLike
1111
from ..base import MultiOutputMixin, RegressorMixin
1212
from ..exceptions import ConvergenceWarning as ConvergenceWarning
1313
from ..model_selection import BaseCrossValidator, check_cv as check_cv
14-
from ..model_selection._split import BaseShuffleSplit
1514
from ..utils import arrayfuncs as arrayfuncs, as_float_array as as_float_array, check_random_state as check_random_state
1615
from ..utils._param_validation import Hidden as Hidden, Interval as Interval, StrOptions as StrOptions
1716
from ..utils.parallel import Parallel as Parallel, delayed as delayed
@@ -160,7 +159,7 @@ class LarsCV(Lars):
160159
max_iter: Int = 500,
161160
normalize: str | bool = "deprecated",
162161
precompute: Literal["auto"] | ArrayLike | bool = "auto",
163-
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
162+
cv: int | BaseCrossValidator | Iterable | None = None,
164163
max_n_alphas: Int = 1000,
165164
n_jobs: None | int = None,
166165
eps: Float = ...,
@@ -193,7 +192,7 @@ class LassoLarsCV(LarsCV):
193192
max_iter: Int = 500,
194193
normalize: str | bool = "deprecated",
195194
precompute: Literal["auto"] | bool = "auto",
196-
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
195+
cv: int | BaseCrossValidator | Iterable | None = None,
197196
max_n_alphas: Int = 1000,
198197
n_jobs: None | int = None,
199198
eps: Float = ...,

stubs/sklearn/linear_model/_logistic.pyi

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ from .._loss.loss import HalfBinomialLoss as HalfBinomialLoss, HalfMultinomialLo
1010
from .._typing import ArrayLike, Float, Int, MatrixLike
1111
from ..metrics import get_scorer as get_scorer, get_scorer_names as get_scorer_names
1212
from ..model_selection import BaseCrossValidator, check_cv as check_cv
13-
from ..model_selection._split import BaseShuffleSplit
1413
from ..preprocessing import LabelBinarizer as LabelBinarizer, LabelEncoder as LabelEncoder
1514
from ..utils import (
1615
check_array as check_array,
@@ -108,7 +107,7 @@ class LogisticRegressionCV(LogisticRegression, LinearClassifierMixin, BaseEstima
108107
*,
109108
Cs: Sequence[float] | int = 10,
110109
fit_intercept: bool = True,
111-
cv: int | None | BaseShuffleSplit | BaseCrossValidator = None,
110+
cv: int | None | BaseCrossValidator = None,
112111
dual: bool = False,
113112
penalty: Literal["l1", "l2", "elasticnet"] = "l2",
114113
scoring: None | str | Callable = None,

stubs/sklearn/linear_model/_omp.pyi

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ from scipy.linalg.lapack import get_lapack_funcs as get_lapack_funcs
99
from .._typing import ArrayLike, Float, Int, MatrixLike
1010
from ..base import MultiOutputMixin, RegressorMixin
1111
from ..model_selection import BaseCrossValidator, check_cv as check_cv
12-
from ..model_selection._split import BaseShuffleSplit
1312
from ..utils import as_float_array as as_float_array, check_array as check_array
1413
from ..utils._param_validation import Hidden as Hidden, Interval as Interval, StrOptions as StrOptions
1514
from ..utils.parallel import Parallel as Parallel, delayed as delayed
@@ -90,7 +89,7 @@ class OrthogonalMatchingPursuitCV(RegressorMixin, LinearModel):
9089
fit_intercept: bool = True,
9190
normalize: str | bool = "deprecated",
9291
max_iter: None | Int = None,
93-
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
92+
cv: int | BaseCrossValidator | Iterable | None = None,
9493
n_jobs: None | Int = None,
9594
verbose: int | bool = False,
9695
) -> None: ...

stubs/sklearn/linear_model/_ridge.pyi

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ from ..base import MultiOutputMixin, RegressorMixin, is_classifier as is_classif
1818
from ..exceptions import ConvergenceWarning as ConvergenceWarning
1919
from ..metrics import check_scoring as check_scoring, get_scorer_names as get_scorer_names
2020
from ..model_selection import BaseCrossValidator, GridSearchCV as GridSearchCV
21-
from ..model_selection._split import BaseShuffleSplit
2221
from ..preprocessing import LabelBinarizer as LabelBinarizer
2322
from ..utils import (
2423
check_array as check_array,
@@ -237,7 +236,7 @@ class RidgeClassifierCV(_RidgeClassifierMixin, _BaseRidgeCV):
237236
*,
238237
fit_intercept: bool = True,
239238
scoring: None | str | Callable = None,
240-
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
239+
cv: int | BaseCrossValidator | Iterable | None = None,
241240
class_weight: None | Mapping | str = None,
242241
store_cv_values: bool = False,
243242
) -> None: ...

0 commit comments

Comments
 (0)