Skip to content

Commit 8b6fa06

Browse files
larsonersseth
authored andcommitted
BUG: Fix bug with plot_white (mne-tools#13595)
1 parent 90b5e94 commit 8b6fa06

9 files changed

Lines changed: 76 additions & 39 deletions

File tree

azure-pipelines.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ stages:
114114
- bash: |
115115
set -e
116116
python -m pip install --progress-bar off --upgrade pip
117-
python -m pip install --progress-bar off "mne-qt-browser[opengl] @ git+https://github.com/mne-tools/mne-qt-browser.git" "git+https://github.com/python-quantities/python-quantities" pyvista scikit-learn python-picard qtpy nibabel sphinx-gallery "PySide6!=6.8.0,!=6.8.0.1,!=6.8.1.1,!=6.9.1" pandas neo pymatreader antio defusedxml curryreader pymef
117+
python -m pip install --progress-bar off "mne-qt-browser[opengl] @ git+https://github.com/mne-tools/mne-qt-browser.git" pyvista scikit-learn python-picard qtpy nibabel sphinx-gallery "PySide6!=6.8.0,!=6.8.0.1,!=6.8.1.1,!=6.9.1" pandas neo pymatreader antio defusedxml curryreader pymef
118118
python -m pip uninstall -yq mne
119119
python -m pip install --progress-bar off --upgrade -e . --group=test
120120
displayName: 'Install dependencies with pip'
@@ -173,7 +173,7 @@ stages:
173173
python -m pip install --progress-bar off --upgrade pip
174174
python -m pip install --progress-bar off --upgrade --pre --only-binary=\"numpy,scipy,matplotlib,vtk\" numpy scipy matplotlib vtk
175175
python -c "import vtk"
176-
python -m pip install --progress-bar off --upgrade -ve .[full] --group=test_extra "git+https://github.com/python-quantities/python-quantities"
176+
python -m pip install --progress-bar off --upgrade -ve .[full] --group=test_extra
177177
displayName: 'Install dependencies with pip'
178178
- bash: |
179179
set -e

doc/changes/dev/13595.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix bug where :func:`mne.viz.plot_evoked_white` did not accept a single "meg" rank value like those returned from :func:`mne.compute_rank`, by `Eric Larson`_.

mne/minimum_norm/tests/test_inverse.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
EvokedArray,
2424
SourceEstimate,
2525
combine_evoked,
26+
compute_rank,
2627
compute_raw_covariance,
2728
convert_forward_solution,
2829
make_ad_hoc_cov,
@@ -993,21 +994,34 @@ def test_make_inverse_operator_diag(evoked, noise_cov, tmp_path, azure_windows):
993994

994995
def test_inverse_operator_noise_cov_rank(evoked, noise_cov):
995996
"""Test MNE inverse operator with a specified noise cov rank."""
996-
fwd_op = read_forward_solution_meg(fname_fwd, surf_ori=True)
997-
inv = make_inverse_operator(evoked.info, fwd_op, noise_cov, rank=dict(meg=64))
997+
fwd_op_meg = read_forward_solution_meg(fname_fwd, surf_ori=True)
998+
inv = make_inverse_operator(evoked.info, fwd_op_meg, noise_cov, rank=dict(meg=64))
998999
assert compute_rank_inverse(inv) == 64
999-
inv = make_inverse_operator(evoked.info, fwd_op, noise_cov, rank=dict(meg=64))
1000+
inv = make_inverse_operator(evoked.info, fwd_op_meg, noise_cov, rank=dict(meg=64))
10001001
assert compute_rank_inverse(inv) == 64
10011002

10021003
bad_cov = noise_cov.copy()
10031004
bad_cov["data"][0, 0] *= 1e12
10041005
with pytest.warns(RuntimeWarning, match="orders of magnitude"):
1005-
make_inverse_operator(evoked.info, fwd_op, bad_cov, rank=dict(meg=64))
1006+
make_inverse_operator(evoked.info, fwd_op_meg, bad_cov, rank=dict(meg=64))
10061007

1007-
fwd_op = read_forward_solution_eeg(fname_fwd, surf_ori=True)
1008-
inv = make_inverse_operator(evoked.info, fwd_op, noise_cov, rank=dict(eeg=20))
1008+
fwd_op_eeg = read_forward_solution_eeg(fname_fwd, surf_ori=True)
1009+
inv = make_inverse_operator(evoked.info, fwd_op_eeg, noise_cov, rank=dict(eeg=20))
10091010
assert compute_rank_inverse(inv) == 20
10101011

1012+
# with and without rank passed explicitly
1013+
inv_info = make_inverse_operator(evoked.info, fwd_op_meg, noise_cov, rank="info")
1014+
info_rank = 302
1015+
assert compute_rank_inverse(inv_info) == info_rank
1016+
rank = compute_rank(noise_cov, info=evoked.copy().pick("meg").info, rank="info")
1017+
assert "meg" in rank
1018+
assert sum(rank.values()) == info_rank
1019+
inv_rank = make_inverse_operator(evoked.info, fwd_op_meg, noise_cov, rank=rank)
1020+
assert compute_rank_inverse(inv_rank) == info_rank
1021+
evoked_info = apply_inverse(evoked, inv_info, lambda2, "MNE")
1022+
evoked_rank = apply_inverse(evoked, inv_rank, lambda2, "MNE")
1023+
assert_allclose(evoked_rank.data, evoked_info.data)
1024+
10111025

10121026
def test_inverse_operator_volume(evoked, tmp_path):
10131027
"""Test MNE inverse computation on volume source space."""

mne/tests/test_cov.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,14 @@ def test_compute_whitener(proj, pca):
9494
assert pca is False
9595
assert_allclose(round_trip, np.eye(n_channels), atol=0.05)
9696

97+
# with and without rank
98+
W_info, _ = compute_whitener(cov, raw.info, pca=pca, rank="info", verbose="error")
99+
assert_allclose(W_info, W)
100+
rank = compute_rank(raw, rank="info", proj=proj)
101+
assert W.shape == (n_reduced, n_channels)
102+
W_rank, _ = compute_whitener(cov, raw.info, pca=pca, rank=rank, verbose="error")
103+
assert_allclose(W_rank, W)
104+
97105
raw.info["bads"] = [raw.ch_names[0]]
98106
picks = pick_types(raw.info, meg=True, eeg=True, exclude=[])
99107
with pytest.warns(RuntimeWarning, match="Too few samples"):

mne/viz/evoked.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1587,13 +1587,9 @@ def plot_evoked_white(
15871587
evoked.del_proj(idx)
15881588

15891589
evoked.pick_types(ref_meg=False, exclude="bads", **_PICK_TYPES_DATA_DICT)
1590-
n_ch_used, rank_list, picks_list, has_sss = _triage_rank_sss(
1590+
n_ch_used, rank_list, picks_list, meg_combined = _triage_rank_sss(
15911591
evoked.info, noise_cov, rank, scalings=None
15921592
)
1593-
if has_sss:
1594-
logger.info(
1595-
"SSS has been applied to data. Showing mag and grad whitening jointly."
1596-
)
15971593

15981594
# get one whitened evoked per cov
15991595
evokeds_white = [
@@ -1663,8 +1659,8 @@ def whitened_gfp(x, rank=None):
16631659
# hacks to get it to plot all channels in the same axes, namely setting
16641660
# the channel unit (most important) and coil type (for consistency) of
16651661
# all MEG channels to be the same.
1666-
meg_idx = sss_title = None
1667-
if has_sss:
1662+
meg_idx = combined_title = None
1663+
if meg_combined:
16681664
titles_["meg"] = "MEG (combined)"
16691665
meg_idx = [
16701666
pi for pi, (ch_type, _) in enumerate(picks_list) if ch_type == "meg"
@@ -1675,7 +1671,7 @@ def whitened_gfp(x, rank=None):
16751671
use = evokeds_white[0].info["chs"][picks[0]][key]
16761672
for pick in picks:
16771673
evokeds_white[0].info["chs"][pick][key] = use
1678-
sss_title = f"{titles_['meg']} ({len(picks)} channel{_pl(picks)})"
1674+
combined_title = f"{titles_['meg']} ({len(picks)} channel{_pl(picks)})"
16791675
evokeds_white[0].plot(
16801676
unit=False,
16811677
axes=axes_evoked,
@@ -1684,8 +1680,8 @@ def whitened_gfp(x, rank=None):
16841680
time_unit=time_unit,
16851681
spatial_colors=spatial_colors,
16861682
)
1687-
if has_sss:
1688-
axes_evoked[meg_idx].set(title=sss_title)
1683+
if meg_combined:
1684+
axes_evoked[meg_idx].set(title=combined_title)
16891685

16901686
# Now plot the GFP for all covs if indicated.
16911687
for evoked_white, noise_cov, rank_, color in iter_gfp:

mne/viz/tests/test_evoked.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Epochs,
1919
compute_covariance,
2020
compute_proj_evoked,
21+
compute_rank,
2122
make_fixed_length_events,
2223
read_cov,
2324
read_events,
@@ -357,6 +358,21 @@ def test_plot_evoked_image():
357358
evoked.plot_image(clim=[-4, 4])
358359

359360

361+
def test_plot_white_rank():
362+
"""Test plot_white with a combined-MEG rank arg."""
363+
cov = read_cov(cov_fname)
364+
cov["method"] = "empirical"
365+
cov["projs"] = [] # avoid warnings
366+
evoked = _get_epochs().average()
367+
evoked.set_eeg_reference("average") # Avoid warnings
368+
rank = compute_rank(evoked, "info")
369+
assert "grad" not in rank
370+
assert "mag" not in rank
371+
assert "meg" in rank
372+
evoked.plot_white(cov)
373+
evoked.plot_white(cov, rank=rank)
374+
375+
360376
def test_plot_white():
361377
"""Test plot_white."""
362378
cov = read_cov(cov_fname)
@@ -373,9 +389,9 @@ def test_plot_white():
373389
evoked.plot_white(cov, rank={"grad": 8}, time_unit="s", axes=fig.axes[:4])
374390
with pytest.raises(ValueError, match=r"must have shape \(4,\), got \(2,"):
375391
evoked.plot_white(cov, axes=fig.axes[:2])
376-
with pytest.raises(ValueError, match="When not using SSS"):
392+
with pytest.raises(ValueError, match="exceeds the number"):
377393
evoked.plot_white(cov, rank={"meg": 306})
378-
evoked.plot_white([cov, cov], time_unit="s")
394+
evoked.plot_white([cov, cov], rank={"meg": 9}, time_unit="s")
379395
plt.close("all")
380396

381397
fig = plot_evoked_white(evoked, [cov, cov])

mne/viz/utils.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2045,13 +2045,13 @@ def _setup_plot_projector(info, noise_cov, proj=True, use_noise_cov=True, nave=1
20452045
def _check_sss(info):
20462046
"""Check SSS history in info."""
20472047
ch_used = [ch for ch in _DATA_CH_TYPES_SPLIT if _contains_ch_type(info, ch)]
2048-
has_meg = "mag" in ch_used and "grad" in ch_used
2049-
has_sss = (
2050-
has_meg
2048+
has_mag_and_grad = "mag" in ch_used and "grad" in ch_used
2049+
needs_meg_combined = (
2050+
has_mag_and_grad
20512051
and len(info["proc_history"]) > 0
20522052
and info["proc_history"][0].get("max_info") is not None
20532053
)
2054-
return ch_used, has_meg, has_sss
2054+
return ch_used, has_mag_and_grad, needs_meg_combined
20552055

20562056

20572057
def _triage_rank_sss(info, covs, rank=None, scalings=None):
@@ -2061,22 +2061,28 @@ def _triage_rank_sss(info, covs, rank=None, scalings=None):
20612061
# Only look at good channels
20622062
picks = _pick_data_channels(info, with_ref_meg=False, exclude="bads")
20632063
info = pick_info(info, picks)
2064-
ch_used, has_meg, has_sss = _check_sss(info)
2065-
if has_sss:
2064+
ch_used, has_mag_and_grad, needs_meg_combined = _check_sss(info)
2065+
if needs_meg_combined:
20662066
if "mag" in rank or "grad" in rank:
20672067
raise ValueError(
20682068
'When using SSS, pass "meg" to set the rank '
20692069
'(separate rank values for "mag" or "grad" are '
20702070
"meaningless)."
20712071
)
2072+
meg_combined = True
20722073
elif "meg" in rank:
2073-
raise ValueError(
2074-
"When not using SSS, pass separate rank values "
2075-
'for "mag" and "grad" (do not use "meg").'
2076-
)
2074+
if needs_meg_combined:
2075+
start = "SSS has been applied to data"
2076+
else:
2077+
start = "Got a single MEG rank value"
2078+
logger.info("%s. Showing mag and grad whitening jointly.", start)
2079+
meg_combined = True
2080+
else:
2081+
meg_combined = False
2082+
del needs_meg_combined
20772083

2078-
picks_list = _picks_by_type(info, meg_combined=has_sss)
2079-
if has_sss:
2084+
picks_list = _picks_by_type(info, meg_combined=meg_combined)
2085+
if meg_combined:
20802086
# reduce ch_used to combined mag grad
20812087
ch_used = list(zip(*picks_list))[0]
20822088
# order pick list by ch_used (required for compat with plot_evoked)
@@ -2087,7 +2093,7 @@ def _triage_rank_sss(info, covs, rank=None, scalings=None):
20872093

20882094
picks_list2 = [k for k in picks_list]
20892095
# add meg picks if needed.
2090-
if has_meg:
2096+
if has_mag_and_grad:
20912097
# append ("meg", picks_meg)
20922098
picks_list2 += _picks_by_type(info, meg_combined=True)
20932099

@@ -2120,7 +2126,7 @@ def _triage_rank_sss(info, covs, rank=None, scalings=None):
21202126
this_rank[ch_type] = rank[ch_type]
21212127

21222128
rank_list.append(this_rank)
2123-
return n_ch_used, rank_list, picks_list, has_sss
2129+
return n_ch_used, rank_list, picks_list, meg_combined
21242130

21252131

21262132
def _check_cov(noise_cov, info):

tools/azure_dependencies.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
55
STD_ARGS="--progress-bar off --upgrade "
66
python -m pip install $STD_ARGS pip setuptools wheel
77
if [ "${TEST_MODE}" == "pip" ]; then
8-
python -m pip install $STD_ARGS --only-binary="numba,llvmlite,numpy,scipy,vtk,dipy,openmeeg" -e .[full] --group=test git+https://github.com/python-quantities/python-quantities
8+
python -m pip install $STD_ARGS --only-binary="numba,llvmlite,numpy,scipy,vtk,dipy,openmeeg" -e .[full] --group=test
99
elif [ "${TEST_MODE}" == "pip-pre" ]; then
1010
${SCRIPT_DIR}/install_pre_requirements.sh
1111
python -m pip install $STD_ARGS --pre -e . --group=test_extra

tools/github_actions_dependencies.sh

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,6 @@ else
3636
EXTRAS=""
3737
fi
3838
echo ""
39-
# until quantities releases...
40-
if [[ "${MNE_CI_KIND}" != "old" ]]; then
41-
STD_ARGS="$STD_ARGS git+https://github.com/python-quantities/python-quantities"
42-
fi
4339

4440
echo "::group::Installing test dependencies using pip"
4541
set -x

0 commit comments

Comments
 (0)