1616from parcels ._core .utils .string import _assert_str_and_python_varname
1717from parcels ._core .uxgrid import UxGrid
1818from parcels ._core .xgrid import XGrid
19- from parcels ._python import assert_same_function_signature
2019from parcels ._reprs import field_repr , vectorfield_repr
2120from parcels ._typing import VectorType
22- from parcels .interpolators import (
23- ZeroInterpolator ,
24- ZeroInterpolator_Vector ,
25- )
21+ from parcels .interpolators ._base import ScalarInterpolator , VectorInterpolator
2622
2723if TYPE_CHECKING :
2824 from parcels ._core .model import Model
@@ -89,7 +85,7 @@ def __init__(
8985 self ,
9086 name : str ,
9187 model : Model ,
92- interp_method : Callable ,
88+ interp_method : ScalarInterpolator ,
9389 ):
9490 # TODO PR: Enable isinstance check once Model is moved to abc.Model
9591 # if not isinstance(model, "Model"):
@@ -103,7 +99,8 @@ def __init__(
10399 self .model = model
104100
105101 # Setting the interpolation method dynamically
106- assert_same_function_signature (interp_method , ref = ZeroInterpolator , context = "Interpolation" )
102+ if not isinstance (interp_method , ScalarInterpolator ):
103+ raise ValueError (f"interp_method must be a `ScalarInterpolator` object. Got { type (interp_method )= !r} " )
107104 self ._interp_method = interp_method
108105
109106 self .igrid = - 1 # Default the grid index to -1
@@ -128,8 +125,9 @@ def interp_method(self):
128125 return self ._interp_method
129126
130127 @interp_method .setter
131- def interp_method (self , method : Callable ):
132- assert_same_function_signature (method , ref = ZeroInterpolator , context = "Interpolation" )
128+ def interp_method (self , method : ScalarInterpolator ):
129+ if not isinstance (method , ScalarInterpolator ):
130+ raise ValueError (f"method must be a `ScalarInterpolator` object. Got { type (method )= !r} " )
133131 self ._interp_method = method
134132
135133 def _check_velocitysampling (self ):
@@ -175,7 +173,7 @@ def eval(self, time: datetime, z, y, x, particles=None):
175173
176174 particle_positions , grid_positions = _get_positions (self , time , z , y , x , particles , _ei )
177175
178- value = self ._interp_method (particle_positions , grid_positions , self )
176+ value = self ._interp_method . interp (particle_positions , grid_positions , self )
179177
180178 _update_particle_states_interp_value (particles , value )
181179
@@ -201,7 +199,7 @@ def __init__(
201199 U : Field , # noqa: N803
202200 V : Field , # noqa: N803
203201 W : Field | None = None , # noqa: N803
204- vector_interp_method : Callable | None = None ,
202+ vector_interp_method : VectorInterpolator | None = None ,
205203 ):
206204 if vector_interp_method is None :
207205 raise ValueError ("vector_interp_method must be provided for VectorField initialization." )
@@ -226,7 +224,11 @@ def __init__(
226224 else :
227225 self .vector_type = "2D"
228226
229- assert_same_function_signature (vector_interp_method , ref = ZeroInterpolator_Vector , context = "Interpolation" )
227+ if not isinstance (vector_interp_method , VectorInterpolator ):
228+ raise ValueError (
229+ f"vector_interp_method must be a `VectorInterpolator` object. Got { type (vector_interp_method )= !r} "
230+ )
231+
230232 self ._vector_interp_method = vector_interp_method
231233
232234 def __repr__ (self ):
@@ -238,7 +240,8 @@ def vector_interp_method(self):
238240
239241 @vector_interp_method .setter
240242 def vector_interp_method (self , method : Callable ):
241- assert_same_function_signature (method , ref = ZeroInterpolator_Vector , context = "Interpolation" )
243+ if not isinstance (method , VectorInterpolator ):
244+ raise ValueError (f"method must be a `VectorInterpolator` object. Got { type (method )= !r} " )
242245 self ._vector_interp_method = method
243246
244247 def eval (self , time : datetime , z , y , x , particles = None ):
@@ -277,7 +280,7 @@ def eval(self, time: datetime, z, y, x, particles=None):
277280
278281 particle_positions , grid_positions = _get_positions (self .U , time , z , y , x , particles , _ei )
279282
280- (u , v , w ) = self ._vector_interp_method (particle_positions , grid_positions , self )
283+ (u , v , w ) = self ._vector_interp_method . interp (particle_positions , grid_positions , self )
281284
282285 for vel in (u , v , w ):
283286 _update_particle_states_interp_value (particles , vel )
0 commit comments