Skip to content

Commit a682408

Browse files
authored
Merge pull request matplotlib#31737 from scottshambaugh/fix-31726-3d-scale-invalid
BUG: Don't try to show / autoscale 3D data that lies outside the axis valid scale range
2 parents 91e66b2 + 9f4a87a commit a682408

6 files changed

Lines changed: 145 additions & 33 deletions

File tree

lib/matplotlib/axis.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,15 @@ def limit_range_for_scale(self, vmin, vmax):
871871
"""
872872
return self._scale.limit_range_for_scale(vmin, vmax, self.get_minpos())
873873

874+
def _nan_out_of_scale_range(self, data):
875+
"""
876+
Return *data* with values that are out of range for this axis's scale
877+
replaced by NaN. E.g. ``<=0`` on a log axis.
878+
"""
879+
data = np.asanyarray(data, dtype=float)
880+
valid = self._scale.val_in_range(data)
881+
return data if np.all(valid) else np.where(valid, data, np.nan)
882+
874883
def _get_autoscale_on(self):
875884
"""Return whether this Axis is autoscaled."""
876885
return self._autoscale_on

lib/matplotlib/scale.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
""" # noqa: E501
3131

3232
import inspect
33-
import math
3433
import textwrap
3534
from functools import wraps
3635

@@ -119,17 +118,20 @@ def val_in_range(self, val):
119118
"""
120119
Return whether the value(s) are within the valid range for this scale.
121120
122-
This method is a generic implementation. Subclasses may implement more
123-
efficient solutions for their domain.
124-
"""
125-
try:
126-
if not math.isfinite(val):
127-
return False
121+
Accepts a scalar or array-like ``val``. For a scalar, returns a
122+
Python ``bool``. For an array, returns a bool ndarray of the same
123+
shape. This is a generic implementation, and subclasses may implement
124+
more efficient solutions for their domain.
125+
"""
126+
arr = np.asarray(val)
127+
with np.errstate(invalid='ignore'):
128+
try:
129+
vmin, vmax = self.limit_range_for_scale(arr, arr, minpos=1e-300)
130+
except (TypeError, ValueError):
131+
result = np.zeros(arr.shape, dtype=bool)
128132
else:
129-
vmin, vmax = self.limit_range_for_scale(val, val, minpos=1e-300)
130-
return vmin == val and vmax == val
131-
except (TypeError, ValueError):
132-
return False
133+
result = np.isfinite(arr) & (vmin == arr) & (vmax == arr)
134+
return bool(result) if arr.ndim == 0 else result
133135

134136

135137
def _make_axis_parameter_optional(init_func):
@@ -219,11 +221,13 @@ def get_transform(self):
219221

220222
def val_in_range(self, val):
221223
"""
222-
Return whether the value is within the valid range for this scale.
224+
Return whether the value(s) are within the valid range for this scale.
223225
224226
This is True for all values, except +-inf and NaN.
225227
"""
226-
return math.isfinite(val)
228+
arr = np.asarray(val)
229+
result = np.isfinite(arr)
230+
return bool(result) if arr.ndim == 0 else result
227231

228232

229233
class FuncTransform(Transform):
@@ -431,11 +435,14 @@ def limit_range_for_scale(self, vmin, vmax, minpos):
431435

432436
def val_in_range(self, val):
433437
"""
434-
Return whether the value is within the valid range for this scale.
438+
Return whether the value(s) are within the valid range for this scale.
435439
436440
This is True for value(s) > 0 except +inf and NaN.
437441
"""
438-
return math.isfinite(val) and val > 0
442+
arr = np.asarray(val)
443+
with np.errstate(invalid='ignore'):
444+
result = np.isfinite(arr) & (arr > 0)
445+
return bool(result) if arr.ndim == 0 else result
439446

440447

441448
class FuncScaleLog(LogScale):
@@ -625,11 +632,13 @@ def get_transform(self):
625632

626633
def val_in_range(self, val):
627634
"""
628-
Return whether the value is within the valid range for this scale.
635+
Return whether the value(s) are within the valid range for this scale.
629636
630637
This is True for all values, except +-inf and NaN.
631638
"""
632-
return math.isfinite(val)
639+
arr = np.asarray(val)
640+
result = np.isfinite(arr)
641+
return bool(result) if arr.ndim == 0 else result
633642

634643

635644
class AsinhTransform(Transform):
@@ -759,11 +768,13 @@ def set_default_locators_and_formatters(self, axis):
759768

760769
def val_in_range(self, val):
761770
"""
762-
Return whether the value is within the valid range for this scale.
771+
Return whether the value(s) are within the valid range for this scale.
763772
764773
This is True for all values, except +-inf and NaN.
765774
"""
766-
return math.isfinite(val)
775+
arr = np.asarray(val)
776+
result = np.isfinite(arr)
777+
return bool(result) if arr.ndim == 0 else result
767778

768779

769780
class LogitTransform(Transform):
@@ -880,11 +891,14 @@ def limit_range_for_scale(self, vmin, vmax, minpos):
880891

881892
def val_in_range(self, val):
882893
"""
883-
Return whether the value is within the valid range for this scale.
894+
Return whether the value(s) are within the valid range for this scale.
884895
885896
This is True for value(s) which are between 0 and 1 (excluded).
886897
"""
887-
return 0 < val < 1
898+
arr = np.asarray(val)
899+
with np.errstate(invalid='ignore'):
900+
result = (0 < arr) & (arr < 1)
901+
return bool(result) if arr.ndim == 0 else result
888902

889903

890904
_scale_mapping = {

lib/matplotlib/tests/test_scale.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,3 +477,23 @@ def test_val_in_range_base_fallback():
477477
assert s.val_in_range(np.nan) is False
478478
assert s.val_in_range(np.inf) is False
479479
assert s.val_in_range(-np.inf) is False
480+
481+
482+
def test_val_in_range_array():
483+
# Vectorized: scalar in -> scalar bool, array in -> bool array.
484+
arr = np.array([0.5, -1.0, 0.0, np.nan, np.inf, 0.25])
485+
cases = {
486+
'linear': [True, True, True, False, False, True],
487+
'log': [True, False, False, False, False, True],
488+
'symlog': [True, True, True, False, False, True],
489+
'asinh': [True, True, True, False, False, True],
490+
'logit': [True, False, False, False, False, True],
491+
}
492+
for name, expected in cases.items():
493+
s = mscale._scale_mapping[name](axis=None)
494+
np.testing.assert_array_equal(s.val_in_range(arr), expected)
495+
496+
# 2D shape is preserved.
497+
out = mscale._scale_mapping['log'](axis=None).val_in_range(
498+
np.array([[1.0, -1.0], [0.5, np.nan]]))
499+
np.testing.assert_array_equal(out, [[True, False], [True, False]])

lib/mpl_toolkits/mplot3d/art3d.py

Lines changed: 64 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,30 @@ def _viewlim_mask(xs, ys, zs, axes):
9797
return mask
9898

9999

100+
def _scale_invalid_mask(xs, ys, zs, axes):
101+
"""
102+
Return the mask of points whose coordinates are invalid for the axis
103+
scale they live on (e.g. <=0 on a log axis).
104+
105+
Parameters
106+
----------
107+
xs, ys, zs : array-like
108+
The points to check, in data coordinates.
109+
axes : Axes3D
110+
The axes whose scales are queried.
111+
112+
Returns
113+
-------
114+
mask : np.ndarray
115+
Boolean array, ``True`` where any of x/y/z is out of its scale's
116+
valid domain.
117+
"""
118+
return np.logical_or.reduce((
119+
np.logical_not(axes.xaxis._scale.val_in_range(xs)),
120+
np.logical_not(axes.yaxis._scale.val_in_range(ys)),
121+
np.logical_not(axes.zaxis._scale.val_in_range(zs))))
122+
123+
100124
class Text3D(mtext.Text):
101125
"""
102126
Text object with 3D position and direction.
@@ -191,8 +215,10 @@ def set_3d_properties(self, z=0, zdir='z', axlim_clip=False):
191215

