Skip to content

Commit 11c31f7

Browse files
committed
address Eric's suggestions
1 parent f38ce7d commit 11c31f7

6 files changed

Lines changed: 36 additions & 66 deletions

File tree

mne/decoding/base.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from ..parallel import parallel_func
2626
from ..utils import _pl, logger, pinv, verbose, warn
27-
from .ged import _handle_restr_mat, _is_cov_symm_pos_semidef, _smart_ajd, _smart_ged
27+
from ._ged import _handle_restr_mat, _is_cov_symm_pos_semidef, _smart_ajd, _smart_ged
2828
from .transformer import MNETransformerMixin
2929

3030

@@ -40,9 +40,9 @@ class _GEDTransformer(MNETransformerMixin, BaseEstimator):
4040
The number of spatial filters to decompose M/EEG signals.
4141
cov_callable : callable
4242
Function used to estimate covariances and reference matrix (C_ref) from the
43-
data.
44-
cov_params : dict
45-
Parameters passed to cov_callable.
43+
data. It should accept only X and y as arguments and return covs, C_ref, info,
44+
rank and additional kwargs passed further to mod_ged_callable.
45+
C_ref, info, rank can be None, while kwargs can be empty dict.
4646
mod_ged_callable : callable
4747
Function used to modify (e.g. sort or normalize) generalized
4848
eigenvalues and eigenvectors.
@@ -91,7 +91,6 @@ def __init__(
9191
self,
9292
n_components,
9393
cov_callable,
94-
cov_params,
9594
mod_ged_callable,
9695
*,
9796
mod_params=None,
@@ -101,7 +100,6 @@ def __init__(
101100
):
102101
self.n_components = n_components
103102
self.cov_callable = cov_callable
104-
self.cov_params = cov_params
105103
self.mod_ged_callable = mod_ged_callable
106104
self.mod_params = mod_params
107105
self.dec_type = dec_type
@@ -117,7 +115,7 @@ def fit(self, X, y=None):
117115
return_y=True,
118116
atleast_3d=False if self.restr_type == "ssd" else True,
119117
)
120-
covs, C_ref, info, rank, kwargs = self.cov_callable(X, y, **self.cov_params)
118+
covs, C_ref, info, rank, kwargs = self.cov_callable(X, y)
121119
covs = np.stack(covs)
122120
self._validate_covariances(covs)
123121
self._validate_covariances([C_ref])

mne/decoding/csp.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Copyright the MNE-Python contributors.
44

55
import copy as cp
6+
from functools import partial
67

