22# License: BSD-3-Clause
33# Copyright the MNE-Python contributors.
44
5+ from collections .abc import Mapping
6+
57import numpy as np
68from scipy import linalg
79
1315from ..epochs import BaseEpochs
1416from ..evoked import Evoked , EvokedArray
1517from ..io import BaseRaw
16- from ..utils import _check_option , logger , pinv
18+ from ..utils import _check_option , _validate_type , logger , pinv
1719
1820
1921def _construct_signal_from_epochs (epochs , events , sfreq , tmin ):
@@ -275,6 +277,22 @@ def __init__(self, n_components=2, reg=None, signal_cov=None, method_params=None
275277 R_func = None ,
276278 )
277279
280+ def _validate_params (self , X ):
281+ _validate_type (self .n_components , int , "n_components" )
282+
283+ # reg is validated in _regularized_covariance
284+
285+ if self .signal_cov is not None :
286+ if isinstance (self .signal_cov , Covariance ):
287+ self .signal_cov = self .signal_cov .data
288+ elif not isinstance (self .signal_cov , np .ndarray ):
289+ raise ValueError ("signal_cov should be mne.Covariance or np.ndarray" )
290+ if not np .array_equal (self .signal_cov .shape , np .tile (X .shape [1 ], 2 )):
291+ raise ValueError (
292+ "signal_cov data should be of shape (n_channels, n_channels)"
293+ )
294+ _validate_type (self .method_params , (Mapping , None ))
295+
278296 def fit (self , X , y = None ):
279297 """Fit Xdawn spatial filters.
280298
@@ -291,7 +309,7 @@ def fit(self, X, y=None):
291309 The Xdawn instance.
292310 """
293311 X , y = self ._check_Xy (X , y )
294-
312+ self . _validate_params ( X )
295313 # Main function
296314 self .classes_ = np .unique (y )
297315 self .filters_ , self .patterns_ , _ = _fit_xdawn (
0 commit comments