Skip to content

Commit eb78548

Browse files
ENH: group triaxial OPM topomaps by orientation
1 parent d1e3cb2 commit eb78548

5 files changed

Lines changed: 250 additions & 99 deletions

File tree

mne/viz/evoked.py

Lines changed: 61 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1868,6 +1868,7 @@ def plot_evoked_joint(
18681868
ts_args.get("time_unit", "s"), evoked.times
18691869
)
18701870
topomap_args = dict() if topomap_args is None else topomap_args.copy()
1871+
opm_group_factor = 1
18711872

18721873
got_axes = False
18731874
illegal_args = {"show", "times", "exclude"}
@@ -1954,9 +1955,32 @@ def plot_evoked_joint(
19541955
del times
19551956
_, times_ts = _check_time_unit(ts_args["time_unit"], times_sec)
19561957

1958+
if len(ch_types) == 1 and set(ch_types) == {"mag"}:
1959+
from .topomap import _prepare_topomap_plot
1960+
from .topomap import _opm_coils
1961+
1962+
(
1963+
_,
1964+
_,
1965+
merge_channels,
1966+
_,
1967+
_,
1968+
_,
1969+
_,
1970+
) = _prepare_topomap_plot(
1971+
evoked,
1972+
"mag",
1973+
sphere=topomap_args.get("sphere", None),
1974+
)
1975+
is_opm = any(ch["coil_type"] in _opm_coils for ch in evoked.info["chs"])
1976+
if is_opm and bool(merge_channels):
1977+
opm_group_factor = 2
1978+
19571979
# prepare axes for topomap
19581980
if not got_axes:
1959-
fig, ts_ax, map_ax = _prepare_joint_axes(len(times_sec), figsize=(8.0, 4.2))
1981+
fig, ts_ax, map_ax = _prepare_joint_axes(
1982+
len(times_sec) * opm_group_factor, figsize=(8.0, 4.2)
1983+
)
19601984
cbar_ax = None
19611985
else:
19621986
ts_ax = ts_args["axes"]
@@ -2044,22 +2068,42 @@ def plot_evoked_joint(
20442068

20452069
# connection lines
20462070
# draw the connection lines between time series and topoplots
2047-
for timepoint, map_ax_ in zip(times_ts, map_ax):
2048-
con = ConnectionPatch(
2049-
xyA=[timepoint, ts_ax.get_ylim()[1]],
2050-
xyB=[0.5, 0],
2051-
coordsA="data",
2052-
coordsB="axes fraction",
2053-
axesA=ts_ax,
2054-
axesB=map_ax_,
2055-
color="grey",
2056-
linestyle="-",
2057-
linewidth=1.5,
2058-
alpha=0.66,
2059-
zorder=1,
2060-
clip_on=False,
2061-
)
2062-
fig.add_artist(con)
2071+
if opm_group_factor == 1:
2072+
for timepoint, map_ax_ in zip(times_ts, map_ax):
2073+
con = ConnectionPatch(
2074+
xyA=[timepoint, ts_ax.get_ylim()[1]],
2075+
xyB=[0.5, 0],
2076+
coordsA="data",
2077+
coordsB="axes fraction",
2078+
axesA=ts_ax,
2079+
axesB=map_ax_,
2080+
color="grey",
2081+
linestyle="-",
2082+
linewidth=1.5,
2083+
alpha=0.66,
2084+
zorder=1,
2085+
clip_on=False,
2086+
)
2087+
ts_ax.add_artist(con)
2088+
else:
2089+
for time_idx, timepoint in enumerate(times_ts):
2090+
for group_idx in range(opm_group_factor):
2091+
map_ax_ = map_ax[time_idx + group_idx * len(times_ts)]
2092+
con = ConnectionPatch(
2093+
xyA=[timepoint, ts_ax.get_ylim()[1]],
2094+
xyB=[0.5, 0],
2095+
coordsA="data",
2096+
coordsB="axes fraction",
2097+
axesA=ts_ax,
2098+
axesB=map_ax_,
2099+
color="grey",
2100+
linestyle="-",
2101+
linewidth=1.0,
2102+
alpha=0.5,
2103+
zorder=1,
2104+
clip_on=False,
2105+
)
2106+
ts_ax.add_artist(con)
20632107

20642108
# mark times in time series plot
20652109
for timepoint in times_ts:

mne/viz/tests/test_ica.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,4 +585,7 @@ def test_plot_components_opm_triaxial(triaxial_raw):
585585
ica = ICA(max_iter=1, random_state=0, n_components=3)
586586
ica.fit(triaxial_raw, picks="mag", verbose="error")
587587
fig = ica.plot_components()
588-
assert len(fig.axes) == 3
588+
assert len(fig.axes) == 6
589+
titles = [ax.get_title() for ax in fig.axes]
590+
assert any("[radial]" in title for title in titles)
591+
assert any("[tangential]" in title for title in titles)

mne/viz/tests/test_topo.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,10 @@ def test_plot_joint_opm_triaxial(triaxial_evoked):
145145
ts_args=dict(time_unit="s"),
146146
topomap_args=dict(time_unit="s", contours=0, res=8, sensors=False),
147147
)
148-
assert len(fig.axes) >= 2
148+
assert len(fig.axes) >= 3
149+
titles = [ax.get_title() for ax in fig.axes]
150+
assert any("radial" in title for title in titles)
151+
assert any("tangential" in title for title in titles)
149152

150153

151154
def test_plot_topo():

mne/viz/tests/test_topomap.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,22 @@ def test_split_opm_overlaps(triaxial_evoked):
851851
assert tangential == ["OPM002", "OPM003", "OPM005", "OPM006"]
852852

853853

854+
def test_plot_evoked_topomap_opm_triaxial_groups(triaxial_evoked):
855+
"""Test grouped radial/tangential topomap rendering for triaxial OPM."""
856+
fig = triaxial_evoked.plot_topomap(
857+
times=[0.0],
858+
ch_type="mag",
859+
contours=0,
860+
res=8,
861+
sensors=False,
862+
show=False,
863+
)
864+
assert len(fig.axes) == 3
865+
titles = [ax.get_title() for ax in fig.axes]
866+
assert any("radial" in title for title in titles)
867+
assert any("tangential" in title for title in titles)
868+
869+
854870
def test_plot_topomap_nirs_overlap(fnirs_epochs):
855871
"""Test plotting nirs topomap with overlapping channels (gh-7414)."""
856872
fig = fnirs_epochs["A"].average(picks="hbo").plot_topomap()

0 commit comments

Comments
 (0)