78
import numpy as np
89
from scipy.linalg import eigh
@@ -126,21 +127,19 @@ def __init__(
126127
self.cov_method_params = cov_method_params
127128
self.component_order = component_order
128129

129-
cov_params = dict(
130+
cov_callable = partial(
131+
_csp_estimate,
130132
reg=reg,
131133
cov_method_params=cov_method_params,
132134
cov_est=cov_est,
133135
rank=rank,
134136
norm_trace=norm_trace,
135137
)
136-
137138
super().__init__(
138139
n_components=n_components,
139-
cov_callable=_csp_estimate,
140-
cov_params=cov_params,
140+
cov_callable=cov_callable,
141141
mod_ged_callable=_csp_mod,
142142
mod_params=dict(evecs_order=component_order),
143-
dec_type="single",
144143
restr_type="restricting",
145144
R_func=sum,
146145
)
@@ -899,20 +898,16 @@ def __init__(
899898
cov_method_params=cov_method_params,
900899
)
901900

902-
cov_params = dict(
901+
cov_callable = partial(
902+
_spoc_estimate,
903903
reg=reg,
904904
cov_method_params=cov_method_params,
905905
rank=rank,
906906
)
907-
908907
super(CSP, self).__init__(
909-
n_components,
910-
_spoc_estimate,
911-
cov_params,
912-
_spoc_mod,
913-
dec_type="single",
914-
restr_type=None,
915-
R_func=None,
908+
n_components=n_components,
909+
cov_callable=cov_callable,
910+
mod_ged_callable=_spoc_mod,
916911
)
917912

918913
# Covariance estimation have to be done on the single epoch level,

mne/decoding/ssd.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# License: BSD-3-Clause
33
# Copyright the MNE-Python contributors.
44

5+
from functools import partial
6+
57
import numpy as np
68
from scipy.linalg import eigh
79
from sklearn.utils.validation import check_is_fitted
@@ -119,7 +121,8 @@ def __init__(
119121
self.cov_method_params = cov_method_params
120122
self.rank = rank
121123

122-
cov_params = dict(
124+
cov_callable = partial(
125+
_ssd_estimate,
123126
reg=reg,
124127
cov_method_params=cov_method_params,
125128
info=info,
@@ -128,17 +131,11 @@ def __init__(
128131
filt_params_noise=filt_params_noise,
129132
rank=rank,
130133
)
131-
132-
mod_params = dict()
133134
super().__init__(
134-
n_components,
135-
_ssd_estimate,
136-
cov_params,
137-
_ssd_mod,
138-
mod_params,
139-
dec_type="single",
135+
n_components=n_components,
136+
cov_callable=cov_callable,
137+
mod_ged_callable=_ssd_mod,
140138
restr_type="ssd",
141-
R_func=None,
142139
)
143140

144141
def _validate_params(self, X):

mne/decoding/tests/test_ged.py

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# License: BSD-3-Clause
33
# Copyright the MNE-Python contributors.
44

5-
import functools
5+
from functools import partial
66
from pathlib import Path
77

88
import numpy as np
@@ -18,15 +18,15 @@
1818
from mne import Epochs, compute_rank, create_info, pick_types, read_events
1919
from mne._fiff.proj import make_eeg_average_ref_proj
2020
from mne.cov import Covariance, _regularized_covariance
21-
from mne.decoding.base import _GEDTransformer
22-
from mne.decoding.ged import (
21+
from mne.decoding._ged import (
2322
_get_restr_mat,
2423
_handle_restr_mat,
2524
_is_cov_pos_def,
2625
_is_cov_symm_pos_semidef,
2726
_smart_ajd,
2827
_smart_ged,
2928
)
29+
from mne.decoding.base import _GEDTransformer
3030
from mne.io import read_raw
3131

3232
data_dir = Path(__file__).parents[2] / "io" / "tests" / "data"
@@ -120,20 +120,16 @@ def _mock_mod_ged_callable(evals, evecs, covs, **kwargs):
120120

121121
param_grid = dict(
122122
n_components=[4],
123-
cov_callable=[_mock_cov_callable],
124-
cov_params=[
125-
dict(cov_method_params=dict(reg="empirical")),
126-
],
123+
cov_callable=[partial(_mock_cov_callable, cov_method_params=dict(reg="empirical"))],
127124
mod_ged_callable=[_mock_mod_ged_callable],
128-
mod_params=[dict()],
129125
dec_type=["single", "multi"],
130126
# XXX: Not covering "ssd" here because test_ssd.py works with 2D data.
131127
# Need to fix its tests first.
132128
restr_type=[
133129
"restricting",
134130
"whitening",
135131
],
136-
R_func=[functools.partial(np.sum, axis=0)],
132+
R_func=[partial(np.sum, axis=0)],
137133
)
138134

139135
ged_estimators = [_GEDTransformer(**p) for p in ParameterGrid(param_grid)]
@@ -185,11 +181,8 @@ def test_ged_binary_cov():
185181
ged = _GEDTransformer(
186182
n_components=4,
187183
cov_callable=_mock_cov_callable,
188-
cov_params=dict(),
189184
mod_ged_callable=_mock_mod_ged_callable,
190-
dec_type="single",
191185
restr_type="restricting",
192-
R_func=None,
193186
)
194187
ged.fit(X, y)
195188
desired_evals = ged.evals_
@@ -212,11 +205,9 @@ def test_ged_binary_cov():
212205
ged = _GEDTransformer(
213206
n_components=4,
214207
cov_callable=_mock_cov_callable,
215-
cov_params=dict(),
216208
mod_ged_callable=_mock_mod_ged_callable,
217209
dec_type="multi",
218210
restr_type="restricting",
219-
R_func=None,
220211
)
221212
ged.fit(X, y)
222213
desired_evals = ged.evals_
@@ -241,11 +232,8 @@ def test_ged_multicov():
241232
ged = _GEDTransformer(
242233
n_components=4,
243234
cov_callable=_mock_cov_callable,
244-
cov_params=dict(),
245235
mod_ged_callable=_mock_mod_ged_callable,
246-
dec_type="single",
247236
restr_type="restricting",
248-
R_func=None,
249237
)
250238
ged.fit(X, y)
251239
desired_filters = ged.filters_
@@ -267,11 +255,9 @@ def test_ged_multicov():
267255
ged = _GEDTransformer(
268256
n_components=4,
269257
cov_callable=_mock_cov_callable,
270-
cov_params=dict(),
271258
mod_ged_callable=_mock_mod_ged_callable,
272259
dec_type="multi",
273260
restr_type="restricting",
274-
R_func=None,
275261
)
276262
ged.fit(X, y)
277263
desired_evals = ged.evals_
@@ -292,12 +278,11 @@ def test_ged_multicov():
292278

293279
ged = _GEDTransformer(
294280
n_components=4,
295-
cov_callable=_mock_cov_callable,
296-
cov_params=dict(cov_method_params=dict(reg="oas"), compute_C_ref=False),
281+
cov_callable=partial(
282+
_mock_cov_callable, cov_method_params=dict(reg="oas"), compute_C_ref=False
283+
),
297284
mod_ged_callable=_mock_mod_ged_callable,
298-
dec_type="single",
299285
restr_type="restricting",
300-
R_func=None,
301286
)
302287
ged.fit(X, y)
303288
desired_filters = ged.filters_
@@ -310,11 +295,7 @@ def test_ged_invalid_cov():
310295
ged = _GEDTransformer(
311296
n_components=1,
312297
cov_callable=_mock_cov_callable,
313-
cov_params=dict(),
314298
mod_ged_callable=_mock_mod_ged_callable,
315-
dec_type="single",
316-
restr_type=None,
317-
R_func=None,
318299
)
319300
asymm_cov = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
320301
with pytest.raises(ValueError, match="not symmetric"):

mne/preprocessing/xdawn.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Copyright the MNE-Python contributors.
44

55
from collections.abc import Mapping
6+
from functools import partial
67

78
import numpy as np
89
from scipy import linalg
@@ -263,16 +264,14 @@ def __init__(self, n_components=2, reg=None, signal_cov=None, method_params=None
263264
self.reg = reg
264265
self.method_params = method_params
265266

266-
cov_params = dict(reg=reg, cov_method_params=method_params, R=signal_cov)
267-
267+
cov_callable = partial(
268+
_xdawn_estimate, reg=reg, cov_method_params=method_params, R=signal_cov
269+
)
268270
super().__init__(
269-
n_components,
270-
_xdawn_estimate,
271-
cov_params,
272-
_xdawn_mod,
271+
n_components=n_components,
272+
cov_callable=cov_callable,
273+
mod_ged_callable=_xdawn_mod,
273274
dec_type="multi",
274-
restr_type=None,
275-
R_func=None,
276275
)
277276

278277
def _validate_params(self, X):

0 commit comments

Comments
 (0)