@@ -46,8 +46,10 @@ class _GEDTransformer(MNETransformerMixin, BaseEstimator):
4646 C_ref, info, rank can be None, while kwargs can be empty dict.
4747 mod_ged_callable : callable | None
4848 Function used to modify (e.g. sort or normalize) generalized
49- eigenvalues and eigenvectors. If None, evals and evecs will be ordered according
50- to :func:`~scipy.linalg.eigh` default. Defaults to None
49+ eigenvalues and eigenvectors. It should accept as arguments evals, evecs
50+ and also covs and optional kwargs returned by cov_callable. It should return
51+ only sorted and/or modified evals and evecs. If None, evals and evecs will be
52+ ordered according to :func:`~scipy.linalg.eigh` default. Defaults to None
5153 dec_type : "single" | "multi"
5254 When "single" and cov_callable returns > 2 covariances,
5355 approximate joint diagonalization based on Pham's algorithm
@@ -93,15 +95,13 @@ def __init__(
9395 cov_callable ,
9496 * ,
9597 mod_ged_callable = None ,
96- mod_params = None ,
9798 dec_type = "single" ,
9899 restr_type = None ,
99100 R_func = None ,
100101 ):
101102 self .n_components = n_components
102103 self .cov_callable = cov_callable
103104 self .mod_ged_callable = mod_ged_callable
104- self .mod_params = mod_params
105105 self .dec_type = dec_type
106106 self .restr_type = restr_type
107107 self .R_func = R_func
@@ -119,23 +119,24 @@ def fit(self, X, y=None):
119119 covs = np .stack (covs )
120120 self ._validate_covariances (covs )
121121 self ._validate_covariances ([C_ref ])
122- mod_params = self .mod_params if self .mod_params is not None else dict ()
123122 mod_ged_callable = (
124123 self .mod_ged_callable if self .mod_ged_callable is not None else _no_op_mod
125124 )
126125 if self .dec_type == "single" :
127126 if len (covs ) > 2 :
128- sample_weights = kwargs ["sample_weights" ]
127+ weights = (
128+ kwargs ["sample_weights" ] if "sample_weights" in kwargs else None
129+ )
129130 restr_mat = _handle_restr_mat (C_ref , self .restr_type , info , rank )
130- evecs = _smart_ajd (covs , restr_mat , weights = sample_weights )
131+ evecs = _smart_ajd (covs , restr_mat , weights = weights )
131132 evals = None
132133 else :
133134 S = covs [0 ]
134135 R = covs [1 ]
135136 restr_mat = _handle_restr_mat (C_ref , self .restr_type , info , rank )
136137 evals , evecs = _smart_ged (S , R , restr_mat , R_func = self .R_func )
137138
138- evals , evecs = mod_ged_callable (evals , evecs , covs , ** mod_params , ** kwargs )
139+ evals , evecs = mod_ged_callable (evals , evecs , covs , ** kwargs )
139140 self .evals_ = evals
140141 self .filters_ = evecs .T
141142 if self .restr_type == "ssd" :
@@ -153,9 +154,7 @@ def fit(self, X, y=None):
153154
154155 evals , evecs = _smart_ged (S , R , restr_mat , R_func = self .R_func )
155156
156- evals , evecs = mod_ged_callable (
157- evals , evecs , covs , ** mod_params , ** kwargs
158- )
157+ evals , evecs = mod_ged_callable (evals , evecs , covs , ** kwargs )
159158 all_evals .append (evals )
160159 all_evecs .append (evecs .T )
161160 all_patterns .append (np .linalg .pinv (evecs ))
0 commit comments