|
| 1 | +# Authors: The MNE-Python contributors. |
| 2 | +# License: BSD-3-Clause |
| 3 | +# Copyright the MNE-Python contributors. |
| 4 | + |
| 5 | +from collections.abc import Mapping |
| 6 | +from functools import partial |
| 7 | + |
| 8 | +import numpy as np |
| 9 | + |
| 10 | +from .._fiff.meas_info import Info |
| 11 | +from ..cov import Covariance |
| 12 | +from ..decoding._covs_ged import _xdawn_estimate |
| 13 | +from ..decoding._mod_ged import _xdawn_mod |
| 14 | +from ..decoding.base import _GEDTransformer |
| 15 | +from ..utils import _check_option, _validate_type |
| 16 | + |
| 17 | + |
| 18 | +class XdawnTransformer(_GEDTransformer): |
| 19 | + """Implementation of the Xdawn Algorithm compatible with scikit-learn. |
| 20 | +
|
| 21 | + Xdawn is a spatial filtering method designed to improve the signal |
| 22 | + to signal + noise ratio (SSNR) of the event related responses. Xdawn was |
| 23 | + originally designed for P300 evoked potential by enhancing the target |
| 24 | + response with respect to the non-target response. This implementation is a |
| 25 | + generalization to any type of event related response. |
| 26 | +
|
| 27 | + .. note:: _XdawnTransformer does not correct for epochs overlap. To correct |
| 28 | + overlaps see ``Xdawn``. |
| 29 | +
|
| 30 | + Parameters |
| 31 | + ---------- |
| 32 | + n_components : int (default 2) |
| 33 | + The number of components to decompose the signals. |
| 34 | + reg : float | str | None (default None) |
| 35 | + If not None (same as ``'empirical'``, default), allow |
| 36 | + regularization for covariance estimation. |
| 37 | + If float, shrinkage is used (0 <= shrinkage <= 1). |
| 38 | + For str options, ``reg`` will be passed to ``method`` to |
| 39 | + :func:`mne.compute_covariance`. |
| 40 | + signal_cov : None | Covariance | array, shape (n_channels, n_channels) |
| 41 | + The signal covariance used for whitening of the data. |
| 42 | + if None, the covariance is estimated from the epochs signal. |
| 43 | + method_params : dict | None |
| 44 | + Parameters to pass to :func:`mne.compute_covariance`. |
| 45 | +
|
| 46 | + .. versionadded:: 0.16 |
| 47 | + restr_type : "restricting" | "whitening" | None |
| 48 | + Restricting transformation for covariance matrices before performing |
| 49 | + generalized eigendecomposition. |
| 50 | + If "restricting" only restriction to the principal subspace of signal_cov |
| 51 | + will be performed. |
| 52 | + If "whitening", covariance matrices will be additionally rescaled according |
| 53 | + to the whitening for the signal_cov. |
| 54 | + If None, no restriction will be applied. Defaults to None. |
| 55 | +
|
| 56 | + .. versionadded:: 1.10 |
| 57 | + info : mne.Info | None |
| 58 | + The mne.Info object with information about the sensors and methods of |
| 59 | + measurement used for covariance estimation and generalized |
| 60 | + eigendecomposition. |
| 61 | + If None, one channel type and no projections will be assumed and if |
| 62 | + rank is dict, it will be sum of ranks per channel type. |
| 63 | + Defaults to None. |
| 64 | +
|
| 65 | + .. versionadded:: 1.10 |
| 66 | + %(rank)s |
| 67 | + Defaults to "full". |
| 68 | +
|
| 69 | + .. versionadded:: 1.10 |
| 70 | +
|
| 71 | +
|
| 72 | + Attributes |
| 73 | + ---------- |
| 74 | + classes_ : array, shape (n_classes) |
| 75 | + The event indices of the classes. |
| 76 | + filters_ : array, shape (n_channels, n_channels) |
| 77 | + The Xdawn components used to decompose the data for each event type. |
| 78 | + patterns_ : array, shape (n_channels, n_channels) |
| 79 | + The Xdawn patterns used to restore the signals for each event type. |
| 80 | + """ |
| 81 | + |
| 82 | + def __init__( |
| 83 | + self, |
| 84 | + n_components=2, |
| 85 | + reg=None, |
| 86 | + signal_cov=None, |
| 87 | + method_params=None, |
| 88 | + restr_type=None, |
| 89 | + info=None, |
| 90 | + rank="full", |
| 91 | + ): |
| 92 | + """Init.""" |
| 93 | + self.n_components = n_components |
| 94 | + self.signal_cov = signal_cov |
| 95 | + self.reg = reg |
| 96 | + self.method_params = method_params |
| 97 | + self.restr_type = restr_type |
| 98 | + self.info = info |
| 99 | + self.rank = rank |
| 100 | + |
| 101 | + cov_callable = partial( |
| 102 | + _xdawn_estimate, |
| 103 | + reg=reg, |
| 104 | + cov_method_params=method_params, |
| 105 | + R=signal_cov, |
| 106 | + info=info, |
| 107 | + rank=rank, |
| 108 | + ) |
| 109 | + super().__init__( |
| 110 | + n_components=n_components, |
| 111 | + cov_callable=cov_callable, |
| 112 | + mod_ged_callable=_xdawn_mod, |
| 113 | + dec_type="multi", |
| 114 | + restr_type=restr_type, |
| 115 | + ) |
| 116 | + |
| 117 | + def _validate_params(self, X): |
| 118 | + _validate_type(self.n_components, int, "n_components") |
| 119 | + |
| 120 | + # reg is validated in _regularized_covariance |
| 121 | + |
| 122 | + if self.signal_cov is not None: |
| 123 | + if isinstance(self.signal_cov, Covariance): |
| 124 | + self.signal_cov = self.signal_cov.data |
| 125 | + elif not isinstance(self.signal_cov, np.ndarray): |
| 126 | + raise ValueError("signal_cov should be mne.Covariance or np.ndarray") |
| 127 | + if not np.array_equal(self.signal_cov.shape, np.tile(X.shape[1], 2)): |
| 128 | + raise ValueError( |
| 129 | + "signal_cov data should be of shape (n_channels, n_channels)" |
| 130 | + ) |
| 131 | + _validate_type(self.method_params, (Mapping, None), "method_params") |
| 132 | + _check_option( |
| 133 | + "restr_type", |
| 134 | + self.restr_type, |
| 135 | + ("restricting", "whitening", None), |
| 136 | + ) |
| 137 | + _validate_type(self.info, (Info, None), "info") |
| 138 | + |
| 139 | + def fit(self, X, y=None): |
| 140 | + """Fit Xdawn spatial filters. |
| 141 | +
|
| 142 | + Parameters |
| 143 | + ---------- |
| 144 | + X : array, shape (n_epochs, n_channels, n_samples) |
| 145 | + The target data. |
| 146 | + y : array, shape (n_epochs,) | None |
| 147 | + The target labels. If None, Xdawn fit on the average evoked. |
| 148 | +
|
| 149 | + Returns |
| 150 | + ------- |
| 151 | + self : Xdawn instance |
| 152 | + The Xdawn instance. |
| 153 | + """ |
| 154 | + from ..preprocessing.xdawn import _fit_xdawn |
| 155 | + |
| 156 | + X, y = self._check_Xy(X, y) |
| 157 | + self._validate_params(X) |
| 158 | + # Main function |
| 159 | + self.classes_ = np.unique(y) |
| 160 | + self.filters_, self.patterns_, _ = _fit_xdawn( |
| 161 | + X, |
| 162 | + y, |
| 163 | + n_components=self.n_components, |
| 164 | + reg=self.reg, |
| 165 | + signal_cov=self.signal_cov, |
| 166 | + method_params=self.method_params, |
| 167 | + ) |
| 168 | + old_filters = self.filters_ |
| 169 | + old_patterns = self.patterns_ |
| 170 | + super().fit(X, y) |
| 171 | + |
| 172 | + # Hack for assert_allclose in transform |
| 173 | + self.new_filters_ = self.filters_.copy() |
| 174 | + # Xdawn performs separate GED for each class. |
| 175 | + # filters_ returned by _fit_xdawn are subset per |
| 176 | + # n_components and then appended and are of shape |
| 177 | + # (n_classes*n_components, n_chs). |
| 178 | + # GEDTransformer creates new dimension per class without subsetting |
| 179 | + # for easier analysis and visualisations. |
| 180 | + # So it needs to be performed post-hoc to conform with Xdawn. |
| 181 | + # The shape returned by GED here is (n_classes, n_evecs, n_chs) |
| 182 | + # Need to transform and subset into (n_classes*n_components, n_chs) |
| 183 | + self.filters_ = self.filters_[:, : self.n_components, :].reshape( |
| 184 | + -1, self.filters_.shape[2] |
| 185 | + ) |
| 186 | + self.patterns_ = self.patterns_[:, : self.n_components, :].reshape( |
| 187 | + -1, self.patterns_.shape[2] |
| 188 | + ) |
| 189 | + np.testing.assert_allclose(old_filters, self.filters_) |
| 190 | + np.testing.assert_allclose(old_patterns, self.patterns_) |
| 191 | + |
| 192 | + return self |
| 193 | + |
| 194 | + def transform(self, X): |
| 195 | + """Transform data with spatial filters. |
| 196 | +
|
| 197 | + Parameters |
| 198 | + ---------- |
| 199 | + X : array, shape (n_epochs, n_channels, n_samples) |
| 200 | + The target data. |
| 201 | +
|
| 202 | + Returns |
| 203 | + ------- |
| 204 | + X : array, shape (n_epochs, n_components * n_classes, n_samples) |
| 205 | + The transformed data. |
| 206 | + """ |
| 207 | + X, _ = self._check_Xy(X) |
| 208 | + orig_X = X.copy() |
| 209 | + |
| 210 | + # Check size |
| 211 | + if self.filters_.shape[1] != X.shape[1]: |
| 212 | + raise ValueError( |
| 213 | + f"X must have {self.filters_.shape[1]} channels, got {X.shape[1]} " |
| 214 | + "instead." |
| 215 | + ) |
| 216 | + |
| 217 | + # Transform |
| 218 | + X = np.dot(self.filters_, X) |
| 219 | + X = X.transpose((1, 0, 2)) |
| 220 | + ged_X = super().transform(orig_X) |
| 221 | + np.testing.assert_allclose(X, ged_X) |
| 222 | + return X |
| 223 | + |
| 224 | + def inverse_transform(self, X): |
| 225 | + """Remove selected components from the signal. |
| 226 | +
|
| 227 | + Given the unmixing matrix, transform data, zero out components, |
| 228 | + and inverse transform the data. This procedure will reconstruct |
| 229 | + the signals from which the dynamics described by the excluded |
| 230 | + components is subtracted. |
| 231 | +
|
| 232 | + Parameters |
| 233 | + ---------- |
| 234 | + X : array, shape (n_epochs, n_components * n_classes, n_times) |
| 235 | + The transformed data. |
| 236 | +
|
| 237 | + Returns |
| 238 | + ------- |
| 239 | + X : array, shape (n_epochs, n_channels * n_classes, n_times) |
| 240 | + The inverse transform data. |
| 241 | + """ |
| 242 | + # Check size |
| 243 | + X, _ = self._check_Xy(X) |
| 244 | + n_epochs, n_comp, n_times = X.shape |
| 245 | + if n_comp != (self.n_components * len(self.classes_)): |
| 246 | + raise ValueError( |
| 247 | + f"X must have {self.n_components * len(self.classes_)} components, " |
| 248 | + f"got {n_comp} instead." |
| 249 | + ) |
| 250 | + |
| 251 | + # Transform |
| 252 | + return np.dot(self.patterns_.T, X).transpose(1, 0, 2) |
| 253 | + |
| 254 | + def _check_Xy(self, X, y=None): |
| 255 | + """Check X and y types and dimensions.""" |
| 256 | + # Check data |
| 257 | + if not isinstance(X, np.ndarray) or X.ndim != 3: |
| 258 | + raise ValueError( |
| 259 | + "X must be an array of shape (n_epochs, n_channels, n_samples)." |
| 260 | + ) |
| 261 | + if y is None: |
| 262 | + y = np.ones(len(X)) |
| 263 | + y = np.asarray(y) |
| 264 | + if len(X) != len(y): |
| 265 | + raise ValueError("X and y must have the same length") |
| 266 | + return X, y |
0 commit comments