Skip to content

Commit 95544c5

Browse files
committed
add feature to perform GED in the principal subspace for xdawn
1 parent 9c7c711 commit 95544c5

4 files changed

Lines changed: 75 additions & 13 deletions

File tree

mne/decoding/_covs_ged.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,6 @@ def _xdawn_estimate(
119119
):
120120
classes = np.unique(y)
121121

122-
# XXX Eventually this could be made to deal with rank deficiency properly
123-
# by exposing this "rank" parameter, but this will require refactoring
124-
# the linalg.eigh call to operate in the lower-dimension
125-
# subspace, then project back out.
126-
127122
# Retrieve or compute whitening covariance
128123
if R is None:
129124
R = _regularized_covariance(
@@ -147,9 +142,8 @@ def _xdawn_estimate(
147142
covs.append(evo_cov)
148143

149144
covs.append(R)
150-
C_ref = None
151-
rank = None
152-
info = None
145+
C_ref = R
146+
rank = rank if isinstance(rank, dict) else None
153147
return covs, C_ref, info, rank, dict()
154148

155149

mne/decoding/_ged.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ def _handle_restr_mat(C_ref, restr_type, info, rank):
1818
if C_ref is None or restr_type is None:
1919
return None
2020
if restr_type == "whitening":
21-
projs = info["projs"]
22-
C_ref_cov = Covariance(C_ref, info.ch_names, info["bads"], projs, 0)
21+
C_ref_cov = Covariance(C_ref, info.ch_names, info["bads"], info["projs"], 0)
2322
restr_mat = compute_whitener(C_ref_cov, info, rank=rank, pca=True)[0]
2423
elif restr_type == "ssd":
2524
restr_mat = _get_ssd_whitener(C_ref, rank)

mne/decoding/base.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from sklearn.utils import check_array, check_X_y, indexable
2323
from sklearn.utils.validation import check_is_fitted
2424

25+
from .._fiff.meas_info import create_info
26+
from ..cov import _compute_rank_raw_array
2527
from ..parallel import parallel_func
2628
from ..utils import _pl, logger, pinv, verbose, warn
2729
from ._ged import _handle_restr_mat, _is_cov_symm_pos_semidef, _smart_ajd, _smart_ged
@@ -122,6 +124,24 @@ def fit(self, X, y=None):
122124
mod_ged_callable = (
123125
self.mod_ged_callable if self.mod_ged_callable is not None else _no_op_mod
124126
)
127+
128+
# If restriction to be done, info and rank should exist.
129+
if self.restr_type is not None and C_ref is not None:
130+
if info is None:
131+
# use mag instead of eeg to avoid the cov EEG projection warning
132+
info = create_info(C_ref.shape[0], 1000.0, "mag")
133+
if isinstance(rank, dict):
134+
rank = dict(mag=sum(rank.values()))
135+
136+
if rank is None:
137+
rank = _compute_rank_raw_array(
138+
np.hstack(X),
139+
info,
140+
rank=None,
141+
scalings=None,
142+
log_ch_type="data",
143+
)
144+
125145
if self.dec_type == "single":
126146
if len(covs) > 2:
127147
weights = (

mne/preprocessing/xdawn.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99
from scipy import linalg
1010

11+
from .._fiff.meas_info import Info
1112
from .._fiff.pick import _pick_data_channels, pick_info
1213
from ..cov import Covariance, _regularized_covariance
1314
from ..decoding._covs_ged import _xdawn_estimate
@@ -246,6 +247,30 @@ class _XdawnTransformer(_GEDTransformer):
246247
Parameters to pass to :func:`mne.compute_covariance`.
247248
248249
.. versionadded:: 0.16
250+
restr_type : "restricting" | "whitening" | None
251+
Restricting transformation for covariance matrices before performing
252+
generalized eigendecomposition.
253+
If "restricting" only restriction to the principal subspace of signal_cov
254+
will be performed.
255+
If "whitening", covariance matrices will be additionally rescaled according
256+
to the whitening for the signal_cov.
257+
If None, no restriction will be applied. Defaults to None.
258+
259+
.. versionadded:: 1.10
260+
info : mne.Info | None
261+
The mne.Info object with information about the sensors and methods of
262+
measurement used for covariance estimation and generalized
263+
eigendecomposition.
264+
If None, one channel type and no projections will be assumed and if
265+
rank is dict, it will be sum of ranks per channel type.
266+
Defaults to None.
267+
268+
.. versionadded:: 1.10
269+
%(rank)s
270+
Defaults to "full".
271+
272+
.. versionadded:: 1.10
273+
249274
250275
Attributes
251276
----------
@@ -257,21 +282,39 @@ class _XdawnTransformer(_GEDTransformer):
257282
The Xdawn patterns used to restore the signals for each event type.
258283
"""
259284

260-
def __init__(self, n_components=2, reg=None, signal_cov=None, method_params=None):
285+
def __init__(
286+
self,
287+
n_components=2,
288+
reg=None,
289+
signal_cov=None,
290+
method_params=None,
291+
restr_type=None,
292+
info=None,
293+
rank="full",
294+
):
261295
"""Init."""
262296
self.n_components = n_components
263297
self.signal_cov = signal_cov
264298
self.reg = reg
265299
self.method_params = method_params
300+
self.restr_type = restr_type
301+
self.info = info
302+
self.rank = rank
266303

267304
cov_callable = partial(
268-
_xdawn_estimate, reg=reg, cov_method_params=method_params, R=signal_cov
305+
_xdawn_estimate,
306+
reg=reg,
307+
cov_method_params=method_params,
308+
R=signal_cov,
309+
info=info,
310+
rank=rank,
269311
)
270312
super().__init__(
271313
n_components=n_components,
272314
cov_callable=cov_callable,
273315
mod_ged_callable=_xdawn_mod,
274316
dec_type="multi",
317+
restr_type=restr_type,
275318
)
276319

277320
def _validate_params(self, X):
@@ -288,7 +331,13 @@ def _validate_params(self, X):
288331
raise ValueError(
289332
"signal_cov data should be of shape (n_channels, n_channels)"
290333
)
291-
_validate_type(self.method_params, (Mapping, None))
334+
_validate_type(self.method_params, (Mapping, None), "method_params")
335+
_check_option(
336+
"restr_type",
337+
self.restr_type,
338+
("restricting", "whitening", None),
339+
)
340+
_validate_type(self.info, (Info, None), "info")
292341

293342
def fit(self, X, y=None):
294343
"""Fit Xdawn spatial filters.

0 commit comments

Comments
 (0)