@@ -3385,16 +3385,17 @@ def _compute_pca_quantities(U, s, V, flip):
33853385 return result
33863386
33873387
3388- def _pca_flip (flip , data ):
3388+ def _pca_flip (flip , data , max_rank ):
33893389 result = None
33903390 if flip is None : # Case of volumetric data: flip is meaningless
33913391 flip = 1
33923392 if data .shape [0 ] < 2 :
33933393 result = data .mean (axis = 0 ) # Trivial accumulator
33943394 else :
3395- U , s , V = np .linalg .svd (data , full_matrices = False )
3395+ from sklearn .utils .extmath import randomized_svd
3396+
3397+ U , s , V = randomized_svd (data , n_components = max_rank )
33963398 # determine sign-flip.
3397- # if flip is a mere int, multiply U and sum
33983399 result = _compute_pca_quantities (U , s , V , flip )
33993400 return result
34003401
@@ -3742,38 +3743,19 @@ def _gen_extract_label_time_course(
37423743 else :
37433744 # For other modes, initialize the label_tc array
37443745 label_tc = np .zeros ((n_labels ,) + stc .data .shape [1 :], dtype = stc .data .dtype )
3745- pca_volumetric = kind == "volume" and mode == "pca_flip"
3746- if pca_volumetric :
3747- # Precompute randomized SVD on data
3748- # Components are restricted to max_channels, which is the highest possible
3749- # rank and is much smaller than the number of sources
3750- from sklearn .utils .extmath import randomized_svd
3751-
3752- u_data , s_data , vh_data = randomized_svd (
3753- stc .data , n_components = max_channels
3754- )
37553746 for i , (vertidx , flip ) in enumerate (zip (label_vertidx , src_flip )):
37563747 if vertidx is not None :
3757- if pca_volumetric :
3758- # Compute SVD of vertices
3759- # We will use it to compute vertidx @ data implicitly,
3760- u_vert , s_vert , vh_Vert = np .linalg . svd ( vertidx . todense ( ))
3761- center_prod = np . diag ( s_vert ) @ vh_Vert @ u_data @ np . diag ( s_data )
3762- u_s , s_s , vh_s = np . linalg . svd ( center_prod )
3763- U = u_vert @ u_s
3764- s = s_s
3765- V = vh_s @ vh_data
3766- label_tc [i ] = _compute_pca_quantities ( U , s , V , flip )
3748+ if isinstance ( vertidx , sparse . csr_array ) :
3749+ assert mri_resolution
3750+ assert vertidx . shape [ 1 ] == stc . data . shape [ 0 ]
3751+ this_data = np .reshape ( stc . data , ( stc . data . shape [ 0 ], - 1 ))
3752+ this_data = vertidx @ this_data
3753+ this_data . shape = ( this_data . shape [ 0 ],) + stc . data . shape [ 1 :]
3754+ else :
3755+ this_data = stc . data [ vertidx ]
3756+ if mode == "pca_flip" :
3757+ label_tc [i ] = func ( flip , this_data , max_channels )
37673758 else :
3768- if isinstance (vertidx , sparse .csr_array ):
3769- assert mri_resolution
3770- assert vertidx .shape [1 ] == stc .data .shape [0 ]
3771- this_data = np .reshape (stc .data , (stc .data .shape [0 ], - 1 ))
3772-
3773- this_data = vertidx @ this_data
3774- this_data .shape = (this_data .shape [0 ],) + stc .data .shape [1 :]
3775- else :
3776- this_data = stc .data [vertidx ]
37773759 label_tc [i ] = func (flip , this_data )
37783760 if mode is not None :
37793761 offset = nvert [:- n_mean ].sum () # effectively :2 or :0
0 commit comments