Skip to content

Commit c19bc18

Browse files
authored
Enhance plot_probegroup to handle varying contact values (#397)
1 parent 208a95b commit c19bc18

File tree

9 files changed

+78
-57
lines changed

9 files changed

+78
-57
lines changed

doc/api.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ Plotting
6565

6666
.. autofunction:: plot_probe
6767

68-
.. autofunction:: plot_probe_group
68+
.. autofunction:: plot_probegroup
6969

7070
Library
7171
-------

doc/generate_format_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import matplotlib.pyplot as plt
77

88
from probeinterface import Probe, ProbeGroup, combine_probes, write_probeinterface
9-
from probeinterface.plotting import plot_probe, plot_probe_group
9+
from probeinterface.plotting import plot_probe, plot_probegroup
1010

1111
from probeinterface import generate_tetrode
1212

examples/ex_03_generate_probe_group.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import matplotlib.pyplot as plt
1414

1515
from probeinterface import Probe, ProbeGroup
16-
from probeinterface.plotting import plot_probe_group
16+
from probeinterface.plotting import plot_probegroup
1717
from probeinterface import generate_dummy_probe
1818

1919
##############################################################################
@@ -39,11 +39,11 @@
3939
##############################################################################
4040
#  We can now plot all probes in the same axis:
4141

42-
plot_probe_group(probegroup, same_axes=True)
42+
plot_probegroup(probegroup, same_axes=True)
4343

4444
##############################################################################
4545
#  or in separate axes:
4646

47-
plot_probe_group(probegroup, same_axes=False, with_contact_id=True)
47+
plot_probegroup(probegroup, same_axes=False, with_contact_id=True)
4848

4949
plt.show()

examples/ex_05_device_channel_indices.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import matplotlib.pyplot as plt
1818

1919
from probeinterface import Probe, ProbeGroup
20-
from probeinterface.plotting import plot_probe, plot_probe_group
20+
from probeinterface.plotting import plot_probe, plot_probegroup
2121
from probeinterface import generate_multi_columns_probe
2222

2323
##############################################################################
@@ -85,6 +85,6 @@
8585
# The indices of the probe group can also be plotted:
8686

8787
fig, ax = plt.subplots()
88-
plot_probe_group(probegroup, with_contact_id=True, same_axes=True, ax=ax)
88+
plot_probegroup(probegroup, with_contact_id=True, same_axes=True, ax=ax)
8989

9090
plt.show()

examples/ex_06_import_export_to_file.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import matplotlib.pyplot as plt
2525

2626
from probeinterface import Probe, ProbeGroup
27-
from probeinterface.plotting import plot_probe, plot_probe_group
27+
from probeinterface.plotting import plot_probe, plot_probegroup
2828
from probeinterface import generate_dummy_probe
2929
from probeinterface import write_probeinterface, read_probeinterface
3030
from probeinterface import write_prb, read_prb
@@ -48,7 +48,7 @@
4848
write_probeinterface('my_two_probe_setup.json', probegroup)
4949

5050
probegroup2 = read_probeinterface('my_two_probe_setup.json')
51-
plot_probe_group(probegroup2)
51+
plot_probegroup(probegroup2)
5252

5353
##############################################################################
5454
# The format looks like this:
@@ -98,6 +98,6 @@
9898
f.write(prb_two_tetrodes)
9999

100100
two_tetrode = read_prb('two_tetrodes.prb')
101-
plot_probe_group(two_tetrode, same_axes=False, with_contact_id=True)
101+
plot_probegroup(two_tetrode, same_axes=False, with_contact_id=True)
102102

103103
plt.show()

examples/ex_07_probe_generator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import matplotlib.pyplot as plt
1818

1919
from probeinterface import Probe, ProbeGroup
20-
from probeinterface.plotting import plot_probe, plot_probe_group
20+
from probeinterface.plotting import plot_probe, plot_probegroup
2121

2222
##############################################################################
2323
# Generate 4 tetrodes:
@@ -35,7 +35,7 @@
3535
df = probegroup.to_dataframe()
3636
df
3737

38-
plot_probe_group(probegroup, with_contact_id=True, same_axes=True)
38+
plot_probegroup(probegroup, with_contact_id=True, same_axes=True)
3939

4040
##############################################################################
4141
# Generate a linear probe:

examples/ex_08_more_plotting_examples.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import matplotlib.pyplot as plt
1414

1515
from probeinterface import Probe, ProbeGroup
16-
from probeinterface.plotting import plot_probe, plot_probe_group
16+
from probeinterface.plotting import plot_probe, plot_probegroup
1717
from probeinterface import generate_multi_columns_probe, generate_linear_probe
1818

1919
##############################################################################

src/probeinterface/plotting.py

Lines changed: 65 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ def create_probe_polygons(
1717
contacts_colors: list | None = None,
1818
contacts_values: np.ndarray | None = None,
1919
cmap: str = "viridis",
20-
contacts_kargs: dict = {},
20+
contact_kwargs: dict = {},
2121
probe_shape_kwargs: dict = {},
22+
contacts_kargs=None, # DEPRECATED
2223
):
2324
"""Create PolyCollection objects for a Probe.
2425
@@ -32,7 +33,7 @@ def create_probe_polygons(
3233
Values to color the contacts with
3334
cmap : str, default: "viridis"
3435
A colormap color
35-
contacts_kargs : dict, default: {}
36+
contact_kwargs : dict, default: {}
3637
Dict with kwargs for contacts (e.g. alpha, edgecolor, lw)
3738
probe_shape_kwargs : dict, default: {}
3839
Dict with kwargs for probe shape (e.g. alpha, edgecolor, lw)
@@ -44,6 +45,16 @@ def create_probe_polygons(
4445
poly_contour : PolyCollection | None
4546
The polygon collection for the probe shape
4647
"""
48+
if contacts_kargs is not None:
49+
import warnings
50+
51+
warnings.warn(
52+
"contacts_kargs is deprecated and will be removed in 0.3.4. Please use `contacts_kwargs` instead.",
53+
category=DeprecationWarning,
54+
stacklevel=2,
55+
)
56+
contact_kwargs = contacts_kargs
57+
4758
if probe.ndim == 2:
4859
from matplotlib.collections import PolyCollection
4960

@@ -59,7 +70,7 @@ def create_probe_polygons(
5970
_probe_shape_kwargs.update(probe_shape_kwargs)
6071

6172
_contacts_kargs = dict(alpha=0.7, edgecolor=[0.3, 0.3, 0.3], lw=0.5)
62-
_contacts_kargs.update(contacts_kargs)
73+
_contacts_kargs.update(contact_kwargs)
6374

6475
n = probe.get_contact_count()
6576

@@ -93,7 +104,7 @@ def plot_probe(
93104
with_contact_id: bool = False,
94105
with_device_index: bool = False,
95106
text_on_contact: list | np.ndarray | None = None,
96-
contacts_values: np.ndarray | None = None,
107+
contacts_values: list | np.ndarray | None = None,
97108
cmap: str = "viridis",
98109
title: bool = True,
99110
contacts_kargs: dict = {},
@@ -119,9 +130,9 @@ def plot_probe(
119130
If True, channel ids are displayed on top of the channels
120131
with_device_index : bool, default: False
121132
If True, device channel indices are displayed on top of the channels
122-
text_on_contact: None | list | numpy.array, default: None
133+
text_on_contact: None | list | np.ndarray, default: None
123134
Addintional text to plot on each contact
124-
contacts_values : np.array, default: None
135+
contacts_values : list | np.ndarray | None, default: None
125136
Values to color the contacts with
126137
cmap : a colormap color, default: "viridis"
127138
A colormap color
@@ -248,7 +259,7 @@ def on_press(event):
248259
return poly, poly_contour
249260

250261

251-
def plot_probegroup(probegroup, same_axes: bool = True, **kargs):
262+
def plot_probegroup(probegroup, same_axes: bool = True, **kwargs):
252263
"""Plot all probes from a ProbeGroup
253264
Can be in an existing set of axes or separate axes.
254265
@@ -258,19 +269,37 @@ def plot_probegroup(probegroup, same_axes: bool = True, **kargs):
258269
The ProbeGroup to plot
259270
same_axes : bool, default: True
260271
If True, the probes are plotted on the same axis
261-
kargs: dict
262-
see docstring for plot_probe for possible kargs
272+
kwargs: dict
273+
Additional keyword arguments to pass to plot_probe.
274+
If same_axes is True, the same kwargs are passed to all probes.
275+
If same_axes is False, the kwargs are passed separately to each probe,
276+
if they have the same length as the total number of contacts in the ProbeGroup.
277+
For example, if contacts_colors is given and has the same length as the total
278+
number of contacts in the ProbeGroup, then the colors are split and passed
279+
separately to each probe.
280+
281+
Available kwargs:
282+
283+
- contacts_colors
284+
- with_contact_id
285+
- with_device_index
286+
- text_on_contact
287+
- contacts_values
288+
- cmap
289+
- title
290+
- contacts_kargs
291+
- probe_shape_kwargs
263292
"""
264293

265294
import matplotlib.pyplot as plt
266295

267-
figsize = kargs.pop("figsize", None)
296+
figsize = kwargs.pop("figsize", None)
268297

269298
n = len(probegroup.probes)
270299

271300
if same_axes:
272-
if "ax" in kargs:
273-
ax = kargs.pop("ax")
301+
if "ax" in kwargs:
302+
ax = kwargs.pop("ax")
274303
else:
275304
if probegroup.ndim == 2:
276305
fig, ax = plt.subplots(figsize=figsize)
@@ -279,14 +308,16 @@ def plot_probegroup(probegroup, same_axes: bool = True, **kargs):
279308
ax = fig.add_subplot(1, 1, 1, projection="3d")
280309
axs = [ax] * n
281310
else:
282-
if "ax" in kargs:
311+
if "ax" in kwargs:
283312
raise ValueError("when same_axes=False, an axes object cannot be passed into this function.")
284313
if probegroup.ndim == 2:
285314
fig, axs = plt.subplots(ncols=n, nrows=1, figsize=figsize)
286315
if n == 1:
287316
axs = [axs]
288317
else:
289-
raise NotImplementedError
318+
raise NotImplementedError(
319+
"same_axes=False is currently only implemented for 2D probes. For 3D probes, please set same_axes=True."
320+
)
290321

291322
if same_axes:
292323
# global lims
@@ -297,36 +328,34 @@ def plot_probegroup(probegroup, same_axes: bool = True, **kargs):
297328
ylims = min(ylims[0], ylims2[0]), max(ylims[1], ylims2[1])
298329
if zlims is not None:
299330
zlims = min(zlims[0], zlims2[0]), max(zlims[1], zlims2[1])
300-
kargs["xlims"] = xlims
301-
kargs["ylims"] = ylims
302-
kargs["zlims"] = zlims
331+
kwargs["xlims"] = xlims
332+
kwargs["ylims"] = ylims
333+
kwargs["zlims"] = zlims
303334
else:
304335
# will be auto for each probe in each axis
305-
kargs["xlims"] = None
306-
kargs["ylims"] = None
307-
kargs["zlims"] = None
336+
kwargs["xlims"] = None
337+
kwargs["ylims"] = None
338+
kwargs["zlims"] = None
308339

309-
kargs["title"] = False
310-
for i, probe in enumerate(probegroup.probes):
311-
plot_probe(probe, ax=axs[i], **kargs)
312-
313-
314-
def plot_probe_group(probegroup, same_axes: bool = True, **kargs):
315-
"""
316-
This function is deprecated and will be removed in 0.2.23
317-
Please use plot_probegroup instead"""
340+
kwargs["title"] = False
318341

319-
from warnings import warn
342+
cum_contact_count = 0
343+
total_contacts = sum(p.get_contact_count() for p in probegroup.probes)
320344

321-
warn(
322-
"`plot_probe_group` is deprecated and will be removed in 2.23. Use plot_probegroup instead",
323-
category=DeprecationWarning,
324-
stacklevel=2,
325-
)
345+
for i, probe in enumerate(probegroup.probes):
346+
n = probe.get_contact_count()
347+
kwargs_probe = kwargs.copy()
348+
for key in ["contacts_colors", "contacts_values", "text_on_contact"]:
349+
if kwargs.get(key) is not None:
350+
val = np.array(kwargs[key])
351+
if len(val) == total_contacts:
352+
kwargs_probe[key] = val[cum_contact_count : cum_contact_count + n]
326353

327-
plot_probegroup(probegroup, same_axes=same_axes, **kargs)
354+
plot_probe(probe, ax=axs[i], **kwargs_probe)
355+
cum_contact_count += n
328356

329357

358+
### MATPLOTLIB INTERACTION ###
330359
def _on_press(probe, event):
331360
ax = event.inaxes
332361
x, y = event.xdata, event.ydata

tests/test_plotting.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22
from probeinterface import generate_dummy_probe, generate_dummy_probe_group
33
from probeinterface.plotting import plot_probe, plot_probegroup
44

5-
# remove once plot_probe_group is removed
6-
from probeinterface.plotting import plot_probe_group
7-
85
import matplotlib.pyplot as plt
96
import numpy as np
107

@@ -38,10 +35,6 @@ def test_plot_probegroup():
3835
plot_probegroup(probegroup, same_axes=True, with_contact_id=True)
3936
plot_probegroup(probegroup, same_axes=False)
4037

41-
# remove when plot_probe_group has been removed
42-
with pytest.warns(DeprecationWarning):
43-
plot_probe_group(probegroup)
44-
4538
# 3d
4639
probegroup_3d = ProbeGroup()
4740
for probe in probegroup.probes:
@@ -74,6 +67,5 @@ def test_plot_probe_two_side():
7467

7568
if __name__ == "__main__":
7669
# test_plot_probe()
77-
# test_plot_probe_group()
7870
test_plot_probe_two_side()
7971
plt.show()

0 commit comments

Comments
 (0)