11import warnings
22
3+ import matplotlib as mpl
34import matplotlib .pyplot as plt
5+ import matplotlib .transforms as mtransforms
46import numpy as np
57
68from mplotutils ._deprecate import _deprecate_positional_args
79
810
11+ def _deprecate_ax1_ax2 (ax , ax2 , ax1 ):
12+ if ax is None :
13+ if ax1 is None :
14+ raise TypeError ("colorbar() missing 1 required positional argument: 'ax'" )
15+
16+ if ax2 is None :
17+ ax = ax1
18+ warnings .warn (
19+ "`ax1` has been deprecated in favor of `ax`" ,
20+ FutureWarning ,
21+ stacklevel = 4 ,
22+ )
23+
24+ else :
25+ ax = [ax1 , ax2 ]
26+ warnings .warn (
27+ "`ax1` and `ax2` has been deprecated in favor of `ax`, i.e. pass ax=[ax1, ax2]" ,
28+ FutureWarning ,
29+ stacklevel = 4 ,
30+ )
31+
32+ else :
33+
34+ if ax1 is not None and ax2 is not None :
35+ raise TypeError ("Cannot pass `ax`, and `ax2`" )
36+
37+ if ax2 is not None and np .ndim (ax ) != 0 :
38+ raise TypeError ("Cannot pass ax2 in addition to a list of axes" )
39+
40+ if ax2 is not None :
41+ ax = [ax , ax2 ]
42+
43+ warnings .warn (
44+ "Passing axes individually has been deprecated in favor of passing them"
45+ " as list, i.e. pass ``ax=[ax1, ax2]``, or ``ax=axs``" ,
46+ FutureWarning ,
47+ stacklevel = 4 ,
48+ )
49+
50+ return ax
51+
52+
953@_deprecate_positional_args ("0.3" )
1054def colorbar (
1155 mappable ,
12- ax1 ,
56+ ax = None ,
1357 ax2 = None ,
1458 * ,
1559 orientation = "vertical" ,
@@ -20,32 +64,29 @@ def colorbar(
2064 shrink = None ,
2165 ** kwargs ,
2266):
23- """
24- automatically resize colorbars on draw
67+ """colorbar that adjusts to the axes height (and automatically resizes)
2568
2669 See below for Example
2770
2871 Parameters
2972 ----------
3073 mappable : handle
3174 The `matplotlib.cm.ScalarMappable` described by this colorbar.
32- ax1 : `matplotlib.axes.Axes`
33- The axes to adjust the colorbar to.
34- ax2 : `~matplotlib.axes.Axes`, default: None.
35- If the colorbar should span more than one axes.
75+ ax : `~matplotlib.axes.Axes` or iterable or `numpy.ndarray` of Axes
76+ One or more parent Axes of the colorbar.
3677 orientation : 'vertical' | 'horizontal'. Default: 'vertical'.
3778 Orientation of the colorbar.
3879 aspect : float, default 20.
39- The ratio of long to short dimensions of the colorbar Mutually exclusive with
80+ The ratio of long to short dimensions of the colorbar. Mutually exclusive with
4081 `size`.
4182 size : float, default: None
42- Width of the colorbar as fraction of the axes width (vertical) or
83+ Width of the colorbar as fraction of all parent axes width (vertical) or
4384 height (horizontal). Mutually exclusive with `aspect`.
4485 pad : float, default: None.
45- Distance of the colorbar to the axes in Figure coordinates .
46- Default: 0.05 (vertical) or 0.15 (horizontal).
86+ Distance between axes and colorbar. In fraction of parent axes .
87+ Default: 0.05 (vertical) or 0.15 (horizontal).
4788 shift : 'symmetric' or float in 0..1, default: 'symmetric'
48- Fraction of the total height that the colorbar is shifted upwards . See Note.
89+ Fraction of the total height that the colorbar is shifted up/ right . See Note.
4990 shrink : None or float in 0..1, default: None.
5091 Fraction of the total height that the colorbar is shrunk. See Note.
5192 **kwargs : keyword arguments
@@ -119,7 +160,7 @@ def colorbar(
119160 ax.set_global()
120161 h = ax.pcolormesh([[0, 1]])
121162
122- cbar = mpu.colorbar(h, axs[0], axs[1] )
163+ cbar = mpu.colorbar(h, axs)
123164
124165 # =========================
125166 # example with 3 axes & 2 colorbars
@@ -134,7 +175,7 @@ def colorbar(
134175 h1 = ax.pcolormesh([[0, 1]])
135176 h2 = ax.pcolormesh([[0, 1]], cmap='Blues')
136177
137- cbar = mpu.colorbar(h, axs[0], axs[1], size=0.05)
178+ cbar = mpu.colorbar(h, [ axs[0], axs[1] ], size=0.05)
138179 cbar = mpu.colorbar(h, axs[2], size=0.05)
139180
140181 plt.draw()
@@ -151,6 +192,9 @@ def colorbar(
151192 plt.colorbar
152193 """
153194
195+ ax = _deprecate_ax1_ax2 (ax , ax2 , kwargs .pop ("ax1" , None ))
196+ axs = np .asarray (ax ).flatten ()
197+
154198 if orientation not in ("vertical" , "horizontal" ):
155199 raise ValueError ("orientation must be 'vertical' or 'horizontal'" )
156200
@@ -159,17 +203,13 @@ def colorbar(
159203 msg = "'anchor' and 'panchor' keywords not supported, use 'shrink' and 'shift'"
160204 raise ValueError (msg )
161205
162- # ensure 'ax' does not end up in plt.colorbar(**kwargs)
163- if "ax" in k :
164- if ax2 is not None :
165- raise ValueError ("Cannot pass `ax`, and `ax2`" )
166- # assume it is ax2 (it can't be ax1)
167- ax2 = kwargs .pop ("ax" )
206+ if not all (isinstance (ax , mpl .axes .Axes ) for ax in axs ):
207+ raise TypeError ("ax must be of Type mpl.axes.Axes" )
168208
169- f = ax1 .get_figure ()
209+ f = axs [ 0 ] .get_figure ()
170210
171- if ax2 is not None and f != ax2 .get_figure ():
172- raise ValueError ( "'ax1' and 'ax2' must belong to the same figure" )
211+ if not all ( f == ax .get_figure () for ax in axs ):
212+ raise TypeError ( "All passed axes must belong to the same figure" )
173213
174214 cbax = _get_cbax (f )
175215
@@ -178,8 +218,8 @@ def colorbar(
178218 if orientation == "vertical" :
179219 func = _resize_colorbar_vert (
180220 cbax ,
181- ax1 ,
182- ax2 = ax2 ,
221+ f ,
222+ axs ,
183223 aspect = aspect ,
184224 size = size ,
185225 pad = pad ,
@@ -189,8 +229,8 @@ def colorbar(
189229 else :
190230 func = _resize_colorbar_horz (
191231 cbax ,
192- ax1 ,
193- ax2 = ax2 ,
232+ f ,
233+ axs ,
194234 aspect = aspect ,
195235 size = size ,
196236 pad = pad ,
@@ -219,8 +259,8 @@ def _get_cbax(f):
219259
220260def _resize_colorbar_vert (
221261 cbax ,
222- ax1 ,
223- ax2 = None ,
262+ f ,
263+ axs ,
224264 aspect = None ,
225265 size = None ,
226266 pad = None ,
@@ -263,40 +303,31 @@ def _resize_colorbar_vert(
263303
264304 size , aspect , pad = _parse_size_aspect_pad (size , aspect , pad , "vertical" )
265305
266- f = ax1 .get_figure ()
267-
268- # swap axes if ax1 is above ax2
269- if ax2 is not None :
270- posn1 = ax1 .get_position ()
271- posn2 = ax2 .get_position ()
272-
273- ax1 , ax2 = (ax1 , ax2 ) if posn1 .y0 < posn2 .y0 else (ax2 , ax1 )
274-
275306 if aspect is not None :
276307 anchor = (0 , 0.5 )
277308 cbax .set_anchor (anchor )
278309 cbax .set_box_aspect (aspect )
279310
280311 # inner function is called by event handler
281312 def inner (event = None ):
282- pos1 = ax1 .get_position ()
313+
314+ # from mpl.colorbar (but not using ax.get_position(original=True).frozen())
315+ parents_bbox = mtransforms .Bbox .union ([ax .get_position () for ax in axs ])
283316
284317 # determine total height of all axes
285- if ax2 is None :
286- full_height = pos1 .height
287- else :
288- pos2 = ax2 .get_position ()
289- full_height = pos2 .y0 - pos1 .y0 + pos2 .height
318+ full_height = parents_bbox .height
290319
291- pad_scaled = pad * pos1 .width
320+ pad_scaled = pad * parents_bbox .width
292321
293322 # calculate position of cbax
294- left = pos1 .x0 + pos1 .width + pad_scaled
295- bottom = pos1 .y0 + shift * full_height
323+ left = parents_bbox .x1 + pad_scaled
324+
325+ bottom = parents_bbox .y0 + shift * full_height
326+
296327 height = (1 - shrink ) * full_height
297328
298329 if aspect is None :
299- size_scaled = size * pos1 .width
330+ size_scaled = size * parents_bbox .width
300331 width = size_scaled
301332 else :
302333 figure_aspect = np .divide (* f .get_size_inches ())
@@ -314,8 +345,8 @@ def inner(event=None):
314345
315346def _resize_colorbar_horz (
316347 cbax ,
317- ax1 ,
318- ax2 = None ,
348+ f ,
349+ axs ,
319350 aspect = None ,
320351 size = None ,
321352 pad = None ,
@@ -357,43 +388,32 @@ def _resize_colorbar_horz(
357388
358389 size , aspect , pad = _parse_size_aspect_pad (size , aspect , pad , "horizontal" )
359390
360- f = ax1 .get_figure ()
361-
362- if ax2 is not None :
363- posn1 = ax1 .get_position ()
364- posn2 = ax2 .get_position ()
365-
366- # swap axes if ax1 is right of ax2
367- ax1 , ax2 = (ax1 , ax2 ) if posn1 .x0 < posn2 .x0 else (ax2 , ax1 )
368-
369391 if aspect is not None :
370392 aspect = 1 / aspect
371393 anchor = (0.5 , 1.0 )
372394 cbax .set_anchor (anchor )
373395 cbax .set_box_aspect (aspect )
374396
375397 def inner (event = None ):
376- posn1 = ax1 .get_position ()
377398
378- if ax2 is None :
379- full_width = posn1 .width
380- else :
381- posn2 = ax2 .get_position ()
382- full_width = posn2 .x0 - posn1 .x0 + posn2 .width
399+ # from mpl.colorbar (but not using ax.get_position(original=True).frozen())
400+ parents_bbox = mtransforms .Bbox .union ([ax .get_position () for ax in axs ])
401+
402+ full_width = parents_bbox .width
383403
384- pad_scaled = pad * posn1 .height
404+ pad_scaled = pad * parents_bbox .height
385405
386- width = full_width - shrink * full_width
406+ width = ( 1 - shrink ) * full_width
387407
388408 if aspect is None :
389- size_scaled = size * posn1 .height
409+ size_scaled = size * parents_bbox .height
390410 height = size_scaled
391411 else :
392412 figure_aspect = np .divide (* f .get_size_inches ())
393413 height = width * (aspect * figure_aspect )
394414
395- left = posn1 .x0 + shift * full_width
396- bottom = posn1 .y0 - (pad_scaled + height )
415+ left = parents_bbox .x0 + shift * full_width
416+ bottom = parents_bbox .y0 - (pad_scaled + height )
397417
398418 pos = [left , bottom , width , height ]
399419
0 commit comments