Skip to content

Commit 726c500

Browse files
committed
move mne.preprocessing._XdawnTransformer to decoding and make it public
1 parent 5266372 commit 726c500

3 files changed

Lines changed: 282 additions & 270 deletions

File tree

mne/decoding/xdawn.py

Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
# Authors: The MNE-Python contributors.
2+
# License: BSD-3-Clause
3+
# Copyright the MNE-Python contributors.
4+
5+
from collections.abc import Mapping
6+
from functools import partial
7+
8+
import numpy as np
9+
10+
from .._fiff.meas_info import Info
11+
from ..cov import Covariance
12+
from ..decoding._covs_ged import _xdawn_estimate
13+
from ..decoding._mod_ged import _xdawn_mod
14+
from ..decoding.base import _GEDTransformer
15+
from ..utils import _check_option, _validate_type
16+
17+
18+
class XdawnTransformer(_GEDTransformer):
19+
"""Implementation of the Xdawn Algorithm compatible with scikit-learn.
20+
21+
Xdawn is a spatial filtering method designed to improve the signal
22+
to signal + noise ratio (SSNR) of the event related responses. Xdawn was
23+
originally designed for P300 evoked potential by enhancing the target
24+
response with respect to the non-target response. This implementation is a
25+
generalization to any type of event related response.
26+
27+
.. note:: _XdawnTransformer does not correct for epochs overlap. To correct
28+
overlaps see ``Xdawn``.
29+
30+
Parameters
31+
----------
32+
n_components : int (default 2)
33+
The number of components to decompose the signals.
34+
reg : float | str | None (default None)
35+
If not None (same as ``'empirical'``, default), allow
36+
regularization for covariance estimation.
37+
If float, shrinkage is used (0 <= shrinkage <= 1).
38+
For str options, ``reg`` will be passed to ``method`` to
39+
:func:`mne.compute_covariance`.
40+
signal_cov : None | Covariance | array, shape (n_channels, n_channels)
41+
The signal covariance used for whitening of the data.
42+
if None, the covariance is estimated from the epochs signal.
43+
method_params : dict | None
44+
Parameters to pass to :func:`mne.compute_covariance`.
45+
46+
.. versionadded:: 0.16
47+
restr_type : "restricting" | "whitening" | None
48+
Restricting transformation for covariance matrices before performing
49+
generalized eigendecomposition.
50+
If "restricting" only restriction to the principal subspace of signal_cov
51+
will be performed.
52+
If "whitening", covariance matrices will be additionally rescaled according
53+
to the whitening for the signal_cov.
54+
If None, no restriction will be applied. Defaults to None.
55+
56+
.. versionadded:: 1.10
57+
info : mne.Info | None
58+
The mne.Info object with information about the sensors and methods of
59+
measurement used for covariance estimation and generalized
60+
eigendecomposition.
61+
If None, one channel type and no projections will be assumed and if
62+
rank is dict, it will be sum of ranks per channel type.
63+
Defaults to None.
64+
65+
.. versionadded:: 1.10
66+
%(rank)s
67+
Defaults to "full".
68+
69+
.. versionadded:: 1.10
70+
71+
72+
Attributes
73+
----------
74+
classes_ : array, shape (n_classes)
75+
The event indices of the classes.
76+
filters_ : array, shape (n_channels, n_channels)
77+
The Xdawn components used to decompose the data for each event type.
78+
patterns_ : array, shape (n_channels, n_channels)
79+
The Xdawn patterns used to restore the signals for each event type.
80+
"""
81+
82+
def __init__(
83+
self,
84+
n_components=2,
85+
reg=None,
86+
signal_cov=None,
87+
method_params=None,
88+
restr_type=None,
89+
info=None,
90+
rank="full",
91+
):
92+
"""Init."""
93+
self.n_components = n_components
94+
self.signal_cov = signal_cov
95+
self.reg = reg
96+
self.method_params = method_params
97+
self.restr_type = restr_type
98+
self.info = info
99+
self.rank = rank
100+
101+
cov_callable = partial(
102+
_xdawn_estimate,
103+
reg=reg,
104+
cov_method_params=method_params,
105+
R=signal_cov,
106+
info=info,
107+
rank=rank,
108+
)
109+
super().__init__(
110+
n_components=n_components,
111+
cov_callable=cov_callable,
112+
mod_ged_callable=_xdawn_mod,
113+
dec_type="multi",
114+
restr_type=restr_type,
115+
)
116+
117+
def _validate_params(self, X):
118+
_validate_type(self.n_components, int, "n_components")
119+
120+
# reg is validated in _regularized_covariance
121+
122+
if self.signal_cov is not None:
123+
if isinstance(self.signal_cov, Covariance):
124+
self.signal_cov = self.signal_cov.data
125+
elif not isinstance(self.signal_cov, np.ndarray):
126+
raise ValueError("signal_cov should be mne.Covariance or np.ndarray")
127+
if not np.array_equal(self.signal_cov.shape, np.tile(X.shape[1], 2)):
128+
raise ValueError(
129+
"signal_cov data should be of shape (n_channels, n_channels)"
130+
)
131+
_validate_type(self.method_params, (Mapping, None), "method_params")
132+
_check_option(
133+
"restr_type",
134+
self.restr_type,
135+
("restricting", "whitening", None),
136+
)
137+
_validate_type(self.info, (Info, None), "info")
138+
139+
def fit(self, X, y=None):
140+
"""Fit Xdawn spatial filters.
141+
142+
Parameters
143+
----------
144+
X : array, shape (n_epochs, n_channels, n_samples)
145+
The target data.
146+
y : array, shape (n_epochs,) | None
147+
The target labels. If None, Xdawn fit on the average evoked.
148+
149+
Returns
150+
-------
151+
self : Xdawn instance
152+
The Xdawn instance.
153+
"""
154+
from ..preprocessing.xdawn import _fit_xdawn
155+
156+
X, y = self._check_Xy(X, y)
157+
self._validate_params(X)
158+
# Main function
159+
self.classes_ = np.unique(y)
160+
self.filters_, self.patterns_, _ = _fit_xdawn(
161+
X,
162+
y,
163+
n_components=self.n_components,
164+
reg=self.reg,
165+
signal_cov=self.signal_cov,
166+
method_params=self.method_params,
167+
)
168+
old_filters = self.filters_
169+
old_patterns = self.patterns_
170+
super().fit(X, y)
171+
172+
# Hack for assert_allclose in transform
173+
self.new_filters_ = self.filters_.copy()
174+
# Xdawn performs separate GED for each class.
175+
# filters_ returned by _fit_xdawn are subset per
176+
# n_components and then appended and are of shape
177+
# (n_classes*n_components, n_chs).
178+
# GEDTransformer creates new dimension per class without subsetting
179+
# for easier analysis and visualisations.
180+
# So it needs to be performed post-hoc to conform with Xdawn.
181+
# The shape returned by GED here is (n_classes, n_evecs, n_chs)
182+
# Need to transform and subset into (n_classes*n_components, n_chs)
183+
self.filters_ = self.filters_[:, : self.n_components, :].reshape(
184+
-1, self.filters_.shape[2]
185+
)
186+
self.patterns_ = self.patterns_[:, : self.n_components, :].reshape(
187+
-1, self.patterns_.shape[2]
188+
)
189+
np.testing.assert_allclose(old_filters, self.filters_)
190+
np.testing.assert_allclose(old_patterns, self.patterns_)
191+
192+
return self
193+
194+
def transform(self, X):
195+
"""Transform data with spatial filters.
196+
197+
Parameters
198+
----------
199+
X : array, shape (n_epochs, n_channels, n_samples)
200+
The target data.
201+
202+
Returns
203+
-------
204+
X : array, shape (n_epochs, n_components * n_classes, n_samples)
205+
The transformed data.
206+
"""
207+
X, _ = self._check_Xy(X)
208+
orig_X = X.copy()
209+
210+
# Check size
211+
if self.filters_.shape[1] != X.shape[1]:
212+
raise ValueError(
213+
f"X must have {self.filters_.shape[1]} channels, got {X.shape[1]} "
214+
"instead."
215+
)
216+
217+
# Transform
218+
X = np.dot(self.filters_, X)
219+
X = X.transpose((1, 0, 2))
220+
ged_X = super().transform(orig_X)
221+
np.testing.assert_allclose(X, ged_X)
222+
return X
223+
224+
def inverse_transform(self, X):
225+
"""Remove selected components from the signal.
226+
227+
Given the unmixing matrix, transform data, zero out components,
228+
and inverse transform the data. This procedure will reconstruct
229+
the signals from which the dynamics described by the excluded
230+
components is subtracted.
231+
232+
Parameters
233+
----------
234+
X : array, shape (n_epochs, n_components * n_classes, n_times)
235+
The transformed data.
236+
237+
Returns
238+
-------
239+
X : array, shape (n_epochs, n_channels * n_classes, n_times)
240+
The inverse transform data.
241+
"""
242+
# Check size
243+
X, _ = self._check_Xy(X)
244+
n_epochs, n_comp, n_times = X.shape
245+
if n_comp != (self.n_components * len(self.classes_)):
246+
raise ValueError(
247+
f"X must have {self.n_components * len(self.classes_)} components, "
248+
f"got {n_comp} instead."
249+
)
250+
251+
# Transform
252+
return np.dot(self.patterns_.T, X).transpose(1, 0, 2)
253+
254+
def _check_Xy(self, X, y=None):
255+
"""Check X and y types and dimensions."""
256+
# Check data
257+
if not isinstance(X, np.ndarray) or X.ndim != 3:
258+
raise ValueError(
259+
"X must be an array of shape (n_epochs, n_channels, n_samples)."
260+
)
261+
if y is None:
262+
y = np.ones(len(X))
263+
y = np.asarray(y)
264+
if len(X) != len(y):
265+
raise ValueError("X and y must have the same length")
266+
return X, y

