Skip to content

Commit 8888c13

Browse files
committed
Simplification of PCA flip
1 parent 69937d2 commit 8888c13

1 file changed

Lines changed: 14 additions & 32 deletions

File tree

mne/source_estimate.py

Lines changed: 14 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3385,16 +3385,17 @@ def _compute_pca_quantities(U, s, V, flip):
33853385
return result
33863386

33873387

3388-
def _pca_flip(flip, data):
3388+
def _pca_flip(flip, data, max_rank):
33893389
result = None
33903390
if flip is None: # Case of volumetric data: flip is meaningless
33913391
flip = 1
33923392
if data.shape[0] < 2:
33933393
result = data.mean(axis=0) # Trivial accumulator
33943394
else:
3395-
U, s, V = np.linalg.svd(data, full_matrices=False)
3395+
from sklearn.utils.extmath import randomized_svd
3396+
3397+
U, s, V = randomized_svd(data, n_components=max_rank)
33963398
# determine sign-flip.
3397-
# if flip is a mere int, multiply U and sum
33983399
result = _compute_pca_quantities(U, s, V, flip)
33993400
return result
34003401

@@ -3742,38 +3743,19 @@ def _gen_extract_label_time_course(
37423743
else:
37433744
# For other modes, initialize the label_tc array
37443745
label_tc = np.zeros((n_labels,) + stc.data.shape[1:], dtype=stc.data.dtype)
3745-
pca_volumetric = kind == "volume" and mode == "pca_flip"
3746-
if pca_volumetric:
3747-
# Precompute randomized SVD on data
3748-
# Components are restricted to max_channels, which is the highest possible
3749-
# rank and is much smaller than the number of sources
3750-
from sklearn.utils.extmath import randomized_svd
3751-
3752-
u_data, s_data, vh_data = randomized_svd(
3753-
stc.data, n_components=max_channels
3754-
)
37553746
for i, (vertidx, flip) in enumerate(zip(label_vertidx, src_flip)):
37563747
if vertidx is not None:
3757-
if pca_volumetric:
3758-
# Compute SVD of vertices
3759-
# We will use it to compute vertidx @ data implicitly,
3760-
u_vert, s_vert, vh_Vert = np.linalg.svd(vertidx.todense())
3761-
center_prod = np.diag(s_vert) @ vh_Vert @ u_data @ np.diag(s_data)
3762-
u_s, s_s, vh_s = np.linalg.svd(center_prod)
3763-
U = u_vert @ u_s
3764-
s = s_s
3765-
V = vh_s @ vh_data
3766-
label_tc[i] = _compute_pca_quantities(U, s, V, flip)
3748+
if isinstance(vertidx, sparse.csr_array):
3749+
assert mri_resolution
3750+
assert vertidx.shape[1] == stc.data.shape[0]
3751+
this_data = np.reshape(stc.data, (stc.data.shape[0], -1))
3752+
this_data = vertidx @ this_data
3753+
this_data.shape = (this_data.shape[0],) + stc.data.shape[1:]
3754+
else:
3755+
this_data = stc.data[vertidx]
3756+
if mode == "pca_flip":
3757+
label_tc[i] = func(flip, this_data, max_channels)
37673758
else:
3768-
if isinstance(vertidx, sparse.csr_array):
3769-
assert mri_resolution
3770-
assert vertidx.shape[1] == stc.data.shape[0]
3771-
this_data = np.reshape(stc.data, (stc.data.shape[0], -1))
3772-
3773-
this_data = vertidx @ this_data
3774-
this_data.shape = (this_data.shape[0],) + stc.data.shape[1:]
3775-
else:
3776-
this_data = stc.data[vertidx]
37773759
label_tc[i] = func(flip, this_data)
37783760
if mode is not None:
37793761
offset = nvert[:-n_mean].sum() # effectively :2 or :0

0 commit comments

Comments
 (0)