Skip to content

Commit df5bbaf

Browse files
committed
refactor: clarify 1D triangulation and interpolation
1 parent 85a9744 commit df5bbaf

2 files changed

Lines changed: 57 additions & 47 deletions

File tree

adaptive/learner/learnerND.py

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -493,21 +493,39 @@ def load_dataframe( # type: ignore[override]
493493
def bounds_are_done(self):
494494
return all(p in self.data for p in self._bounds_points)
495495

496+
def _sorted_line_data(self):
497+
coordinates = self.points[:, 0]
498+
sorted_indices = np.argsort(coordinates)
499+
return coordinates[sorted_indices], self.values[sorted_indices]
500+
496501
def _ip(self):
497-
"""A `scipy.interpolate.LinearNDInterpolator` instance
498-
containing the learner's data."""
502+
"""A SciPy interpolator containing the learner's data."""
499503
# XXX: take our own triangulation into account when generating the _ip
500-
if self.ndim == 1:
501-
points = self.points.ravel()
502-
sorted_idx = np.argsort(points)
503-
return interpolate.interp1d(
504-
points[sorted_idx],
505-
self.values[sorted_idx],
506-
axis=0,
507-
bounds_error=False,
508-
fill_value=np.nan,
509-
)
510-
return interpolate.LinearNDInterpolator(self.points, self.values)
504+
if self.ndim != 1:
505+
return interpolate.LinearNDInterpolator(self.points, self.values)
506+
507+
coordinates, values = self._sorted_line_data()
508+
return interpolate.interp1d(
509+
coordinates,
510+
values,
511+
axis=0,
512+
bounds_error=False,
513+
fill_value=np.nan,
514+
)
515+
516+
def _plot_1d(self, n=None):
517+
hv = ensure_holoviews()
518+
if len(self.data) < 2:
519+
return hv.Path([]) * hv.Scatter([]).opts(size=5)
520+
521+
(x_bounds,) = self._bbox
522+
n = n or 201
523+
xs = np.linspace(*x_bounds, n)
524+
ys = self._ip()(xs)
525+
scatter_points = [
526+
(point[0], value) for point, value in sorted(self.data.items())
527+
]
528+
return hv.Path((xs, ys)) * hv.Scatter(scatter_points).opts(size=5)
511529

512530
@property
513531
def tri(self):
@@ -891,34 +909,23 @@ def remove_unfinished(self):
891909
##########################
892910

893911
def plot(self, n=None, tri_alpha=0):
894-
"""Plot the function we want to learn, only works in 2D.
912+
"""Plot the function we want to learn in 1D or 2D.
895913
896914
Parameters
897915
----------
898916
n : int
899-
the number of boxes in the interpolation grid along each axis
917+
The number of interpolation points per axis.
900918
tri_alpha : float (0 to 1)
901919
Opacity of triangulation lines
902920
"""
903-
hv = ensure_holoviews()
904921
if self.vdim > 1:
905922
raise NotImplementedError(
906923
"holoviews currently does not support", "3D surface plots in bokeh."
907924
)
908925
if self.ndim == 1:
909-
if len(self.data) >= 2:
910-
(x,) = self._bbox
911-
n = n or 201
912-
xs = np.linspace(x[0], x[1], n)
913-
ys = self._ip()(xs)
914-
path = hv.Path((xs, ys))
915-
points = [(x[0], y) for x, y in sorted(self.data.items())]
916-
scatter = hv.Scatter(points)
917-
else:
918-
path = hv.Path([])
919-
scatter = hv.Scatter([])
920-
return path * scatter.opts(size=5)
926+
return self._plot_1d(n)
921927

928+
hv = ensure_holoviews()
922929
if self.ndim != 2:
923930
raise NotImplementedError(
924931
"Only 1D and 2D plots are implemented: You can "

adaptive/learner/triangulation.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,14 @@ def is_iterable_and_sized(obj):
231231
return isinstance(obj, Iterable) and isinstance(obj, Sized)
232232

233233

234+
def _flat_simplices(coords):
235+
"""Yield the intervals of a 1D triangulation in sorted coordinate order."""
236+
sorted_indices = sorted(range(len(coords)), key=coords.__getitem__)
237+
for left, right in zip(sorted_indices, sorted_indices[1:]):
238+
if coords[left] != coords[right]:
239+
yield left, right
240+
241+
234242
def simplex_volume_in_embedding(vertices) -> float:
235243
"""Calculate the volume of a simplex in a higher dimensional embedding.
236244
That is: dim > len(vertices) - 1. For example if you would like to know the
@@ -257,12 +265,13 @@ def simplex_volume_in_embedding(vertices) -> float:
257265
# Modified from https://codereview.stackexchange.com/questions/77593/calculating-the-volume-of-a-tetrahedron
258266

259267
vertices = asarray(vertices, dtype=float)
260-
if len(vertices) == 2:
261-
# 1-simplex (line segment): volume is the Euclidean distance
268+
num_vertices = len(vertices)
269+
if num_vertices == 2:
270+
# A 1-simplex is just a line segment.
262271
return float(norm(vertices[1] - vertices[0]))
263272

264-
dim = len(vertices[0])
265-
if dim == 2:
273+
embedding_dim = len(vertices[0])
274+
if embedding_dim == 2:
266275
# Heron's formula
267276
a, b, c = scipy.spatial.distance.pdist(vertices, metric="euclidean")
268277
s = 0.5 * (a + b + c)
@@ -272,13 +281,12 @@ def simplex_volume_in_embedding(vertices) -> float:
272281
sq_dists = scipy.spatial.distance.pdist(vertices, metric="sqeuclidean")
273282

274283
# Add border while compressed
275-
num_verts = scipy.spatial.distance.num_obs_y(sq_dists)
276-
bordered = concatenate((ones(num_verts), sq_dists))
284+
bordered = concatenate((ones(num_vertices), sq_dists))
277285

278286
# Make matrix and find volume
279287
sq_dists_mat = scipy.spatial.distance.squareform(bordered)
280288

281-
coeff = -((-2) ** (num_verts - 1)) * factorial(num_verts - 1) ** 2
289+
coeff = -((-2) ** (num_vertices - 1)) * factorial(num_vertices - 1) ** 2
282290
vol_square = fast_det(sq_dists_mat) / coeff
283291

284292
if vol_square < 0:
@@ -346,19 +354,14 @@ def __init__(self, coords):
346354
self.vertex_to_simplices = [set() for _ in coords]
347355

348356
if dim == 1:
349-
# For 1D, sort points and create intervals as simplices,
350-
# skipping adjacent duplicates to avoid degenerate zero-volume simplices
351-
sorted_indices = sorted(range(len(coords)), key=lambda i: coords[i])
352-
for i in range(len(sorted_indices) - 1):
353-
if coords[sorted_indices[i]] == coords[sorted_indices[i + 1]]:
354-
continue
355-
self.add_simplex((sorted_indices[i], sorted_indices[i + 1]))
357+
simplices = _flat_simplices(coords)
356358
else:
357-
# find a Delaunay triangulation to start with, then we will throw it
358-
# away and continue with our own algorithm
359-
initial_tri = scipy.spatial.Delaunay(coords)
360-
for simplex in initial_tri.simplices:
361-
self.add_simplex(simplex)
359+
# Find a Delaunay triangulation to start with, then we will throw it
360+
# away and continue with our own algorithm.
361+
simplices = scipy.spatial.Delaunay(coords).simplices
362+
363+
for simplex in simplices:
364+
self.add_simplex(simplex)
362365

363366
def delete_simplex(self, simplex):
364367
simplex = tuple(sorted(simplex))

0 commit comments

Comments
 (0)