mne/preprocessing/tests/test_xdawn.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222

2323
pytest.importorskip("sklearn")
2424

25-
from mne.preprocessing.xdawn import Xdawn, _XdawnTransformer # noqa: E402
25+
from mne.decoding.xdawn import XdawnTransformer
26+
from mne.preprocessing.xdawn import Xdawn # noqa: E402
2627

2728
base_dir = Path(__file__).parents[2] / "io" / "tests" / "data"
2829
raw_fname = base_dir / "test_raw.fif"
@@ -236,7 +237,7 @@ def test_xdawn_regularization():
236237

237238

238239
def test_XdawnTransformer():
239-
"""Test _XdawnTransformer."""
240+
"""Test XdawnTransformer."""
240241
pytest.importorskip("sklearn")
241242
# Get data
242243
raw, events, picks = _get_data()
@@ -255,37 +256,37 @@ def test_XdawnTransformer():
255256
X = epochs._data
256257
y = epochs.events[:, -1]
257258
# Fit
258-
xdt = _XdawnTransformer()
259+
xdt = XdawnTransformer()
259260
xdt.fit(X, y)
260261
pytest.raises(ValueError, xdt.fit, X, y[1:])
261262
pytest.raises(ValueError, xdt.fit, "foo")
262263

263264
# Provide covariance object
264265
signal_cov = compute_raw_covariance(raw, picks=picks)
265-
xdt = _XdawnTransformer(signal_cov=signal_cov)
266+
xdt = XdawnTransformer(signal_cov=signal_cov)
266267
xdt.fit(X, y)
267268
# Provide ndarray
268269
signal_cov = np.eye(len(picks))
269-
xdt = _XdawnTransformer(signal_cov=signal_cov)
270+
xdt = XdawnTransformer(signal_cov=signal_cov)
270271
xdt.fit(X, y)
271272
# Provide ndarray of bad shape
272273
signal_cov = np.eye(len(picks) - 1)
273-
xdt = _XdawnTransformer(signal_cov=signal_cov)
274+
xdt = XdawnTransformer(signal_cov=signal_cov)
274275
pytest.raises(ValueError, xdt.fit, X, y)
275276
# Provide another type
276277
signal_cov = 42
277-
xdt = _XdawnTransformer(signal_cov=signal_cov)
278+
xdt = XdawnTransformer(signal_cov=signal_cov)
278279
pytest.raises(ValueError, xdt.fit, X, y)
279280

