Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ezyrb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"ReducedOrderModel",
"PODAE",
"RegularGrid",
"Nearest",
"MultiReducedOrderModel",
"SklearnApproximation",
"SklearnReduction",
Expand Down
2 changes: 2 additions & 0 deletions ezyrb/approximation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"KNeighborsRegressor",
"RadiusNeighborsRegressor",
"SklearnApproximation",
"Nearest",
]

from .approximation import Approximation
Expand All @@ -19,3 +20,4 @@
from .kneighbors_regressor import KNeighborsRegressor
from .radius_neighbors_regressor import RadiusNeighborsRegressor
from .sklearn_approximation import SklearnApproximation
from .nearest import Nearest
41 changes: 41 additions & 0 deletions ezyrb/approximation/nearest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Wrapper for Nearest Neighbor Interpolator."""
import logging
import numpy as np
from scipy.interpolate import NearestNDInterpolator as NearestND
from scipy.interpolate import interp1d
from .approximation import Approximation

logger = logging.getLogger(__name__)

class Nearest(Approximation):
"""
Nearest Neighbors interpolator for univariate and multivariate approximation.

:param kwargs: arguments passed to the internal instance of
scipy.interpolate.NearestNDInterpolator or scipy.interpolate.interp1d.
"""
def __init__(self, rescale=True, **kwargs):
logger.debug("Initializing Nearest with rescale=%s, kwargs: %s", rescale, kwargs)
super().__init__()
self.rescale = rescale
self.kwargs = kwargs
self.interpolator = None

def fit(self, points, values):
as_np_array = np.array(points)

if as_np_array.ndim == 1 or (as_np_array.ndim == 2 and as_np_array.shape[1] == 1):
logger.debug("Using 1D nearest interpolation")
self.interpolator = interp1d(
np.squeeze(as_np_array), values, kind='nearest',
axis=0, bounds_error=False, fill_value="extrapolate"
)
else:
logger.debug("Using ND nearest interpolation with rescale=%s", self.rescale)
# Pass the rescale flag specifically here
self.interpolator = NearestND(as_np_array, values, rescale=self.rescale, **self.kwargs)

logger.info("Nearest fitted successfully")

def predict(self, new_point):
return self.interpolator(new_point).squeeze()
58 changes: 58 additions & 0 deletions tests/test_nearest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import numpy as np
from unittest import TestCase
from ezyrb import Nearest


class TestNearest(TestCase):
def test_params(self):
reg = Nearest(rescale=False)
assert reg.rescale == False

def test_default_params(self):
reg = Nearest()
assert reg.interpolator is None

def test_predict1d(self):
reg = Nearest()
reg.fit([[1], [6], [8]], [[1, 0], [20, 5], [8, 6]])
result = reg.predict([[1], [8], [6]])
assert (result[0] == [1, 0]).all()
assert (result[1] == [8, 6]).all()
assert (result[2] == [20, 5]).all()

def test_predict_multivariate(self):
reg = Nearest()
points = [[0, 0], [1, 1]]
values = [[10, 10], [20, 20]]
reg.fit(points, values)
result = reg.predict([[0.1, 0.1]])
assert (result == [10, 10]).all()

def test_wrong_input_shape(self):
with self.assertRaises(Exception):
reg = Nearest()
reg.fit([[1, 2], [6], [8, 9]], [[1, 0], [20, 5], [8, 6]])

def test_wrong_sample_count(self):
with self.assertRaises(Exception):
reg = Nearest()
reg.fit([[1, 2], [4, 5], [8, 9]], [[10, 10], [20, 20]])

def test_batch_multivariate(self):
"""Test batch prediction with 2D input."""
reg = Nearest()
points = [[0, 0], [1, 1], [2, 2]]
values = [[10, 10], [20, 20], [30, 30]]
reg.fit(points, values)
result = reg.predict([[0.1, 0.1], [0.9, 0.9]])
assert result.shape == (2, 2)
assert (result[0] == [10, 10]).all()
assert (result[1] == [20, 20]).all()

def test_scalar_output(self):
"""Test with scalar (1D) output values."""
reg = Nearest()
reg.fit([[1], [6], [8]], [10, 20, 30])
result = reg.predict([[1], [6]])
assert result.shape == (2,)
assert (result == [10, 20]).all()
7 changes: 7 additions & 0 deletions tests/test_parallel/test_nearest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import pytest

import ezyrb
from ezyrb.parallel import ReducedOrderModel as ParallelROM
ezyrb.ReducedOrderModel = ParallelROM

from tests.test_nearest import *
Loading