Skip to content

Commit 2a1c5cb

Browse files
committed
Add big sklearn compliance test
1 parent 0d58c8d commit 2a1c5cb

7 files changed

Lines changed: 241 additions & 63 deletions

File tree

mne/decoding/base.py

Lines changed: 68 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,23 @@
88
import numbers
99

1010
import numpy as np
11+
import scipy.linalg
1112
from sklearn import model_selection as models
1213
from sklearn.base import ( # noqa: F401
1314
BaseEstimator,
1415
MetaEstimatorMixin,
15-
TransformerMixin,
1616
clone,
1717
is_classifier,
1818
)
1919
from sklearn.linear_model import LogisticRegression
2020
from sklearn.metrics import check_scoring
2121
from sklearn.model_selection import KFold, StratifiedKFold, check_cv
2222
from sklearn.utils import check_array, check_X_y, indexable
23+
from sklearn.utils.validation import check_is_fitted
2324

2425
from ..parallel import parallel_func
2526
from ..utils import _pl, logger, pinv, verbose, warn
26-
from .ged import _get_ssd_rank, _handle_restr_map, _smart_ajd, _smart_ged
27+
from .ged import _handle_restr_map, _smart_ajd, _smart_ged
2728
from .transformer import MNETransformerMixin
2829

2930

@@ -55,7 +56,7 @@ class GEDTransformer(MNETransformerMixin, BaseEstimator):
5556
(except the last) returned by cov_callable is decomposed with the last
5657
covariance. In this case, number of covariances should be number of classes + 1.
5758
Defaults to "single".
58-
restr_map : "restricting" | "whitening" | "ssd" | None
59+
restr_type : "restricting" | "whitening" | "ssd" | None
5960
Restricting transformation for covariance matrices before performing GED.
6061
If "restricting" only restriction to the principal subspace of the C_ref
6162
will be performed.
@@ -94,7 +95,7 @@ def __init__(
9495
mod_ged_callable,
9596
mod_params,
9697
dec_type="single",
97-
restr_map=None,
98+
restr_type=None,
9899
R_func=None,
99100
):
100101
self.n_filters = n_filters
@@ -103,27 +104,35 @@ def __init__(
103104
self.mod_ged_callable = mod_ged_callable
104105
self.mod_params = mod_params
105106
self.dec_type = dec_type
106-
self.restr_map = restr_map
107+
self.restr_type = restr_type
107108
self.R_func = R_func
108109

109110
def fit(self, X, y=None):
110111
"""..."""
112+
X, y = self._check_data(
113+
X,
114+
y=y,
115+
fit=True,
116+
return_y=True,
117+
atleast_3d=False if self.restr_type == "ssd" else True,
118+
)
111119
covs, C_ref, info, rank, kwargs = self.cov_callable(X, y, **self.cov_params)
120+
self._validate_covariances(covs + [C_ref])
112121
if self.dec_type == "single":
113122
if len(covs) > 2:
123+
covs = np.array(covs)
114124
sample_weights = kwargs["sample_weights"]
115-
restr_map = _handle_restr_map(C_ref, self.restr_map, info, rank)
125+
restr_map = _handle_restr_map(C_ref, self.restr_type, info, rank)
116126
evecs = _smart_ajd(covs, restr_map, weights=sample_weights)
117127
evals = None
118128
else:
119129
S = covs[0]
120130
R = covs[1]
121-
if self.restr_map == "ssd":
122-
rank = _get_ssd_rank(S, R, info, rank)
131+
if self.restr_type == "ssd":
123132
mult_order = "ssd"
124133
else:
125134
mult_order = None
126-
restr_map = _handle_restr_map(C_ref, self.restr_map, info, rank)
135+
restr_map = _handle_restr_map(C_ref, self.restr_type, info, rank)
127136
evals, evecs = _smart_ged(
128137
S, R, restr_map, R_func=self.R_func, mult_order=mult_order
129138
)
@@ -133,19 +142,26 @@ def fit(self, X, y=None):
133142
)
134143
self.evals_ = evals
135144
self.filters_ = evecs.T
136-
if self.restr_map == "ssd":
145+
if self.restr_type == "ssd":
137146
self.patterns_ = np.linalg.pinv(evecs)
138147
else:
139148
self.patterns_ = pinv(evecs)
140149

