Skip to content

Commit ac7bc72

Browse files
authored
Feat: Make matplotlib optional (#366)
1 parent 380d373 commit ac7bc72

19 files changed

Lines changed: 104 additions & 54 deletions

HISTORY.rst

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,28 @@
22
History
33
=======
44

5+
9.0.1 (2025-04-16)
6+
------------------
7+
8+
Docs
9+
~~~~
10+
- Include single quotes in command: `pip install 'scikit-spatial[plotting]'`
11+
12+
13+
9.0.0 (2025-04-16)
14+
------------------
15+
16+
Breaking Changes
17+
~~~~~~~~~~~~~~~~
18+
- Make matplotlib an optional dependency. It can be installed by `pip install 'scikit-spatial[plotting]'`
19+
20+
521
8.1.0 (2024-12-23)
622
------------------
723

824
Features
925
~~~~~~~~
10-
- Add optiont to return error for line and plane of best fit.
26+
- Add option to return error for line and plane of best fit.
1127

1228

1329
8.0.0 (2024-10-03)

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,13 @@ $ conda install scikit-spatial -c conda-forge
150150

151151
```
152152

153+
The `matplotlib` dependency is optional. To enable plotting, you can install scikit-spatial with the extra `plotting`.
154+
155+
```bash
156+
$ pip install 'scikit-spatial[plotting]'
157+
158+
```
159+
153160
# Example Usage
154161

155162
## Measurement

docs/source/plotting.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@ Plotting
66

77
This library uses ``matplotlib`` to enable plotting of all of its spatial objects. Each object has a ``plot_2d`` and/or ``plot_3d`` method. For example, a :class:`Point` object can be plotted in 2D or 3D, while a :class:`Sphere` object can only be plotted in 3D.
88

9+
The ``matplotlib`` dependency is optional. To install it, you can install scikit-spatial with the extra `plotting`.
10+
11+
.. code-block:: console
12+
13+
$ pip install 'scikit-spatial[plotting]'
14+
15+
916
The ``plot_2d`` methods require an instance of :class:`~matplotlib.axes.Axes` as the first argument, while the ``plot_3d`` methods require an instance of :class:`~mpl_toolkits.mplot3d.axes3d.Axes3D`. This allows for placing multiple spatial objects on the same plot, which is useful for visualizing computations such as projection or intersection.
1017

1118
The methods also pass keyword arguments to ``matplotlib`` functions. For example, ``Point.plot_2d`` uses :meth:`~matplotlib.axes.Axes.scatter` under the hood, so any keyword arguments to :meth:`~matplotlib.axes.Axes.scatter` can also be input to the method. Some plotting methods have additional keyword arguments that are not passed to ``matplotlib``, such as ``Line.plot_2d``, which takes parameters ``t_1`` and ``t_2`` to determine the start and end points of the line.

docs/source/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Sphinx==5.3.0
22
numpydoc==1.5.0
3-
scikit-spatial==8.1.0
3+
scikit-spatial[plotting]==9.0.1
44
setuptools==70.0.0
55
sphinx-bootstrap-theme==0.8.1
66
sphinx-gallery==0.9.0

pyproject.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "scikit-spatial"
7-
version = "8.1.0"
7+
version = "9.0.1"
88
description = "Spatial objects and computations based on NumPy arrays."
99

1010
license = { text = "BSD-3-Clause" }
@@ -31,13 +31,15 @@ classifiers = [
3131
requires-python = ">=3.8"
3232

3333
dependencies = [
34-
"matplotlib>=3",
3534
"numpy>1.24; python_version >= '3.12'",
3635
"numpy>=1; python_version < '3.12'",
3736
"scipy>1.11; python_version >= '3.12'",
3837
"scipy>=1; python_version < '3.12'",
3938
]
4039

40+
[project.optional-dependencies]
41+
plotting = ["matplotlib>=3"]
42+
4143
[project.urls]
4244
repository = "https://github.com/ajhynes7/scikit-spatial"
4345
documentation = "https://scikit-spatial.readthedocs.io"

src/skspatial/objects/_base_sphere.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Module for base class of Circle and Sphere."""
22

3+
from typing import cast
4+
35
import numpy as np
46

57
from skspatial._functions import _contains_point
@@ -85,4 +87,4 @@ def project_point(self, point: array_like) -> Point:
8587

8688
vector_to_point = Vector.from_points(self.point, point)
8789

88-
return self.point + self.radius * vector_to_point.unit()
90+
return cast(Point, self.point + self.radius * vector_to_point.unit())

src/skspatial/objects/circle.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55
import math
66
from typing import Tuple, cast
77

8-
import matplotlib.pyplot as plt
98
import numpy as np
10-
from matplotlib.axes import Axes
119

1210
from skspatial._functions import np_float
1311
from skspatial.objects._base_sphere import _BaseSphere
@@ -368,8 +366,8 @@ def intersect_line(self, line: Line) -> Tuple[Point, Point]:
368366
point_2 = point_1 + line.direction.unit()
369367

370368
# Translate the points on the line to mimic the circle being centered on the origin.
371-
point_translated_1: Point = point_1 - self.point
372-
point_translated_2: Point = point_2 - self.point
369+
point_translated_1 = point_1 - self.point
370+
point_translated_2 = point_2 - self.point
373371

374372
x_1, y_1 = point_translated_1
375373
x_2, y_2 = point_translated_2
@@ -465,7 +463,7 @@ def best_fit(cls, points: array_like) -> Circle:
465463

466464
return cls(center, radius)
467465

468-
def plot_2d(self, ax_2d: Axes, **kwargs) -> None:
466+
def plot_2d(self, ax_2d, **kwargs) -> None:
469467
"""
470468
Plot the circle in 2D.
471469
@@ -493,6 +491,8 @@ def plot_2d(self, ax_2d: Axes, **kwargs) -> None:
493491
>>> limits = plt.axis([-10, 10, -10, 10])
494492
495493
"""
494+
import matplotlib.pyplot as plt
495+
496496
circle = plt.Circle(tuple(self.point), self.radius, **kwargs)
497497

498498
ax_2d.add_artist(circle)

src/skspatial/objects/cylinder.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from typing import List, Optional, Tuple, cast
88

99
import numpy as np
10-
from mpl_toolkits.mplot3d import Axes3D
1110
from scipy.optimize import minimize
1211

1312
from skspatial._functions import _solve_quadratic
@@ -639,7 +638,7 @@ def _spherical_to_cartesian(spherical_coordinates: _SphericalCoordinates) -> Vec
639638

640639
return cls(point_a, vector_ab, radius)
641640

642-
def plot_3d(self, ax_3d: Axes3D, n_along_axis: int = 100, n_angles: int = 30, **kwargs) -> None:
641+
def plot_3d(self, ax_3d, n_along_axis: int = 100, n_angles: int = 30, **kwargs) -> None:
643642
"""
644643
Plot a 3D cylinder.
645644

src/skspatial/objects/line.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,11 @@
55
from typing import Optional, cast
66

77
import numpy as np
8-
from matplotlib.axes import Axes
9-
from mpl_toolkits.mplot3d import Axes3D
108

119
from skspatial.objects._base_line_plane import _BaseLinePlane
1210
from skspatial.objects.point import Point
1311
from skspatial.objects.points import Points
1412
from skspatial.objects.vector import Vector
15-
from skspatial.plotting import _connect_points_2d, _connect_points_3d
1613
from skspatial.transformation import transform_coordinates
1714
from skspatial.typing import array_like
1815

@@ -265,7 +262,7 @@ def to_point(self, t: float = 1) -> Point:
265262
"""
266263
vector_along_line = t * self.direction
267264

268-
return self.point + vector_along_line
265+
return cast(Point, self.point + vector_along_line)
269266

270267
def project_point(self, point: array_like) -> Point:
271268
"""
@@ -787,7 +784,7 @@ def transform_points(self, points: array_like) -> np.ndarray:
787784

788785
return column.flatten()
789786

790-
def plot_2d(self, ax_2d: Axes, t_1: float = 0, t_2: float = 1, **kwargs) -> None:
787+
def plot_2d(self, ax_2d, t_1: float = 0, t_2: float = 1, **kwargs) -> None:
791788
"""
792789
Plot a 2D line.
793790
@@ -821,12 +818,14 @@ def plot_2d(self, ax_2d: Axes, t_1: float = 0, t_2: float = 1, **kwargs) -> None
821818
>>> grid = ax.grid()
822819
823820
"""
821+
from skspatial.plotting import _connect_points_2d
822+
824823
point_1 = self.to_point(t_1)
825824
point_2 = self.to_point(t_2)
826825

827826
_connect_points_2d(ax_2d, point_1, point_2, **kwargs)
828827

829-
def plot_3d(self, ax_3d: Axes3D, t_1: float = 0, t_2: float = 1, **kwargs) -> None:
828+
def plot_3d(self, ax_3d, t_1: float = 0, t_2: float = 1, **kwargs) -> None:
830829
"""
831830
Plot a 3D line.
832831
@@ -862,6 +861,8 @@ def plot_3d(self, ax_3d: Axes3D, t_1: float = 0, t_2: float = 1, **kwargs) -> No
862861
>>> line.point.plot_3d(ax, s=100)
863862
864863
"""
864+
from skspatial.plotting import _connect_points_3d
865+
865866
point_1 = self.to_point(t_1)
866867
point_2 = self.to_point(t_2)
867868

src/skspatial/objects/line_segment.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,11 @@
55
import math
66

77
import numpy as np
8-
from matplotlib.axes import Axes
9-
from mpl_toolkits.mplot3d import Axes3D
108

119
from skspatial.objects._base_spatial import _BaseSpatial
1210
from skspatial.objects.line import Line
1311
from skspatial.objects.point import Point
1412
from skspatial.objects.vector import Vector
15-
from skspatial.plotting import _connect_points_2d, _connect_points_3d
1613
from skspatial.typing import array_like
1714

1815

@@ -157,7 +154,7 @@ def intersect_line_segment(self, other: LineSegment, **kwargs) -> Point:
157154

158155
return point_intersection
159156

160-
def plot_2d(self, ax_2d: Axes, **kwargs) -> None:
157+
def plot_2d(self, ax_2d, **kwargs) -> None:
161158
"""
162159
Plot a 2D line segment.
163160
@@ -190,9 +187,11 @@ def plot_2d(self, ax_2d: Axes, **kwargs) -> None:
190187
>>> grid = ax.grid()
191188
192189
"""
190+
from skspatial.plotting import _connect_points_2d
191+
193192
_connect_points_2d(ax_2d, self.point_a, self.point_b, **kwargs)
194193

195-
def plot_3d(self, ax_3d: Axes3D, **kwargs) -> None:
194+
def plot_3d(self, ax_3d, **kwargs) -> None:
196195
"""
197196
Plot a 3D line segment.
198197
@@ -226,4 +225,6 @@ def plot_3d(self, ax_3d: Axes3D, **kwargs) -> None:
226225
>>> segment.plot_3d(ax, c='k')
227226
228227
"""
228+
from skspatial.plotting import _connect_points_3d
229+
229230
_connect_points_3d(ax_3d, self.point_a, self.point_b, **kwargs)

0 commit comments

Comments
 (0)