|
19 | 19 | from .cov import Covariance |
20 | 20 | from .evoked import _get_peak |
21 | 21 | from .filter import FilterMixin, _check_fun, resample |
22 | | -from .fixes import _eye_array, _safe_svd |
| 22 | +from .fixes import _eye_array |
23 | 23 | from .parallel import parallel_func |
24 | 24 | from .source_space._source_space import ( |
25 | 25 | SourceSpaces, |
@@ -3375,27 +3375,27 @@ def _get_ico_tris(grade, verbose=None, return_surf=False): |
3375 | 3375 | return ico |
3376 | 3376 |
|
3377 | 3377 |
|
| 3378 | +def _compute_pca_quantities(U, s, V, flip): |
| 3379 | + if isinstance(flip, int): |
| 3380 | + sign = np.sign((flip * U[:, 0]).sum()) |
| 3381 | + else: |
| 3382 | + sign = np.sign(np.dot(U[:, 0], flip)) |
| 3383 | + scale = np.linalg.norm(s) / np.sqrt(len(U)) |
| 3384 | + result = sign * scale * V[0] |
| 3385 | + return result |
| 3386 | + |
| 3387 | + |
3378 | 3388 | def _pca_flip(flip, data): |
3379 | 3389 | result = None |
3380 | 3390 | if flip is None: # Case of volumetric data: flip is meaningless |
3381 | 3391 | flip = 1 |
3382 | 3392 | if data.shape[0] < 2: |
3383 | 3393 | result = data.mean(axis=0) # Trivial accumulator |
3384 | 3394 | else: |
| 3395 | + U, s, V = np.linalg.svd(data, full_matrices=False) |
3385 | 3396 | # determine sign-flip. |
3386 | 3397 | # if flip is a mere int, multiply U and sum |
3387 | | - if isinstance(flip, int): |
3388 | | - # We assume here that flip is thus denoting a volumetric. |
3389 | | - # It means LAPACK is likely to overflow on big matrices => We use numpy |
3390 | | - U, s, V = np.linalg.svd(data, full_matrices=False) |
3391 | | - |
3392 | | - sign = np.sign((flip * U[:, 0]).sum()) |
3393 | | - else: |
3394 | | - U, s, V = _safe_svd(data, full_matrices=False) |
3395 | | - sign = np.sign(np.dot(U[:, 0], flip)) |
3396 | | - # use average power in label for scaling |
3397 | | - scale = np.linalg.norm(s) / np.sqrt(len(data)) |
3398 | | - result = sign * scale * V[0] |
| 3398 | + result = _compute_pca_quantities(U, s, V, flip) |
3399 | 3399 | return result |
3400 | 3400 |
|
3401 | 3401 |
|
@@ -3678,6 +3678,7 @@ def _gen_extract_label_time_course( |
3678 | 3678 | allow_empty=False, |
3679 | 3679 | mri_resolution=True, |
3680 | 3680 | verbose=None, |
| 3681 | + max_channels=400, |
3681 | 3682 | ): |
3682 | 3683 | # loop through source estimates and extract time series |
3683 | 3684 | if src is None and mode in ["mean", "max"]: |
@@ -3741,17 +3742,39 @@ def _gen_extract_label_time_course( |
3741 | 3742 | else: |
3742 | 3743 | # For other modes, initialize the label_tc array |
3743 | 3744 | label_tc = np.zeros((n_labels,) + stc.data.shape[1:], dtype=stc.data.dtype) |
| 3745 | + pca_volumetric = kind == "volume" and mode == "pca_flip" |
| 3746 | + if pca_volumetric: |
| 3747 | + # Precompute a randomized SVD of the full source data once, before the label loop. |
| 3748 | + # The number of components is capped at max_channels: the rank of the data cannot |
| 3749 | + # exceed it, and it is much smaller than the number of sources. |
| 3750 | + from sklearn.utils.extmath import randomized_svd |
| 3751 | + |
| 3752 | + u_data, s_data, vh_data = randomized_svd( |
| 3753 | + stc.data, n_components=max_channels |
| 3754 | + ) |
3744 | 3755 | for i, (vertidx, flip) in enumerate(zip(label_vertidx, src_flip)): |
3745 | 3756 | if vertidx is not None: |
3746 | | - if isinstance(vertidx, sparse.csr_array): |
3747 | | - assert mri_resolution |
3748 | | - assert vertidx.shape[1] == stc.data.shape[0] |
3749 | | - this_data = np.reshape(stc.data, (stc.data.shape[0], -1)) |
3750 | | - this_data = vertidx @ this_data |
3751 | | - this_data.shape = (this_data.shape[0],) + stc.data.shape[1:] |
| 3757 | + if pca_volumetric: |
| 3758 | + # Compute the SVD of vertidx. We combine it with the precomputed SVD of the |
| 3759 | + # data to obtain the SVD of vertidx @ data implicitly, without forming the product. |
| 3760 | + u_vert, s_vert, vh_Vert = np.linalg.svd(vertidx.todense()) |
| 3761 | + center_prod = np.diag(s_vert) @ vh_Vert @ u_data @ np.diag(s_data) |
| 3762 | + u_s, s_s, vh_s = np.linalg.svd(center_prod) |
| 3763 | + U = u_vert @ u_s |
| 3764 | + s = s_s |
| 3765 | + V = vh_s @ vh_data |
| 3766 | + label_tc[i] = _compute_pca_quantities(U, s, V, flip) |
3752 | 3767 | else: |
3753 | | - this_data = stc.data[vertidx] |
3754 | | - label_tc[i] = func(flip, this_data) |
| 3768 | + if isinstance(vertidx, sparse.csr_array): |
| 3769 | + assert mri_resolution |
| 3770 | + assert vertidx.shape[1] == stc.data.shape[0] |
| 3771 | + this_data = np.reshape(stc.data, (stc.data.shape[0], -1)) |
| 3772 | + |
| 3773 | + this_data = vertidx @ this_data |
| 3774 | + this_data.shape = (this_data.shape[0],) + stc.data.shape[1:] |
| 3775 | + else: |
| 3776 | + this_data = stc.data[vertidx] |
| 3777 | + label_tc[i] = func(flip, this_data) |
3755 | 3778 | if mode is not None: |
3756 | 3779 | offset = nvert[:-n_mean].sum() # effectively :2 or :0 |
3757 | 3780 | for i, nv in enumerate(nvert[2:]): |
|
0 commit comments