Skip to content

Commit 5f7a231

Browse files
Lutz Grossclaude
andcommitted
Add InterpolationTable class with order-0 C++ workers (issues #41, #55, #121)
- Add interpolateFromTable{1,2,3}DOrder0 C++ workers (piecewise-constant / nearest-neighbour) with OpenMP and MPI support, mirroring the existing order-1 workers in Data.cpp / Data.h - Expose via _interpolateTable{1,2,3}dOrder0 Boost.Python bindings - New escriptcore/py_src/interpolation.py: InterpolationTable class - Supports order 0 and 1, 1-D/2-D/3-D coordinate grids - x can have shape (), (1,), (2,), or (3,); result is always scalar - Table indexing: table[ix, iy, iz] (x-axis first) - Export InterpolationTable from esys.escript - Deprecated interpolateTable() now delegates to InterpolationTable Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 078f532 commit 5f7a231

7 files changed

Lines changed: 738 additions & 91 deletions

File tree

escript/py_src/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
from esys.escriptcore.escriptcpp import *
7474
from esys.escriptcore.start import HAVE_SYMBOLS
7575
from esys.escriptcore.util import *
76+
from esys.escriptcore.interpolation import InterpolationTable
7677
from esys.escriptcore.nonlinearPDE import NonlinearPDE
7778
from esys.escriptcore.datamanager import DataManager
7879
if HAVE_SYMBOLS:
Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
##############################################################################
2+
#
3+
# Copyright (c) 2003-2026 by the esys.escript Group
4+
# https://github.com/LutzGross/esys-escript.github.io
5+
#
6+
# Primary Business: Queensland, Australia
7+
# Licensed under the Apache License, version 2.0
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# See CREDITS file for contributors and development history
11+
#
12+
##############################################################################
13+
14+
"""
15+
Provides :class:`InterpolationTable`, a class-based interface for
16+
interpolating scalar values from a regular-grid lookup table onto a mesh.
17+
Supports order-0 (piecewise constant) and order-1 (linear) interpolation
18+
in 1D, 2D, and 3D.
19+
"""
20+
21+
import numpy as np
22+
from . import escriptcpp as escore
23+
24+
25+
class InterpolationTable:
26+
"""
27+
Interpolates scalar values from a regular-grid lookup table onto a mesh.
28+
29+
The coordinate data *x* passed to :meth:`__call__` determines the
30+
interpolation dimension:
31+
32+
======================== =========== =======================
33+
``x.getShape()`` Lookup dim Required table rank
34+
======================== =========== =======================
35+
``()`` 1-D 1 — shape ``(nx,)``
36+
``(1,)`` 1-D 1 — shape ``(nx,)``
37+
``(2,)`` 2-D 2 — shape ``(nx, ny)``
38+
``(3,)`` 3-D 3 — shape ``(nx, ny, nz)``
39+
======================== =========== =======================
40+
41+
The returned `Data` object is always **scalar** (shape ``()``).
42+
43+
**Table indexing convention**: ``table[ix, iy, iz]`` where *ix* is the
44+
index along the x-axis (first coordinate), *iy* along the y-axis
45+
(second), and *iz* along the z-axis (third).
46+
47+
Example usage::
48+
49+
import numpy as np
50+
from esys.finley import Rectangle
51+
from esys.escript import Function
52+
from esys.escriptcore.interpolation import InterpolationTable
53+
54+
dom = Rectangle(20, 20)
55+
x = Function(dom).getX() # shape (2,)
56+
57+
# 2-D scalar table, order-1 (linear) interpolation
58+
t = np.random.rand(5, 5)
59+
interp = InterpolationTable(t, origin=(0., 0.), step=(0.25, 0.25))
60+
result = interp(x) # scalar Data on Function(dom)
61+
62+
# 2-D scalar table, order-0 (piecewise constant) interpolation
63+
interp0 = InterpolationTable(t, origin=(0., 0.), step=(0.25, 0.25), order=0)
64+
result0 = interp0(x)
65+
66+
:param table: lookup table as a numpy array of rank 1, 2, or 3
67+
:type table: ``numpy.ndarray``
68+
:param origin: coordinate(s) of the first table entry; a single float
69+
for 1-D or a tuple for 2-D / 3-D
70+
:type origin: ``float`` or ``tuple`` of ``float``
71+
:param step: cell size(s); all values must be strictly positive
72+
:type step: ``float`` or ``tuple`` of ``float``
73+
:param order: interpolation order — 0 for piecewise constant (nearest
74+
neighbour), 1 for linear (default)
75+
:type order: ``int``
76+
:param undef: upper threshold; result values above this trigger a
77+
``RuntimeError``
78+
:type undef: ``float``
79+
:param check_boundaries: if ``True``, a ``RuntimeError`` is raised when
80+
a coordinate lies outside the table extent; otherwise the nearest
81+
boundary value is used
82+
:type check_boundaries: ``bool``
83+
"""
84+
85+
def __init__(self, table, origin, step, order=1, undef=1.e50,
86+
check_boundaries=False):
87+
if not isinstance(table, np.ndarray):
88+
table = np.array(table, dtype=float)
89+
if np.isscalar(origin):
90+
origin = (origin,)
91+
if np.isscalar(step):
92+
step = (step,)
93+
ndim = len(origin)
94+
if len(step) != ndim:
95+
raise ValueError("origin and step must have the same length")
96+
if ndim < 1 or ndim > 3:
97+
raise ValueError("ndim (length of origin) must be 1, 2, or 3")
98+
if table.ndim != ndim:
99+
raise ValueError(
100+
"table rank {} does not match coordinate dimension {} "
101+
"(set by length of origin)".format(table.ndim, ndim))
102+
if any(s <= 0 for s in step):
103+
raise ValueError("All step sizes must be strictly positive")
104+
if order not in (0, 1):
105+
raise ValueError("order must be 0 or 1")
106+
self._table = np.ascontiguousarray(table, dtype=float)
107+
self._origin = tuple(float(v) for v in origin)
108+
self._step = tuple(float(v) for v in step)
109+
self._ndim = ndim
110+
self._order = order
111+
self._undef = float(undef)
112+
self._check_boundaries = check_boundaries
113+
114+
# ------------------------------------------------------------------
115+
# Internal C++ dispatch
116+
# ------------------------------------------------------------------
117+
118+
def _cpp_1d(self, x0):
119+
if self._order == 0:
120+
return x0._interpolateTable1dOrder0(
121+
self._table, self._origin[0], self._step[0],
122+
self._undef, self._check_boundaries)
123+
else:
124+
return x0.interpolateTable(
125+
self._table, self._origin[0], self._step[0],
126+
self._undef, self._check_boundaries)
127+
128+
def _cpp_2d(self, x0, x1):
129+
if self._order == 0:
130+
return x0._interpolateTable2dOrder0(
131+
self._table, self._origin[0], self._step[0],
132+
x1, self._origin[1], self._step[1],
133+
self._undef, self._check_boundaries)
134+
else:
135+
return x0.interpolateTable(
136+
self._table, self._origin[0], self._step[0],
137+
x1, self._origin[1], self._step[1],
138+
self._undef, self._check_boundaries)
139+
140+
def _cpp_3d(self, x0, x1, x2):
141+
if self._order == 0:
142+
return x0._interpolateTable3dOrder0(
143+
self._table, self._origin[0], self._step[0],
144+
x1, self._origin[1], self._step[1],
145+
x2, self._origin[2], self._step[2],
146+
self._undef, self._check_boundaries)
147+
else:
148+
return x0._interpolateTable3d(
149+
self._table, self._origin[0], self._step[0],
150+
x1, self._origin[1], self._step[1],
151+
x2, self._origin[2], self._step[2],
152+
self._undef, self._check_boundaries)
153+
154+
# ------------------------------------------------------------------
155+
# Public interface
156+
# ------------------------------------------------------------------
157+
158+
def __call__(self, x):
159+
"""
160+
Return scalar interpolated values at the mesh points described by *x*.
161+
162+
:param x: coordinate data with shape ``()``, ``(1,)``, ``(2,)``, or
163+
``(3,)``; the shape must be consistent with the table rank
164+
:type x: `Data`
165+
:return: scalar interpolated field
166+
:rtype: `Data` with shape ``()``
167+
:raises TypeError: if *x* is not a `Data` object
168+
:raises ValueError: if *x* has an unsupported shape or is inconsistent
169+
with the table rank
170+
"""
171+
if not isinstance(x, escore.Data):
172+
raise TypeError("x must be a Data object")
173+
174+
sh = x.getShape()
175+
176+
if sh == ():
177+
if self._ndim != 1:
178+
raise ValueError(
179+
"scalar x requires a rank-1 table (ndim=1), "
180+
"got ndim={}".format(self._ndim))
181+
return self._cpp_1d(x)
182+
183+
if sh == (1,):
184+
if self._ndim != 1:
185+
raise ValueError(
186+
"x with shape (1,) requires a rank-1 table (ndim=1), "
187+
"got ndim={}".format(self._ndim))
188+
return self._cpp_1d(x[0])
189+
190+
if sh == (2,):
191+
if self._ndim != 2:
192+
raise ValueError(
193+
"x with shape (2,) requires a rank-2 table (ndim=2), "
194+
"got ndim={}".format(self._ndim))
195+
return self._cpp_2d(x[0], x[1])
196+
197+
if sh == (3,):
198+
if self._ndim != 3:
199+
raise ValueError(
200+
"x with shape (3,) requires a rank-3 table (ndim=3), "
201+
"got ndim={}".format(self._ndim))
202+
return self._cpp_3d(x[0], x[1], x[2])
203+
204+
raise ValueError(
205+
"x must have shape (), (1,), (2,), or (3,); got {}".format(sh))
206+
207+
def interpolate(self, x):
208+
"""Alias for :meth:`__call__`."""
209+
return self(x)
210+
211+
# ------------------------------------------------------------------
212+
# Read-only properties
213+
# ------------------------------------------------------------------
214+
215+
@property
216+
def order(self):
217+
"""Interpolation order (0 or 1)."""
218+
return self._order
219+
220+
@property
221+
def ndim(self):
222+
"""Number of coordinate dimensions (1, 2, or 3)."""
223+
return self._ndim
224+
225+
@property
226+
def origin(self):
227+
"""Tuple of starting coordinates."""
228+
return self._origin
229+
230+
@property
231+
def step(self):
232+
"""Tuple of cell sizes."""
233+
return self._step

escriptcore/py_src/util.py

Lines changed: 16 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848

4949
from . import escriptcpp as escore
5050
from .escriptcpp import C_GeneralTensorProduct, Data
51+
from .interpolation import InterpolationTable
5152
from .escriptcpp import getVersion, getMPIRankWorld, getMPIWorldMax, hasFeature
5253
#from .escriptcpp import printParallelThreadCounts
5354
from .escriptcpp import listEscriptParams
@@ -128,60 +129,30 @@ def interpolateTable(tab, dat, start, step, undef=1.e50, check_boundaries=False)
128129
"""
129130
Interpolates values from a lookup table.
130131
132+
.. deprecated::
133+
Use :class:`~esys.escriptcore.interpolation.InterpolationTable` instead.
134+
This function delegates to ``InterpolationTable`` with ``order=1``.
135+
131136
:param tab: the lookup table array
132137
:type tab: ``numpy.ndarray``
133-
:param dat: input data to interpolate
134-
:type dat: `Data` or ``numpy.ndarray``
138+
:param dat: coordinate data — scalar `Data` for 1-D or vector `Data` for 2-D/3-D
139+
:type dat: `Data`
135140
:param start: starting coordinate(s) for the table
136141
:type start: ``float`` or ``tuple`` of ``float``
137142
:param step: step size(s) for the table
138143
:type step: ``float`` or ``tuple`` of ``float``
139-
:param undef: value to use for undefined/out-of-bounds results
144+
:param undef: upper bound on interpolated values
140145
:type undef: ``float``
141-
:param check_boundaries: if ``True``, check that input is within table bounds
146+
:param check_boundaries: if ``True``, raise an error for out-of-bounds coordinates
142147
:type check_boundaries: ``bool``
143148
:return: interpolated values
144-
:rtype: `Data` or ``numpy.ndarray``
145-
146-
.. deprecated::
147-
This function is deprecated and is known to contain bugs.
148-
"""
149-
print("WARNING: This function is deprecated and is known to contain bugs.")
150-
try:
151-
dim=len(start)
152-
except TypeError:
153-
start=(start,)
154-
dim=1
155-
try:
156-
slen=len(step)
157-
except TypeError:
158-
step=(step,)
159-
slen=1
160-
if dim<1 or dim>3:
161-
raise ValueError("Length of start list must be between 1 and 3.")
162-
if dim!=slen:
163-
raise ValueError("Length of start and step must be the same.")
164-
dshape=dat.getShape()
165-
if len(dshape)==0:
166-
datdim=0
167-
firstdim=dat
168-
else:
169-
datdim=dshape[0]
170-
firstdim=dat[0]
171-
#So now we know firstdim is a scalar
172-
if (dim==1 and datdim>1) or (dim>1 and datdim!=dim):
173-
print((dim, datdim))
174-
raise ValueError("The dimension of dat must be equal to the length of start.")
175-
if dim==3:
176-
d1=dat[1]
177-
d2=dat[2]
178-
return firstdim._interpolateTable3d(tab, start[0], step[0], d1, start[1], step[1], d2, start[2], step[2], undef, check_boundaries)
179-
if dim==2:
180-
d1=dat[1]
181-
return firstdim.interpolateTable(tab, start[0], step[0], d1, start[1], step[1], undef, check_boundaries)
182-
# return d1.interpolateTable(tab, start[1], step[1], firstdim, start[0], step[0], undef, check_boundaries)
183-
else:
184-
return firstdim.interpolateTable(tab, start[0], step[0], undef, check_boundaries)
149+
:rtype: `Data`
150+
"""
151+
warnings.warn(
152+
"interpolateTable is deprecated; use InterpolationTable instead.",
153+
DeprecationWarning, stacklevel=2)
154+
return InterpolationTable(tab, start, step, order=1, undef=undef,
155+
check_boundaries=check_boundaries)(dat)
185156

186157

187158
def saveDataCSV(filename, append=False, refid=False, sep=", ", csep="_", **data):

0 commit comments

Comments
 (0)