Skip to content

Commit 029691b

Browse files
committed
add _validate_params for _XdawnTransformer
1 parent 99d297e commit 029691b

1 file changed

Lines changed: 20 additions & 2 deletions

File tree

mne/preprocessing/xdawn.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# License: BSD-3-Clause
33
# Copyright the MNE-Python contributors.
44

5+
from collections.abc import Mapping
6+
57
import numpy as np
68
from scipy import linalg
79

@@ -13,7 +15,7 @@
1315
from ..epochs import BaseEpochs
1416
from ..evoked import Evoked, EvokedArray
1517
from ..io import BaseRaw
16-
from ..utils import _check_option, logger, pinv
18+
from ..utils import _check_option, _validate_type, logger, pinv
1719

1820

1921
def _construct_signal_from_epochs(epochs, events, sfreq, tmin):
@@ -275,6 +277,22 @@ def __init__(self, n_components=2, reg=None, signal_cov=None, method_params=None
275277
R_func=None,
276278
)
277279

280+
def _validate_params(self, X):
281+
_validate_type(self.n_components, int, "n_components")
282+
283+
# reg is validated in _regularized_covariance
284+
285+
if self.signal_cov is not None:
286+
if isinstance(self.signal_cov, Covariance):
287+
self.signal_cov = self.signal_cov.data
288+
elif not isinstance(self.signal_cov, np.ndarray):
289+
raise ValueError("signal_cov should be mne.Covariance or np.ndarray")
290+
if not np.array_equal(self.signal_cov.shape, np.tile(X.shape[1], 2)):
291+
raise ValueError(
292+
"signal_cov data should be of shape (n_channels, n_channels)"
293+
)
294+
_validate_type(self.method_params, (Mapping, None))
295+
278296
def fit(self, X, y=None):
279297
"""Fit Xdawn spatial filters.
280298
@@ -291,7 +309,7 @@ def fit(self, X, y=None):
291309
The Xdawn instance.
292310
"""
293311
X, y = self._check_Xy(X, y)
294-
312+
self._validate_params(X)
295313
# Main function
296314
self.classes_ = np.unique(y)
297315
self.filters_, self.patterns_, _ = _fit_xdawn(

0 commit comments

Comments
 (0)