Skip to content

Commit 783c815

Browse files
authored
Feat: Add errors for line and plane of best fit (#362)
1 parent dbaac9e commit 783c815

7 files changed

Lines changed: 114 additions & 18 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
*.xml
55
.*/
66
.coverage
7+
.python-version
78
build/
89
dist/
910
docs/build/

.python-version

Lines changed: 0 additions & 1 deletion
This file was deleted.

src/skspatial/objects/cylinder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,9 @@ def _best_fit(points_centered: Points, centroid: Point) -> Tuple[Vector, Point,
542542

543543
def _compute_initial_direction(points: Points) -> np.ndarray:
544544
"""Compute the initial direction as the best fit line."""
545-
initial_direction = Line.best_fit(points).vector.unit()
545+
line_best_fit = cast(Line, Line.best_fit(points))
546+
547+
initial_direction = line_best_fit.vector.unit()
546548
spherical_coordinates = _cartesian_to_spherical(*initial_direction)
547549

548550
return np.array([spherical_coordinates.theta, spherical_coordinates.phi])

src/skspatial/objects/line.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -657,23 +657,38 @@ def intersect_line(self, other: Line, check_coplanar: bool = True, **kwargs) ->
657657
return self.point + vector_a_scaled
658658

659659
@classmethod
660-
def best_fit(cls, points: array_like, tol: Optional[float] = None, **kwargs) -> Line:
660+
def best_fit(
661+
cls,
662+
points: array_like,
663+
tol: Optional[float] = None,
664+
return_error: bool = False,
665+
**kwargs,
666+
) -> Line | tuple[Line, float]:
661667
"""
662668
Return the line of best fit for a set of points.
663669
670+
Also optionally return a value representing the error of the fit.
671+
This is the sum of the squared singular values from SVD (excluding the first), which equals the sum of the squared orthogonal distances from the points to the line.
672+
673+
"The singular values reflect the amount of data variance captured by the bases.
674+
The first basis (the one with largest singular value) lies in the direction of the greatest data variance.
675+
The second basis captures the orthogonal direction with the second greatest variance, and so on." [1]_
676+
664677
Parameters
665678
----------
666679
points : array_like
667680
Input points.
668681
tol : float | None, optional
669682
Keyword passed to :meth:`Points.are_collinear` (default None).
683+
return_error : bool, optional
684+
If True, also return a value representing the error of the fit (default False).
670685
kwargs : dict, optional
671686
Additional keywords passed to :func:`numpy.linalg.svd`
672687
673688
Returns
674689
-------
675-
Line
676-
The line of best fit.
690+
Line | tuple[Line, float]
691+
The line of best fit, and optionally the error of the fit.
677692
678693
Raises
679694
------
@@ -697,6 +712,10 @@ def best_fit(cls, points: array_like, tol: Optional[float] = None, **kwargs) ->
697712
>>> line.direction.round(3)
698713
Vector([0.707, 0.707])
699714
715+
References
716+
----------
717+
.. [1] "Singular Value Decomposition", Oracle, https://docs.oracle.com/en/database/oracle/machine-learning/oml4sql/23/dmcon/singular-value-decomposition.html#GUID-14AA4B45-3B36-4056-9B9A-BD9DC471F0AD
718+
700719
"""
701720
points_spatial = Points(points)
702721

@@ -705,10 +724,17 @@ def best_fit(cls, points: array_like, tol: Optional[float] = None, **kwargs) ->
705724

706725
points_centered, centroid = points_spatial.mean_center(return_centroid=True)
707726

708-
_, _, vh = np.linalg.svd(points_centered, **kwargs)
709-
direction = vh[0, :]
727+
_, S, Vh = np.linalg.svd(points_centered, **kwargs)
728+
729+
direction = Vh[0, :]
730+
line_best_fit = cls(centroid, direction)
731+
732+
if return_error:
733+
error = np.sum(S[1:] ** 2)
734+
735+
return line_best_fit, error
710736

711-
return cls(centroid, direction)
737+
return line_best_fit
712738

713739
def transform_points(self, points: array_like) -> np.ndarray:
714740
"""

src/skspatial/objects/plane.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -703,33 +703,44 @@ def intersect_plane(self, other: Plane, **kwargs) -> Line:
703703
return Line(point_line, direction_line)
704704

705705
@classmethod
706-
def best_fit(cls, points: array_like, tol: Optional[float] = None, **kwargs) -> Plane:
706+
def best_fit(
707+
cls,
708+
points: array_like,
709+
tol: Optional[float] = None,
710+
return_error: bool = False,
711+
**kwargs,
712+
) -> Plane | tuple[Plane, float]:
707713
"""
708714
Return the plane of best fit for a set of 3D points.
709715
716+
Also optionally return a value representing the error of the fit.
717+
This is the sum of the squared singular values from SVD (excluding the first two), which equals the sum of the squared orthogonal distances from the points to the plane.
718+
719+
"The singular values reflect the amount of data variance captured by the bases.
720+
The first basis (the one with largest singular value) lies in the direction of the greatest data variance.
721+
The second basis captures the orthogonal direction with the second greatest variance, and so on." [1]_
722+
710723
Parameters
711724
----------
712725
points : array_like
713726
Input 3D points.
714727
tol : float | None, optional
715728
Keyword passed to :meth:`Points.are_collinear` (default None).
729+
return_error : bool, optional
730+
If True, also return a value representing the error of the fit (default False).
716731
kwargs : dict, optional
717732
Additional keywords passed to :func:`numpy.linalg.svd`
718733
719734
Returns
720735
-------
721-
Plane
722-
The plane of best fit.
736+
Plane | tuple[Plane, float]
737+
The plane of best fit, and optionally the error of the fit.
723738
724739
Raises
725740
------
726741
ValueError
727742
If the points are collinear or are not 3D.
728743
729-
References
730-
----------
731-
https://scicomp.stackexchange.com/a/6901
732-
733744
Examples
734745
--------
735746
>>> from skspatial.objects import Plane
@@ -755,6 +766,11 @@ def best_fit(cls, points: array_like, tol: Optional[float] = None, **kwargs) ->
755766
>>> Plane.best_fit(points, full_matrices=False)
756767
Plane(point=Point([0.5, 0.5, 0. ]), normal=Vector([0., 0., 1.]))
757768
769+
References
770+
----------
771+
.. [1] "Singular Value Decomposition", Oracle, https://docs.oracle.com/en/database/oracle/machine-learning/oml4sql/23/dmcon/singular-value-decomposition.html#GUID-14AA4B45-3B36-4056-9B9A-BD9DC471F0AD
772+
.. [2] https://scicomp.stackexchange.com/a/6901
773+
758774
"""
759775
points = Points(points)
760776

@@ -766,10 +782,16 @@ def best_fit(cls, points: array_like, tol: Optional[float] = None, **kwargs) ->
766782

767783
points_centered, centroid = points.mean_center(return_centroid=True)
768784

769-
u, _, _ = np.linalg.svd(points_centered.T, **kwargs)
770-
normal = Vector(u[:, 2])
785+
U, S, _ = np.linalg.svd(points_centered.T, **kwargs)
786+
normal = Vector(U[:, 2])
787+
788+
plane_fit = cls(centroid, normal)
789+
790+
if return_error:
791+
error_fit = np.sum(S[2:] ** 2)
792+
return plane_fit, error_fit
771793

772-
return cls(centroid, normal)
794+
return plane_fit
773795

774796
def to_mesh(
775797
self,

tests/unit/objects/test_line.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,26 @@ def test_best_fit(points, line_expected):
323323
assert line_fit.point.is_close(line_expected.point)
324324

325325

326+
@pytest.mark.parametrize(
327+
("points", "line_expected", "error_expected"),
328+
[
329+
([[0, 0], [1, 0]], Line([0.5, 0], [1, 0]), 0),
330+
([[1, 0], [0, 0]], Line([0.5, 0], [-1, 0]), 0),
331+
([[0, 0], [1, 1], [2, 2]], Line([1, 1], [1, 1]), 0),
332+
([[0, 0], [0, 1], [1, 0], [1, 1]], Line([0.5, 0.5], [1, 0]), 1),
333+
([[0, 0], [0, 2], [2, 0], [2, 2]], Line([1, 1], [1, 0]), 4),
334+
([[0, 0], [0, 3], [3, 0], [3, 3]], Line([1.5, 1.5], [1, 0]), 9),
335+
],
336+
)
337+
def test_best_fit_with_error(points, line_expected, error_expected):
338+
line_fit, error_fit = Line.best_fit(np.array(points), return_error=True)
339+
340+
assert line_fit.is_close(line_expected)
341+
assert line_fit.point.is_close(line_expected.point)
342+
343+
assert math.isclose(error_fit, error_expected, abs_tol=1e-9)
344+
345+
326346
@pytest.mark.parametrize(
327347
("points", "message_expected"),
328348
[

tests/unit/objects/test_plane.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,32 @@ def test_best_fit(points, plane_expected):
437437
assert plane_fit.point.is_close(plane_expected.point)
438438

439439

440+
@pytest.mark.parametrize(
441+
("points", "plane_expected", "error_expected"),
442+
[
443+
([[0, 0], [1, 1], [0, 2]], Plane([1 / 3, 1, 0], [0, 0, 1]), 0),
444+
(
445+
[[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]],
446+
Plane([0.25, 0.25, 0.25], [1, 1, 1]),
447+
0.25,
448+
),
449+
(
450+
[[0, 0, 0], [2, 0, 0], [0, 2, 0], [0, 0, 2]],
451+
Plane([0.5, 0.5, 0.5], [1, 1, 1]),
452+
1,
453+
),
454+
],
455+
)
456+
def test_best_fit_with_error(points, plane_expected, error_expected):
457+
points = Points(points).set_dimension(3)
458+
plane_fit, error_fit = Plane.best_fit(points, return_error=True)
459+
460+
assert plane_fit.is_close(plane_expected)
461+
assert plane_fit.point.is_close(plane_expected.point)
462+
463+
assert math.isclose(error_fit, error_expected, abs_tol=1e-9)
464+
465+
440466
@pytest.mark.parametrize(
441467
("points", "message_expected"),
442468
[

0 commit comments

Comments
 (0)