22# License: BSD-3-Clause
33# Copyright the MNE-Python contributors.
44
5- import functools
5+ from functools import partial
66from pathlib import Path
77
88import numpy as np
1818from mne import Epochs , compute_rank , create_info , pick_types , read_events
1919from mne ._fiff .proj import make_eeg_average_ref_proj
2020from mne .cov import Covariance , _regularized_covariance
21- from mne .decoding .base import _GEDTransformer
22- from mne .decoding .ged import (
21+ from mne .decoding ._ged import (
2322 _get_restr_mat ,
2423 _handle_restr_mat ,
2524 _is_cov_pos_def ,
2625 _is_cov_symm_pos_semidef ,
2726 _smart_ajd ,
2827 _smart_ged ,
2928)
29+ from mne .decoding .base import _GEDTransformer
3030from mne .io import read_raw
3131
3232data_dir = Path (__file__ ).parents [2 ] / "io" / "tests" / "data"
@@ -120,20 +120,16 @@ def _mock_mod_ged_callable(evals, evecs, covs, **kwargs):
120120
121121param_grid = dict (
122122 n_components = [4 ],
123- cov_callable = [_mock_cov_callable ],
124- cov_params = [
125- dict (cov_method_params = dict (reg = "empirical" )),
126- ],
123+ cov_callable = [partial (_mock_cov_callable , cov_method_params = dict (reg = "empirical" ))],
127124 mod_ged_callable = [_mock_mod_ged_callable ],
128- mod_params = [dict ()],
129125 dec_type = ["single" , "multi" ],
130126 # XXX: Not covering "ssd" here because test_ssd.py works with 2D data.
131127 # Need to fix its tests first.
132128 restr_type = [
133129 "restricting" ,
134130 "whitening" ,
135131 ],
136- R_func = [functools . partial (np .sum , axis = 0 )],
132+ R_func = [partial (np .sum , axis = 0 )],
137133)
138134
139135ged_estimators = [_GEDTransformer (** p ) for p in ParameterGrid (param_grid )]
@@ -185,11 +181,8 @@ def test_ged_binary_cov():
185181 ged = _GEDTransformer (
186182 n_components = 4 ,
187183 cov_callable = _mock_cov_callable ,
188- cov_params = dict (),
189184 mod_ged_callable = _mock_mod_ged_callable ,
190- dec_type = "single" ,
191185 restr_type = "restricting" ,
192- R_func = None ,
193186 )
194187 ged .fit (X , y )
195188 desired_evals = ged .evals_
@@ -212,11 +205,9 @@ def test_ged_binary_cov():
212205 ged = _GEDTransformer (
213206 n_components = 4 ,
214207 cov_callable = _mock_cov_callable ,
215- cov_params = dict (),
216208 mod_ged_callable = _mock_mod_ged_callable ,
217209 dec_type = "multi" ,
218210 restr_type = "restricting" ,
219- R_func = None ,
220211 )
221212 ged .fit (X , y )
222213 desired_evals = ged .evals_
@@ -241,11 +232,8 @@ def test_ged_multicov():
241232 ged = _GEDTransformer (
242233 n_components = 4 ,
243234 cov_callable = _mock_cov_callable ,
244- cov_params = dict (),
245235 mod_ged_callable = _mock_mod_ged_callable ,
246- dec_type = "single" ,
247236 restr_type = "restricting" ,
248- R_func = None ,
249237 )
250238 ged .fit (X , y )
251239 desired_filters = ged .filters_
@@ -267,11 +255,9 @@ def test_ged_multicov():
267255 ged = _GEDTransformer (
268256 n_components = 4 ,
269257 cov_callable = _mock_cov_callable ,
270- cov_params = dict (),
271258 mod_ged_callable = _mock_mod_ged_callable ,
272259 dec_type = "multi" ,
273260 restr_type = "restricting" ,
274- R_func = None ,
275261 )
276262 ged .fit (X , y )
277263 desired_evals = ged .evals_
@@ -292,12 +278,11 @@ def test_ged_multicov():
292278
293279 ged = _GEDTransformer (
294280 n_components = 4 ,
295- cov_callable = _mock_cov_callable ,
296- cov_params = dict (cov_method_params = dict (reg = "oas" ), compute_C_ref = False ),
281+ cov_callable = partial (
282+ _mock_cov_callable , cov_method_params = dict (reg = "oas" ), compute_C_ref = False
283+ ),
297284 mod_ged_callable = _mock_mod_ged_callable ,
298- dec_type = "single" ,
299285 restr_type = "restricting" ,
300- R_func = None ,
301286 )
302287 ged .fit (X , y )
303288 desired_filters = ged .filters_
@@ -310,11 +295,7 @@ def test_ged_invalid_cov():
310295 ged = _GEDTransformer (
311296 n_components = 1 ,
312297 cov_callable = _mock_cov_callable ,
313- cov_params = dict (),
314298 mod_ged_callable = _mock_mod_ged_callable ,
315- dec_type = "single" ,
316- restr_type = None ,
317- R_func = None ,
318299 )
319300 asymm_cov = np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]])
320301 with pytest .raises (ValueError , match = "not symmetric" ):
0 commit comments