88import numbers
99
1010import numpy as np
11+ import scipy .linalg
1112from sklearn import model_selection as models
1213from sklearn .base import ( # noqa: F401
1314 BaseEstimator ,
1415 MetaEstimatorMixin ,
15- TransformerMixin ,
1616 clone ,
1717 is_classifier ,
1818)
1919from sklearn .linear_model import LogisticRegression
2020from sklearn .metrics import check_scoring
2121from sklearn .model_selection import KFold , StratifiedKFold , check_cv
2222from sklearn .utils import check_array , check_X_y , indexable
23+ from sklearn .utils .validation import check_is_fitted
2324
2425from ..parallel import parallel_func
2526from ..utils import _pl , logger , pinv , verbose , warn
26- from .ged import _get_ssd_rank , _handle_restr_map , _smart_ajd , _smart_ged
27+ from .ged import _handle_restr_map , _smart_ajd , _smart_ged
2728from .transformer import MNETransformerMixin
2829
2930
@@ -55,7 +56,7 @@ class GEDTransformer(MNETransformerMixin, BaseEstimator):
5556 (except the last) returned by cov_callable is decomposed with the last
5657 covariance. In this case, number of covariances should be number of classes + 1.
5758 Defaults to "single".
58- restr_map : "restricting" | "whitening" | "ssd" | None
59+ restr_type : "restricting" | "whitening" | "ssd" | None
5960 Restricting transformation for covariance matrices before performing GED.
6061 If "restricting" only restriction to the principal subspace of the C_ref
6162 will be performed.
@@ -94,7 +95,7 @@ def __init__(
9495 mod_ged_callable ,
9596 mod_params ,
9697 dec_type = "single" ,
97- restr_map = None ,
98+ restr_type = None ,
9899 R_func = None ,
99100 ):
100101 self .n_filters = n_filters
@@ -103,27 +104,35 @@ def __init__(
103104 self .mod_ged_callable = mod_ged_callable
104105 self .mod_params = mod_params
105106 self .dec_type = dec_type
106- self .restr_map = restr_map
107+ self .restr_type = restr_type
107108 self .R_func = R_func
108109
109110 def fit (self , X , y = None ):
110111 """..."""
112+ X , y = self ._check_data (
113+ X ,
114+ y = y ,
115+ fit = True ,
116+ return_y = True ,
117+ atleast_3d = False if self .restr_type == "ssd" else True ,
118+ )
111119 covs , C_ref , info , rank , kwargs = self .cov_callable (X , y , ** self .cov_params )
120+ self ._validate_covariances (covs + [C_ref ])
112121 if self .dec_type == "single" :
113122 if len (covs ) > 2 :
123+ covs = np .array (covs )
114124 sample_weights = kwargs ["sample_weights" ]
115- restr_map = _handle_restr_map (C_ref , self .restr_map , info , rank )
125+ restr_map = _handle_restr_map (C_ref , self .restr_type , info , rank )
116126 evecs = _smart_ajd (covs , restr_map , weights = sample_weights )
117127 evals = None
118128 else :
119129 S = covs [0 ]
120130 R = covs [1 ]
121- if self .restr_map == "ssd" :
122- rank = _get_ssd_rank (S , R , info , rank )
131+ if self .restr_type == "ssd" :
123132 mult_order = "ssd"
124133 else :
125134 mult_order = None
126- restr_map = _handle_restr_map (C_ref , self .restr_map , info , rank )
135+ restr_map = _handle_restr_map (C_ref , self .restr_type , info , rank )
127136 evals , evecs = _smart_ged (
128137 S , R , restr_map , R_func = self .R_func , mult_order = mult_order
129138 )
@@ -133,19 +142,26 @@ def fit(self, X, y=None):
133142 )
134143 self .evals_ = evals
135144 self .filters_ = evecs .T
136- if self .restr_map == "ssd" :
145+ if self .restr_type == "ssd" :
137146 self .patterns_ = np .linalg .pinv (evecs )
138147 else :
139148 self .patterns_ = pinv (evecs )
140149
141150 elif self .dec_type == "multi" :
142151 self .classes_ = np .unique (y )
143152 R = covs [- 1 ]
144- restr_map = _handle_restr_map (C_ref , self .restr_map , info , rank )
153+ if self .restr_type == "ssd" :
154+ mult_order = "ssd"
155+ else :
156+ mult_order = None
157+ restr_map = _handle_restr_map (C_ref , self .restr_type , info , rank )
145158 all_evals , all_evecs , all_patterns = list (), list (), list ()
146159 for i in range (len (self .classes_ )):
147160 S = covs [i ]
148- evals , evecs = _smart_ged (S , R , restr_map , R_func = self .R_func )
161+
162+ evals , evecs = _smart_ged (
163+ S , R , restr_map , R_func = self .R_func , mult_order = mult_order
164+ )
149165
150166 evals , evecs = self .mod_ged_callable (
151167 evals , evecs , covs , ** self .mod_params , ** kwargs
@@ -161,9 +177,48 @@ def fit(self, X, y=None):
161177
def transform(self, X):
    """Apply the first ``n_filters`` fitted spatial filters to each epoch of X.

    Parameters
    ----------
    X : array-like
        Data to transform. It is validated with ``self._check_data`` and then
        iterated along its first axis; each item is left-multiplied by the
        selected filters.
        NOTE(review): presumably (n_epochs, n_channels, n_times) -- confirm.

    Returns
    -------
    X : ndarray
        Filtered items stacked along a new first axis.
    """
    check_is_fitted(self, "filters_")
    X = self._check_data(X)
    if self.dec_type == "single":
        # One filter bank: keep its leading n_filters rows.
        selected = self.filters_[: self.n_filters]
    elif self.dec_type == "multi":
        # One filter bank per entry along axis 0 of filters_: keep the leading
        # n_filters rows of each and stack them into one 2D filter matrix.
        per_bank = [bank[: self.n_filters] for bank in self.filters_]
        selected = np.concatenate(per_bank, axis=0)
    projected = [np.matmul(selected, epoch) for epoch in X]
    return np.asarray(projected)
166194
def _validate_covariances(self, covs):
    """Check that every matrix is symmetric and positive semidefinite.

    Parameters
    ----------
    covs : iterable of (ndarray | None)
        Square matrices to validate. ``None`` entries are skipped.

    Raises
    ------
    ValueError
        If a matrix is not symmetric (within ``rtol=1e-10``, ``atol=1e-11``)
        or has an eigenvalue that is negative beyond round-off.
    """
    for cov in covs:
        if cov is None:
            continue
        if not scipy.linalg.issymmetric(cov, rtol=1e-10, atol=1e-11):
            raise ValueError(
                "One of covariances or C_ref is not symmetric, "
                "check your cov_callable"
            )
        # Symmetry was verified above, so the symmetric eigensolver applies;
        # it always returns real eigenvalues in ascending order.
        evals = np.linalg.eigvalsh(cov)
        if evals.size == 0:
            continue
        # Rank-deficient covariances can yield eigenvalues a hair below zero
        # from floating-point round-off; tolerate those relative to the
        # largest eigenvalue magnitude (same relative scale as the symmetry
        # check above).
        tol = 1e-10 * np.abs(evals).max()
        if evals[0] < -tol:
            # The exception must actually be raised; merely constructing it
            # lets invalid matrices pass silently.
            raise ValueError(
                "One of covariances or C_ref has negative eigenvalues, "
                "check your cov_callable"
            )
210+
def __sklearn_tags__(self):
    """Return the parent's estimator tags, adjusted for this transformer."""
    tags = super().__sklearn_tags__()
    tags.estimator_type = "transformer"

    target = tags.target_tags
    # Can be a transformer where S and R covs are not based on y classes.
    target.required = False
    target.one_d_labels = True

    inputs = tags.input_tags
    inputs.two_d_array = True
    inputs.three_d_array = True
    return tags
221+
167222
168223class LinearModel (MetaEstimatorMixin , BaseEstimator ):
169224 """Compute and store patterns from linear models.
0 commit comments