Skip to content

Commit 3986c99

Browse files
committed
fix multiplication order in original SSD
1 parent 89fb141 commit 3986c99

3 files changed

Lines changed: 7 additions & 23 deletions

File tree

mne/decoding/base.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,8 @@ def fit(self, X, y=None):
129129
else:
130130
S = covs[0]
131131
R = covs[1]
132-
if self.restr_type == "ssd":
133-
mult_order = "ssd"
134-
else:
135-
mult_order = None
136132
restr_mat = _handle_restr_mat(C_ref, self.restr_type, info, rank)
137-
evals, evecs = _smart_ged(
138-
S, R, restr_mat, R_func=self.R_func, mult_order=mult_order
139-
)
133+
evals, evecs = _smart_ged(S, R, restr_mat, R_func=self.R_func)
140134

141135
evals, evecs = self.mod_ged_callable(
142136
evals, evecs, covs, **self.mod_params, **kwargs
@@ -151,18 +145,12 @@ def fit(self, X, y=None):
151145
elif self.dec_type == "multi":
152146
self.classes_ = np.unique(y)
153147
R = covs[-1]
154-
if self.restr_type == "ssd":
155-
mult_order = "ssd"
156-
else:
157-
mult_order = None
158148
restr_mat = _handle_restr_mat(C_ref, self.restr_type, info, rank)
159149
all_evals, all_evecs, all_patterns = list(), list(), list()
160150
for i in range(len(self.classes_)):
161151
S = covs[i]
162152

163-
evals, evecs = _smart_ged(
164-
S, R, restr_mat, R_func=self.R_func, mult_order=mult_order
165-
)
153+
evals, evecs = _smart_ged(S, R, restr_mat, R_func=self.R_func)
166154

167155
evals, evecs = self.mod_ged_callable(
168156
evals, evecs, covs, **self.mod_params, **kwargs

mne/decoding/ged.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _handle_restr_mat(C_ref, restr_type, info, rank):
3434
return restr_mat
3535

3636

37-
def _smart_ged(S, R, restr_mat=None, R_func=None, mult_order=None):
37+
def _smart_ged(S, R, restr_mat=None, R_func=None):
3838
"""Perform smart generalized eigenvalue decomposition (GED) of S and R.
3939
4040
If restr_mat is provided S and R will be restricted to the principal subspace
@@ -47,12 +47,8 @@ def _smart_ged(S, R, restr_mat=None, R_func=None, mult_order=None):
4747
evals, evecs = scipy.linalg.eigh(S, R)
4848
return evals, evecs
4949

50-
if mult_order == "ssd":
51-
S_restr = restr_mat @ (S @ restr_mat.T)
52-
R_restr = restr_mat @ (R @ restr_mat.T)
53-
else:
54-
S_restr = restr_mat @ S @ restr_mat.T
55-
R_restr = restr_mat @ R @ restr_mat.T
50+
S_restr = restr_mat @ S @ restr_mat.T
51+
R_restr = restr_mat @ R @ restr_mat.T
5652
if R_func is not None:
5753
R_restr = R_func([S_restr, R_restr])
5854
evals, evecs_restr = scipy.linalg.eigh(S_restr, R_restr)

mne/decoding/ssd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,6 @@ def _dimensionality_reduction(cov_signal, cov_noise, info, rank):
460460
logger.info("Preserving covariance rank (%i)", rank)
461461

462462
# project covariance matrices to rank subspace
463-
cov_signal = np.matmul(rank_proj.T, np.matmul(cov_signal, rank_proj))
464-
cov_noise = np.matmul(rank_proj.T, np.matmul(cov_noise, rank_proj))
463+
cov_signal = rank_proj.T @ cov_signal @ rank_proj
464+
cov_noise = rank_proj.T @ cov_noise @ rank_proj
465465
return cov_signal, cov_noise, rank_proj

0 commit comments

Comments
 (0)