1515from ..utils import _verbose_safe_false , logger
1616
1717
18- def _concat_cov (x_class , * , cov_kind , log_rank , reg , cov_method_params , rank , info ):
18+ def _concat_cov (x_class , * , cov_kind , log_rank , reg , cov_method_params , info , rank ):
1919 """Concatenate epochs before computing the covariance."""
2020 _ , n_channels , _ = x_class .shape
2121
@@ -34,7 +34,7 @@ def _concat_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, rank, in
3434 return cov , n_channels # the weight here is just the number of channels
3535
3636
37- def _epoch_cov (x_class , * , cov_kind , log_rank , reg , cov_method_params , rank , info ):
37+ def _epoch_cov (x_class , * , cov_kind , log_rank , reg , cov_method_params , info , rank ):
3838 """Mean of per-epoch covariances."""
3939 name = reg if isinstance (reg , str ) else "empirical"
4040 name += " with shrinkage" if isinstance (reg , float ) else ""
@@ -62,22 +62,29 @@ def _epoch_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, rank, inf
6262 return cov , weight
6363
6464
65- def _csp_estimate (X , y , reg , cov_method_params , cov_est , rank , norm_trace ):
65+ def _handle_info_rank (X , info , rank ):
66+ if info is None :
67+ # use mag instead of eeg to avoid the cov EEG projection warning
68+ info = create_info (X .shape [1 ], 1000.0 , "mag" )
69+ if isinstance (rank , dict ):
70+ rank = dict (mag = sum (rank .values ()))
71+
72+ return info , rank
73+
74+
75+ def _csp_estimate (X , y , reg , cov_method_params , cov_est , info , rank , norm_trace ):
6676 _ , n_channels , _ = X .shape
6777 classes_ = np .unique (y )
6878 if cov_est == "concat" :
6979 cov_estimator = _concat_cov
7080 elif cov_est == "epoch" :
7181 cov_estimator = _epoch_cov
72- # Someday we could allow the user to pass this, then we wouldn't need to convert
73- # but in the meantime they can use a pipeline with a scaler
74- _info = create_info (n_channels , 1000.0 , "mag" )
75- if isinstance (rank , dict ):
76- _rank = {"mag" : sum (rank .values ())}
77- else :
78- _rank = _compute_rank_raw_array (
79- X .transpose (1 , 0 , 2 ).reshape (X .shape [1 ], - 1 ),
80- _info ,
82+
83+ info , rank = _handle_info_rank (X , info , rank )
84+ if not isinstance (rank , dict ):
85+ rank = _compute_rank_raw_array (
86+ np .hstack (X ),
87+ info ,
8188 rank = rank ,
8289 scalings = None ,
8390 log_ch_type = "data" ,
@@ -92,8 +99,8 @@ def _csp_estimate(X, y, reg, cov_method_params, cov_est, rank, norm_trace):
9299 log_rank = ci == 0 ,
93100 reg = reg ,
94101 cov_method_params = cov_method_params ,
95- rank = _rank ,
96- info = _info ,
102+ info = info ,
103+ rank = rank ,
97104 )
98105
99106 if norm_trace :
@@ -105,7 +112,7 @@ def _csp_estimate(X, y, reg, cov_method_params, cov_est, rank, norm_trace):
105112 covs = np .stack (covs )
106113 C_ref = covs .mean (0 )
107114
108- return covs , C_ref , _info , _rank , dict (sample_weights = np .array (sample_weights ))
115+ return covs , C_ref , info , rank , dict (sample_weights = np .array (sample_weights ))
109116
110117
111118def _xdawn_estimate (
@@ -118,6 +125,7 @@ def _xdawn_estimate(
118125 rank = "full" ,
119126):
120127 classes = np .unique (y )
128+ info , rank = _handle_info_rank (X , info , rank )
121129
122130 # Retrieve or compute whitening covariance
123131 if R is None :
@@ -143,7 +151,14 @@ def _xdawn_estimate(
143151
144152 covs .append (R )
145153 C_ref = R
146- rank = rank if isinstance (rank , dict ) else None
154+ if not isinstance (rank , dict ):
155+ rank = _compute_rank_raw_array (
156+ np .hstack (X ),
157+ info ,
158+ rank = rank ,
159+ scalings = None ,
160+ log_ch_type = "data" ,
161+ )
147162 return covs , C_ref , info , rank , dict ()
148163
149164
0 commit comments