@@ -46,7 +46,7 @@ class _GEDTransformer(MNETransformerMixin, BaseEstimator):
4646 mod_ged_callable : callable
4747 Function used to modify (e.g. sort or normalize) generalized
4848 eigenvalues and eigenvectors.
49- mod_params : dict
49+ mod_params : dict | None
5050 Parameters passed to mod_ged_callable.
5151 dec_type : "single" | "multi"
5252 When "single" and cov_callable returns > 2 covariances,
@@ -93,7 +93,8 @@ def __init__(
9393 cov_callable ,
9494 cov_params ,
9595 mod_ged_callable ,
96- mod_params ,
96+ * ,
97+ mod_params = None ,
9798 dec_type = "single" ,
9899 restr_type = None ,
99100 R_func = None ,
@@ -120,6 +121,7 @@ def fit(self, X, y=None):
120121 covs = np .stack (covs )
121122 self ._validate_covariances (covs )
122123 self ._validate_covariances ([C_ref ])
124+ mod_params = self .mod_params if self .mod_params is not None else dict ()
123125 if self .dec_type == "single" :
124126 if len (covs ) > 2 :
125127 sample_weights = kwargs ["sample_weights" ]
@@ -133,7 +135,7 @@ def fit(self, X, y=None):
133135 evals , evecs = _smart_ged (S , R , restr_mat , R_func = self .R_func )
134136
135137 evals , evecs = self .mod_ged_callable (
136- evals , evecs , covs , ** self . mod_params , ** kwargs
138+ evals , evecs , covs , ** mod_params , ** kwargs
137139 )
138140 self .evals_ = evals
139141 self .filters_ = evecs .T
@@ -153,7 +155,7 @@ def fit(self, X, y=None):
153155 evals , evecs = _smart_ged (S , R , restr_mat , R_func = self .R_func )
154156
155157 evals , evecs = self .mod_ged_callable (
156- evals , evecs , covs , ** self . mod_params , ** kwargs
158+ evals , evecs , covs , ** mod_params , ** kwargs
157159 )
158160 all_evals .append (evals )
159161 all_evecs .append (evecs .T )
0 commit comments