Skip to content

Commit 0fd4be5

Browse files
committed
Found a trick to make everything much faster with only two svds
1 parent 775ec80 commit 0fd4be5

1 file changed

Lines changed: 38 additions & 11 deletions

File tree

mne/source_estimate.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3743,20 +3743,47 @@ def _gen_extract_label_time_course(
37433743
else:
37443744
# For other modes, initialize the label_tc array
37453745
label_tc = np.zeros((n_labels,) + stc.data.shape[1:], dtype=stc.data.dtype)
3746+
3747+
pca_volume = mode == "pca_flip" and kind == "volume"
3748+
if pca_volume:
3749+
from sklearn.utils.extmath import randomized_svd
3750+
3751+
logger.debug("First SVD for PCA volume on stc data")
3752+
u_b, s_b, vh_b = randomized_svd(stc.data, max_channels)
37463753
for i, (vertidx, flip) in enumerate(zip(label_vertidx, src_flip)):
37473754
if vertidx is not None:
3748-
if isinstance(vertidx, sparse.csr_array):
3749-
assert mri_resolution
3750-
assert vertidx.shape[1] == stc.data.shape[0]
3751-
this_data = np.reshape(stc.data, (stc.data.shape[0], -1))
3752-
this_data = vertidx @ this_data
3753-
this_data.shape = (this_data.shape[0],) + stc.data.shape[1:]
3754-
else:
3755-
this_data = stc.data[vertidx]
3756-
if mode == "pca_flip":
3757-
label_tc[i] = func(flip, this_data, max_channels)
3755+
if pca_volume:
3756+
# Use a trick for efficiency:
3757+
# stc = Ub Sb VhB
3758+
# full_data = vertidx @ stc
3759+
# = vertidx @ Ub @ Sb @ Vhb
3760+
# Consider U_f, s_f, Vh_f = SVD(vertidx @ Ub @ Sb)
3761+
# Then U,S,V = svd(full_data) is such that
3762+
# U_f = U, S = s_f and V = Vh_f @ Vhb
3763+
# This trick is more efficient, because:
3764+
# - We compute a first SVD once on stc, restricted to
3765+
# only first max_channels singular vals/vecs (quite fast)
3766+
# - We project vertidx to be from Nvertex x Nsources
3767+
# to Nvertex x rank.
3768+
# - We compute SVD on Nvertex x rank
3769+
# As rank << Nsources, we end up saving a lot of computations.
3770+
tmp_array = vertidx @ u_b @ np.diag(s_b)
3771+
U, S, v_tmp = np.linalg.svd(tmp_array, full_matrices=False)
3772+
V = v_tmp @ vh_b
3773+
label_tc[i] = _compute_pca_quantities(U, S, V, flip)
37583774
else:
3759-
label_tc[i] = func(flip, this_data)
3775+
if isinstance(vertidx, sparse.csr_array):
3776+
assert mri_resolution
3777+
assert vertidx.shape[1] == stc.data.shape[0]
3778+
this_data = np.reshape(stc.data, (stc.data.shape[0], -1))
3779+
this_data = vertidx @ this_data
3780+
this_data.shape = (this_data.shape[0],) + stc.data.shape[1:]
3781+
else:
3782+
this_data = stc.data[vertidx]
3783+
if mode == "pca_flip":
3784+
label_tc[i] = func(flip, this_data, max_channels)
3785+
else:
3786+
label_tc[i] = func(flip, this_data)
37603787
logger.debug(f"Done with label {i}")
37613788
if mode is not None:
37623789
offset = nvert[:-n_mean].sum() # effectively :2 or :0

0 commit comments

Comments
 (0)