Skip to content

Commit 629d396

Browse files
Ignore data on 3D plots that are outside the valid axis scale range
1 parent 3af60ad commit 629d396

3 files changed

Lines changed: 95 additions & 12 deletions

File tree

lib/mpl_toolkits/mplot3d/art3d.py

Lines changed: 64 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,30 @@ def _viewlim_mask(xs, ys, zs, axes):
9797
return mask
9898

9999

100+
def _scale_invalid_mask(xs, ys, zs, axes):
101+
"""
102+
Return the mask of points whose coordinates are invalid for the axis
103+
scale they live on (e.g. <=0 on a log axis).
104+
105+
Parameters
106+
----------
107+
xs, ys, zs : array-like
108+
The points to check, in data coordinates.
109+
axes : Axes3D
110+
The axes whose scales are queried.
111+
112+
Returns
113+
-------
114+
mask : np.ndarray
115+
Boolean array, ``True`` where any of x/y/z is out of its scale's
116+
valid domain.
117+
"""
118+
return np.logical_or.reduce((
119+
np.logical_not(axes.xaxis._scale.val_in_range(xs)),
120+
np.logical_not(axes.yaxis._scale.val_in_range(ys)),
121+
np.logical_not(axes.zaxis._scale.val_in_range(zs))))
122+
123+
100124
class Text3D(mtext.Text):
101125
"""
102126
Text object with 3D position and direction.
@@ -191,8 +215,10 @@ def set_3d_properties(self, z=0, zdir='z', axlim_clip=False):
191215

192216
@artist.allow_rasterization
193217
def draw(self, renderer):
218+
mask = _scale_invalid_mask(self._x, self._y, self._z, self.axes)
194219
if self._axlim_clip:
195-
mask = _viewlim_mask(self._x, self._y, self._z, self.axes)
220+
mask = mask | _viewlim_mask(self._x, self._y, self._z, self.axes)
221+
if np.any(mask):
196222
pos3d = np.ma.array([self._x, self._y, self._z],
197223
mask=mask, dtype=float).filled(np.nan)
198224
else:
@@ -328,9 +354,12 @@ def get_data_3d(self):
328354

329355
@artist.allow_rasterization
330356
def draw(self, renderer):
357+
scale_mask = _scale_invalid_mask(*self._verts3d, self.axes)
331358
if self._axlim_clip:
359+
scale_mask = scale_mask | _viewlim_mask(*self._verts3d, self.axes)
360+
if np.any(scale_mask):
332361
mask = np.broadcast_to(
333-
_viewlim_mask(*self._verts3d, self.axes),
362+
scale_mask,
334363
(len(self._verts3d), *self._verts3d[0].shape)
335364
)
336365
xs3d, ys3d, zs3d = np.ma.array(self._verts3d,
@@ -424,10 +453,13 @@ class Collection3D(Collection):
424453
def do_3d_projection(self):
425454
"""Project the points according to renderer matrix."""
426455
vs_list = [vs for vs, _ in self._3dverts_codes]
456+
masks = [_scale_invalid_mask(*vs.T, self.axes) for vs in vs_list]
427457
if self._axlim_clip:
428-
vs_list = [np.ma.array(vs, mask=np.broadcast_to(
429-
_viewlim_mask(*vs.T, self.axes), vs.shape))
430-
for vs in vs_list]
458+
masks = [m | _viewlim_mask(*vs.T, self.axes)
459+
for m, vs in zip(masks, vs_list)]
460+
vs_list = [np.ma.array(vs, mask=np.broadcast_to(m, vs.shape))
461+
if np.any(m) else vs
462+
for vs, m in zip(vs_list, masks)]
431463
xyzs_list = [proj3d._scale_proj_transform(
432464
vs[:, 0], vs[:, 1], vs[:, 2], self.axes) for vs in vs_list]
433465
self._paths = [mpath.Path(np.ma.column_stack([xs, ys]), cs)
@@ -508,6 +540,14 @@ def do_3d_projection(self):
508540
if np.ma.isMA(segments):
509541
mask = segments.mask
510542

543+
scale_mask = _scale_invalid_mask(segments[..., 0],
544+
segments[..., 1],
545+
segments[..., 2],
546+
self.axes)
547+
if np.any(scale_mask):
548+
mask = mask | np.broadcast_to(scale_mask[..., np.newaxis],
549+
(*scale_mask.shape, 3))
550+
511551
if self._axlim_clip:
512552
viewlim_mask = _viewlim_mask(segments[..., 0],
513553
segments[..., 1],
@@ -597,12 +637,15 @@ def get_path(self):
597637

598638
def do_3d_projection(self):
599639
s = self._segment3d
640+
xs0, ys0, zs0 = zip(*s)
641+
mask = _scale_invalid_mask(xs0, ys0, zs0, self.axes)
600642
if self._axlim_clip:
601-
mask = _viewlim_mask(*zip(*s), self.axes)
643+
mask = mask | _viewlim_mask(xs0, ys0, zs0, self.axes)
644+
if np.any(mask):
602645
xs, ys, zs = np.ma.array(zip(*s),
603646
dtype=float, mask=mask).filled(np.nan)
604647
else:
605-
xs, ys, zs = zip(*s)
648+
xs, ys, zs = xs0, ys0, zs0
606649
vxs, vys, vzs, vis = proj3d._scale_proj_transform_clip(xs, ys, zs, self.axes)
607650
self._path2d = mpath.Path(np.ma.column_stack([vxs, vys]))
608651
return min(vzs)
@@ -657,12 +700,15 @@ def set_3d_properties(self, path, zs=0, zdir='z', axlim_clip=False):
657700

658701
def do_3d_projection(self):
659702
s = self._segment3d
703+
xs0, ys0, zs0 = zip(*s)
704+
mask = _scale_invalid_mask(xs0, ys0, zs0, self.axes)
660705
if self._axlim_clip:
661-
mask = _viewlim_mask(*zip(*s), self.axes)
706+
mask = mask | _viewlim_mask(xs0, ys0, zs0, self.axes)
707+
if np.any(mask):
662708
xs, ys, zs = np.ma.array(zip(*s),
663709
dtype=float, mask=mask).filled(np.nan)
664710
else:
665-
xs, ys, zs = zip(*s)
711+
xs, ys, zs = xs0, ys0, zs0
666712
vxs, vys, vzs, vis = proj3d._scale_proj_transform_clip(xs, ys, zs, self.axes)
667713
self._path2d = mpath.Path(np.ma.column_stack([vxs, vys]), self._code3d)
668714
return min(vzs)
@@ -801,8 +847,10 @@ def set_3d_properties(self, zs, zdir, axlim_clip=False):
801847
self.stale = True
802848

803849
def do_3d_projection(self):
850+
mask = _scale_invalid_mask(*self._offsets3d, self.axes)
804851
if self._axlim_clip:
805-
mask = _viewlim_mask(*self._offsets3d, self.axes)
852+
mask = mask | _viewlim_mask(*self._offsets3d, self.axes)
853+
if np.any(mask):
806854
xs, ys, zs = np.ma.array(self._offsets3d, mask=mask)
807855
else:
808856
xs, ys, zs = self._offsets3d
@@ -1023,8 +1071,10 @@ def do_3d_projection(self):
10231071
for xyz in self._offsets3d:
10241072
if np.ma.isMA(xyz):
10251073
mask = mask | xyz.mask
1074+
mask = mask | _scale_invalid_mask(*self._offsets3d, self.axes)
10261075
if self._axlim_clip:
10271076
mask = mask | _viewlim_mask(*self._offsets3d, self.axes)
1077+
if np.any(mask):
10281078
mask = np.broadcast_to(mask,
10291079
(len(self._offsets3d), *self._offsets3d[0].shape))
10301080
xyzs = np.ma.array(self._offsets3d, mask=mask)
@@ -1362,9 +1412,11 @@ def do_3d_projection(self):
13621412
if self._edge_is_mapped:
13631413
self._edgecolor3d = self._edgecolors
13641414

1365-
needs_masking = np.any(self._invalid_vertices)
13661415
num_faces = len(self._faces)
1367-
mask = self._invalid_vertices
1416+
mask = self._invalid_vertices | _scale_invalid_mask(
1417+
self._faces[..., 0], self._faces[..., 1],
1418+
self._faces[..., 2], self.axes)
1419+
needs_masking = np.any(mask)
13681420

13691421
# Some faces might contain masked vertices, so we want to ignore any
13701422
# errors that those might cause

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,20 @@
5858
from . import axis3d
5959

6060

61+
def _mask_scale_invalid(data, axis):
62+
"""
63+
Return ``data`` with values invalid for ``axis``'s scale (e.g. ``<=0`` on a
64+
log axis) replaced by NaN, so they don't pollute data limits.
65+
"""
66+
if data is None:
67+
return data
68+
data = np.asanyarray(data, dtype=float)
69+
valid = axis._scale.val_in_range(data)
70+
if np.all(valid):
71+
return data
72+
return np.where(valid, data, np.nan)
73+
74+
6175
@_docstring.interpd
6276
@_api.define_aliases({
6377
"xlim": ["xlim3d"], "ylim": ["ylim3d"], "zlim": ["zlim3d"]})
@@ -640,13 +654,16 @@ def autoscale(self, enable=True, axis='both', tight=None):
640654
def auto_scale_xyz(self, X, Y, Z=None, had_data=None):
641655
# This updates the bounding boxes as to keep a record as to what the
642656
# minimum sized rectangular volume holds the data.
657+
X = _mask_scale_invalid(X, self.xaxis)
658+
Y = _mask_scale_invalid(Y, self.yaxis)
643659
if np.shape(X) == np.shape(Y):
644660
self.xy_dataLim.update_from_data_xy(
645661
np.column_stack([np.ravel(X), np.ravel(Y)]), not had_data)
646662
else:
647663
self.xy_dataLim.update_from_data_x(X, not had_data)
648664
self.xy_dataLim.update_from_data_y(Y, not had_data)
649665
if Z is not None:
666+
Z = _mask_scale_invalid(Z, self.zaxis)
650667
self.zz_dataLim.update_from_data_x(Z, not had_data)
651668
# Let autoscale_view figure out how to use this data.
652669
self.autoscale_view()

lib/mpl_toolkits/mplot3d/tests/test_axes3d.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3189,3 +3189,17 @@ def test_scale3d_calc_coord():
31893189
# Pane coordinate should match axis limit (y-pane at max)
31903190
assert pane_idx == 1
31913191
assert point[pane_idx] == pytest.approx(ax.get_ylim()[1])
3192+
3193+
3194+
def test_plot_surface_log_scale_invalid_values():
3195+
"""Ensure non-positive Z values on a log z-axis does not corrupt zlim."""
3196+
fig = plt.figure()
3197+
ax = fig.add_subplot(projection='3d')
3198+
ax.set_zscale('log')
3199+
X, Y = np.meshgrid(np.linspace(1, 3, 4), np.linspace(1, 3, 4))
3200+
Z = X * Y - 4 # half the entries are <= 0, invalid for a log scale
3201+
ax.plot_surface(X, Y, Z)
3202+
fig.canvas.draw()
3203+
3204+
zmin, zmax = ax.get_zlim()
3205+
assert 1e-3 < zmin < zmax < 1e3, f"zlim corrupted: {(zmin, zmax)}"

0 commit comments

Comments
 (0)