Skip to content

Commit 0c31c92

Browse files
committed
Add function to plot statistical clusters on brain surface
1 parent a07f970 commit 0c31c92

5 files changed

Lines changed: 241 additions & 5 deletions

File tree

doc/api/visualization.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ Visualization
6868
plot_volume_source_estimates
6969
plot_vector_source_estimates
7070
plot_sparse_source_estimates
71+
plot_stat_cluster
7172
plot_tfr_topomap
7273
plot_topo_image_epochs
7374
plot_topomap

mne/viz/_3d.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
)
8484
from ._dipole import _check_concat_dipoles, _plot_dipole_3d, _plot_dipole_mri_outlines
8585
from .evoked_field import EvokedField
86+
from .ui_events import subscribe
8687
from .utils import (
8788
_check_time_unit,
8889
_get_cmap,
@@ -4301,3 +4302,158 @@ def _get_3d_option(key):
43014302
else:
43024303
opt = opt.lower() == "true"
43034304
return opt
4305+
4306+
4307+
def plot_stat_cluster(cluster, src, brain, time="max-extent", color="magenta", width=1):
4308+
"""Plot the spatial extent of a cluster on top of a brain.
4309+
4310+
Parameters
4311+
----------
4312+
cluster : tuple (time_idx, vertex_idx)
4313+
The cluster to plot.
4314+
src : SourceSpaces
4315+
The source space that was used for the inverse computation.
4316+
brain : Brain
4317+
The brain figure on which to plot the cluster.
4318+
time : float | "interactive" | "max-extent"
4319+
The time (in seconds) at which to plot the spatial extent of the cluster.
4320+
If set to ``"interactive"`` the time will follow the selected time in the brain
4321+
figure.
4322+
By default, ``"max-extent"``, the time of maximal spatial extent is chosen.
4323+
color : str
4324+
A maplotlib-style color specification indicating the color to use when plotting
4325+
the spatial extent of the cluster.
4326+
width : int
4327+
The width of the lines used to draw the outlines.
4328+
4329+
Returns
4330+
-------
4331+
brain : Brain
4332+
The brain figure, now with the cluster plotted on top of it.
4333+
"""
4334+
# Here due to circular import
4335+
from ..label import Label
4336+
4337+
# args check
4338+
if isinstance(cluster, tuple):
4339+
if len(cluster) != 2:
4340+
raise ValueError(
4341+
"A cluster is a tuple of two elements, a list time "
4342+
"indices and list of vertex indices"
4343+
)
4344+
else:
4345+
raise TypeError(f"Tuple expected, got {type(cluster)} instead.")
4346+
4347+
cluster_time_idx, cluster_vertex_index = cluster
4348+
4349+
# A cluster is defined both in space and time. If we want to plot the boundaries of
4350+
# the cluster in space, we must choose a specific time for which to show the
4351+
# boundaries (as they change over time).
4352+
if time == "max-extent":
4353+
time_idx, n_vertices = np.unique(cluster_time_idx, return_counts=True)
4354+
time_idx = time_idx[np.argmax(n_vertices)]
4355+
elif time == "interactive":
4356+
time_idx = brain._data["time_idx"]
4357+
elif isinstance(time, float):
4358+
time_idx = np.searchsorted(brain._times[:-1], time)
4359+
else:
4360+
raise ValueError(
4361+
"Time should be 'max-extent', 'interactive', or floating point"
4362+
f" value, got '{time}' instead."
4363+
)
4364+
4365+
# Select only the vertex indices at the chosen time
4366+
draw_vertex_index = [
4367+
v for v, t in zip(cluster_vertex_index, cluster_time_idx) if t == time_idx
4368+
]
4369+
4370+
# Let's create an anatomical label containing these vertex indices.
4371+
# Problem 1): a label must be defined for either the left or right hemisphere. It
4372+
# cannot span both hemispheres. So we must filter the vertices based on their
4373+
# hemisphere.
4374+
# Problem 2): we have vertex *indices* that need to be transformed into proper
4375+
# vertex numbers. Not every vertex in the original high-resolution brain mesh is a
4376+
# source point in the source estimate. Do draw nice smooth curves, we need to
4377+
# interpolate the vertex indices.
4378+
4379+
# Both problems can be solved by accessing the vertices defined in the source space
4380+
# object. The source space object is actually a list of two source spaces.
4381+
src_lh, src_rh = src
4382+
4383+
# Split the vertices based on the hemisphere in which they are located.
4384+
lh_verts, rh_verts = src_lh["vertno"], src_rh["vertno"]
4385+
n_lh_verts = len(lh_verts)
4386+
draw_lh_verts = [lh_verts[v] for v in draw_vertex_index if v < n_lh_verts]
4387+
draw_rh_verts = [
4388+
rh_verts[v - n_lh_verts] for v in draw_vertex_index if v >= n_lh_verts
4389+
]
4390+
4391+
# Vertices in a label must be unique and in increasing order
4392+
draw_lh_verts = np.unique(draw_lh_verts)
4393+
draw_rh_verts = np.unique(draw_rh_verts)
4394+
4395+
# We are now ready to create the anatomical label objects
4396+
cluster_index = 0
4397+
for label in brain.labels["lh"] + brain.labels["rh"]:
4398+
if label.name.startswith("cluster-"):
4399+
try:
4400+
cluster_index = max(cluster_index, int(label.name.split("-", 1)[1]))
4401+
except ValueError:
4402+
pass
4403+
lh_label = Label(draw_lh_verts, hemi="lh", name=f"cluster-{cluster_index}")
4404+
rh_label = Label(draw_rh_verts, hemi="rh", name=f"cluster-{cluster_index}")
4405+
4406+
# Interpolate the vertices in each label to the full resolution mesh
4407+
if len(lh_label) > 0:
4408+
lh_label = lh_label.smooth(
4409+
smooth=3, subject=brain._subject, subjects_dir=brain._subjects_dir
4410+
)
4411+
brain.add_label(lh_label, borders=width, color=color)
4412+
if len(rh_label) > 0:
4413+
rh_label = rh_label.smooth(
4414+
smooth=3, subject=brain._subject, subjects_dir=brain._subjects_dir
4415+
)
4416+
brain.add_label(rh_label, borders=width, color=color)
4417+
4418+
def on_time_change(event):
4419+
time_idx = np.searchsorted(brain._times, event.time)
4420+
for hemi in brain._hemis:
4421+
mesh = brain._layered_meshes[hemi]
4422+
for i, label in enumerate(brain.labels[hemi]):
4423+
if label.name == f"cluster-{cluster_index}":
4424+
del brain.labels[hemi][i]
4425+
mesh.remove_overlay(label.name)
4426+
4427+
# Select only the vertex indices at the chosen time
4428+
draw_vertex_index = [
4429+
v for v, t in zip(cluster_vertex_index, cluster_time_idx) if t == time_idx
4430+
]
4431+
draw_lh_verts = [lh_verts[v] for v in draw_vertex_index if v < n_lh_verts]
4432+
draw_rh_verts = [
4433+
rh_verts[v - n_lh_verts] for v in draw_vertex_index if v >= n_lh_verts
4434+
]
4435+
4436+
# Vertices in a label must be unique and in increasing order
4437+
draw_lh_verts = np.unique(draw_lh_verts)
4438+
draw_rh_verts = np.unique(draw_rh_verts)
4439+
lh_label = Label(draw_lh_verts, hemi="lh", name=f"cluster-{cluster_index}")
4440+
rh_label = Label(draw_rh_verts, hemi="rh", name=f"cluster-{cluster_index}")
4441+
if len(lh_label) > 0:
4442+
lh_label = lh_label.smooth(
4443+
smooth=3,
4444+
subject=brain._subject,
4445+
subjects_dir=brain._subjects_dir,
4446+
verbose=False,
4447+
)
4448+
brain.add_label(lh_label, borders=width, color=color)
4449+
if len(rh_label) > 0:
4450+
rh_label = rh_label.smooth(
4451+
smooth=3,
4452+
subject=brain._subject,
4453+
subjects_dir=brain._subjects_dir,
4454+
verbose=False,
4455+
)
4456+
brain.add_label(rh_label, borders=width, color=color)
4457+
4458+
if time == "interactive":
4459+
subscribe(brain, "time_change", on_time_change)

mne/viz/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ __all__ = [
7272
"plot_source_estimates",
7373
"plot_source_spectrogram",
7474
"plot_sparse_source_estimates",
75+
"plot_stat_cluster",
7576
"plot_tfr_topomap",
7677
"plot_topo_image_epochs",
7778
"plot_topomap",
@@ -97,6 +98,7 @@ from ._3d import (
9798
plot_head_positions,
9899
plot_source_estimates,
99100
plot_sparse_source_estimates,
101+
plot_stat_cluster,
100102
plot_vector_source_estimates,
101103
plot_volume_source_estimates,
102104
set_3d_options,

mne/viz/tests/test_3d.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
plot_head_positions,
5050
plot_source_estimates,
5151
plot_sparse_source_estimates,
52+
plot_stat_cluster,
5253
snapshot_brain_montage,
5354
)
5455
from mne.viz._3d import _get_map_ticks, _linearize_map, _process_clim
@@ -1413,3 +1414,53 @@ def test_link_brains(renderer_interactive):
14131414
with pytest.raises(TypeError, match="type is Brain"):
14141415
link_brains("foo")
14151416
link_brains(brain, time=True, camera=True)
1417+
1418+
1419+
@testing.requires_testing_data
1420+
def test_plot_stat_cluster(renderer_interactive):
1421+
"""Test plotting clusters on brain in static and interactive mode."""
1422+
pytest.importorskip("nibabel")
1423+
sample_src = read_source_spaces(src_fname)
1424+
vertices = [s["vertno"] for s in sample_src]
1425+
n_time = 5
1426+
n_verts = sum(len(v) for v in vertices)
1427+
1428+
# simulate stc data
1429+
stc_data = np.zeros(n_verts * n_time)
1430+
stc_size = stc_data.size
1431+
stc_data[(np.random.rand(stc_size // 20) * stc_size).astype(int)] = (
1432+
np.random.RandomState(0).rand(stc_data.size // 20)
1433+
)
1434+
stc_data.shape = (n_verts, n_time)
1435+
stc = SourceEstimate(stc_data, vertices, 1, 1)
1436+
1437+
# Simulate a cluster
1438+
cluster_time_idx = [1, 1, 2, 3]
1439+
cluster_vertex_idx = [0, 1, 2, 3]
1440+
cluster = (cluster_time_idx, cluster_vertex_idx)
1441+
1442+
brain = plot_source_estimates(
1443+
stc,
1444+
"sample",
1445+
background=(1, 1, 0),
1446+
subjects_dir=subjects_dir,
1447+
colorbar=True,
1448+
clim="auto",
1449+
)
1450+
# Test for incorrect argument in time
1451+
with pytest.raises(ValueError):
1452+
plot_stat_cluster(cluster, sample_src, brain, "foo")
1453+
1454+
# test for incorrect shape of cluster
1455+
with pytest.raises(ValueError):
1456+
plot_stat_cluster(([1]), sample_src, brain)
1457+
1458+
# test for incorrect data type of cluster
1459+
with pytest.raises(ValueError):
1460+
plot_stat_cluster([[1, 2, 3], [1, 2, 3]], sample_src, brain)
1461+
1462+
# All correct
1463+
plot_stat_cluster(cluster, sample_src, brain)
1464+
1465+
brain.close()
1466+
del brain

tutorials/stats-source-space/20_cluster_1samp_spatiotemporal.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from mne.epochs import equalize_epoch_counts
3030
from mne.minimum_norm import apply_inverse, read_inverse_operator
3131
from mne.stats import spatio_temporal_cluster_1samp_test, summarize_clusters_stc
32+
from mne.viz import plot_stat_cluster
3233

3334
# %%
3435
# Set parameters
@@ -142,19 +143,18 @@
142143
# Read the source space we are morphing to
143144
src = mne.read_source_spaces(src_fname)
144145
fsave_vertices = [s["vertno"] for s in src]
145-
morph_mat = mne.compute_source_morph(
146+
morph = mne.compute_source_morph(
146147
src=inverse_operator["src"],
147148
subject_to="fsaverage",
148149
spacing=fsave_vertices,
149150
subjects_dir=subjects_dir,
150-
).morph_mat
151-
152-
n_vertices_fsave = morph_mat.shape[0]
151+
)
152+
n_vertices_fsave = morph.morph_mat.shape[0]
153153

154154
# We have to change the shape for the dot() to work properly
155155
X = X.reshape(n_vertices_sample, n_times * n_subjects * 2)
156156
print("Morphing data.")
157-
X = morph_mat.dot(X) # morph_mat is a sparse matrix
157+
X = morph.morph_mat.dot(X) # morph_mat is a sparse matrix
158158
X = X.reshape(n_vertices_fsave, n_times, n_subjects, 2)
159159

160160
# %%
@@ -264,3 +264,29 @@
264264

265265
# We could save this via the following:
266266
# brain.save_image('clusters.png')
267+
268+
# %%
269+
# Alternatively, you may wish to observe clusters are considered statistically
270+
# significant under the permutation distribution with resect all the source estimates.
271+
# This can easily be done by plotting the cluster boundary on top of the source
272+
# estimates using the code snippet below.
273+
# ----------------------------------------------------------
274+
275+
difference = morph.apply(condition1 - condition2)
276+
difference_plot = difference.plot(
277+
hemi="both",
278+
views="lateral",
279+
subjects_dir=subjects_dir,
280+
size=(800, 800),
281+
initial_time=0.1,
282+
)
283+
284+
# We are plotting only 1st clusters here for illustration purpose.
285+
plot_stat_cluster(
286+
good_clusters[0], src, difference_plot, time="max-extent", color="magenta", width=1
287+
)
288+
289+
# Plotting second cluster on the interactive mode for illustration purpose.
290+
plot_stat_cluster(
291+
good_clusters[1], src, difference_plot, time="interactive", color="magenta", width=1
292+
)

0 commit comments

Comments
 (0)