Skip to content

Commit 222ed16

Browse files
committed
add fit_transform test
1 parent 74177d2 commit 222ed16

1 file changed

Lines changed: 6 additions & 0 deletions

File tree

mne/decoding/tests/test_base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
is_classifier,
2929
is_regressor,
3030
)
31+
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
3132
from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge
3233
from sklearn.model_selection import (
3334
GridSearchCV,
@@ -380,6 +381,10 @@ def test_linearmodel():
380381
wrong_X = rng.rand(n, n_features, 99)
381382
clf.fit(wrong_X, y)
382383

384+
# check fit_transform call
385+
clf = LinearModel(LinearDiscriminantAnalysis())
386+
_ = clf.fit_transform(X, y)
387+
383388
# check categorical target fit in standard linear model with GridSearchCV
384389
parameters = {"kernel": ["linear"], "C": [1, 10]}
385390
clf = LinearModel(
@@ -416,6 +421,7 @@ def test_linearmodel():
416421
wrong_y = rng.rand(n, n_features, 99)
417422
clf.fit(X, wrong_y)
418423

424+
# check that model has to be a predictor
419425
clf = LinearModel(StandardScaler())
420426
with pytest.raises(ValueError, match="classifier or regressor"):
421427
clf.fit(X, Y)

0 commit comments

Comments
 (0)