Skip to content

Commit d720dc2

Browse files
committed
add some minor tests and fixes
1 parent 3266d16 commit d720dc2

5 files changed

Lines changed: 54 additions & 5 deletions

File tree

mne/decoding/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def get_spatial_filter(self, info):
215215
check_is_fitted(self, ["filters_", "patterns_", "evals_"])
216216
sp_filter = SpatialFilter(
217217
info,
218-
evecs=self.filters_,
218+
filters=self.filters_,
219219
evals=self.evals_,
220220
patterns=self.patterns_,
221221
patterns_method="pinv",
@@ -463,7 +463,7 @@ def get_spatial_filter(self, info):
463463
check_is_fitted(self, ["filters_", "patterns_"])
464464
sp_filter = SpatialFilter(
465465
info,
466-
evecs=self.filters_.T,
466+
filters=self.filters_,
467467
patterns=self.patterns_,
468468
patterns_method="haufe",
469469
)

mne/decoding/tests/test_base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,14 @@ def test_linearmodel():
418418
wrong_y = rng.rand(n, n_features, 99)
419419
clf.fit(X, wrong_y)
420420

421+
# check get_spatial_filter
422+
info = create_info(n_features, 1000.0, "eeg")
423+
sp_filter = clf.get_spatial_filter(info)
424+
assert sp_filter.patterns_method == "haufe"
425+
np.testing.assert_array_equal(sp_filter.filters, clf.filters_)
426+
np.testing.assert_array_equal(sp_filter.patterns, clf.patterns_)
427+
assert sp_filter.evals is None
428+
421429

422430
def test_cross_val_multiscore():
423431
"""Test cross_val_multiscore for computing scores on decoding over time."""

mne/decoding/tests/test_ged.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def test_sklearn_compliance(estimator, check):
145145
check(estimator)
146146

147147

148-
def _get_X_y(event_id):
148+
def _get_X_y(event_id, return_info=False):
149149
raw = read_raw(raw_fname, preload=False)
150150
events = read_events(event_name)
151151
picks = pick_types(
@@ -166,6 +166,8 @@ def _get_X_y(event_id):
166166
)
167167
X = epochs.get_data(copy=False, units=dict(eeg="uV", grad="fT/cm", mag="fT"))
168168
y = epochs.events[:, -1]
169+
if return_info:
170+
return X, y, epochs.info
169171
return X, y
170172

171173

@@ -386,3 +388,22 @@ def test__no_op_mod():
386388
assert evals is evals_no_op
387389
assert evecs is evecs_no_op
388390
assert sorter_no_op is None
391+
392+
393+
def test_get_spatial_filter():
394+
"""Test instantiation of spatial filter."""
395+
event_id = dict(aud_l=1, vis_l=3)
396+
X, y, info = _get_X_y(event_id, return_info=True)
397+
398+
ged = _GEDTransformer(
399+
n_components=4,
400+
cov_callable=_mock_cov_callable,
401+
mod_ged_callable=_mock_mod_ged_callable,
402+
restr_type="restricting",
403+
)
404+
ged.fit(X, y)
405+
sp_filter = ged.get_spatial_filter(info)
406+
assert sp_filter.patterns_method == "pinv"
407+
np.testing.assert_array_equal(sp_filter.filters, ged.filters_)
408+
np.testing.assert_array_equal(sp_filter.patterns, ged.patterns_)
409+
np.testing.assert_array_equal(sp_filter.evals, ged.evals_)

mne/viz/decoding/ged.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,10 +294,12 @@ def __init__(
294294
self.patterns = patterns
295295
self.patterns_method = patterns_method
296296

297-
if n_comps > n_chs:
297+
# In case of multi-target classification in LinearModel
298+
# number of targets can be greater than number of channels.
299+
if patterns_method != "haufe" and n_comps > n_chs:
298300
raise ValueError(
299301
"Number of components can't be greater "
300-
"than number of channels in filters,"
302+
"than number of channels in filters, "
301303
"perhaps the provided matrix is transposed?"
302304
)
303305
if self.filters.shape != self.patterns.shape:

mne/viz/decoding/tests/test_ged.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Authors: The MNE-Python contributors.
2+
# License: BSD-3-Clause
3+
# Copyright the MNE-Python contributors.
4+
5+
import numpy as np
6+
import pytest
7+
8+
from mne import create_info
9+
from mne.viz import SpatialFilter
10+
11+
12+
def test_plot_scree_raises():
13+
"""Tests that plot_scree can't plot without evals."""
14+
info = create_info(2, 1000.0, "eeg")
15+
filters = np.array([[1, 2], [3, 4]])
16+
sp_filter = SpatialFilter(info, filters, evals=None)
17+
with pytest.raises(AttributeError):
18+
sp_filter.plot_scree()

0 commit comments

Comments
 (0)