Skip to content

Commit aa0b9e6

Browse files
kshitij-mathsndem0
authored andcommitted
feat: add clough tocher from sklearn
1 parent 6750140 commit aa0b9e6

2 files changed

Lines changed: 54 additions & 0 deletions

File tree

ezyrb/approximation/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"KNeighborsRegressor",
1010
"RadiusNeighborsRegressor",
1111
"SklearnApproximation",
12+
"CloughTocher",
1213
]
1314

1415
from .approximation import Approximation
@@ -19,3 +20,4 @@
1920
from .kneighbors_regressor import KNeighborsRegressor
2021
from .radius_neighbors_regressor import RadiusNeighborsRegressor
2122
from .sklearn_approximation import SklearnApproximation
23+
from .clough_tocher import CloughTocher
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""Wrapper for Clough-Tocher 2D Interpolator."""
2+
3+
import logging
4+
import numpy as np
5+
from scipy.interpolate import CloughTocher2DInterpolator as CT
6+
7+
from .approximation import Approximation
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class CloughTocher(Approximation):
13+
r"""
14+
:math:`C^1` smooth, piecewise cubic interpolator for 2D multivariate approximation.
15+
16+
Note: This interpolator only supports 2-dimensional parameter spaces
17+
(i.e., mapping :math:`\mathbb{R}^2 \to \mathbb{R}^m`).
18+
19+
:param kwargs: arguments passed to the internal instance of
20+
scipy.interpolate.CloughTocher2DInterpolator.
21+
"""
22+
23+
def __init__(self, **kwargs):
24+
logger.debug("Initializing CloughTocher with kwargs: %s", kwargs)
25+
super().__init__()
26+
self.kwargs = kwargs
27+
self.interpolator = None
28+
29+
def fit(self, points, values):
30+
"""
31+
Construct the interpolator given `points` and `values`.
32+
"""
33+
as_np_array = np.array(points)
34+
35+
# Mathematical constraint: CT only works in R^2
36+
if as_np_array.ndim != 2 or as_np_array.shape[1] != 2:
37+
logger.error(
38+
"CloughTocher requested for data with shape %s",
39+
as_np_array.shape,
40+
)
41+
raise ValueError(
42+
"CloughTocher interpolator only supports exactly 2D parameter spaces."
43+
)
44+
45+
self.interpolator = CT(as_np_array, values, **self.kwargs)
46+
logger.info("CloughTocher fitted successfully")
47+
48+
def predict(self, new_point):
49+
"""
50+
Evaluate interpolator at given `new_points`.
51+
"""
52+
return self.interpolator(new_point).squeeze()

0 commit comments

Comments
 (0)