280281
# Fit with y as None
281-
xdt = _XdawnTransformer()
282+
xdt = XdawnTransformer()
282283
xdt.fit(X)
283284

284-
# Compare xdawn and _XdawnTransformer
285+
# Compare xdawn and XdawnTransformer
285286
xd = Xdawn(correct_overlap=False)
286287
xd.fit(epochs)
287288

288-
xdt = _XdawnTransformer()
289+
xdt = XdawnTransformer()
289290
xdt.fit(X, y)
290291
assert_array_almost_equal(
291292
xd.filters_["cond2"][:2, :], xdt.filters_.reshape(2, 2, 8)[0]
@@ -363,15 +364,15 @@ def test_xdawn_decoding_performance():
363364
epochs, mixing_mat = _simulate_erplike_mixed_data(n_epochs=100)
364365
y = epochs.events[:, 2]
365366

366-
# results of Xdawn and _XdawnTransformer should match
367+
# results of Xdawn and XdawnTransformer should match
367368
xdawn_pipe = make_pipeline(
368369
Xdawn(n_components=n_xdawn_comps),
369370
Vectorizer(),
370371
MinMaxScaler(),
371372
LogisticRegression(solver="liblinear"),
372373
)
373374
xdawn_trans_pipe = make_pipeline(
374-
_XdawnTransformer(n_components=n_xdawn_comps),
375+
XdawnTransformer(n_components=n_xdawn_comps),
375376
Vectorizer(),
376377
MinMaxScaler(),
377378
LogisticRegression(solver="liblinear"),

0 commit comments

Comments
 (0)