@@ -17,8 +17,9 @@ def create_probe_polygons(
1717 contacts_colors : list | None = None ,
1818 contacts_values : np .ndarray | None = None ,
1919 cmap : str = "viridis" ,
20- contacts_kargs : dict = {},
20+ contact_kwargs : dict = {},
2121 probe_shape_kwargs : dict = {},
22+ contacts_kargs = None , # DEPRECATED
2223):
2324 """Create PolyCollection objects for a Probe.
2425
@@ -32,7 +33,7 @@ def create_probe_polygons(
3233 Values to color the contacts with
3334 cmap : str, default: "viridis"
3435 A colormap color
35- contacts_kargs : dict, default: {}
36+ contact_kwargs : dict, default: {}
3637 Dict with kwargs for contacts (e.g. alpha, edgecolor, lw)
3738 probe_shape_kwargs : dict, default: {}
3839 Dict with kwargs for probe shape (e.g. alpha, edgecolor, lw)
@@ -44,6 +45,16 @@ def create_probe_polygons(
4445 poly_contour : PolyCollection | None
4546 The polygon collection for the probe shape
4647 """
48+ if contacts_kargs is not None :
49+ import warnings
50+
51+ warnings .warn (
52+ "contacts_kargs is deprecated and will be removed in 0.3.4. Please use `contacts_kwargs` instead." ,
53+ category = DeprecationWarning ,
54+ stacklevel = 2 ,
55+ )
56+ contact_kwargs = contacts_kargs
57+
4758 if probe .ndim == 2 :
4859 from matplotlib .collections import PolyCollection
4960
@@ -59,7 +70,7 @@ def create_probe_polygons(
5970 _probe_shape_kwargs .update (probe_shape_kwargs )
6071
6172 _contacts_kargs = dict (alpha = 0.7 , edgecolor = [0.3 , 0.3 , 0.3 ], lw = 0.5 )
62- _contacts_kargs .update (contacts_kargs )
73+ _contacts_kargs .update (contact_kwargs )
6374
6475 n = probe .get_contact_count ()
6576
@@ -93,7 +104,7 @@ def plot_probe(
93104 with_contact_id : bool = False ,
94105 with_device_index : bool = False ,
95106 text_on_contact : list | np .ndarray | None = None ,
96- contacts_values : np .ndarray | None = None ,
107+ contacts_values : list | np .ndarray | None = None ,
97108 cmap : str = "viridis" ,
98109 title : bool = True ,
99110 contacts_kargs : dict = {},
@@ -119,9 +130,9 @@ def plot_probe(
119130 If True, channel ids are displayed on top of the channels
120131 with_device_index : bool, default: False
121132 If True, device channel indices are displayed on top of the channels
122- text_on_contact: None | list | numpy.array , default: None
133+ text_on_contact: None | list | np.ndarray , default: None
123134 Addintional text to plot on each contact
124- contacts_values : np.array , default: None
135+ contacts_values : list | np.ndarray | None , default: None
125136 Values to color the contacts with
126137 cmap : a colormap color, default: "viridis"
127138 A colormap color
@@ -248,7 +259,7 @@ def on_press(event):
248259 return poly , poly_contour
249260
250261
251- def plot_probegroup (probegroup , same_axes : bool = True , ** kargs ):
262+ def plot_probegroup (probegroup , same_axes : bool = True , ** kwargs ):
252263 """Plot all probes from a ProbeGroup
253264 Can be in an existing set of axes or separate axes.
254265
@@ -258,19 +269,37 @@ def plot_probegroup(probegroup, same_axes: bool = True, **kargs):
258269 The ProbeGroup to plot
259270 same_axes : bool, default: True
260271 If True, the probes are plotted on the same axis
261- kargs: dict
262- see docstring for plot_probe for possible kargs
272+ kwargs: dict
273+ Additional keyword arguments to pass to plot_probe.
274+ If same_axes is True, the same kwargs are passed to all probes.
275+ If same_axes is False, the kwargs are passed separately to each probe,
276+ if they have the same length as the total number of contacts in the ProbeGroup.
277+ For example, if contacts_colors is given and has the same length as the total
278+ number of contacts in the ProbeGroup, then the colors are split and passed
279+ separately to each probe.
280+
281+ Available kwargs:
282+
283+ - contacts_colors
284+ - with_contact_id
285+ - with_device_index
286+ - text_on_contact
287+ - contacts_values
288+ - cmap
289+ - title
290+ - contacts_kargs
291+ - probe_shape_kwargs
263292 """
264293
265294 import matplotlib .pyplot as plt
266295
267- figsize = kargs .pop ("figsize" , None )
296+ figsize = kwargs .pop ("figsize" , None )
268297
269298 n = len (probegroup .probes )
270299
271300 if same_axes :
272- if "ax" in kargs :
273- ax = kargs .pop ("ax" )
301+ if "ax" in kwargs :
302+ ax = kwargs .pop ("ax" )
274303 else :
275304 if probegroup .ndim == 2 :
276305 fig , ax = plt .subplots (figsize = figsize )
@@ -279,14 +308,16 @@ def plot_probegroup(probegroup, same_axes: bool = True, **kargs):
279308 ax = fig .add_subplot (1 , 1 , 1 , projection = "3d" )
280309 axs = [ax ] * n
281310 else :
282- if "ax" in kargs :
311+ if "ax" in kwargs :
283312 raise ValueError ("when same_axes=False, an axes object cannot be passed into this function." )
284313 if probegroup .ndim == 2 :
285314 fig , axs = plt .subplots (ncols = n , nrows = 1 , figsize = figsize )
286315 if n == 1 :
287316 axs = [axs ]
288317 else :
289- raise NotImplementedError
318+ raise NotImplementedError (
319+ "same_axes=False is currently only implemented for 2D probes. For 3D probes, please set same_axes=True."
320+ )
290321
291322 if same_axes :
292323 # global lims
@@ -297,36 +328,34 @@ def plot_probegroup(probegroup, same_axes: bool = True, **kargs):
297328 ylims = min (ylims [0 ], ylims2 [0 ]), max (ylims [1 ], ylims2 [1 ])
298329 if zlims is not None :
299330 zlims = min (zlims [0 ], zlims2 [0 ]), max (zlims [1 ], zlims2 [1 ])
300- kargs ["xlims" ] = xlims
301- kargs ["ylims" ] = ylims
302- kargs ["zlims" ] = zlims
331+ kwargs ["xlims" ] = xlims
332+ kwargs ["ylims" ] = ylims
333+ kwargs ["zlims" ] = zlims
303334 else :
304335 # will be auto for each probe in each axis
305- kargs ["xlims" ] = None
306- kargs ["ylims" ] = None
307- kargs ["zlims" ] = None
336+ kwargs ["xlims" ] = None
337+ kwargs ["ylims" ] = None
338+ kwargs ["zlims" ] = None
308339
309- kargs ["title" ] = False
310- for i , probe in enumerate (probegroup .probes ):
311- plot_probe (probe , ax = axs [i ], ** kargs )
312-
313-
314- def plot_probe_group (probegroup , same_axes : bool = True , ** kargs ):
315- """
316- This function is deprecated and will be removed in 0.2.23
317- Please use plot_probegroup instead"""
340+ kwargs ["title" ] = False
318341
319- from warnings import warn
342+ cum_contact_count = 0
343+ total_contacts = sum (p .get_contact_count () for p in probegroup .probes )
320344
321- warn (
322- "`plot_probe_group` is deprecated and will be removed in 2.23. Use plot_probegroup instead" ,
323- category = DeprecationWarning ,
324- stacklevel = 2 ,
325- )
345+ for i , probe in enumerate (probegroup .probes ):
346+ n = probe .get_contact_count ()
347+ kwargs_probe = kwargs .copy ()
348+ for key in ["contacts_colors" , "contacts_values" , "text_on_contact" ]:
349+ if kwargs .get (key ) is not None :
350+ val = np .array (kwargs [key ])
351+ if len (val ) == total_contacts :
352+ kwargs_probe [key ] = val [cum_contact_count : cum_contact_count + n ]
326353
327- plot_probegroup (probegroup , same_axes = same_axes , ** kargs )
354+ plot_probe (probe , ax = axs [i ], ** kwargs_probe )
355+ cum_contact_count += n
328356
329357
358+ ### MATPLOTLIB INTERACTION ###
330359def _on_press (probe , event ):
331360 ax = event .inaxes
332361 x , y = event .xdata , event .ydata
0 commit comments