Skip to content

Commit 969a73e

Browse files
committed
fix SSD's filters_ shape inconsistency
1 parent 87a2466 commit 969a73e

1 file changed

Lines changed: 6 additions & 7 deletions

File tree

mne/decoding/ssd.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ class SSD(_GEDTransformer):
9595
9696
Attributes
9797
----------
98-
filters_ : array, shape (n_channels, n_components)
98+
filters_ : array, shape (n_channels or less, n_channels)
9999
The spatial filters to be multiplied with the signal.
100-
patterns_ : array, shape (n_components, n_channels)
100+
patterns_ : array, shape (n_channels or less, n_channels)
101101
The patterns for reconstructing the signal from the filtered data.
102102
103103
References
@@ -272,13 +272,12 @@ def fit(self, X, y=None):
272272
# project back to sensor space
273273
self.filters_ = np.matmul(rank_proj, eigvects_[:, ix])
274274
self.patterns_ = np.linalg.pinv(self.filters_)
275+
# Need to unify with Xdawn and CSP as they store it as (n_components, n_chs)
276+
self.filters_ = self.filters_.T
275277

276278
old_filters = self.filters_
277279
old_patterns = self.patterns_
278280
super().fit(X, y)
279-
# SSD, as opposed to CSP and Xdawn stores filters as (n_chs, n_components)
280-
# So need to transpose into (n_components, n_chs)
281-
self.filters_ = self.filters_.T
282281

283282
np.testing.assert_allclose(self.eigvals_, self.evals_)
284283
np.testing.assert_allclose(old_filters, self.filters_)
@@ -287,7 +286,7 @@ def fit(self, X, y=None):
287286
# We assume that ordering by spectral ratio is more important
288287
# than the initial ordering. This ordering should be also learned when
289288
# fitting.
290-
X_ssd = self.filters_.T @ X[..., self.picks_, :]
289+
X_ssd = self.filters_ @ X[..., self.picks_, :]
291290
sorter_spec = slice(None)
292291
if self.sort_by_spectral_ratio:
293292
_, sorter_spec = self.get_spectral_ratio(ssd_sources=X_ssd)
@@ -315,7 +314,7 @@ def transform(self, X):
315314
if self.return_filtered:
316315
X_aux = X[..., self.picks_, :]
317316
X = filter_data(X_aux, self.sfreq_, **self.filt_params_signal)
318-
X_ssd = self.filters_.T @ X[..., self.picks_, :]
317+
X_ssd = self.filters_ @ X[..., self.picks_, :]
319318
X_ssd = X_ssd[..., self.sorter_spec_, :][..., : self.n_components, :]
320319
return X_ssd
321320

0 commit comments

Comments
 (0)