1515from .._fiff .proj import _has_eeg_average_ref_proj
1616from ..defaults import DEFAULTS , _handle_default
1717from ..utils import (
18- _reject_data_segments ,
1918 _validate_type ,
2019 fill_doc ,
2120 verbose ,
@@ -202,13 +201,10 @@ def _create_properties_layout(figsize=None, fig=None):
202201def _plot_ica_properties (
203202 pick ,
204203 ica ,
205- inst ,
206204 psds_mean ,
207205 freqs ,
208- n_trials ,
209- epoch_var ,
210206 plot_lowpass_edge ,
211- epochs_src ,
207+ this_epochs_src ,
212208 set_title_and_labels ,
213209 plot_std ,
214210 psd_ylabel ,
@@ -219,7 +215,7 @@ def _plot_ica_properties(
219215 fig ,
220216 axes ,
221217 kind ,
222- dropped_indices ,
218+ bad_indices ,
223219):
224220 """Plot ICA properties (helper)."""
225221 from mpl_toolkits .axes_grid1 .axes_divider import make_axes_locatable
@@ -237,21 +233,15 @@ def _plot_ica_properties(
237233 )
238234
239235 # image and erp
240- # we create a new epoch with dropped rows
241- src_data = epochs_src .get_data (copy = False )
242- n = len (src_data ) + len (dropped_indices )
243- epoch_data = np .zeros ((n ,) + (src_data .shape [1 :]), dtype = src_data .dtype )
244- use_idx = np .setdiff1d (np .arange (n ), dropped_indices )
245- epoch_data [use_idx ] = src_data
246- from ..epochs import EpochsArray
247-
248- epochs_src = EpochsArray (
249- epoch_data , epochs_src .info , tmin = epochs_src .tmin , verbose = 0
250- )
251-
236+ n_trials = len (this_epochs_src )
237+ epoch_var = np .var (this_epochs_src .get_data (), axis = - 1 )
238+ assert epoch_var .shape [1 ] == 1 # single channel
239+ epoch_var = epoch_var [:, 0 ]
240+ assert epoch_var .shape == (len (this_epochs_src ),)
241+ this_epochs_src ._data [bad_indices ] = 0
252242 plot_epochs_image (
253- epochs_src ,
254- picks = pick ,
243+ this_epochs_src ,
244+ picks = [ 0 ] ,
255245 axes = [image_ax , erp_ax ],
256246 combine = None ,
257247 colorbar = False ,
@@ -270,47 +260,40 @@ def _plot_ica_properties(
270260 alpha = 0.2 ,
271261 )
272262 if plot_lowpass_edge :
273- spec_ax .axvline (
274- inst .info ["lowpass" ], lw = 2 , linestyle = "--" , color = "k" , alpha = 0.2
275- )
263+ spec_ax .axvline (ica .info ["lowpass" ], lw = 2 , linestyle = "--" , color = "k" , alpha = 0.2 )
276264
277265 # epoch variance
266+ good_indices = np .setdiff1d (np .arange (n_trials ), bad_indices )
278267 var_ax_divider = make_axes_locatable (var_ax )
279- hist_ax = var_ax_divider .append_axes ("right" , size = "33%" , pad = "2.5%" )
280- var_ax .scatter (
281- range (len (epoch_var )), epoch_var , alpha = 0.5 , facecolor = [0 , 0 , 0 ], lw = 0
282- )
268+ hist_ax = var_ax_divider .append_axes ("right" , size = "33%" , pad = "2.5%" , sharey = var_ax )
269+ facecolor = np .zeros ((len (epoch_var ), 3 ))
270+ alpha = np .full (len (epoch_var ), 0.5 )
283271 # rejected epochs in red
284- # TODO: This can't be right as the variance is computed on the good/remaining
285- # epochs, so these are by necessity zero
272+ facecolor [ bad_indices ] = [ 1 , 0 , 0 ]
273+ alpha [ bad_indices ] = 0.75
286274 var_ax .scatter (
287- dropped_indices ,
288- np .zeros (len (dropped_indices )),
289- alpha = 1.0 ,
290- facecolor = [1 , 0 , 0 ],
291- lw = 0 ,
275+ np .arange (n_trials ), epoch_var , alpha = alpha , facecolor = facecolor , lw = 0
292276 )
293277 # compute percentage of dropped epochs
294- var_percent = float ( len ( dropped_indices )) / float ( len (epoch_var )) * 100.0
278+ var_percent = 100 * len (bad_indices ) / n_trials
295279
296280 # histogram & histogram
281+ epoch_var_good = epoch_var [good_indices ]
297282 _ , counts , _ = hist_ax .hist (
298- epoch_var , orientation = "horizontal" , color = "k" , alpha = 0.5
283+ epoch_var_good , orientation = "horizontal" , color = "k" , alpha = 0.5
299284 )
300285
301286 # kde
302- ymin , ymax = hist_ax .get_ylim ()
303287 try :
304- kde = gaussian_kde (epoch_var )
288+ kde = gaussian_kde (epoch_var_good )
305289 except np .linalg .LinAlgError :
306290 pass # singular: happens when there is nothing plotted
307291 else :
308- x = np .linspace (ymin , ymax , 50 )
292+ x = np .linspace (epoch_var_good . min (), epoch_var_good . max () , 50 )
309293 kde_ = kde (x )
310294 kde_ /= kde_ .max () or 1.0
311295 kde_ *= hist_ax .get_xlim ()[- 1 ] * 0.9
312296 hist_ax .plot (kde_ , x , color = "k" )
313- hist_ax .set_ylim (ymin , ymax )
314297
315298 # aesthetics
316299 # ----------
@@ -319,7 +302,7 @@ def _plot_ica_properties(
319302 # erp
320303 set_title_and_labels (erp_ax , [], "Time (s)" , "AU" )
321304 erp_ax .spines ["right" ].set_color ("k" )
322- erp_ax .set_xlim (epochs_src .times [[0 , - 1 ]])
305+ erp_ax .set_xlim (this_epochs_src .times [[0 , - 1 ]])
323306 # remove half of yticks if more than 5
324307 yt = erp_ax .get_yticks ()
325308 if len (yt ) > 5 :
@@ -603,23 +586,26 @@ def _fast_plot_ica_properties(
603586 # calculations
604587 # ------------
605588 if isinstance (precomputed_data , tuple ):
606- kind , dropped_indices , epochs_src , data = precomputed_data
589+ kind , bad_indices , epochs_src = precomputed_data
607590 else :
608- kind , dropped_indices , epochs_src , data = _prepare_data_ica_properties (
591+ kind , bad_indices , epochs_src = _prepare_data_ica_properties (
609592 inst , ica , reject_by_annotation , reject
610593 )
611- del reject
612- ica_data = np .swapaxes (data [:, picks , :], 0 , 1 )
594+ del reject , inst
595+ if len (epochs_src ) == 0 :
596+ return [fig ]
597+ epochs_src_picked = epochs_src .pick (picks )
598+ del epochs_src
599+ good_indices = np .setdiff1d (np .arange (len (epochs_src_picked )), bad_indices )
613600
614601 # spectrum
615- Nyquist = inst .info ["sfreq" ] / 2.0
616- lp = inst .info ["lowpass" ]
602+ Nyquist = ica .info ["sfreq" ] / 2.0
603+ lp = ica .info ["lowpass" ]
617604 if "fmax" not in psd_args :
618605 psd_args ["fmax" ] = min (lp * 1.25 , Nyquist )
619606 plot_lowpass_edge = lp < Nyquist and (psd_args ["fmax" ] > lp )
620- spectrum = epochs_src .compute_psd (picks = picks , ** psd_args )
621- # we've already restricted picks ↑↑↑↑↑↑↑↑↑↑↑
622- # in the spectrum object, so here we do picks=all ↓↓↓↓↓↓↓↓↓↓↓
607+ # we've already restricted picks in epochs_src_picked, so here we do picks=all
608+ spectrum = epochs_src_picked [good_indices ].compute_psd (picks = "all" , ** psd_args )
623609 psds , freqs = spectrum .get_data (return_freqs = True , picks = "all" , exclude = [])
624610 # we also pass exclude=[] so that when this is called by right-clicking in
625611 # a plot_sources() window on an ICA component name that has been marked as
@@ -653,20 +639,14 @@ def set_title_and_labels(ax, title, xlab, ylab):
653639 if idx > 0 :
654640 fig , axes = _create_properties_layout (figsize = figsize )
655641
656- # we reconstruct an epoch_variance with 0 where indexes where dropped
657- epoch_var = np .var (ica_data [idx ], axis = 1 )
658-
659642 # the actual plot
660643 fig = _plot_ica_properties (
661644 pick ,
662645 ica ,
663- inst ,
664646 psds_mean ,
665647 freqs ,
666- ica_data .shape [1 ],
667- epoch_var ,
668648 plot_lowpass_edge ,
669- epochs_src ,
649+ epochs_src_picked . copy (). pick ( picks = [ idx ]) ,
670650 set_title_and_labels ,
671651 plot_std ,
672652 psd_ylabel ,
@@ -677,7 +657,7 @@ def set_title_and_labels(ax, title, xlab, ylab):
677657 fig ,
678658 axes ,
679659 kind ,
680- dropped_indices ,
660+ bad_indices ,
681661 )
682662 all_fig .append (fig )
683663
@@ -710,65 +690,76 @@ def _prepare_data_ica_properties(inst, ica, reject_by_annotation=True, reject="a
710690 data : array of shape (n_epochs, n_ica_sources, n_times)
711691 A view on epochs ICA sources data.
712692 """
713- from ..epochs import BaseEpochs
693+ from ..epochs import BaseEpochs , Epochs , make_fixed_length_events
714694 from ..io import BaseRaw , RawArray
715695
716696 _validate_type (inst , (BaseRaw , BaseEpochs ), "inst" , "Raw or Epochs" )
697+ bad_indices = []
717698 if isinstance (inst , BaseRaw ):
718699 # when auto, delegate reject to the ica
719- from ..epochs import make_fixed_length_epochs
720700
721701 if reject == "auto" :
722702 reject = ica .reject_
723- drop_inds = None
724- dropped_indices = []
725- if reject is None :
726- inst_current = inst
727- else :
728- data = inst .get_data ()
729- data , drop_inds = _reject_data_segments (
730- data , reject , flat = None , decim = None , info = inst .info , tstep = 2.0
731- )
732- inst_current = RawArray (data , inst .info )
733- # break up continuous signal into segments; suppress "All epochs were
734- # dropped!" because we handle that case gracefully below
735- with warnings .catch_warnings ():
736- warnings .filterwarnings (
737- "ignore" , "All epochs were dropped!" , RuntimeWarning
738- )
739- epochs_src = make_fixed_length_epochs (
740- ica .get_sources (inst_current ),
741- duration = 2 ,
742- preload = True ,
743- reject_by_annotation = reject_by_annotation ,
744- proj = False ,
745- verbose = False ,
746- )
703+ events = make_fixed_length_events (inst , duration = 2 )
704+ kwargs = dict (
705+ tmin = 0 ,
706+ tmax = 2 - 1.0 / inst .info ["sfreq" ],
707+ baseline = None ,
708+ verbose = "error" ,
709+ preload = True ,
710+ proj = False ,
711+ )
712+ epochs = Epochs (
713+ inst ,
714+ events ,
715+ reject = reject ,
716+ reject_by_annotation = reject_by_annotation ,
717+ ** kwargs ,
718+ )
747719 # if all epochs were dropped by annotations, stitch the good segments
748720 # together so that the plot can still be generated
749- if reject_by_annotation and len (epochs_src ) == 0 :
750- good_data = inst_current .get_data (reject_by_annotation = "omit" )
721+ epochs_src = None
722+ if reject_by_annotation and len (epochs ) == 0 :
723+ epochs_good = Epochs (
724+ inst ,
725+ events ,
726+ reject = reject ,
727+ reject_by_annotation = False ,
728+ ** kwargs ,
729+ )
730+ got_samps = len (epochs_good ) * len (epochs .times )
751731 min_samples = int (2 * inst .info ["sfreq" ])
752- if good_data .shape [1 ] >= min_samples :
753- inst_good = RawArray (good_data , inst_current .info .copy (), verbose = False )
754- epochs_src = make_fixed_length_epochs (
732+ if got_samps >= min_samples :
733+ good_data = np .reshape (
734+ np .transpose (epochs_good .get_data (), [1 , 0 , 2 ]),
735+ (len (epochs .ch_names ), got_samps ),
736+ )
737+ inst_good = RawArray (good_data , inst .info .copy (), verbose = "error" )
738+ epochs_src = Epochs (
755739 ica .get_sources (inst_good ),
756- duration = 2 ,
757- preload = True ,
740+ events ,
741+ reject = reject ,
758742 reject_by_annotation = False ,
759- proj = False ,
760- verbose = False ,
743+ ** kwargs ,
761744 )
762- # getting dropped epochs indexes
763- if drop_inds is not None :
764- dropped_indices = [(d [0 ] // len (epochs_src .times )) for d in drop_inds ]
745+ if epochs_src is None :
746+ ica .get_sources (inst )
747+ epochs_src = Epochs (
748+ ica .get_sources (inst ),
749+ events ,
750+ # We have already rejected by annotation and reject above, but we don't
751+ # here so we can keep data for bad epochs around
752+ reject = None ,
753+ reject_by_annotation = False ,
754+ ** kwargs ,
755+ )
756+ bad_indices = np .where ([len (log ) for log in epochs .drop_log ])[0 ]
757+ assert len (epochs_src ) == len (epochs ) + len (bad_indices )
765758 kind = "Segment"
766759 else :
767- drop_inds = None
768760 epochs_src = ica .get_sources (inst )
769- dropped_indices = []
770761 kind = "Epochs"
771- return kind , dropped_indices , epochs_src , epochs_src . get_data ( copy = False )
762+ return kind , bad_indices , epochs_src
772763
773764
774765def _plot_ica_sources_evoked (evoked , picks , exclude , title , show , ica , labels = None ):
0 commit comments