Skip to content

Commit f38ce7d

Browse files
Genusterlarsoner
andauthored
review suggestions
Co-authored-by: Eric Larson <larson.eric.d@gmail.com>
1 parent 029691b commit f38ce7d

5 files changed

Lines changed: 8 additions & 18 deletions

File tree

mne/decoding/_covs_ged.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,8 @@ def _concat_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, rank, in
3030
log_rank=log_rank,
3131
log_ch_type="data",
3232
)
33-
weight = x_class.shape[0]
3433

35-
return cov, weight
34+
return cov, n_channels # the weight here is just the number of channels
3635

3736

3837
def _epoch_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, rank, info):

mne/decoding/base.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class _GEDTransformer(MNETransformerMixin, BaseEstimator):
4646
mod_ged_callable : callable
4747
Function used to modify (e.g. sort or normalize) generalized
4848
eigenvalues and eigenvectors.
49-
mod_params : dict
49+
mod_params : dict | None
5050
Parameters passed to mod_ged_callable.
5151
dec_type : "single" | "multi"
5252
When "single" and cov_callable returns > 2 covariances,
@@ -93,7 +93,8 @@ def __init__(
9393
cov_callable,
9494
cov_params,
9595
mod_ged_callable,
96-
mod_params,
96+
*,
97+
mod_params=None,
9798
dec_type="single",
9899
restr_type=None,
99100
R_func=None,
@@ -120,6 +121,7 @@ def fit(self, X, y=None):
120121
covs = np.stack(covs)
121122
self._validate_covariances(covs)
122123
self._validate_covariances([C_ref])
124+
mod_params = self.mod_params if self.mod_params is not None else dict()
123125
if self.dec_type == "single":
124126
if len(covs) > 2:
125127
sample_weights = kwargs["sample_weights"]
@@ -133,7 +135,7 @@ def fit(self, X, y=None):
133135
evals, evecs = _smart_ged(S, R, restr_mat, R_func=self.R_func)
134136

135137
evals, evecs = self.mod_ged_callable(
136-
evals, evecs, covs, **self.mod_params, **kwargs
138+
evals, evecs, covs, **mod_params, **kwargs
137139
)
138140
self.evals_ = evals
139141
self.filters_ = evecs.T
@@ -153,7 +155,7 @@ def fit(self, X, y=None):
153155
evals, evecs = _smart_ged(S, R, restr_mat, R_func=self.R_func)
154156

155157
evals, evecs = self.mod_ged_callable(
156-
evals, evecs, covs, **self.mod_params, **kwargs
158+
evals, evecs, covs, **mod_params, **kwargs
157159
)
158160
all_evals.append(evals)
159161
all_evecs.append(evecs.T)

mne/decoding/csp.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,12 @@ def __init__(
134134
norm_trace=norm_trace,
135135
)
136136

137-
mod_params = dict(evecs_order=component_order)
138137
super().__init__(
139138
n_components=n_components,
140139
cov_callable=_csp_estimate,
141140
cov_params=cov_params,
142141
mod_ged_callable=_csp_mod,
143-
mod_params=mod_params,
142+
mod_params=dict(evecs_order=component_order),
144143
dec_type="single",
145144
restr_type="restricting",
146145
R_func=sum,
@@ -906,13 +905,11 @@ def __init__(
906905
rank=rank,
907906
)
908907

909-
mod_params = dict()
910908
super(CSP, self).__init__(
911909
n_components,
912910
_spoc_estimate,
913911
cov_params,
914912
_spoc_mod,
915-
mod_params,
916913
dec_type="single",
917914
restr_type=None,
918915
R_func=None,

mne/decoding/tests/test_ged.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ def test_ged_binary_cov():
187187
cov_callable=_mock_cov_callable,
188188
cov_params=dict(),
189189
mod_ged_callable=_mock_mod_ged_callable,
190-
mod_params=dict(),
191190
dec_type="single",
192191
restr_type="restricting",
193192
R_func=None,
@@ -215,7 +214,6 @@ def test_ged_binary_cov():
215214
cov_callable=_mock_cov_callable,
216215
cov_params=dict(),
217216
mod_ged_callable=_mock_mod_ged_callable,
218-
mod_params=dict(),
219217
dec_type="multi",
220218
restr_type="restricting",
221219
R_func=None,
@@ -245,7 +243,6 @@ def test_ged_multicov():
245243
cov_callable=_mock_cov_callable,
246244
cov_params=dict(),
247245
mod_ged_callable=_mock_mod_ged_callable,
248-
mod_params=dict(),
249246
dec_type="single",
250247
restr_type="restricting",
251248
R_func=None,
@@ -272,7 +269,6 @@ def test_ged_multicov():
272269
cov_callable=_mock_cov_callable,
273270
cov_params=dict(),
274271
mod_ged_callable=_mock_mod_ged_callable,
275-
mod_params=dict(),
276272
dec_type="multi",
277273
restr_type="restricting",
278274
R_func=None,
@@ -299,7 +295,6 @@ def test_ged_multicov():
299295
cov_callable=_mock_cov_callable,
300296
cov_params=dict(cov_method_params=dict(reg="oas"), compute_C_ref=False),
301297
mod_ged_callable=_mock_mod_ged_callable,
302-
mod_params=dict(),
303298
dec_type="single",
304299
restr_type="restricting",
305300
R_func=None,
@@ -317,7 +312,6 @@ def test_ged_invalid_cov():
317312
cov_callable=_mock_cov_callable,
318313
cov_params=dict(),
319314
mod_ged_callable=_mock_mod_ged_callable,
320-
mod_params=dict(),
321315
dec_type="single",
322316
restr_type=None,
323317
R_func=None,

mne/preprocessing/xdawn.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,13 +265,11 @@ def __init__(self, n_components=2, reg=None, signal_cov=None, method_params=None
265265

266266
cov_params = dict(reg=reg, cov_method_params=method_params, R=signal_cov)
267267

268-
mod_params = dict()
269268
super().__init__(
270269
n_components,
271270
_xdawn_estimate,
272271
cov_params,
273272
_xdawn_mod,
274-
mod_params,
275273
dec_type="multi",
276274
restr_type=None,
277275
R_func=None,

0 commit comments

Comments
 (0)