88import numpy as np
99from scipy import linalg
1010
11+ from .._fiff .meas_info import Info
1112from .._fiff .pick import _pick_data_channels , pick_info
1213from ..cov import Covariance , _regularized_covariance
1314from ..decoding ._covs_ged import _xdawn_estimate
@@ -246,6 +247,30 @@ class _XdawnTransformer(_GEDTransformer):
246247 Parameters to pass to :func:`mne.compute_covariance`.
247248
248249 .. versionadded:: 0.16
250+ restr_type : "restricting" | "whitening" | None
251+ Restricting transformation for covariance matrices before performing
252+ generalized eigendecomposition.
253+ If "restricting" only restriction to the principal subspace of signal_cov
254+ will be performed.
255+ If "whitening", covariance matrices will be additionally rescaled according
256+ to the whitening for the signal_cov.
257+ If None, no restriction will be applied. Defaults to None.
258+
259+ .. versionadded:: 1.10
260+ info : mne.Info | None
261+ The mne.Info object with information about the sensors and methods of
262+ measurement used for covariance estimation and generalized
263+ eigendecomposition.
264+ If None, one channel type and no projections will be assumed and if
265+ rank is dict, it will be sum of ranks per channel type.
266+ Defaults to None.
267+
268+ .. versionadded:: 1.10
269+ %(rank)s
270+ Defaults to "full".
271+
272+ .. versionadded:: 1.10
273+
249274
250275 Attributes
251276 ----------
@@ -257,21 +282,39 @@ class _XdawnTransformer(_GEDTransformer):
257282 The Xdawn patterns used to restore the signals for each event type.
258283 """
259284
260- def __init__ (self , n_components = 2 , reg = None , signal_cov = None , method_params = None ):
285+ def __init__ (
286+ self ,
287+ n_components = 2 ,
288+ reg = None ,
289+ signal_cov = None ,
290+ method_params = None ,
291+ restr_type = None ,
292+ info = None ,
293+ rank = "full" ,
294+ ):
261295 """Init."""
262296 self .n_components = n_components
263297 self .signal_cov = signal_cov
264298 self .reg = reg
265299 self .method_params = method_params
300+ self .restr_type = restr_type
301+ self .info = info
302+ self .rank = rank
266303
267304 cov_callable = partial (
268- _xdawn_estimate , reg = reg , cov_method_params = method_params , R = signal_cov
305+ _xdawn_estimate ,
306+ reg = reg ,
307+ cov_method_params = method_params ,
308+ R = signal_cov ,
309+ info = info ,
310+ rank = rank ,
269311 )
270312 super ().__init__ (
271313 n_components = n_components ,
272314 cov_callable = cov_callable ,
273315 mod_ged_callable = _xdawn_mod ,
274316 dec_type = "multi" ,
317+ restr_type = restr_type ,
275318 )
276319
277320 def _validate_params (self , X ):
@@ -288,7 +331,13 @@ def _validate_params(self, X):
288331 raise ValueError (
289332 "signal_cov data should be of shape (n_channels, n_channels)"
290333 )
291- _validate_type (self .method_params , (Mapping , None ))
334+ _validate_type (self .method_params , (Mapping , None ), "method_params" )
335+ _check_option (
336+ "restr_type" ,
337+ self .restr_type ,
338+ ("restricting" , "whitening" , None ),
339+ )
340+ _validate_type (self .info , (Info , None ), "info" )
292341
293342 def fit (self , X , y = None ):
294343 """Fit Xdawn spatial filters.
0 commit comments