Skip to content

Commit 8755089

Browse files
committed
add option for CSP to select restr_type and provide info
1 parent 95544c5 commit 8755089

3 files changed

Lines changed: 69 additions & 39 deletions

File tree

mne/decoding/_covs_ged.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ..utils import _verbose_safe_false, logger
1616

1717

18-
def _concat_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, rank, info):
18+
def _concat_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, info, rank):
1919
"""Concatenate epochs before computing the covariance."""
2020
_, n_channels, _ = x_class.shape
2121

@@ -34,7 +34,7 @@ def _concat_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, rank, in
3434
return cov, n_channels # the weight here is just the number of channels
3535

3636

37-
def _epoch_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, rank, info):
37+
def _epoch_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, info, rank):
3838
"""Mean of per-epoch covariances."""
3939
name = reg if isinstance(reg, str) else "empirical"
4040
name += " with shrinkage" if isinstance(reg, float) else ""
@@ -62,22 +62,29 @@ def _epoch_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, rank, inf
6262
return cov, weight
6363

6464

65-
def _csp_estimate(X, y, reg, cov_method_params, cov_est, rank, norm_trace):
65+
def _handle_info_rank(X, info, rank):
66+
if info is None:
67+
# use mag instead of eeg to avoid the cov EEG projection warning
68+
info = create_info(X.shape[1], 1000.0, "mag")
69+
if isinstance(rank, dict):
70+
rank = dict(mag=sum(rank.values()))
71+
72+
return info, rank
73+
74+
75+
def _csp_estimate(X, y, reg, cov_method_params, cov_est, info, rank, norm_trace):
6676
_, n_channels, _ = X.shape
6777
classes_ = np.unique(y)
6878
if cov_est == "concat":
6979
cov_estimator = _concat_cov
7080
elif cov_est == "epoch":
7181
cov_estimator = _epoch_cov
72-
# Someday we could allow the user to pass this, then we wouldn't need to convert
73-
# but in the meantime they can use a pipeline with a scaler
74-
_info = create_info(n_channels, 1000.0, "mag")
75-
if isinstance(rank, dict):
76-
_rank = {"mag": sum(rank.values())}
77-
else:
78-
_rank = _compute_rank_raw_array(
79-
X.transpose(1, 0, 2).reshape(X.shape[1], -1),
80-
_info,
82+
83+
info, rank = _handle_info_rank(X, info, rank)
84+
if not isinstance(rank, dict):
85+
rank = _compute_rank_raw_array(
86+
np.hstack(X),
87+
info,
8188
rank=rank,
8289
scalings=None,
8390
log_ch_type="data",
@@ -92,8 +99,8 @@ def _csp_estimate(X, y, reg, cov_method_params, cov_est, rank, norm_trace):
9299
log_rank=ci == 0,
93100
reg=reg,
94101
cov_method_params=cov_method_params,
95-
rank=_rank,
96-
info=_info,
102+
info=info,
103+
rank=rank,
97104
)
98105

99106
if norm_trace:
@@ -105,7 +112,7 @@ def _csp_estimate(X, y, reg, cov_method_params, cov_est, rank, norm_trace):
105112
covs = np.stack(covs)
106113
C_ref = covs.mean(0)
107114

108-
return covs, C_ref, _info, _rank, dict(sample_weights=np.array(sample_weights))
115+
return covs, C_ref, info, rank, dict(sample_weights=np.array(sample_weights))
109116

110117

111118
def _xdawn_estimate(
@@ -118,6 +125,7 @@ def _xdawn_estimate(
118125
rank="full",
119126
):
120127
classes = np.unique(y)
128+
info, rank = _handle_info_rank(X, info, rank)
121129

122130
# Retrieve or compute whitening covariance
123131
if R is None:
@@ -143,7 +151,14 @@ def _xdawn_estimate(
143151

144152
covs.append(R)
145153
C_ref = R
146-
rank = rank if isinstance(rank, dict) else None
154+
if not isinstance(rank, dict):
155+
rank = _compute_rank_raw_array(
156+
np.hstack(X),
157+
info,
158+
rank=rank,
159+
scalings=None,
160+
log_ch_type="data",
161+
)
147162
return covs, C_ref, info, rank, dict()
148163

149164

mne/decoding/base.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
from sklearn.utils import check_array, check_X_y, indexable
2323
from sklearn.utils.validation import check_is_fitted
2424

25-
from .._fiff.meas_info import create_info
26-
from ..cov import _compute_rank_raw_array
2725
from ..parallel import parallel_func
2826
from ..utils import _pl, logger, pinv, verbose, warn
2927
from ._ged import _handle_restr_mat, _is_cov_symm_pos_semidef, _smart_ajd, _smart_ged
@@ -70,7 +68,7 @@ class _GEDTransformer(MNETransformerMixin, BaseEstimator):
7068
preserved for compatibility.
7169
If None, no restriction will be applied. Defaults to None.
7270
R_func : callable | None
73-
If provided GED will be performed on (S, R_func(S,R)).
71+
If provided, GED will be performed on (S, R_func(S,R)).
7472
7573
Attributes
7674
----------
@@ -88,7 +86,10 @@ class _GEDTransformer(MNETransformerMixin, BaseEstimator):
8886
CSP
8987
SPoC
9088
SSD
91-
mne.preprocessing.Xdawn
89+
90+
Notes
91+
-----
92+
.. versionadded:: 1.10
9293
"""
9394

9495
def __init__(
@@ -125,23 +126,6 @@ def fit(self, X, y=None):
125126
self.mod_ged_callable if self.mod_ged_callable is not None else _no_op_mod
126127
)
127128

128-
# If restriction to be done, info and rank should exist.
129-
if self.restr_type is not None and C_ref is not None:
130-
if info is None:
131-
# use mag instead of eeg to avoid the cov EEG projection warning
132-
info = create_info(C_ref.shape[0], 1000.0, "mag")
133-
if isinstance(rank, dict):
134-
rank = dict(mag=sum(rank.values()))
135-
136-
if rank is None:
137-
rank = _compute_rank_raw_array(
138-
np.hstack(X),
139-
info,
140-
rank=None,
141-
scalings=None,
142-
log_ch_type="data",
143-
)
144-
145129
if self.dec_type == "single":
146130
if len(covs) > 2:
147131
weights = (

mne/decoding/csp.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from scipy.linalg import eigh
1010
from sklearn.utils.validation import check_is_fitted
1111

12-
from .._fiff.meas_info import create_info
12+
from .._fiff.meas_info import Info, create_info
1313
from ..cov import _compute_rank_raw_array, _regularized_covariance, _smart_eigh
1414
from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT
1515
from ..evoked import EvokedArray
@@ -70,6 +70,26 @@ class CSP(_GEDTransformer):
7070
Parameters to pass to :func:`mne.compute_covariance`.
7171
7272
.. versionadded:: 0.16
73+
74+
restr_type : "restricting" | "whitening" | None
75+
Restricting transformation for covariance matrices before performing
76+
generalized eigendecomposition.
77+
If "restricting" only restriction to the principal subspace of signal_cov
78+
will be performed.
79+
If "whitening", covariance matrices will be additionally rescaled according
80+
to the whitening for the signal_cov.
81+
If None, no restriction will be applied. Defaults to "restricting".
82+
83+
.. versionadded:: 1.10
84+
info : mne.Info | None
85+
The mne.Info object with information about the sensors and methods of
86+
measurement used for covariance estimation and generalized
87+
eigendecomposition.
88+
If None, one channel type and no projections will be assumed and if
89+
rank is dict, it will be sum of ranks per channel type.
90+
Defaults to None.
91+
92+
.. versionadded:: 1.10
7393
%(rank_none)s
7494
7595
.. versionadded:: 0.17
@@ -113,11 +133,14 @@ def __init__(
113133
transform_into="average_power",
114134
norm_trace=False,
115135
cov_method_params=None,
136+
restr_type="restricting",
137+
info=None,
116138
rank=None,
117139
component_order="mutual_info",
118140
):
119141
# Init default CSP
120142
self.n_components = n_components
143+
self.info = info
121144
self.rank = rank
122145
self.reg = reg
123146
self.cov_est = cov_est
@@ -126,12 +149,14 @@ def __init__(
126149
self.norm_trace = norm_trace
127150
self.cov_method_params = cov_method_params
128151
self.component_order = component_order
152+
self.restr_type = restr_type
129153

130154
cov_callable = partial(
131155
_csp_estimate,
132156
reg=reg,
133157
cov_method_params=cov_method_params,
134158
cov_est=cov_est,
159+
info=info,
135160
rank=rank,
136161
norm_trace=norm_trace,
137162
)
@@ -140,7 +165,7 @@ def __init__(
140165
n_components=n_components,
141166
cov_callable=cov_callable,
142167
mod_ged_callable=mod_ged_callable,
143-
restr_type="restricting",
168+
restr_type=restr_type,
144169
R_func=sum,
145170
)
146171

@@ -172,6 +197,12 @@ def _validate_params(self, *, y):
172197
n_classes = len(self.classes_)
173198
if n_classes < 2:
174199
raise ValueError(f"n_classes must be >= 2, but got {n_classes} class")
200+
_check_option(
201+
"restr_type",
202+
self.restr_type,
203+
("restricting", "whitening", None),
204+
)
205+
_validate_type(self.info, (Info, None), "info")
175206

176207
def fit(self, X, y):
177208
"""Estimate the CSP decomposition on epochs.

0 commit comments

Comments
 (0)