Skip to content

Commit 74177d2

Browse files
committed
move predict and classes_ back to the wrapped attrs
1 parent 9f2354f commit 74177d2

2 files changed

Lines changed: 6 additions & 25 deletions

File tree

mne/decoding/base.py

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -378,10 +378,12 @@ class LinearModel(MetaEstimatorMixin, BaseEstimator):
378378
_model_attr_wrap = (
379379
"transform",
380380
"fit_transform",
381+
"predict",
381382
"predict_proba",
382383
"predict_log_proba",
383384
"decision_function",
384385
"score",
386+
"classes_",
385387
)
386388

387389
def __init__(self, model=None):
@@ -472,22 +474,6 @@ def fit(self, X, y, **fit_params):
472474

473475
return self
474476

475-
def predict(self, X):
476-
"""Predict class labels for X using fitted linear model.
477-
478-
Parameters
479-
----------
480-
X : array, shape (n_samples, n_features)
481-
The data matrix for which we want to get the predictions.
482-
483-
Returns
484-
-------
485-
y_pred : array, shape (n_samples,)
486-
Vector containing the class labels for each sample.
487-
"""
488-
check_is_fitted(self)
489-
return self.model_.predict(X)
490-
491477
@property
492478
def filters_(self):
493479
check_is_fitted(self)
@@ -506,15 +492,6 @@ def filters_(self):
506492
filters = filters[0]
507493
return filters
508494

509-
@property
510-
def classes_(self):
511-
check_is_fitted(self)
512-
if is_regressor(self.model_):
513-
raise AttributeError("Regressors don't have the 'classes_' attribute")
514-
elif hasattr(self.model_, "classes_"):
515-
return self.model_.classes_
516-
return None
517-
518495
# XXX Remove this property after 'model' warning cycle
519496
@property
520497
def model(self):

mne/decoding/tests/test_base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,10 @@ def test_linearmodel():
416416
wrong_y = rng.rand(n, n_features, 99)
417417
clf.fit(X, wrong_y)
418418

419+
clf = LinearModel(StandardScaler())
420+
with pytest.raises(ValueError, match="classifier or regressor"):
421+
clf.fit(X, Y)
422+
419423

420424
def test_cross_val_multiscore():
421425
"""Test cross_val_multiscore for computing scores on decoding over time."""

0 commit comments

Comments
 (0)