diff --git a/ezyrb/__init__.py b/ezyrb/__init__.py index f57bb34c..affe49d6 100644 --- a/ezyrb/__init__.py +++ b/ezyrb/__init__.py @@ -16,6 +16,7 @@ "ReducedOrderModel", "PODAE", "RegularGrid", + "Nearest", "MultiReducedOrderModel", "SklearnApproximation", "SklearnReduction", diff --git a/ezyrb/approximation/__init__.py b/ezyrb/approximation/__init__.py index 69f06a84..4b388f12 100644 --- a/ezyrb/approximation/__init__.py +++ b/ezyrb/approximation/__init__.py @@ -9,6 +9,7 @@ "KNeighborsRegressor", "RadiusNeighborsRegressor", "SklearnApproximation", + "Nearest", ] from .approximation import Approximation @@ -19,3 +20,4 @@ from .kneighbors_regressor import KNeighborsRegressor from .radius_neighbors_regressor import RadiusNeighborsRegressor from .sklearn_approximation import SklearnApproximation +from .nearest import Nearest diff --git a/ezyrb/approximation/nearest.py b/ezyrb/approximation/nearest.py new file mode 100644 index 00000000..5eb7242d --- /dev/null +++ b/ezyrb/approximation/nearest.py @@ -0,0 +1,41 @@ +"""Wrapper for Nearest Neighbor Interpolator.""" +import logging +import numpy as np +from scipy.interpolate import NearestNDInterpolator as NearestND +from scipy.interpolate import interp1d +from .approximation import Approximation + +logger = logging.getLogger(__name__) + +class Nearest(Approximation): + """ + Nearest Neighbors interpolator for univariate and multivariate approximation. + + :param kwargs: arguments passed to the internal instance of + scipy.interpolate.NearestNDInterpolator or scipy.interpolate.interp1d. + """ + def __init__(self, rescale=True, **kwargs): + logger.debug("Initializing Nearest with rescale=%s, kwargs: %s", rescale, kwargs) + super().__init__() + self.rescale = rescale + self.kwargs = kwargs + self.interpolator = None + + def fit(self, points, values): + as_np_array = np.array(points) + + if as_np_array.ndim == 1 or (as_np_array.ndim == 2 and as_np_array.shape[1] == 1): + logger.debug("Using 1D nearest interpolation") + self.interpolator = interp1d( + np.squeeze(as_np_array), values, kind='nearest', + axis=0, bounds_error=False, fill_value="extrapolate" + ) + else: + logger.debug("Using ND nearest interpolation with rescale=%s", self.rescale) + # Pass the rescale flag specifically here + self.interpolator = NearestND(as_np_array, values, rescale=self.rescale, **self.kwargs) + + logger.info("Nearest fitted successfully") + + def predict(self, new_point): + return self.interpolator(new_point).squeeze() \ No newline at end of file diff --git a/tests/test_nearest.py b/tests/test_nearest.py new file mode 100644 index 00000000..c4f4b863 --- /dev/null +++ b/tests/test_nearest.py @@ -0,0 +1,58 @@ +import numpy as np +from unittest import TestCase +from ezyrb import Nearest + + +class TestNearest(TestCase): + def test_params(self): + reg = Nearest(rescale=False) + assert reg.rescale == False + + def test_default_params(self): + reg = Nearest() + assert reg.interpolator is None + + def test_predict1d(self): + reg = Nearest() + reg.fit([[1], [6], [8]], [[1, 0], [20, 5], [8, 6]]) + result = reg.predict([[1], [8], [6]]) + assert (result[0] == [1, 0]).all() + assert (result[1] == [8, 6]).all() + assert (result[2] == [20, 5]).all() + + def test_predict_multivariate(self): + reg = Nearest() + points = [[0, 0], [1, 1]] + values = [[10, 10], [20, 20]] + reg.fit(points, values) + result = reg.predict([[0.1, 0.1]]) + assert (result == [10, 10]).all() + + def test_wrong_input_shape(self): + with self.assertRaises(Exception): + reg = Nearest() + reg.fit([[1, 2], [6], [8, 9]], [[1, 0], [20, 5], [8, 6]]) + + def test_wrong_sample_count(self): + with self.assertRaises(Exception): + reg = Nearest() + reg.fit([[1, 2], [4, 5], [8, 9]], [[10, 10], [20, 20]]) + + def test_batch_multivariate(self): + """Test batch prediction with 2D input.""" + reg = Nearest() + points = [[0, 0], [1, 1], [2, 2]] + values = [[10, 10], [20, 20], [30, 30]] + reg.fit(points, values) + result = reg.predict([[0.1, 0.1], [0.9, 0.9]]) + assert result.shape == (2, 2) + assert (result[0] == [10, 10]).all() + assert (result[1] == [20, 20]).all() + + def test_scalar_output(self): + """Test with scalar (1D) output values.""" + reg = Nearest() + reg.fit([[1], [6], [8]], [10, 20, 30]) + result = reg.predict([[1], [6]]) + assert result.shape == (2,) + assert (result == [10, 20]).all() \ No newline at end of file diff --git a/tests/test_parallel/test_nearest.py b/tests/test_parallel/test_nearest.py new file mode 100644 index 00000000..4b95d2a7 --- /dev/null +++ b/tests/test_parallel/test_nearest.py @@ -0,0 +1,7 @@ +import pytest + +import ezyrb +from ezyrb.parallel import ReducedOrderModel as ParallelROM +ezyrb.ReducedOrderModel = ParallelROM + +from tests.test_nearest import *