@@ -3743,20 +3743,47 @@ def _gen_extract_label_time_course(
37433743 else :
37443744 # For other modes, initialize the label_tc array
37453745 label_tc = np .zeros ((n_labels ,) + stc .data .shape [1 :], dtype = stc .data .dtype )
3746+
3747+ pca_volume = mode == "pca_flip" and kind == "volume"
3748+ if pca_volume :
3749+ from sklearn .utils .extmath import randomized_svd
3750+
3751+ logger .debug ("First SVD for PCA volume on stc data" )
3752+ u_b , s_b , vh_b = randomized_svd (stc .data , max_channels )
37463753 for i , (vertidx , flip ) in enumerate (zip (label_vertidx , src_flip )):
37473754 if vertidx is not None :
3748- if isinstance (vertidx , sparse .csr_array ):
3749- assert mri_resolution
3750- assert vertidx .shape [1 ] == stc .data .shape [0 ]
3751- this_data = np .reshape (stc .data , (stc .data .shape [0 ], - 1 ))
3752- this_data = vertidx @ this_data
3753- this_data .shape = (this_data .shape [0 ],) + stc .data .shape [1 :]
3754- else :
3755- this_data = stc .data [vertidx ]
3756- if mode == "pca_flip" :
3757- label_tc [i ] = func (flip , this_data , max_channels )
3755+ if pca_volume :
3756+ # Use a trick for efficiency:
3757+ # stc = Ub Sb VhB
3758+ # full_data = vertidx @ stc
3759+ # = vertidx @ Ub @ Sb @ Vhb
3760+ # Consider U_f, s_f, Vh_f = SVD(vertidx @ Ub @ Sb)
3761+ # Then U,S,V = svd(full_data) is such that
3762+ # U_f = U, S = s_f and V = Vh_f @ Vhb
3763+ # This trick is more efficient, because:
3764+ # - We compute a first SVD once on stc, restricted to
3765+ # only first max_channels singular vals/vecs (quite fast)
3766+ # - We project vertidx to be from Nvertex x Nsources
3767+ # to Nvertex x rank.
3768+ # - We compute SVD on Nvertex x rank
3769+ # As rank << Nsources, we end up saving a lot of computations.
3770+ tmp_array = vertidx @ u_b @ np .diag (s_b )
3771+ U , S , v_tmp = np .linalg .svd (tmp_array , full_matrices = False )
3772+ V = v_tmp @ vh_b
3773+ label_tc [i ] = _compute_pca_quantities (U , S , V , flip )
37583774 else :
3759- label_tc [i ] = func (flip , this_data )
3775+ if isinstance (vertidx , sparse .csr_array ):
3776+ assert mri_resolution
3777+ assert vertidx .shape [1 ] == stc .data .shape [0 ]
3778+ this_data = np .reshape (stc .data , (stc .data .shape [0 ], - 1 ))
3779+ this_data = vertidx @ this_data
3780+ this_data .shape = (this_data .shape [0 ],) + stc .data .shape [1 :]
3781+ else :
3782+ this_data = stc .data [vertidx ]
3783+ if mode == "pca_flip" :
3784+ label_tc [i ] = func (flip , this_data , max_channels )
3785+ else :
3786+ label_tc [i ] = func (flip , this_data )
37603787 logger .debug (f"Done with label { i } " )
37613788 if mode is not None :
37623789 offset = nvert [:- n_mean ].sum () # effectively :2 or :0
0 commit comments