Skip to content

Commit 3c7df08

Browse files
committed
replace mod_params with partial as well
1 parent 85fb50f commit 3c7df08

2 files changed

Lines changed: 12 additions & 13 deletions

File tree

mne/decoding/base.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,10 @@ class _GEDTransformer(MNETransformerMixin, BaseEstimator):
4646
C_ref, info, rank can be None, while kwargs can be empty dict.
4747
mod_ged_callable : callable | None
4848
Function used to modify (e.g. sort or normalize) generalized
49-
eigenvalues and eigenvectors. If None, evals and evecs will be ordered according
50-
to :func:`~scipy.linalg.eigh` default. Defaults to None
49+
eigenvalues and eigenvectors. It should accept as arguments evals, evecs
50+
and also covs and optional kwargs returned by cov_callable. It should return
51+
only sorted and/or modified evals and evecs. If None, evals and evecs will be
52+
ordered according to :func:`~scipy.linalg.eigh` default. Defaults to None
5153
dec_type : "single" | "multi"
5254
When "single" and cov_callable returns > 2 covariances,
5355
approximate joint diagonalization based on Pham's algorithm
@@ -93,15 +95,13 @@ def __init__(
9395
cov_callable,
9496
*,
9597
mod_ged_callable=None,
96-
mod_params=None,
9798
dec_type="single",
9899
restr_type=None,
99100
R_func=None,
100101
):
101102
self.n_components = n_components
102103
self.cov_callable = cov_callable
103104
self.mod_ged_callable = mod_ged_callable
104-
self.mod_params = mod_params
105105
self.dec_type = dec_type
106106
self.restr_type = restr_type
107107
self.R_func = R_func
@@ -119,23 +119,24 @@ def fit(self, X, y=None):
119119
covs = np.stack(covs)
120120
self._validate_covariances(covs)
121121
self._validate_covariances([C_ref])
122-
mod_params = self.mod_params if self.mod_params is not None else dict()
123122
mod_ged_callable = (
124123
self.mod_ged_callable if self.mod_ged_callable is not None else _no_op_mod
125124
)
126125
if self.dec_type == "single":
127126
if len(covs) > 2:
128-
sample_weights = kwargs["sample_weights"]
127+
weights = (
128+
kwargs["sample_weights"] if "sample_weights" in kwargs else None
129+
)
129130
restr_mat = _handle_restr_mat(C_ref, self.restr_type, info, rank)
130-
evecs = _smart_ajd(covs, restr_mat, weights=sample_weights)
131+
evecs = _smart_ajd(covs, restr_mat, weights=weights)
131132
evals = None
132133
else:
133134
S = covs[0]
134135
R = covs[1]
135136
restr_mat = _handle_restr_mat(C_ref, self.restr_type, info, rank)
136137
evals, evecs = _smart_ged(S, R, restr_mat, R_func=self.R_func)
137138

138-
evals, evecs = mod_ged_callable(evals, evecs, covs, **mod_params, **kwargs)
139+
evals, evecs = mod_ged_callable(evals, evecs, covs, **kwargs)
139140
self.evals_ = evals
140141
self.filters_ = evecs.T
141142
if self.restr_type == "ssd":
@@ -153,9 +154,7 @@ def fit(self, X, y=None):
153154

154155
evals, evecs = _smart_ged(S, R, restr_mat, R_func=self.R_func)
155156

156-
evals, evecs = mod_ged_callable(
157-
evals, evecs, covs, **mod_params, **kwargs
158-
)
157+
evals, evecs = mod_ged_callable(evals, evecs, covs, **kwargs)
159158
all_evals.append(evals)
160159
all_evecs.append(evecs.T)
161160
all_patterns.append(np.linalg.pinv(evecs))

mne/decoding/csp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,11 +135,11 @@ def __init__(
135135
rank=rank,
136136
norm_trace=norm_trace,
137137
)
138+
mod_ged_callable = partial(_csp_mod, evecs_order=component_order)
138139
super().__init__(
139140
n_components=n_components,
140141
cov_callable=cov_callable,
141-
mod_ged_callable=_csp_mod,
142-
mod_params=dict(evecs_order=component_order),
142+
mod_ged_callable=mod_ged_callable,
143143
restr_type="restricting",
144144
R_func=sum,
145145
)

0 commit comments

Comments
 (0)