Skip to content

Commit 40dec84

Browse files
committed
more old sklearn fixes
1 parent 175f6c6 commit 40dec84

2 files changed

Lines changed: 14 additions & 4 deletions

File tree

mne/decoding/base.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@ class LinearModel(MetaEstimatorMixin, BaseEstimator):
381381
"predict",
382382
"predict_proba",
383383
"predict_log_proba",
384+
"_estimator_type", # remove after sklearn 1.6
384385
"decision_function",
385386
"score",
386387
"classes_",
@@ -428,14 +429,23 @@ def __getattr__(self, attr):
428429
def _fit_transform(self, X, y):
429430
return self.fit(X, y).transform(X)
430431

431-
def _validate_params(self):
432-
is_predictor = is_regressor(self._orig_model) or is_classifier(self._orig_model)
432+
def _validate_params(self, X):
433+
model = self._orig_model
434+
if isinstance(model, MetaEstimatorMixin):
435+
model = model.estimator
436+
is_predictor = is_regressor(model) or is_classifier(model)
433437
if not is_predictor:
434438
raise ValueError(
435439
"Linear model should be a supervised predictor "
436440
"(classifier or regressor)"
437441
)
438442

443+
# For sklearn < 1.6
444+
try:
445+
self._check_n_features(X, reset=True)
446+
except AttributeError:
447+
pass
448+
439449
def fit(self, X, y, **fit_params):
440450
"""Estimate the coefficients of the linear model.
441451
@@ -456,7 +466,7 @@ def fit(self, X, y, **fit_params):
456466
self : instance of LinearModel
457467
Returns the modified instance.
458468
"""
459-
self._validate_params()
469+
self._validate_params(X)
460470
X, y = validate_data(self, X, y, multi_output=True)
461471

462472
# fit the Model

mne/decoding/tests/test_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def test_get_coef():
9999
"""Test getting linear coefficients (filters/patterns) from estimators."""
100100
lm_classification = LinearModel(LogisticRegression(solver="liblinear"))
101101
assert hasattr(lm_classification, "__sklearn_tags__")
102-
if check_version("sklearn", "1.4"):
102+
if check_version("sklearn", "1.6"):
103103
print(lm_classification.__sklearn_tags__())
104104
assert is_classifier(lm_classification.model)
105105
assert is_classifier(lm_classification)

0 commit comments

Comments
 (0)