Skip to content

Commit 893b784

Browse files
authored
Fix bug with ica.plot_properties (#13885)
1 parent b5cfff0 commit 893b784

7 files changed

Lines changed: 147 additions & 129 deletions

File tree

doc/changes/dev/13885.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix bug with :meth:`mne.preprocessing.ICA.plot_properties` when using ``reject`` in :meth:`mne.preprocessing.ICA.fit`, by `Eric Larson`_.

examples/preprocessing/find_ref_artifacts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
on the reference channels are removed.
2525
2626
This technique is fully described and validated in :footcite:`HannaEtAl2020`
27-
2827
"""
2928
# Authors: Jeff Hanna <jeff.hanna@gmail.com>
3029
#
@@ -78,6 +77,7 @@
7877
ica_kwargs = dict(
7978
method="picard",
8079
fit_params=dict(tol=1e-4), # use a high tol here for speed
80+
random_state=99,
8181
)
8282
all_picks = mne.pick_types(raw_tog.info, meg=True, ref_meg=True)
8383
ica_tog = ICA(n_components=60, max_iter="auto", allow_ref_meg=True, **ica_kwargs)

examples/preprocessing/muscle_ica.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,7 @@
4848

4949
# %%
5050
# By inspection, let's select out the muscle-artifact components based on
51-
# :footcite:`DharmapraniEtAl2016` manually.
52-
#
53-
# The criteria are:
51+
# :footcite:`DharmapraniEtAl2016` manually. The criteria are:
5452
#
5553
# - Positive slope of log-log power spectrum between 7 and 75 Hz
5654
# (here just flat because it's not in log-log)

mne/_fiff/pick.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1238,7 +1238,7 @@ def _picks_to_idx(
12381238
extra_repr = ", treated as range({n_chan})"
12391239
else:
12401240
picks = none # let _picks_str_to_idx handle it
1241-
extra_repr = f'None, treated as "{none}"'
1241+
extra_repr = f', treated as "{none}"'
12421242

12431243
#
12441244
# slice

mne/viz/ica.py

Lines changed: 96 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from .._fiff.proj import _has_eeg_average_ref_proj
1616
from ..defaults import DEFAULTS, _handle_default
1717
from ..utils import (
18-
_reject_data_segments,
1918
_validate_type,
2019
fill_doc,
2120
verbose,
@@ -202,13 +201,10 @@ def _create_properties_layout(figsize=None, fig=None):
202201
def _plot_ica_properties(
203202
pick,
204203
ica,
205-
inst,
206204
psds_mean,
207205
freqs,
208-
n_trials,
209-
epoch_var,
210206
plot_lowpass_edge,
211-
epochs_src,
207+
this_epochs_src,
212208
set_title_and_labels,
213209
plot_std,
214210
psd_ylabel,
@@ -219,7 +215,7 @@ def _plot_ica_properties(
219215
fig,
220216
axes,
221217
kind,
222-
dropped_indices,
218+
bad_indices,
223219
):
224220
"""Plot ICA properties (helper)."""
225221
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
@@ -237,23 +233,15 @@ def _plot_ica_properties(
237233
)
238234

239235
# image and erp
240-
# we create a new epoch with dropped rows
241-
epoch_data = epochs_src.get_data(copy=False)
242-
epoch_data = np.insert(
243-
arr=epoch_data,
244-
obj=(dropped_indices - np.arange(len(dropped_indices))).astype(int),
245-
values=0.0,
246-
axis=0,
247-
)
248-
from ..epochs import EpochsArray
249-
250-
epochs_src = EpochsArray(
251-
epoch_data, epochs_src.info, tmin=epochs_src.tmin, verbose=0
252-
)
253-
236+
n_trials = len(this_epochs_src)
237+
epoch_var = np.var(this_epochs_src.get_data(), axis=-1)
238+
assert epoch_var.shape[1] == 1 # single channel
239+
epoch_var = epoch_var[:, 0]
240+
assert epoch_var.shape == (len(this_epochs_src),)
241+
this_epochs_src._data[bad_indices] = 0
254242
plot_epochs_image(
255-
epochs_src,
256-
picks=pick,
243+
this_epochs_src,
244+
picks=[0],
257245
axes=[image_ax, erp_ax],
258246
combine=None,
259247
colorbar=False,
@@ -273,44 +261,41 @@ def _plot_ica_properties(
273261
)
274262
if plot_lowpass_edge:
275263
spec_ax.axvline(
276-
inst.info["lowpass"], lw=2, linestyle="--", color="k", alpha=0.2
264+
this_epochs_src.info["lowpass"], lw=2, linestyle="--", color="k", alpha=0.2
277265
)
278266

279267
# epoch variance
268+
good_indices = np.setdiff1d(np.arange(n_trials), bad_indices)
280269
var_ax_divider = make_axes_locatable(var_ax)
281-
hist_ax = var_ax_divider.append_axes("right", size="33%", pad="2.5%")
282-
var_ax.scatter(
283-
range(len(epoch_var)), epoch_var, alpha=0.5, facecolor=[0, 0, 0], lw=0
284-
)
270+
hist_ax = var_ax_divider.append_axes("right", size="33%", pad="2.5%", sharey=var_ax)
271+
facecolor = np.zeros((len(epoch_var), 3))
272+
alpha = np.full(len(epoch_var), 0.5)
285273
# rejected epochs in red
274+
facecolor[bad_indices] = [1, 0, 0]
275+
alpha[bad_indices] = 0.75
286276
var_ax.scatter(
287-
dropped_indices,
288-
epoch_var[dropped_indices],
289-
alpha=1.0,
290-
facecolor=[1, 0, 0],
291-
lw=0,
277+
np.arange(n_trials), epoch_var, alpha=alpha, facecolor=facecolor, lw=0
292278
)
293279
# compute percentage of dropped epochs
294-
var_percent = float(len(dropped_indices)) / float(len(epoch_var)) * 100.0
280+
var_percent = 100 * len(bad_indices) / n_trials
295281

296282
# histogram & histogram
283+
epoch_var_good = epoch_var[good_indices]
297284
_, counts, _ = hist_ax.hist(
298-
epoch_var, orientation="horizontal", color="k", alpha=0.5
285+
epoch_var_good, orientation="horizontal", color="k", alpha=0.5
299286
)
300287

301288
# kde
302-
ymin, ymax = hist_ax.get_ylim()
303289
try:
304-
kde = gaussian_kde(epoch_var)
290+
kde = gaussian_kde(epoch_var_good)
305291
except np.linalg.LinAlgError:
306292
pass # singular: happens when there is nothing plotted
307293
else:
308-
x = np.linspace(ymin, ymax, 50)
294+
x = np.linspace(epoch_var_good.min(), epoch_var_good.max(), 50)
309295
kde_ = kde(x)
310296
kde_ /= kde_.max() or 1.0
311297
kde_ *= hist_ax.get_xlim()[-1] * 0.9
312298
hist_ax.plot(kde_, x, color="k")
313-
hist_ax.set_ylim(ymin, ymax)
314299

315300
# aesthetics
316301
# ----------
@@ -319,16 +304,16 @@ def _plot_ica_properties(
319304
# erp
320305
set_title_and_labels(erp_ax, [], "Time (s)", "AU")
321306
erp_ax.spines["right"].set_color("k")
322-
erp_ax.set_xlim(epochs_src.times[[0, -1]])
307+
erp_ax.set_xlim(this_epochs_src.times[[0, -1]])
323308
# remove half of yticks if more than 5
324309
yt = erp_ax.get_yticks()
325310
if len(yt) > 5:
326-
erp_ax.yaxis.set_ticks(yt[::2])
311+
erp_ax.set_yticks(yt[::2])
327312

328313
# remove xticks - erp plot shows xticks for both image and erp plot
329-
image_ax.xaxis.set_ticks([])
314+
image_ax.set_xticks([])
330315
yt = image_ax.get_yticks()
331-
image_ax.yaxis.set_ticks(yt[1:])
316+
image_ax.set_yticks(yt[1:])
332317
image_ax.set_ylim([-0.5, n_trials + 0.5])
333318

334319
def _set_scale(ax, scale):
@@ -342,10 +327,6 @@ def _set_scale(ax, scale):
342327
set_title_and_labels(spec_ax, "Spectrum", "Frequency (Hz)", psd_ylabel)
343328
spec_ax.yaxis.labelpad = 0
344329
spec_ax.set_xlim(freqs[[0, -1]])
345-
ylim = spec_ax.get_ylim()
346-
air = np.diff(ylim)[0] * 0.1
347-
spec_ax.set_ylim(ylim[0] - air, ylim[1] + air)
348-
image_ax.axhline(0, color="k", linewidth=0.5)
349330
if log_scale:
350331
_set_scale(spec_ax, "log")
351332

@@ -603,24 +584,24 @@ def _fast_plot_ica_properties(
603584
# calculations
604585
# ------------
605586
if isinstance(precomputed_data, tuple):
606-
kind, dropped_indices, epochs_src, data = precomputed_data
587+
kind, bad_indices, epochs_src = precomputed_data
607588
else:
608-
kind, dropped_indices, epochs_src, data = _prepare_data_ica_properties(
589+
kind, bad_indices, epochs_src = _prepare_data_ica_properties(
609590
inst, ica, reject_by_annotation, reject
610591
)
611-
del reject
612-
ica_data = np.swapaxes(data[:, picks, :], 0, 1)
613-
dropped_src = ica_data
592+
del reject, inst
593+
epochs_src_picked = epochs_src.pick(picks)
594+
del epochs_src
595+
good_indices = np.setdiff1d(np.arange(len(epochs_src_picked)), bad_indices)
614596

615597
# spectrum
616-
Nyquist = inst.info["sfreq"] / 2.0
617-
lp = inst.info["lowpass"]
598+
Nyquist = epochs_src_picked.info["sfreq"] / 2.0
599+
lp = epochs_src_picked.info["lowpass"]
618600
if "fmax" not in psd_args:
619601
psd_args["fmax"] = min(lp * 1.25, Nyquist)
620602
plot_lowpass_edge = lp < Nyquist and (psd_args["fmax"] > lp)
621-
spectrum = epochs_src.compute_psd(picks=picks, **psd_args)
622-
# we've already restricted picks ↑↑↑↑↑↑↑↑↑↑↑
623-
# in the spectrum object, so here we do picks=all ↓↓↓↓↓↓↓↓↓↓↓
603+
# we've already restricted picks in epochs_src_picked, so here we do picks=all
604+
spectrum = epochs_src_picked[good_indices].compute_psd(picks="all", **psd_args)
624605
psds, freqs = spectrum.get_data(return_freqs=True, picks="all", exclude=[])
625606
# we also pass exclude=[] so that when this is called by right-clicking in
626607
# a plot_sources() window on an ICA component name that has been marked as
@@ -654,30 +635,14 @@ def set_title_and_labels(ax, title, xlab, ylab):
654635
if idx > 0:
655636
fig, axes = _create_properties_layout(figsize=figsize)
656637

657-
# we reconstruct an epoch_variance with 0 where indexes where dropped
658-
epoch_var = np.var(ica_data[idx], axis=1)
659-
drop_var = np.var(dropped_src[idx], axis=1)
660-
drop_indices_corrected = (
661-
dropped_indices - np.arange(len(dropped_indices))
662-
).astype(int)
663-
epoch_var = np.insert(
664-
arr=epoch_var,
665-
obj=drop_indices_corrected,
666-
values=drop_var[dropped_indices],
667-
axis=0,
668-
)
669-
670638
# the actual plot
671639
fig = _plot_ica_properties(
672640
pick,
673641
ica,
674-
inst,
675642
psds_mean,
676643
freqs,
677-
ica_data.shape[1],
678-
epoch_var,
679644
plot_lowpass_edge,
680-
epochs_src,
645+
epochs_src_picked.copy().pick(picks=[idx]),
681646
set_title_and_labels,
682647
plot_std,
683648
psd_ylabel,
@@ -688,7 +653,7 @@ def set_title_and_labels(ax, title, xlab, ylab):
688653
fig,
689654
axes,
690655
kind,
691-
dropped_indices,
656+
bad_indices,
692657
)
693658
all_fig.append(fig)
694659

@@ -721,65 +686,76 @@ def _prepare_data_ica_properties(inst, ica, reject_by_annotation=True, reject="a
721686
data : array of shape (n_epochs, n_ica_sources, n_times)
722687
A view on epochs ICA sources data.
723688
"""
724-
from ..epochs import BaseEpochs
689+
from ..epochs import BaseEpochs, Epochs, make_fixed_length_events
725690
from ..io import BaseRaw, RawArray
726691

727692
_validate_type(inst, (BaseRaw, BaseEpochs), "inst", "Raw or Epochs")
693+
bad_indices = []
728694
if isinstance(inst, BaseRaw):
729695
# when auto, delegate reject to the ica
730-
from ..epochs import make_fixed_length_epochs
731696

732697
if reject == "auto":
733698
reject = ica.reject_
734-
drop_inds = None
735-
dropped_indices = []
736-
if reject is None:
737-
inst_current = inst
738-
else:
739-
data = inst.get_data()
740-
data, drop_inds = _reject_data_segments(
741-
data, reject, flat=None, decim=None, info=inst.info, tstep=2.0
742-
)
743-
inst_current = RawArray(data, inst.info)
744-
# break up continuous signal into segments; suppress "All epochs were
745-
# dropped!" because we handle that case gracefully below
746-
with warnings.catch_warnings():
747-
warnings.filterwarnings(
748-
"ignore", "All epochs were dropped!", RuntimeWarning
749-
)
750-
epochs_src = make_fixed_length_epochs(
751-
ica.get_sources(inst_current),
752-
duration=2,
753-
preload=True,
754-
reject_by_annotation=reject_by_annotation,
755-
proj=False,
756-
verbose=False,
757-
)
758-
# if all epochs were dropped by annotations, stitch the good segments
759-
# together so that the plot can still be generated
760-
if reject_by_annotation and len(epochs_src) == 0:
761-
good_data = inst_current.get_data(reject_by_annotation="omit")
699+
# First we try making epochs in the normal way and see if we have enough
700+
events = make_fixed_length_events(inst, duration=2)
701+
kwargs = dict(
702+
tmin=0,
703+
tmax=2 - 1.0 / inst.info["sfreq"],
704+
baseline=None,
705+
verbose="error",
706+
proj=False,
707+
)
708+
epochs = Epochs(
709+
inst,
710+
events,
711+
reject=reject,
712+
reject_by_annotation=reject_by_annotation,
713+
preload=False,
714+
**kwargs,
715+
).drop_bad(verbose="error")
716+
# If all epochs were dropped, stitch the good segments according to
717+
# reject_by_annotation back together and get sources for those, subject to
718+
# the reject param
719+
if reject_by_annotation and len(epochs) == 0:
720+
good_data = inst.get_data(reject_by_annotation="omit")
721+
inst_stitched = RawArray(good_data, inst.info.copy(), verbose="error")
722+
events_stitched = make_fixed_length_events(inst_stitched, duration=2)
723+
epochs_stitched = Epochs(
724+
inst_stitched,
725+
events_stitched,
726+
reject=reject,
727+
reject_by_annotation=False,
728+
preload=False,
729+
**kwargs,
730+
).drop_bad(verbose="error")
731+
got_samps = len(epochs_stitched) * len(epochs_stitched.times)
762732
min_samples = int(2 * inst.info["sfreq"])
763-
if good_data.shape[1] >= min_samples:
764-
inst_good = RawArray(good_data, inst_current.info.copy(), verbose=False)
765-
epochs_src = make_fixed_length_epochs(
766-
ica.get_sources(inst_good),
767-
duration=2,
768-
preload=True,
769-
reject_by_annotation=False,
770-
proj=False,
771-
verbose=False,
772-
)
773-
# getting dropped epochs indexes
774-
if drop_inds is not None:
775-
dropped_indices = [(d[0] // len(epochs_src.times)) + 1 for d in drop_inds]
733+
if got_samps >= min_samples:
734+
inst = inst_stitched
735+
events = events_stitched
736+
epochs = epochs_stitched
737+
epochs_src = Epochs(
738+
ica.get_sources(inst),
739+
events,
740+
# We have already rejected by annotation and reject above, but we don't
741+
# here so we can keep data for bad epochs around
742+
reject=None,
743+
reject_by_annotation=False,
744+
preload=True,
745+
**kwargs,
746+
)
747+
bad_indices = np.where([len(log) for log in epochs.drop_log])[0]
776748
kind = "Segment"
749+
assert len(epochs_src) == len(epochs) + len(bad_indices)
750+
if len(epochs_src) == len(bad_indices):
751+
raise RuntimeError(
752+
f"No clean 2-second segments found out of {len(events)} using "
753+
f"{reject=} and {reject_by_annotation=}."
754+
)
777755
else:
778-
drop_inds = None
779756
epochs_src = ica.get_sources(inst)
780-
dropped_indices = []
781757
kind = "Epochs"
782-
return kind, dropped_indices, epochs_src, epochs_src.get_data(copy=False)
758+
return kind, bad_indices, epochs_src
783759

784760

785761
def _plot_ica_sources_evoked(evoked, picks, exclude, title, show, ica, labels=None):

0 commit comments

Comments
 (0)