Skip to content

Commit 5266372

Browse files
committed
use mne's pinv in SSD and Xdawn instead of np.linalg.pinv
1 parent 969a73e commit 5266372

3 files changed

Lines changed: 5 additions & 7 deletions

File tree

mne/decoding/base.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,7 @@ def fit(self, X, y=None):
143143
evals, evecs = mod_ged_callable(evals, evecs, covs, **kwargs)
144144
self.evals_ = evals
145145
self.filters_ = evecs.T
146-
if self.restr_type == "ssd":
147-
self.patterns_ = np.linalg.pinv(evecs)
148-
else:
149-
self.patterns_ = pinv(evecs)
146+
self.patterns_ = pinv(evecs)
150147

151148
elif self.dec_type == "multi":
152149
self.classes_ = np.unique(y)
@@ -161,7 +158,7 @@ def fit(self, X, y=None):
161158
evals, evecs = mod_ged_callable(evals, evecs, covs, **kwargs)
162159
all_evals.append(evals)
163160
all_evecs.append(evecs.T)
164-
all_patterns.append(np.linalg.pinv(evecs))
161+
all_patterns.append(pinv(evecs))
165162
self.evals_ = np.array(all_evals)
166163
self.filters_ = np.array(all_evecs)
167164
self.patterns_ = np.array(all_patterns)

mne/decoding/ssd.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
_verbose_safe_false,
2222
fill_doc,
2323
logger,
24+
pinv,
2425
)
2526
from ._covs_ged import _ssd_estimate
2627
from ._mod_ged import _ssd_mod
@@ -271,7 +272,7 @@ def fit(self, X, y=None):
271272
self.eigvals_ = eigvals_[ix]
272273
# project back to sensor space
273274
self.filters_ = np.matmul(rank_proj, eigvects_[:, ix])
274-
self.patterns_ = np.linalg.pinv(self.filters_)
275+
self.patterns_ = pinv(self.filters_)
275276
# Need to unify with Xdawn and CSP as they store it as (n_components, n_chs)
276277
self.filters_ = self.filters_.T
277278

mne/preprocessing/xdawn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def _fit_xdawn(
208208
)
209209
evecs = evecs[:, np.argsort(evals)[::-1]] # sort eigenvectors
210210
evecs /= np.apply_along_axis(np.linalg.norm, 0, evecs)
211-
_patterns = np.linalg.pinv(evecs.T)
211+
_patterns = pinv(evecs.T)
212212
filters.append(evecs[:, :n_components].T)
213213
patterns.append(_patterns[:, :n_components].T)
214214

0 commit comments

Comments
 (0)