192216
@artist.allow_rasterization
193217
def draw(self, renderer):
218+
mask = _scale_invalid_mask(self._x, self._y, self._z, self.axes)
194219
if self._axlim_clip:
195-
mask = _viewlim_mask(self._x, self._y, self._z, self.axes)
220+
mask |= _viewlim_mask(self._x, self._y, self._z, self.axes)
221+
if np.any(mask):
196222
pos3d = np.ma.array([self._x, self._y, self._z],
197223
mask=mask, dtype=float).filled(np.nan)
198224
else:
@@ -328,9 +354,12 @@ def get_data_3d(self):
328354

329355
@artist.allow_rasterization
330356
def draw(self, renderer):
357+
scale_mask = _scale_invalid_mask(*self._verts3d, self.axes)
331358
if self._axlim_clip:
359+
scale_mask |= _viewlim_mask(*self._verts3d, self.axes)
360+
if np.any(scale_mask):
332361
mask = np.broadcast_to(
333-
_viewlim_mask(*self._verts3d, self.axes),
362+
scale_mask,
334363
(len(self._verts3d), *self._verts3d[0].shape)
335364
)
336365
xs3d, ys3d, zs3d = np.ma.array(self._verts3d,
@@ -424,10 +453,13 @@ class Collection3D(Collection):
424453
def do_3d_projection(self):
425454
"""Project the points according to renderer matrix."""
426455
vs_list = [vs for vs, _ in self._3dverts_codes]
456+
masks = [_scale_invalid_mask(*vs.T, self.axes) for vs in vs_list]
427457
if self._axlim_clip:
428-
vs_list = [np.ma.array(vs, mask=np.broadcast_to(
429-
_viewlim_mask(*vs.T, self.axes), vs.shape))
430-
for vs in vs_list]
458+
masks = [m | _viewlim_mask(*vs.T, self.axes)
459+
for m, vs in zip(masks, vs_list)]
460+
vs_list = [np.ma.array(vs, mask=np.broadcast_to(m, vs.shape))
461+
if np.any(m) else vs
462+
for vs, m in zip(vs_list, masks)]
431463
xyzs_list = [proj3d._scale_proj_transform(
432464
vs[:, 0], vs[:, 1], vs[:, 2], self.axes) for vs in vs_list]
433465
self._paths = [mpath.Path(np.ma.column_stack([xs, ys]), cs)
@@ -520,6 +552,14 @@ def do_3d_projection(self):
520552
if np.ma.isMA(segments) and segments.mask is not np.ma.nomask:
521553
mask = segments.mask
522554

555+
scale_mask = _scale_invalid_mask(segments[..., 0],
556+
segments[..., 1],
557+
segments[..., 2],
558+
self.axes)
559+
if np.any(scale_mask):
560+
mask |= np.broadcast_to(scale_mask[..., np.newaxis],
561+
(*scale_mask.shape, 3))
562+
523563
if self._axlim_clip:
524564
viewlim_mask = _viewlim_mask(segments[..., 0],
525565
segments[..., 1],
@@ -612,12 +652,15 @@ def get_path(self):
612652

613653
def do_3d_projection(self):
614654
s = self._segment3d
655+
xs0, ys0, zs0 = zip(*s)
656+
mask = _scale_invalid_mask(xs0, ys0, zs0, self.axes)
615657
if self._axlim_clip:
616-
mask = _viewlim_mask(*zip(*s), self.axes)
658+
mask |= _viewlim_mask(xs0, ys0, zs0, self.axes)
659+
if np.any(mask):
617660
xs, ys, zs = np.ma.array(zip(*s),
618661
dtype=float, mask=mask).filled(np.nan)
619662
else:
620-
xs, ys, zs = zip(*s)
663+
xs, ys, zs = xs0, ys0, zs0
621664
vxs, vys, vzs, vis = proj3d._scale_proj_transform_clip(xs, ys, zs, self.axes)
622665
self._path2d = mpath.Path(np.ma.column_stack([vxs, vys]))
623666
return min(vzs)
@@ -672,12 +715,15 @@ def set_3d_properties(self, path, zs=0, zdir='z', axlim_clip=False):
672715

673716
def do_3d_projection(self):
674717
s = self._segment3d
718+
xs0, ys0, zs0 = zip(*s)
719+
mask = _scale_invalid_mask(xs0, ys0, zs0, self.axes)
675720
if self._axlim_clip:
676-
mask = _viewlim_mask(*zip(*s), self.axes)
721+
mask |= _viewlim_mask(xs0, ys0, zs0, self.axes)
722+
if np.any(mask):
677723
xs, ys, zs = np.ma.array(zip(*s),
678724
dtype=float, mask=mask).filled(np.nan)
679725
else:
680-
xs, ys, zs = zip(*s)
726+
xs, ys, zs = xs0, ys0, zs0
681727
vxs, vys, vzs, vis = proj3d._scale_proj_transform_clip(xs, ys, zs, self.axes)
682728
self._path2d = mpath.Path(np.ma.column_stack([vxs, vys]), self._code3d)
683729
return min(vzs)
@@ -816,8 +862,10 @@ def set_3d_properties(self, zs, zdir, axlim_clip=False):
816862
self.stale = True
817863

818864
def do_3d_projection(self):
865+
mask = _scale_invalid_mask(*self._offsets3d, self.axes)
819866
if self._axlim_clip:
820-
mask = _viewlim_mask(*self._offsets3d, self.axes)
867+
mask |= _viewlim_mask(*self._offsets3d, self.axes)
868+
if np.any(mask):
821869
xs, ys, zs = np.ma.array(self._offsets3d, mask=mask)
822870
else:
823871
xs, ys, zs = self._offsets3d
@@ -1038,8 +1086,10 @@ def do_3d_projection(self):
10381086
for xyz in self._offsets3d:
10391087
if np.ma.isMA(xyz):
10401088
mask = mask | xyz.mask
1089+
mask = mask | _scale_invalid_mask(*self._offsets3d, self.axes)
10411090
if self._axlim_clip:
10421091
mask = mask | _viewlim_mask(*self._offsets3d, self.axes)
1092+
if np.any(mask):
10431093
mask = np.broadcast_to(mask,
10441094
(len(self._offsets3d), *self._offsets3d[0].shape))
10451095
xyzs = np.ma.array(self._offsets3d, mask=mask)
@@ -1377,9 +1427,11 @@ def do_3d_projection(self):
13771427
if self._edge_is_mapped:
13781428
self._edgecolor3d = self._edgecolors
13791429

1380-
needs_masking = np.any(self._invalid_vertices)
13811430
num_faces = len(self._faces)
1382-
mask = self._invalid_vertices
1431+
mask = self._invalid_vertices | _scale_invalid_mask(
1432+
self._faces[..., 0], self._faces[..., 1],
1433+
self._faces[..., 2], self.axes)
1434+
needs_masking = np.any(mask)
13831435

13841436
# Some faces might contain masked vertices, so we want to ignore any
13851437
# errors that those might cause

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,13 +640,16 @@ def autoscale(self, enable=True, axis='both', tight=None):
640640
def auto_scale_xyz(self, X, Y, Z=None, had_data=None):
641641
# This updates the bounding boxes as to keep a record as to what the
642642
# minimum sized rectangular volume holds the data.
643+
X = self.xaxis._nan_out_of_scale_range(X)
644+
Y = self.yaxis._nan_out_of_scale_range(Y)
643645
if np.shape(X) == np.shape(Y):
644646
self.xy_dataLim.update_from_data_xy(
645647
np.column_stack([np.ravel(X), np.ravel(Y)]), not had_data)
646648
else:
647649
self.xy_dataLim.update_from_data_x(X, not had_data)
648650
self.xy_dataLim.update_from_data_y(Y, not had_data)
649651
if Z is not None:
652+
Z = self.zaxis._nan_out_of_scale_range(Z)
650653
self.zz_dataLim.update_from_data_x(Z, not had_data)
651654
# Let autoscale_view figure out how to use this data.
652655
self.autoscale_view()

lib/mpl_toolkits/mplot3d/tests/test_axes3d.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3188,3 +3188,17 @@ def test_scale3d_calc_coord():
31883188
# Pane coordinate should match axis limit (y-pane at max)
31893189
assert pane_idx == 1
31903190
assert point[pane_idx] == pytest.approx(ax.get_ylim()[1])
3191+
3192+
3193+
def test_plot_surface_log_scale_invalid_values():
3194+
"""Ensure non-positive Z values on a log z-axis does not corrupt zlim."""
3195+
fig = plt.figure()
3196+
ax = fig.add_subplot(projection='3d')
3197+
ax.set_zscale('log')
3198+
X, Y = np.meshgrid(np.linspace(1, 3, 4), np.linspace(1, 3, 4))
3199+
Z = X * Y - 4 # half the entries are <= 0, invalid for a log scale
3200+
ax.plot_surface(X, Y, Z)
3201+
fig.canvas.draw()
3202+
3203+
zmin, zmax = ax.get_zlim()
3204+
assert 1e-3 < zmin < zmax < 1e3, f"zlim corrupted: {(zmin, zmax)}"

0 commit comments

Comments
 (0)