@@ -145,7 +145,7 @@ def test_sklearn_compliance(estimator, check):
145145 check (estimator )
146146
147147
148- def _get_X_y (event_id ):
148+ def _get_X_y (event_id , return_info = False ):
149149 raw = read_raw (raw_fname , preload = False )
150150 events = read_events (event_name )
151151 picks = pick_types (
@@ -166,6 +166,8 @@ def _get_X_y(event_id):
166166 )
167167 X = epochs .get_data (copy = False , units = dict (eeg = "uV" , grad = "fT/cm" , mag = "fT" ))
168168 y = epochs .events [:, - 1 ]
169+ if return_info :
170+ return X , y , epochs .info
169171 return X , y
170172
171173
@@ -386,3 +388,22 @@ def test__no_op_mod():
386388 assert evals is evals_no_op
387389 assert evecs is evecs_no_op
388390 assert sorter_no_op is None
391+
392+
393+ def test_get_spatial_filter ():
394+ """Test instantiation of spatial filter."""
395+ event_id = dict (aud_l = 1 , vis_l = 3 )
396+ X , y , info = _get_X_y (event_id , return_info = True )
397+
398+ ged = _GEDTransformer (
399+ n_components = 4 ,
400+ cov_callable = _mock_cov_callable ,
401+ mod_ged_callable = _mock_mod_ged_callable ,
402+ restr_type = "restricting" ,
403+ )
404+ ged .fit (X , y )
405+ sp_filter = ged .get_spatial_filter (info )
406+ assert sp_filter .patterns_method == "pinv"
407+ np .testing .assert_array_equal (sp_filter .filters , ged .filters_ )
408+ np .testing .assert_array_equal (sp_filter .patterns , ged .patterns_ )
409+ np .testing .assert_array_equal (sp_filter .evals , ged .evals_ )
0 commit comments