Skip to content

Commit 567a1b4

Browse files
authored
Add optional Rust-accelerated triangulation backend for LearnerND (#493)
* Add optional Rust-accelerated triangulation backend for LearnerND If the optional adaptive-triangulation package (>= 0.2.1) is installed, its Rust Triangulation implementation is now used automatically as a drop-in replacement for the pure-Python one, making LearnerND significantly faster. - New adaptive/learner/triangulation_backend.py selects the backend at import time; ADAPTIVE_TRIANGULATION_BACKEND=auto|python|rust overrides the automatic selection globally. - LearnerND gains a triangulation_backend keyword argument accepting "auto" (default), "python", "rust", or a Triangulation-compatible class for per-learner control. - The pure-Python adaptive.learner.triangulation names are never shadowed, so existing pickles that reference them keep loading. - Fix _pop_highest_existing_simplex evaluating `None in ...simplices` for stale queue entries (Python sets return False; the Rust proxy raises TypeError). - Add the `rust` optional-dependency extra, tests, and docs. * chore(docs): update TOC --------- Co-authored-by: basnijholt <basnijholt@users.noreply.github.com>
1 parent 1f60c43 commit 567a1b4

9 files changed

Lines changed: 329 additions & 8 deletions

File tree

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ To see Adaptive in action, try the [example notebook on Binder](https://mybinder
3636
- [:floppy_disk: Exporting Data](#floppy_disk-exporting-data)
3737
- [:test_tube: Implemented Algorithms](#test_tube-implemented-algorithms)
3838
- [:package: Installation](#package-installation)
39+
- [Faster triangulation (optional)](#faster-triangulation-optional)
3940
- [:wrench: Development](#wrench-development)
4041
- [:books: Citing](#books-citing)
4142
- [:page_facing_up: Draft Paper](#page_facing_up-draft-paper)
@@ -151,6 +152,17 @@ pip install "adaptive[notebook]"
151152

152153
The `[notebook]` above will also install the optional dependencies for running `adaptive` inside a Jupyter notebook.
153154

155+
### Faster triangulation (optional)
156+
157+
Installing the optional [adaptive-triangulation](https://github.com/python-adaptive/adaptive-triangulation) package makes `adaptive.LearnerND` significantly faster by replacing the pure-Python triangulation with a Rust implementation:
158+
159+
```bash
160+
pip install "adaptive[rust]"
161+
```
162+
163+
No code changes are needed — the Rust backend is detected and used automatically.
164+
To control the selection, pass `LearnerND(..., triangulation_backend="python" | "rust" | "auto")` per learner, or set the environment variable `ADAPTIVE_TRIANGULATION_BACKEND=python` to force the pure-Python implementation globally (or to `rust` to make a missing Rust backend an error instead of a silent fallback).
165+
154166
To use Adaptive in Jupyterlab, you need to install the following labextensions.
155167

156168
```bash

adaptive/learner/learner1D.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from adaptive.learner.base_learner import BaseLearner, uses_nth_neighbors
1616
from adaptive.learner.learnerND import volume
17-
from adaptive.learner.triangulation import simplex_volume_in_embedding
17+
from adaptive.learner.triangulation_backend import simplex_volume_in_embedding
1818
from adaptive.notebook_integration import ensure_holoviews
1919
from adaptive.types import Float, Int, Real
2020
from adaptive.utils import (

adaptive/learner/learner2D.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from scipy.interpolate import CloughTocher2DInterpolator, LinearNDInterpolator
1414

1515
from adaptive.learner.base_learner import BaseLearner
16-
from adaptive.learner.triangulation import simplex_volume_in_embedding
16+
from adaptive.learner.triangulation_backend import simplex_volume_in_embedding
1717
from adaptive.notebook_integration import ensure_holoviews
1818
from adaptive.types import Bool, Float, Real
1919
from adaptive.utils import (

adaptive/learner/learnerND.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@
1414
from sortedcontainers import SortedKeyList
1515

1616
from adaptive.learner.base_learner import BaseLearner, uses_nth_neighbors
17-
from adaptive.learner.triangulation import (
17+
from adaptive.learner.triangulation import fast_det
18+
from adaptive.learner.triangulation_backend import (
1819
Triangulation,
1920
circumsphere,
20-
fast_det,
2121
point_in_simplex,
22+
resolve_triangulation_class,
2223
simplex_volume_in_embedding,
2324
)
2425
from adaptive.notebook_integration import ensure_holoviews, ensure_plotly
@@ -274,6 +275,15 @@ class LearnerND(BaseLearner):
274275
If not provided, then a default is used, which uses
275276
the deviation from a linear estimate, as well as
276277
triangle area, to determine the loss.
278+
anisotropic : bool, optional
279+
If True, the triangulation is stretched along the local gradient
280+
when choosing new points. Only works with scalar output.
281+
triangulation_backend : str or type, optional
282+
Which triangulation implementation to use: ``"auto"`` (default,
283+
prefers the optional Rust-accelerated `adaptive-triangulation
284+
<https://github.com/python-adaptive/adaptive-triangulation>`_
285+
package when it is installed), ``"python"``, ``"rust"``, or a
286+
`~adaptive.learner.triangulation.Triangulation`-compatible class.
277287
278288
279289
Attributes
@@ -308,7 +318,20 @@ class LearnerND(BaseLearner):
308318
children based on volume.
309319
"""
310320

311-
def __init__(self, func, bounds, loss_per_simplex=None, *, anisotropic=False):
321+
# Class-level fallback so that learners restored from pickles made
322+
# before `triangulation_backend` existed keep working.
323+
_triangulation_class = Triangulation
324+
325+
def __init__(
326+
self,
327+
func,
328+
bounds,
329+
loss_per_simplex=None,
330+
*,
331+
anisotropic=False,
332+
triangulation_backend="auto",
333+
):
334+
self._triangulation_class = resolve_triangulation_class(triangulation_backend)
312335
self._vdim = None
313336
self.loss_per_simplex = loss_per_simplex or default_loss
314337

@@ -385,6 +408,7 @@ def new(self) -> LearnerND:
385408
self.bounds,
386409
self.loss_per_simplex,
387410
anisotropic=self.anisotropic,
411+
triangulation_backend=self._triangulation_class,
388412
)
389413

390414
@property
@@ -513,7 +537,7 @@ def tri(self):
513537
return self._tri
514538

515539
try:
516-
self._tri = Triangulation(self.points)
540+
self._tri = self._triangulation_class(self.points)
517541
except ValueError:
518542
# A ValueError is raised if we do not have enough points or
519543
# the provided points are coplanar, so we need more points to
@@ -649,7 +673,7 @@ def _try_adding_pending_point_to_simplex(self, point, simplex):
649673

650674
if simplex not in self._subtriangulations:
651675
vertices = self.tri.get_vertices(simplex)
652-
self._subtriangulations[simplex] = Triangulation(vertices)
676+
self._subtriangulations[simplex] = self._triangulation_class(vertices)
653677

654678
self._pending_to_simplex[point] = simplex
655679
return self._subtriangulations[simplex].add_point(point)
@@ -713,7 +737,8 @@ def _pop_highest_existing_simplex(self):
713737
):
714738
return abs(loss), simplex, subsimplex
715739
if (
716-
simplex in self._subtriangulations
740+
subsimplex is not None
741+
and simplex in self._subtriangulations
717742
and simplex in self.tri.simplices
718743
and subsimplex in self._subtriangulations[simplex].simplices
719744
):
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
"""Select the triangulation backend used by the learners.
2+
3+
If the optional Rust-accelerated `adaptive-triangulation
4+
<https://github.com/python-adaptive/adaptive-triangulation>`_ package is
5+
installed (``pip install "adaptive[rust]"``), it is used automatically as a
6+
drop-in replacement for the pure-Python implementation in
7+
`adaptive.learner.triangulation`, which makes `~adaptive.LearnerND`
8+
significantly faster.
9+
10+
The selection can be overridden with the ``ADAPTIVE_TRIANGULATION_BACKEND``
11+
environment variable:
12+
13+
- ``auto`` (default): use the Rust backend if available, else pure Python
14+
- ``python``: always use the pure-Python implementation
15+
- ``rust``: require the Rust backend, raising `ImportError` if it is missing
16+
17+
The active backend is exposed as the string ``TRIANGULATION_BACKEND``
18+
(``"python"`` or ``"rust"``).
19+
20+
Note that the pure-Python implementation in `adaptive.learner.triangulation`
21+
is always importable under its own name, regardless of the selected backend,
22+
so pickles that reference it keep working.
23+
"""
24+
25+
from __future__ import annotations
26+
27+
import os
28+
29+
# Minimal version that is a complete drop-in for the learners
30+
# (incl. ``get_opposing_vertices`` and pickle/deepcopy support).
31+
_MIN_RUST_VERSION = (0, 2, 1)
32+
33+
34+
def _rust_version() -> tuple[int, ...] | None:
35+
"""Return the installed ``adaptive_triangulation`` version, or None."""
36+
try:
37+
import adaptive_triangulation
38+
except ImportError:
39+
return None
40+
version = adaptive_triangulation.__version__
41+
return tuple(int(part) for part in version.split(".")[:3] if part.isdigit())
42+
43+
44+
def _import_rust_triangulation():
45+
"""Import the Rust `Triangulation`, raising a helpful `ImportError`."""
46+
version = _rust_version()
47+
if version is None:
48+
raise ImportError(
49+
"The 'rust' triangulation backend was requested but the "
50+
"'adaptive-triangulation' package is not installed. "
51+
'Install it with: pip install "adaptive[rust]"'
52+
)
53+
if version < _MIN_RUST_VERSION:
54+
raise ImportError(
55+
"The 'rust' triangulation backend requires "
56+
f"adaptive-triangulation >= {'.'.join(map(str, _MIN_RUST_VERSION))}, "
57+
f"found {'.'.join(map(str, version))}. Upgrade it with: "
58+
'pip install -U "adaptive[rust]"'
59+
)
60+
from adaptive_triangulation import Triangulation
61+
62+
return Triangulation
63+
64+
65+
def _select_backend() -> str:
66+
backend = os.environ.get("ADAPTIVE_TRIANGULATION_BACKEND", "auto").lower()
67+
if backend not in ("auto", "python", "rust"):
68+
raise ValueError(
69+
f"ADAPTIVE_TRIANGULATION_BACKEND={backend!r} is invalid, "
70+
"use 'auto', 'python', or 'rust'."
71+
)
72+
if backend == "auto":
73+
version = _rust_version()
74+
return (
75+
"rust" if version is not None and version >= _MIN_RUST_VERSION else "python"
76+
)
77+
if backend == "rust":
78+
_import_rust_triangulation() # raise with guidance if unusable
79+
return backend
80+
81+
82+
def resolve_triangulation_class(backend="auto"):
83+
"""Return the `Triangulation` class to use for *backend*.
84+
85+
Parameters
86+
----------
87+
backend : str or type
88+
``"auto"`` (the module-level default backend, which prefers the Rust
89+
implementation when available), ``"python"``, ``"rust"``, or a
90+
`Triangulation`-compatible class.
91+
"""
92+
if isinstance(backend, type):
93+
return backend
94+
if backend == "auto":
95+
return Triangulation
96+
if backend == "python":
97+
from adaptive.learner.triangulation import Triangulation as tri_class
98+
99+
return tri_class
100+
if backend == "rust":
101+
return _import_rust_triangulation()
102+
raise ValueError(
103+
f"Invalid triangulation backend {backend!r}, use 'auto', 'python', "
104+
"'rust', or a Triangulation-compatible class."
105+
)
106+
107+
108+
TRIANGULATION_BACKEND: str = _select_backend()
109+
110+
if TRIANGULATION_BACKEND == "rust":
111+
from adaptive_triangulation import (
112+
Triangulation,
113+
circumsphere,
114+
fast_2d_circumcircle,
115+
fast_2d_point_in_simplex,
116+
fast_3d_circumcircle,
117+
fast_norm,
118+
orientation,
119+
point_in_simplex,
120+
simplex_volume_in_embedding,
121+
)
122+
else:
123+
from adaptive.learner.triangulation import (
124+
Triangulation,
125+
circumsphere,
126+
fast_2d_circumcircle,
127+
fast_2d_point_in_simplex,
128+
fast_3d_circumcircle,
129+
fast_norm,
130+
orientation,
131+
point_in_simplex,
132+
simplex_volume_in_embedding,
133+
)
134+
135+
__all__ = [
136+
"TRIANGULATION_BACKEND",
137+
"Triangulation",
138+
"resolve_triangulation_class",
139+
"circumsphere",
140+
"fast_2d_circumcircle",
141+
"fast_2d_point_in_simplex",
142+
"fast_3d_circumcircle",
143+
"fast_norm",
144+
"orientation",
145+
"point_in_simplex",
146+
"simplex_volume_in_embedding",
147+
]

0 commit comments

Comments
 (0)