1515from .._fiff .proj import _has_eeg_average_ref_proj
1616from ..defaults import DEFAULTS , _handle_default
1717from ..utils import (
18- _reject_data_segments ,
1918 _validate_type ,
2019 fill_doc ,
2120 verbose ,
@@ -202,13 +201,10 @@ def _create_properties_layout(figsize=None, fig=None):
202201def _plot_ica_properties (
203202 pick ,
204203 ica ,
205- inst ,
206204 psds_mean ,
207205 freqs ,
208- n_trials ,
209- epoch_var ,
210206 plot_lowpass_edge ,
211- epochs_src ,
207+ this_epochs_src ,
212208 set_title_and_labels ,
213209 plot_std ,
214210 psd_ylabel ,
@@ -219,7 +215,7 @@ def _plot_ica_properties(
219215 fig ,
220216 axes ,
221217 kind ,
222- dropped_indices ,
218+ bad_indices ,
223219):
224220 """Plot ICA properties (helper)."""
225221 from mpl_toolkits .axes_grid1 .axes_divider import make_axes_locatable
@@ -237,23 +233,15 @@ def _plot_ica_properties(
237233 )
238234
239235 # image and erp
240- # we create a new epoch with dropped rows
241- epoch_data = epochs_src .get_data (copy = False )
242- epoch_data = np .insert (
243- arr = epoch_data ,
244- obj = (dropped_indices - np .arange (len (dropped_indices ))).astype (int ),
245- values = 0.0 ,
246- axis = 0 ,
247- )
248- from ..epochs import EpochsArray
249-
250- epochs_src = EpochsArray (
251- epoch_data , epochs_src .info , tmin = epochs_src .tmin , verbose = 0
252- )
253-
236+ n_trials = len (this_epochs_src )
237+ epoch_var = np .var (this_epochs_src .get_data (), axis = - 1 )
238+ assert epoch_var .shape [1 ] == 1 # single channel
239+ epoch_var = epoch_var [:, 0 ]
240+ assert epoch_var .shape == (len (this_epochs_src ),)
241+ this_epochs_src ._data [bad_indices ] = 0
254242 plot_epochs_image (
255- epochs_src ,
256- picks = pick ,
243+ this_epochs_src ,
244+ picks = [ 0 ] ,
257245 axes = [image_ax , erp_ax ],
258246 combine = None ,
259247 colorbar = False ,
@@ -273,44 +261,41 @@ def _plot_ica_properties(
273261 )
274262 if plot_lowpass_edge :
275263 spec_ax .axvline (
276- inst .info ["lowpass" ], lw = 2 , linestyle = "--" , color = "k" , alpha = 0.2
264+ this_epochs_src .info ["lowpass" ], lw = 2 , linestyle = "--" , color = "k" , alpha = 0.2
277265 )
278266
279267 # epoch variance
268+ good_indices = np .setdiff1d (np .arange (n_trials ), bad_indices )
280269 var_ax_divider = make_axes_locatable (var_ax )
281- hist_ax = var_ax_divider .append_axes ("right" , size = "33%" , pad = "2.5%" )
282- var_ax .scatter (
283- range (len (epoch_var )), epoch_var , alpha = 0.5 , facecolor = [0 , 0 , 0 ], lw = 0
284- )
270+ hist_ax = var_ax_divider .append_axes ("right" , size = "33%" , pad = "2.5%" , sharey = var_ax )
271+ facecolor = np .zeros ((len (epoch_var ), 3 ))
272+ alpha = np .full (len (epoch_var ), 0.5 )
285273 # rejected epochs in red
274+ facecolor [bad_indices ] = [1 , 0 , 0 ]
275+ alpha [bad_indices ] = 0.75
286276 var_ax .scatter (
287- dropped_indices ,
288- epoch_var [dropped_indices ],
289- alpha = 1.0 ,
290- facecolor = [1 , 0 , 0 ],
291- lw = 0 ,
277+ np .arange (n_trials ), epoch_var , alpha = alpha , facecolor = facecolor , lw = 0
292278 )
293279 # compute percentage of dropped epochs
294- var_percent = float ( len ( dropped_indices )) / float ( len (epoch_var )) * 100.0
280+ var_percent = 100 * len (bad_indices ) / n_trials
295281
296282 # histogram & histogram
283+ epoch_var_good = epoch_var [good_indices ]
297284 _ , counts , _ = hist_ax .hist (
298- epoch_var , orientation = "horizontal" , color = "k" , alpha = 0.5
285+ epoch_var_good , orientation = "horizontal" , color = "k" , alpha = 0.5
299286 )
300287
301288 # kde
302- ymin , ymax = hist_ax .get_ylim ()
303289 try :
304- kde = gaussian_kde (epoch_var )
290+ kde = gaussian_kde (epoch_var_good )
305291 except np .linalg .LinAlgError :
306292 pass # singular: happens when there is nothing plotted
307293 else :
308- x = np .linspace (ymin , ymax , 50 )
294+ x = np .linspace (epoch_var_good . min (), epoch_var_good . max () , 50 )
309295 kde_ = kde (x )
310296 kde_ /= kde_ .max () or 1.0
311297 kde_ *= hist_ax .get_xlim ()[- 1 ] * 0.9
312298 hist_ax .plot (kde_ , x , color = "k" )
313- hist_ax .set_ylim (ymin , ymax )
314299
315300 # aesthetics
316301 # ----------
@@ -319,16 +304,16 @@ def _plot_ica_properties(
319304 # erp
320305 set_title_and_labels (erp_ax , [], "Time (s)" , "AU" )
321306 erp_ax .spines ["right" ].set_color ("k" )
322- erp_ax .set_xlim (epochs_src .times [[0 , - 1 ]])
307+ erp_ax .set_xlim (this_epochs_src .times [[0 , - 1 ]])
323308 # remove half of yticks if more than 5
324309 yt = erp_ax .get_yticks ()
325310 if len (yt ) > 5 :
326- erp_ax .yaxis . set_ticks (yt [::2 ])
311+ erp_ax .set_yticks (yt [::2 ])
327312
328313 # remove xticks - erp plot shows xticks for both image and erp plot
329- image_ax .xaxis . set_ticks ([])
314+ image_ax .set_xticks ([])
330315 yt = image_ax .get_yticks ()
331- image_ax .yaxis . set_ticks (yt [1 :])
316+ image_ax .set_yticks (yt [1 :])
332317 image_ax .set_ylim ([- 0.5 , n_trials + 0.5 ])
333318
334319 def _set_scale (ax , scale ):
@@ -342,10 +327,6 @@ def _set_scale(ax, scale):
342327 set_title_and_labels (spec_ax , "Spectrum" , "Frequency (Hz)" , psd_ylabel )
343328 spec_ax .yaxis .labelpad = 0
344329 spec_ax .set_xlim (freqs [[0 , - 1 ]])
345- ylim = spec_ax .get_ylim ()
346- air = np .diff (ylim )[0 ] * 0.1
347- spec_ax .set_ylim (ylim [0 ] - air , ylim [1 ] + air )
348- image_ax .axhline (0 , color = "k" , linewidth = 0.5 )
349330 if log_scale :
350331 _set_scale (spec_ax , "log" )
351332
@@ -603,24 +584,24 @@ def _fast_plot_ica_properties(
603584 # calculations
604585 # ------------
605586 if isinstance (precomputed_data , tuple ):
606- kind , dropped_indices , epochs_src , data = precomputed_data
587+ kind , bad_indices , epochs_src = precomputed_data
607588 else :
608- kind , dropped_indices , epochs_src , data = _prepare_data_ica_properties (
589+ kind , bad_indices , epochs_src = _prepare_data_ica_properties (
609590 inst , ica , reject_by_annotation , reject
610591 )
611- del reject
612- ica_data = np .swapaxes (data [:, picks , :], 0 , 1 )
613- dropped_src = ica_data
592+ del reject , inst
593+ epochs_src_picked = epochs_src .pick (picks )
594+ del epochs_src
595+ good_indices = np .setdiff1d (np .arange (len (epochs_src_picked )), bad_indices )
614596
615597 # spectrum
616- Nyquist = inst .info ["sfreq" ] / 2.0
617- lp = inst .info ["lowpass" ]
598+ Nyquist = epochs_src_picked .info ["sfreq" ] / 2.0
599+ lp = epochs_src_picked .info ["lowpass" ]
618600 if "fmax" not in psd_args :
619601 psd_args ["fmax" ] = min (lp * 1.25 , Nyquist )
620602 plot_lowpass_edge = lp < Nyquist and (psd_args ["fmax" ] > lp )
621- spectrum = epochs_src .compute_psd (picks = picks , ** psd_args )
622- # we've already restricted picks ↑↑↑↑↑↑↑↑↑↑↑
623- # in the spectrum object, so here we do picks=all ↓↓↓↓↓↓↓↓↓↓↓
603+ # we've already restricted picks in epochs_src_picked, so here we do picks=all
604+ spectrum = epochs_src_picked [good_indices ].compute_psd (picks = "all" , ** psd_args )
624605 psds , freqs = spectrum .get_data (return_freqs = True , picks = "all" , exclude = [])
625606 # we also pass exclude=[] so that when this is called by right-clicking in
626607 # a plot_sources() window on an ICA component name that has been marked as
@@ -654,30 +635,14 @@ def set_title_and_labels(ax, title, xlab, ylab):
654635 if idx > 0 :
655636 fig , axes = _create_properties_layout (figsize = figsize )
656637
657- # we reconstruct an epoch_variance with 0 where indexes where dropped
658- epoch_var = np .var (ica_data [idx ], axis = 1 )
659- drop_var = np .var (dropped_src [idx ], axis = 1 )
660- drop_indices_corrected = (
661- dropped_indices - np .arange (len (dropped_indices ))
662- ).astype (int )
663- epoch_var = np .insert (
664- arr = epoch_var ,
665- obj = drop_indices_corrected ,
666- values = drop_var [dropped_indices ],
667- axis = 0 ,
668- )
669-
670638 # the actual plot
671639 fig = _plot_ica_properties (
672640 pick ,
673641 ica ,
674- inst ,
675642 psds_mean ,
676643 freqs ,
677- ica_data .shape [1 ],
678- epoch_var ,
679644 plot_lowpass_edge ,
680- epochs_src ,
645+ epochs_src_picked . copy (). pick ( picks = [ idx ]) ,
681646 set_title_and_labels ,
682647 plot_std ,
683648 psd_ylabel ,
@@ -688,7 +653,7 @@ def set_title_and_labels(ax, title, xlab, ylab):
688653 fig ,
689654 axes ,
690655 kind ,
691- dropped_indices ,
656+ bad_indices ,
692657 )
693658 all_fig .append (fig )
694659
@@ -721,65 +686,76 @@ def _prepare_data_ica_properties(inst, ica, reject_by_annotation=True, reject="a
721686 data : array of shape (n_epochs, n_ica_sources, n_times)
722687 A view on epochs ICA sources data.
723688 """
724- from ..epochs import BaseEpochs
689+ from ..epochs import BaseEpochs , Epochs , make_fixed_length_events
725690 from ..io import BaseRaw , RawArray
726691
727692 _validate_type (inst , (BaseRaw , BaseEpochs ), "inst" , "Raw or Epochs" )
693+ bad_indices = []
728694 if isinstance (inst , BaseRaw ):
729695 # when auto, delegate reject to the ica
730- from ..epochs import make_fixed_length_epochs
731696
732697 if reject == "auto" :
733698 reject = ica .reject_
734- drop_inds = None
735- dropped_indices = []
736- if reject is None :
737- inst_current = inst
738- else :
739- data = inst .get_data ()
740- data , drop_inds = _reject_data_segments (
741- data , reject , flat = None , decim = None , info = inst .info , tstep = 2.0
742- )
743- inst_current = RawArray (data , inst .info )
744- # break up continuous signal into segments; suppress "All epochs were
745- # dropped!" because we handle that case gracefully below
746- with warnings .catch_warnings ():
747- warnings .filterwarnings (
748- "ignore" , "All epochs were dropped!" , RuntimeWarning
749- )
750- epochs_src = make_fixed_length_epochs (
751- ica .get_sources (inst_current ),
752- duration = 2 ,
753- preload = True ,
754- reject_by_annotation = reject_by_annotation ,
755- proj = False ,
756- verbose = False ,
757- )
758- # if all epochs were dropped by annotations, stitch the good segments
759- # together so that the plot can still be generated
760- if reject_by_annotation and len (epochs_src ) == 0 :
761- good_data = inst_current .get_data (reject_by_annotation = "omit" )
699+ # First we try making epochs in the normal way and see if we have enough
700+ events = make_fixed_length_events (inst , duration = 2 )
701+ kwargs = dict (
702+ tmin = 0 ,
703+ tmax = 2 - 1.0 / inst .info ["sfreq" ],
704+ baseline = None ,
705+ verbose = "error" ,
706+ proj = False ,
707+ )
708+ epochs = Epochs (
709+ inst ,
710+ events ,
711+ reject = reject ,
712+ reject_by_annotation = reject_by_annotation ,
713+ preload = False ,
714+ ** kwargs ,
715+ ).drop_bad (verbose = "error" )
716+ # If all epochs were dropped, stitch the good segments according to
717+ # reject_by_annotation back together and get sources for those, subject to
718+ # the reject param
719+ if reject_by_annotation and len (epochs ) == 0 :
720+ good_data = inst .get_data (reject_by_annotation = "omit" )
721+ inst_stitched = RawArray (good_data , inst .info .copy (), verbose = "error" )
722+ events_stitched = make_fixed_length_events (inst_stitched , duration = 2 )
723+ epochs_stitched = Epochs (
724+ inst_stitched ,
725+ events_stitched ,
726+ reject = reject ,
727+ reject_by_annotation = False ,
728+ preload = False ,
729+ ** kwargs ,
730+ ).drop_bad (verbose = "error" )
731+ got_samps = len (epochs_stitched ) * len (epochs_stitched .times )
762732 min_samples = int (2 * inst .info ["sfreq" ])
763- if good_data .shape [1 ] >= min_samples :
764- inst_good = RawArray (good_data , inst_current .info .copy (), verbose = False )
765- epochs_src = make_fixed_length_epochs (
766- ica .get_sources (inst_good ),
767- duration = 2 ,
768- preload = True ,
769- reject_by_annotation = False ,
770- proj = False ,
771- verbose = False ,
772- )
773- # getting dropped epochs indexes
774- if drop_inds is not None :
775- dropped_indices = [(d [0 ] // len (epochs_src .times )) + 1 for d in drop_inds ]
733+ if got_samps >= min_samples :
734+ inst = inst_stitched
735+ events = events_stitched
736+ epochs = epochs_stitched
737+ epochs_src = Epochs (
738+ ica .get_sources (inst ),
739+ events ,
740+ # We have already rejected by annotation and reject above, but we don't
741+ # here so we can keep data for bad epochs around
742+ reject = None ,
743+ reject_by_annotation = False ,
744+ preload = True ,
745+ ** kwargs ,
746+ )
747+ bad_indices = np .where ([len (log ) for log in epochs .drop_log ])[0 ]
776748 kind = "Segment"
749+ assert len (epochs_src ) == len (epochs ) + len (bad_indices )
750+ if len (epochs_src ) == len (bad_indices ):
751+ raise RuntimeError (
752+ f"No clean 2-second segments found out of { len (events )} using "
753+ f"{ reject = } and { reject_by_annotation = } ."
754+ )
777755 else :
778- drop_inds = None
779756 epochs_src = ica .get_sources (inst )
780- dropped_indices = []
781757 kind = "Epochs"
782- return kind , dropped_indices , epochs_src , epochs_src . get_data ( copy = False )
758+ return kind , bad_indices , epochs_src
783759
784760
785761def _plot_ica_sources_evoked (evoked , picks , exclude , title , show , ica , labels = None ):
0 commit comments