Skip to content

Commit 420c0a7

Browse files
Fix non-square wireframe 3D plotting
1 parent f4cf125 commit 420c0a7

3 files changed

Lines changed: 20 additions & 4 deletions

File tree

lib/mpl_toolkits/mplot3d/art3d.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -497,15 +497,27 @@ def do_3d_projection(self):
497497
"""
498498
Project the points according to renderer matrix.
499499
"""
500-
segments = np.asanyarray(self._segments3d)
500+
segments = self._segments3d
501+
502+
# Handle ragged inputs, but prefer a faster path for same-length segments
503+
segment_lengths = [len(segment) for segment in segments]
504+
ragged = len(set(segment_lengths)) > 1
505+
if ragged:
506+
# Branch masked / non-masked for speed
507+
if any(np.ma.isMA(segment) for segment in segments):
508+
segments = np.ma.concatenate(segments)
509+
else:
510+
segments = np.concatenate(segments)
511+
else:
512+
segments = np.asanyarray(segments)
501513

502514
# Handle empty segments
503515
if segments.size == 0:
504516
LineCollection.set_segments(self, [])
505517
return np.nan
506518

507519
mask = False
508-
if np.ma.isMA(segments):
520+
if np.ma.isMA(segments) and segments.mask is not np.ma.nomask:
509521
mask = segments.mask
510522

511523
if self._axlim_clip:
@@ -519,9 +531,12 @@ def do_3d_projection(self):
519531
(*viewlim_mask.shape, 3))
520532
mask = mask | viewlim_mask
521533

522-
xyzs = np.ma.array(
523-
proj3d._scale_proj_transform_vectors(segments, self.axes), mask=mask)
534+
xyzs = proj3d._scale_proj_transform_vectors(segments, self.axes)
535+
if mask is not False:
536+
xyzs = np.ma.array(xyzs, mask=mask)
524537
segments_2d = xyzs[..., 0:2]
538+
if ragged:
539+
segments_2d = np.split(segments_2d, np.cumsum(segment_lengths[:-1]))
525540
LineCollection.set_segments(self, segments_2d)
526541

527542
# FIXME
-10.2 KB
Loading

lib/mpl_toolkits/mplot3d/tests/test_axes3d.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,7 @@ def test_wireframe3dasymmetric():
854854
fig = plt.figure()
855855
ax = fig.add_subplot(projection='3d')
856856
X, Y, Z = axes3d.get_test_data(0.05)
857+
X, Y, Z = X[:-1], Y[:-1], Z[:-1] # Drop a row so the grid is non-square
857858
ax.plot_wireframe(X, Y, Z, rcount=3, ccount=13)
858859

859860

0 commit comments

Comments
 (0)