Skip to content

Commit 69937d2

Browse files
committed
PCA flip for volumetric source estimates now uses a randomized SVD so that the SVD can be computed at all
1 parent fd71779 commit 69937d2

1 file changed

Lines changed: 44 additions & 21 deletions

File tree

mne/source_estimate.py

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .cov import Covariance
2020
from .evoked import _get_peak
2121
from .filter import FilterMixin, _check_fun, resample
22-
from .fixes import _eye_array, _safe_svd
22+
from .fixes import _eye_array
2323
from .parallel import parallel_func
2424
from .source_space._source_space import (
2525
SourceSpaces,
@@ -3375,27 +3375,27 @@ def _get_ico_tris(grade, verbose=None, return_surf=False):
33753375
return ico
33763376

33773377

3378+
def _compute_pca_quantities(U, s, V, flip):
    """Build the sign-aligned, power-scaled first PCA component from an SVD.

    Parameters
    ----------
    U : array
        Left singular vectors; only the first column ``U[:, 0]`` is used,
        and ``len(U)`` sets the normalization of the scale.
    s : array
        Singular values; their Euclidean norm sets the scale.
    V : array
        Right singular vectors stored as rows; only ``V[0]`` is used.
    flip : scalar | array
        Weights used to decide the sign of the component. A scalar (Python
        or NumPy) multiplies every entry of ``U[:, 0]`` before summation;
        an array is combined with ``U[:, 0]`` through a dot product.

    Returns
    -------
    result : array
        ``sign * scale * V[0]``.
    """
    first_left = U[:, 0]
    if np.ndim(flip) == 0:
        # Any scalar weight, not only a built-in ``int``: NumPy integer or
        # float scalars sent to the dot-product branch would turn ``sign``
        # into a vector and break the final product.
        sign = np.sign((flip * first_left).sum())
    else:
        sign = np.sign(np.dot(first_left, flip))
    # use average power (norm of the singular values over sqrt of the
    # number of rows of U) for scaling
    scale = np.linalg.norm(s) / np.sqrt(len(U))
    return sign * scale * V[0]
3386+
3387+
33783388
def _pca_flip(flip, data):
    """Reduce the rows of ``data`` to a single sign-aligned PCA time course.

    Parameters
    ----------
    flip : None | int | array
        Sign weights handed to ``_compute_pca_quantities``. ``None`` (the
        volumetric case, where a flip is meaningless) is replaced by ``1``.
    data : array
        Input whose first axis is reduced.

    Returns
    -------
    result : array
        The mean over the first axis when ``data`` has fewer than two rows,
        otherwise the output of ``_compute_pca_quantities`` applied to the
        reduced SVD of ``data``.
    """
    # Case of volumetric data: flip is meaningless, use a unit weight
    weights = 1 if flip is None else flip
    if data.shape[0] < 2:
        # Trivial accumulator: nothing to decompose
        return data.mean(axis=0)
    left, sing_vals, right = np.linalg.svd(data, full_matrices=False)
    # Sign determination and power scaling are delegated to the helper
    return _compute_pca_quantities(left, sing_vals, right, weights)
34003400

34013401

@@ -3678,6 +3678,7 @@ def _gen_extract_label_time_course(
36783678
allow_empty=False,
36793679
mri_resolution=True,
36803680
verbose=None,
3681+
max_channels=400,
36813682
):
36823683
# loop through source estimates and extract time series
36833684
if src is None and mode in ["mean", "max"]:
@@ -3741,17 +3742,39 @@ def _gen_extract_label_time_course(
37413742
else:
37423743
# For other modes, initialize the label_tc array
37433744
label_tc = np.zeros((n_labels,) + stc.data.shape[1:], dtype=stc.data.dtype)
3745+
pca_volumetric = kind == "volume" and mode == "pca_flip"
3746+
if pca_volumetric:
3747+
# Precompute randomized SVD on data
3748+
# Components are restricted to max_channels, which is the highest possible
3749+
# rank and is much smaller than the number of sources
3750+
from sklearn.utils.extmath import randomized_svd
3751+
3752+
u_data, s_data, vh_data = randomized_svd(
3753+
stc.data, n_components=max_channels
3754+
)
37443755
for i, (vertidx, flip) in enumerate(zip(label_vertidx, src_flip)):
37453756
if vertidx is not None:
3746-
if isinstance(vertidx, sparse.csr_array):
3747-
assert mri_resolution
3748-
assert vertidx.shape[1] == stc.data.shape[0]
3749-
this_data = np.reshape(stc.data, (stc.data.shape[0], -1))
3750-
this_data = vertidx @ this_data
3751-
this_data.shape = (this_data.shape[0],) + stc.data.shape[1:]
3757+
if pca_volumetric:
3758+
# Compute SVD of vertices
3759+
# We will use it to compute vertidx @ data implicitly,
3760+
u_vert, s_vert, vh_Vert = np.linalg.svd(vertidx.todense())
3761+
center_prod = np.diag(s_vert) @ vh_Vert @ u_data @ np.diag(s_data)
3762+
u_s, s_s, vh_s = np.linalg.svd(center_prod)
3763+
U = u_vert @ u_s
3764+
s = s_s
3765+
V = vh_s @ vh_data
3766+
label_tc[i] = _compute_pca_quantities(U, s, V, flip)
37523767
else:
3753-
this_data = stc.data[vertidx]
3754-
label_tc[i] = func(flip, this_data)
3768+
if isinstance(vertidx, sparse.csr_array):
3769+
assert mri_resolution
3770+
assert vertidx.shape[1] == stc.data.shape[0]
3771+
this_data = np.reshape(stc.data, (stc.data.shape[0], -1))
3772+
3773+
this_data = vertidx @ this_data
3774+
this_data.shape = (this_data.shape[0],) + stc.data.shape[1:]
3775+
else:
3776+
this_data = stc.data[vertidx]
3777+
label_tc[i] = func(flip, this_data)
37553778
if mode is not None:
37563779
offset = nvert[:-n_mean].sum() # effectively :2 or :0
37573780
for i, nv in enumerate(nvert[2:]):

0 commit comments

Comments
 (0)