Skip to content

Commit 5aafb79

Browse files
committed
Add optional Rust-accelerated triangulation backend for LearnerND
If the optional adaptive-triangulation package (>= 0.2.1) is installed, its Rust Triangulation implementation is now used automatically as a drop-in replacement for the pure-Python one, making LearnerND significantly faster. - New adaptive/learner/triangulation_backend.py selects the backend at import time; ADAPTIVE_TRIANGULATION_BACKEND=auto|python|rust overrides the automatic selection globally. - LearnerND gains a triangulation_backend keyword argument accepting "auto" (default), "python", "rust", or a Triangulation-compatible class for per-learner control. - The pure-Python adaptive.learner.triangulation names are never shadowed, so existing pickles that reference them keep loading. - Fix _pop_highest_existing_simplex evaluating `None in ...simplices` for stale queue entries (Python sets return False; the Rust proxy raises TypeError). - Add the `rust` optional-dependency extra, tests, and docs.
1 parent 1f60c43 commit 5aafb79

9 files changed

Lines changed: 328 additions & 8 deletions

File tree

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,17 @@ pip install "adaptive[notebook]"
151151

152152
The `[notebook]` above will also install the optional dependencies for running `adaptive` inside a Jupyter notebook.
153153

154+
### Faster triangulation (optional)
155+
156+
Installing the optional [adaptive-triangulation](https://github.com/python-adaptive/adaptive-triangulation) package makes `adaptive.LearnerND` significantly faster by replacing the pure-Python triangulation with a Rust implementation:
157+
158+
```bash
159+
pip install "adaptive[rust]"
160+
```
161+
162+
No code changes are needed — the Rust backend is detected and used automatically.
163+
To control the selection, pass `LearnerND(..., triangulation_backend="python" | "rust" | "auto")` per learner, or set the environment variable `ADAPTIVE_TRIANGULATION_BACKEND=python` to force the pure-Python implementation globally (or to `rust` to make a missing Rust backend an error instead of a silent fallback).
164+
154165
To use Adaptive in Jupyterlab, you need to install the following labextensions.
155166

156167
```bash

adaptive/learner/learner1D.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from adaptive.learner.base_learner import BaseLearner, uses_nth_neighbors
1616
from adaptive.learner.learnerND import volume
17-
from adaptive.learner.triangulation import simplex_volume_in_embedding
17+
from adaptive.learner.triangulation_backend import simplex_volume_in_embedding
1818
from adaptive.notebook_integration import ensure_holoviews
1919
from adaptive.types import Float, Int, Real
2020
from adaptive.utils import (

adaptive/learner/learner2D.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from scipy.interpolate import CloughTocher2DInterpolator, LinearNDInterpolator
1414

1515
from adaptive.learner.base_learner import BaseLearner
16-
from adaptive.learner.triangulation import simplex_volume_in_embedding
16+
from adaptive.learner.triangulation_backend import simplex_volume_in_embedding
1717
from adaptive.notebook_integration import ensure_holoviews
1818
from adaptive.types import Bool, Float, Real
1919
from adaptive.utils import (

adaptive/learner/learnerND.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@
1414
from sortedcontainers import SortedKeyList
1515

1616
from adaptive.learner.base_learner import BaseLearner, uses_nth_neighbors
17-
from adaptive.learner.triangulation import (
17+
from adaptive.learner.triangulation import fast_det
18+
from adaptive.learner.triangulation_backend import (
1819
Triangulation,
1920
circumsphere,
20-
fast_det,
2121
point_in_simplex,
22+
resolve_triangulation_class,
2223
simplex_volume_in_embedding,
2324
)
2425
from adaptive.notebook_integration import ensure_holoviews, ensure_plotly
@@ -274,6 +275,15 @@ class LearnerND(BaseLearner):
274275
If not provided, then a default is used, which uses
275276
the deviation from a linear estimate, as well as
276277
triangle area, to determine the loss.
278+
anisotropic : bool, optional
279+
If True, the triangulation is stretched along the local gradient
280+
when choosing new points. Only works with scalar output.
281+
triangulation_backend : str or type, optional
282+
Which triangulation implementation to use: ``"auto"`` (default,
283+
prefers the optional Rust-accelerated `adaptive-triangulation
284+
<https://github.com/python-adaptive/adaptive-triangulation>`_
285+
package when it is installed), ``"python"``, ``"rust"``, or a
286+
`~adaptive.learner.triangulation.Triangulation`-compatible class.
277287
278288
279289
Attributes
@@ -308,7 +318,20 @@ class LearnerND(BaseLearner):
308318
children based on volume.
309319
"""
310320

311-
def __init__(self, func, bounds, loss_per_simplex=None, *, anisotropic=False):
321+
# Class-level fallback so that learners restored from pickles made
322+
# before `triangulation_backend` existed keep working.
323+
_triangulation_class = Triangulation
324+
325+
def __init__(
326+
self,
327+
func,
328+
bounds,
329+
loss_per_simplex=None,
330+
*,
331+
anisotropic=False,
332+
triangulation_backend="auto",
333+
):
334+
self._triangulation_class = resolve_triangulation_class(triangulation_backend)
312335
self._vdim = None
313336
self.loss_per_simplex = loss_per_simplex or default_loss
314337

@@ -385,6 +408,7 @@ def new(self) -> LearnerND:
385408
self.bounds,
386409
self.loss_per_simplex,
387410
anisotropic=self.anisotropic,
411+
triangulation_backend=self._triangulation_class,
388412
)
389413

390414
@property
@@ -513,7 +537,7 @@ def tri(self):
513537
return self._tri
514538

515539
try:
516-
self._tri = Triangulation(self.points)
540+
self._tri = self._triangulation_class(self.points)
517541
except ValueError:
518542
# A ValueError is raised if we do not have enough points or
519543
# the provided points are coplanar, so we need more points to
@@ -649,7 +673,7 @@ def _try_adding_pending_point_to_simplex(self, point, simplex):
649673

650674
if simplex not in self._subtriangulations:
651675
vertices = self.tri.get_vertices(simplex)
652-
self._subtriangulations[simplex] = Triangulation(vertices)
676+
self._subtriangulations[simplex] = self._triangulation_class(vertices)
653677

654678
self._pending_to_simplex[point] = simplex
655679
return self._subtriangulations[simplex].add_point(point)
@@ -713,7 +737,8 @@ def _pop_highest_existing_simplex(self):
713737
):
714738
return abs(loss), simplex, subsimplex
715739
if (
716-
simplex in self._subtriangulations
740+
subsimplex is not None
741+
and simplex in self._subtriangulations
717742
and simplex in self.tri.simplices
718743
and subsimplex in self._subtriangulations[simplex].simplices
719744
):
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
"""Select the triangulation backend used by the learners.
2+
3+
If the optional Rust-accelerated `adaptive-triangulation
4+
<https://github.com/python-adaptive/adaptive-triangulation>`_ package is
5+
installed (``pip install "adaptive[rust]"``), it is used automatically as a
6+
drop-in replacement for the pure-Python implementation in
7+
`adaptive.learner.triangulation`, which makes `~adaptive.LearnerND`
8+
significantly faster.
9+
10+
The selection can be overridden with the ``ADAPTIVE_TRIANGULATION_BACKEND``
11+
environment variable:
12+
13+
- ``auto`` (default): use the Rust backend if available, else pure Python
14+
- ``python``: always use the pure-Python implementation
15+
- ``rust``: require the Rust backend, raising `ImportError` if it is missing
16+
17+
The active backend is exposed as the string ``TRIANGULATION_BACKEND``
18+
(``"python"`` or ``"rust"``).
19+
20+
Note that the pure-Python implementation in `adaptive.learner.triangulation`
21+
is always importable under its own name, regardless of the selected backend,
22+
so pickles that reference it keep working.
23+
"""
24+
25+
from __future__ import annotations
26+
27+
import os
28+
29+
# Minimal version that is a complete drop-in for the learners
30+
# (incl. ``get_opposing_vertices`` and pickle/deepcopy support).
31+
_MIN_RUST_VERSION = (0, 2, 1)
32+
33+
34+
def _rust_version() -> tuple[int, ...] | None:
35+
"""Return the installed ``adaptive_triangulation`` version, or None."""
36+
try:
37+
import adaptive_triangulation
38+
except ImportError:
39+
return None
40+
version = adaptive_triangulation.__version__
41+
return tuple(int(part) for part in version.split(".")[:3] if part.isdigit())
42+
43+
44+
def _import_rust_triangulation():
45+
"""Import the Rust `Triangulation`, raising a helpful `ImportError`."""
46+
version = _rust_version()
47+
if version is None:
48+
raise ImportError(
49+
"The 'rust' triangulation backend was requested but the "
50+
"'adaptive-triangulation' package is not installed. "
51+
'Install it with: pip install "adaptive[rust]"'
52+
)
53+
if version < _MIN_RUST_VERSION:
54+
raise ImportError(
55+
"The 'rust' triangulation backend requires "
56+
f"adaptive-triangulation >= {'.'.join(map(str, _MIN_RUST_VERSION))}, "
57+
f"found {'.'.join(map(str, version))}. Upgrade it with: "
58+
'pip install -U "adaptive[rust]"'
59+
)
60+
from adaptive_triangulation import Triangulation
61+
62+
return Triangulation
63+
64+
65+
def _select_backend() -> str:
66+
backend = os.environ.get("ADAPTIVE_TRIANGULATION_BACKEND", "auto").lower()
67+
if backend not in ("auto", "python", "rust"):
68+
raise ValueError(
69+
f"ADAPTIVE_TRIANGULATION_BACKEND={backend!r} is invalid, "
70+
"use 'auto', 'python', or 'rust'."
71+
)
72+
if backend == "auto":
73+
version = _rust_version()
74+
return (
75+
"rust" if version is not None and version >= _MIN_RUST_VERSION else "python"
76+
)
77+
if backend == "rust":
78+
_import_rust_triangulation() # raise with guidance if unusable
79+
return backend
80+
81+
82+
def resolve_triangulation_class(backend="auto"):
83+
"""Return the `Triangulation` class to use for *backend*.
84+
85+
Parameters
86+
----------
87+
backend : str or type
88+
``"auto"`` (the module-level default backend, which prefers the Rust
89+
implementation when available), ``"python"``, ``"rust"``, or a
90+
`Triangulation`-compatible class.
91+
"""
92+
if isinstance(backend, type):
93+
return backend
94+
if backend == "auto":
95+
return Triangulation
96+
if backend == "python":
97+
from adaptive.learner.triangulation import Triangulation as tri_class
98+
99+
return tri_class
100+
if backend == "rust":
101+
return _import_rust_triangulation()
102+
raise ValueError(
103+
f"Invalid triangulation backend {backend!r}, use 'auto', 'python', "
104+
"'rust', or a Triangulation-compatible class."
105+
)
106+
107+
108+
TRIANGULATION_BACKEND: str = _select_backend()
109+
110+
if TRIANGULATION_BACKEND == "rust":
111+
from adaptive_triangulation import (
112+
Triangulation,
113+
circumsphere,
114+
fast_2d_circumcircle,
115+
fast_2d_point_in_simplex,
116+
fast_3d_circumcircle,
117+
fast_norm,
118+
orientation,
119+
point_in_simplex,
120+
simplex_volume_in_embedding,
121+
)
122+
else:
123+
from adaptive.learner.triangulation import (
124+
Triangulation,
125+
circumsphere,
126+
fast_2d_circumcircle,
127+
fast_2d_point_in_simplex,
128+
fast_3d_circumcircle,
129+
fast_norm,
130+
orientation,
131+
point_in_simplex,
132+
simplex_volume_in_embedding,
133+
)
134+
135+
__all__ = [
136+
"TRIANGULATION_BACKEND",
137+
"Triangulation",
138+
"resolve_triangulation_class",
139+
"circumsphere",
140+
"fast_2d_circumcircle",
141+
"fast_2d_point_in_simplex",
142+
"fast_3d_circumcircle",
143+
"fast_norm",
144+
"orientation",
145+
"point_in_simplex",
146+
"simplex_volume_in_embedding",
147+
]
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
"""Tests for the automatic triangulation backend selection."""
2+
3+
import os
4+
import subprocess
5+
import sys
6+
7+
import pytest
8+
9+
from adaptive.learner import triangulation as python_triangulation
10+
from adaptive.learner import triangulation_backend as backend
11+
from adaptive.learner.triangulation_backend import (
12+
_MIN_RUST_VERSION,
13+
_rust_version,
14+
)
15+
16+
17+
def _run(code, backend_env):
18+
env = {**os.environ, "ADAPTIVE_TRIANGULATION_BACKEND": backend_env}
19+
return subprocess.run(
20+
[sys.executable, "-c", code], env=env, capture_output=True, text=True
21+
)
22+
23+
24+
def rust_is_usable():
25+
version = _rust_version()
26+
return version is not None and version >= _MIN_RUST_VERSION
27+
28+
29+
def test_backend_matches_installation():
30+
if rust_is_usable():
31+
import adaptive_triangulation
32+
33+
assert backend.TRIANGULATION_BACKEND == "rust"
34+
assert backend.Triangulation is adaptive_triangulation.Triangulation
35+
else:
36+
assert backend.TRIANGULATION_BACKEND == "python"
37+
assert backend.Triangulation is python_triangulation.Triangulation
38+
39+
40+
def test_python_triangulation_is_never_shadowed():
41+
# Old pickles reference adaptive.learner.triangulation.Triangulation by
42+
# qualified name, so the pure-Python class must stay importable as itself.
43+
assert python_triangulation.Triangulation.__module__ == (
44+
"adaptive.learner.triangulation"
45+
)
46+
47+
48+
def test_force_python_backend():
49+
code = (
50+
"from adaptive.learner import triangulation, triangulation_backend;"
51+
"assert triangulation_backend.TRIANGULATION_BACKEND == 'python';"
52+
"assert triangulation_backend.Triangulation is triangulation.Triangulation"
53+
)
54+
result = _run(code, "python")
55+
assert result.returncode == 0, result.stderr
56+
57+
58+
def test_force_rust_backend():
59+
code = (
60+
"from adaptive.learner import triangulation_backend;"
61+
"assert triangulation_backend.TRIANGULATION_BACKEND == 'rust'"
62+
)
63+
result = _run(code, "rust")
64+
if rust_is_usable():
65+
assert result.returncode == 0, result.stderr
66+
else:
67+
# Forcing the Rust backend without (a recent enough version of)
68+
# adaptive-triangulation must raise a helpful ImportError.
69+
assert result.returncode != 0
70+
assert "ImportError" in result.stderr
71+
assert "adaptive-triangulation" in result.stderr
72+
73+
74+
def test_invalid_backend_raises():
75+
result = _run("import adaptive.learner.triangulation_backend", "bogus")
76+
assert result.returncode != 0
77+
assert "ValueError" in result.stderr
78+
79+
80+
def test_resolve_triangulation_class():
81+
resolve = backend.resolve_triangulation_class
82+
assert resolve("auto") is backend.Triangulation
83+
assert resolve("python") is python_triangulation.Triangulation
84+
if rust_is_usable():
85+
import adaptive_triangulation
86+
87+
assert resolve("rust") is adaptive_triangulation.Triangulation
88+
else:
89+
with pytest.raises(ImportError, match="adaptive-triangulation"):
90+
resolve("rust")
91+
92+
class MyTriangulation(python_triangulation.Triangulation):
93+
pass
94+
95+
assert resolve(MyTriangulation) is MyTriangulation
96+
with pytest.raises(ValueError, match="Invalid triangulation backend"):
97+
resolve("bogus")
98+
99+
100+
def test_learnernd_triangulation_backend_argument():
101+
from adaptive import LearnerND
102+
103+
learner = LearnerND(
104+
lambda xy: sum(xy) ** 2,
105+
bounds=[(-1, 1), (-1, 1)],
106+
triangulation_backend="python",
107+
)
108+
assert learner._triangulation_class is python_triangulation.Triangulation
109+
assert learner.new()._triangulation_class is python_triangulation.Triangulation
110+
111+
112+
@pytest.mark.skipif(not rust_is_usable(), reason="needs adaptive-triangulation")
113+
def test_learnernd_uses_rust_backend():
114+
import adaptive_triangulation
115+
116+
from adaptive import LearnerND
117+
118+
learner = LearnerND(lambda xy: sum(xy) ** 2, bounds=[(-1, 1), (-1, 1)])
119+
for _ in range(50):
120+
points, _ = learner.ask(1)
121+
for point in points:
122+
learner.tell(point, learner.function(point))
123+
assert isinstance(learner.tri, adaptive_triangulation.Triangulation)
124+
assert learner.npoints >= 50

0 commit comments

Comments
 (0)