141150
elif self.dec_type == "multi":
142151
self.classes_ = np.unique(y)
143152
R = covs[-1]
144-
restr_map = _handle_restr_map(C_ref, self.restr_map, info, rank)
153+
if self.restr_type == "ssd":
154+
mult_order = "ssd"
155+
else:
156+
mult_order = None
157+
restr_map = _handle_restr_map(C_ref, self.restr_type, info, rank)
145158
all_evals, all_evecs, all_patterns = list(), list(), list()
146159
for i in range(len(self.classes_)):
147160
S = covs[i]
148-
evals, evecs = _smart_ged(S, R, restr_map, R_func=self.R_func)
161+
162+
evals, evecs = _smart_ged(
163+
S, R, restr_map, R_func=self.R_func, mult_order=mult_order
164+
)
149165

150166
evals, evecs = self.mod_ged_callable(
151167
evals, evecs, covs, **self.mod_params, **kwargs
@@ -161,9 +177,48 @@ def fit(self, X, y=None):
161177

162178
def transform(self, X):
163179
"""..."""
164-
X = np.dot(self.filters_, X)
180+
check_is_fitted(self, "filters_")
181+
X = self._check_data(X)
182+
if self.dec_type == "single":
183+
pick_filters = self.filters_[: self.n_filters]
184+
elif self.dec_type == "multi":
185+
pick_filters = np.concatenate(
186+
[
187+
self.filters_[i, : self.n_filters]
188+
for i in range(self.filters_.shape[0])
189+
],
190+
axis=0,
191+
)
192+
X = np.asarray([pick_filters @ epoch for epoch in X])
165193
return X
166194

195+
def _validate_covariances(self, covs):
196+
for cov in covs:
197+
if cov is None:
198+
continue
199+
is_sym = scipy.linalg.issymmetric(cov, rtol=1e-10, atol=1e-11)
200+
if not is_sym:
201+
raise ValueError(
202+
"One of covariances or C_ref is not symmetric, "
203+
"check your cov_callable"
204+
)
205+
if not np.all(np.linalg.eigvals(cov) >= 0):
206+
ValueError(
207+
"One of covariances or C_ref has negative eigenvalues, "
208+
"check your cov_callable"
209+
)
210+
211+
def __sklearn_tags__(self):
212+
"""Tag the transformer."""
213+
tags = super().__sklearn_tags__()
214+
tags.estimator_type = "transformer"
215+
# Can be a transformer where S and R covs are not based on y classes.
216+
tags.target_tags.required = False
217+
tags.target_tags.one_d_labels = True
218+
tags.input_tags.two_d_array = True
219+
tags.input_tags.three_d_array = True
220+
return tags
221+
167222

168223
class LinearModel(MetaEstimatorMixin, BaseEstimator):
169224
"""Compute and store patterns from linear models.

mne/decoding/covs_ged.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
from .._fiff.meas_info import Info, create_info
1111
from .._fiff.pick import _picks_to_idx
1212
from ..cov import Covariance, _compute_rank_raw_array, _regularized_covariance
13+
from ..defaults import _handle_default
1314
from ..filter import filter_data
15+
from ..rank import compute_rank
1416
from ..utils import _verbose_safe_false, logger, pinv
1517

1618

@@ -293,6 +295,26 @@ def _ssd_estimate(
293295
)
294296
covs = [S, R]
295297
C_ref = S
298+
299+
all_ranks = list()
300+
for cov in covs:
301+
r = list(
302+
compute_rank(
303+
Covariance(
304+
cov,
305+
info.ch_names,
306+
list(),
307+
list(),
308+
0,
309+
verbose=_verbose_safe_false(),
310+
),
311+
rank,
312+
_handle_default("scalings_cov_rank", None),
313+
info,
314+
).values()
315+
)[0]
316+
all_ranks.append(r)
317+
rank = np.min(all_ranks)
296318
return covs, C_ref, info, rank, dict()
297319

298320

mne/decoding/csp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def __init__(
142142
_csp_mod,
143143
mod_params,
144144
dec_type="single",
145-
restr_map="restricting",
145+
restr_type="restricting",
146146
R_func=sum,
147147
)
148148

@@ -911,7 +911,7 @@ def __init__(
911911
_spoc_mod,
912912
mod_params,
913913
dec_type="single",
914-
restr_map=None,
914+
restr_type=None,
915915
R_func=None,
916916
)
917917

mne/decoding/ged.py

Lines changed: 23 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,26 @@
66
import scipy.linalg
77

88
from ..cov import Covariance, _smart_eigh, compute_whitener
9-
from ..defaults import _handle_default
10-
from ..rank import compute_rank
11-
from ..utils import _verbose_safe_false, logger
9+
from ..utils import logger
1210

1311

14-
def _handle_restr_map(C_ref, restr_map, info, rank):
12+
def _handle_restr_map(C_ref, restr_type, info, rank):
1513
"""Get restricting map to C_ref rank-dimensional principal subspace.
1614
1715
Returns matrix of shape (rank, n_chs) used to restrict or
1816
restrict+rescale (whiten) covariances matrices.
1917
"""
20-
if C_ref is None or restr_map is None:
18+
if C_ref is None or restr_type is None:
2119
return None
22-
if restr_map == "whitening":
20+
if restr_type == "whitening":
2321
projs = info["projs"]
2422
C_ref_cov = Covariance(C_ref, info.ch_names, info["bads"], projs, 0)
25-
restr_map = compute_whitener(C_ref_cov, info, rank=rank, pca=True)
26-
elif restr_map == "ssd":
23+
restr_map = compute_whitener(C_ref_cov, info, rank=rank, pca=True)[0]
24+
elif restr_type == "ssd":
2725
restr_map = _get_ssd_whitener(C_ref, rank)
28-
elif restr_map == "restricting":
26+
elif restr_type == "restricting":
2927
restr_map = _get_restricting_map(C_ref, info, rank)
30-
elif isinstance(restr_map, callable):
28+
elif isinstance(restr_type, callable):
3129
pass
3230
else:
3331
raise ValueError(
@@ -147,6 +145,15 @@ def _ajd_pham(X, eps=1e-6, max_iter=15):
147145
return V, D
148146

149147

148+
def _is_all_pos_def(covs):
149+
for cov in covs:
150+
try:
151+
_ = scipy.linalg.cholesky(cov)
152+
except np.linalg.LinAlgError:
153+
return False
154+
return True
155+
156+
150157
def _smart_ajd(covs, restr_map=None, weights=None):
151158
"""Perform smart approximate joint diagonalization.
152159
@@ -157,6 +164,12 @@ def _smart_ajd(covs, restr_map=None, weights=None):
157164
The matrix of generalized eigenvectors is of shape (n_chs, r).
158165
"""
159166
if restr_map is None:
167+
is_all_pos_def = _is_all_pos_def(covs)
168+
if not is_all_pos_def:
169+
raise ValueError(
170+
"If C_ref is not provided by covariance estimator, "
171+
"all the covs should be positive definite"
172+
)
160173
evecs, D = _ajd_pham(covs)
161174
return evecs
162175

@@ -191,42 +204,6 @@ def _normalize_eigenvectors(evecs, covs, sample_weights):
191204
return evecs
192205

193206

194-
def _get_ssd_rank(S, R, info, rank):
195-
# find ranks of covariance matrices
196-
rank_signal = list(
197-
compute_rank(
198-
Covariance(
199-
S,
200-
info.ch_names,
201-
list(),
202-
list(),
203-
0,
204-
verbose=_verbose_safe_false(),
205-
),
206-
rank,
207-
_handle_default("scalings_cov_rank", None),
208-
info,
209-
).values()
210-
)[0]
211-
rank_noise = list(
212-
compute_rank(
213-
Covariance(
214-
R,
215-
info.ch_names,
216-
list(),
217-
list(),
218-
0,
219-
verbose=_verbose_safe_false(),
220-
),
221-
rank,
222-
_handle_default("scalings_cov_rank", None),
223-
info,
224-
).values()
225-
)[0]
226-
rank = np.min([rank_signal, rank_noise]) # should be identical
227-
return rank
228-
229-
230207
def _get_ssd_whitener(S, rank):
231208
"""Perform dimensionality reduction on the covariance matrices."""
232209
n_channels = S.shape[0]

mne/decoding/ssd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def __init__(
137137
_ssd_mod,
138138
mod_params,
139139
dec_type="single",
140-
restr_map="ssd",
140+
restr_type="ssd",
141141
R_func=None,
142142
)
143143

0 commit comments

Comments
 (0)