|
| 1 | +"""Select the triangulation backend used by the learners. |
| 2 | +
|
| 3 | +If the optional Rust-accelerated `adaptive-triangulation |
| 4 | +<https://github.com/python-adaptive/adaptive-triangulation>`_ package is |
| 5 | +installed (``pip install "adaptive[rust]"``), it is used automatically as a |
| 6 | +drop-in replacement for the pure-Python implementation in |
| 7 | +`adaptive.learner.triangulation`, which makes `~adaptive.LearnerND` |
| 8 | +significantly faster. |
| 9 | +
|
| 10 | +The selection can be overridden with the ``ADAPTIVE_TRIANGULATION_BACKEND`` |
| 11 | +environment variable: |
| 12 | +
|
| 13 | +- ``auto`` (default): use the Rust backend if available, else pure Python |
| 14 | +- ``python``: always use the pure-Python implementation |
| 15 | +- ``rust``: require the Rust backend, raising `ImportError` if it is missing |
| 16 | +
|
| 17 | +The active backend is exposed as the string ``TRIANGULATION_BACKEND`` |
| 18 | +(``"python"`` or ``"rust"``). |
| 19 | +
|
| 20 | +Note that the pure-Python implementation in `adaptive.learner.triangulation` |
| 21 | +is always importable under its own name, regardless of the selected backend, |
| 22 | +so pickles that reference it keep working. |
| 23 | +""" |
| 24 | + |
| 25 | +from __future__ import annotations |
| 26 | + |
| 27 | +import os |
| 28 | + |
| 29 | +# Minimal version that is a complete drop-in for the learners |
| 30 | +# (incl. ``get_opposing_vertices`` and pickle/deepcopy support). |
| 31 | +_MIN_RUST_VERSION = (0, 2, 1) |
| 32 | + |
| 33 | + |
| 34 | +def _rust_version() -> tuple[int, ...] | None: |
| 35 | + """Return the installed ``adaptive_triangulation`` version, or None.""" |
| 36 | + try: |
| 37 | + import adaptive_triangulation |
| 38 | + except ImportError: |
| 39 | + return None |
| 40 | + version = adaptive_triangulation.__version__ |
| 41 | + return tuple(int(part) for part in version.split(".")[:3] if part.isdigit()) |
| 42 | + |
| 43 | + |
| 44 | +def _import_rust_triangulation(): |
| 45 | + """Import the Rust `Triangulation`, raising a helpful `ImportError`.""" |
| 46 | + version = _rust_version() |
| 47 | + if version is None: |
| 48 | + raise ImportError( |
| 49 | + "The 'rust' triangulation backend was requested but the " |
| 50 | + "'adaptive-triangulation' package is not installed. " |
| 51 | + 'Install it with: pip install "adaptive[rust]"' |
| 52 | + ) |
| 53 | + if version < _MIN_RUST_VERSION: |
| 54 | + raise ImportError( |
| 55 | + "The 'rust' triangulation backend requires " |
| 56 | + f"adaptive-triangulation >= {'.'.join(map(str, _MIN_RUST_VERSION))}, " |
| 57 | + f"found {'.'.join(map(str, version))}. Upgrade it with: " |
| 58 | + 'pip install -U "adaptive[rust]"' |
| 59 | + ) |
| 60 | + from adaptive_triangulation import Triangulation |
| 61 | + |
| 62 | + return Triangulation |
| 63 | + |
| 64 | + |
| 65 | +def _select_backend() -> str: |
| 66 | + backend = os.environ.get("ADAPTIVE_TRIANGULATION_BACKEND", "auto").lower() |
| 67 | + if backend not in ("auto", "python", "rust"): |
| 68 | + raise ValueError( |
| 69 | + f"ADAPTIVE_TRIANGULATION_BACKEND={backend!r} is invalid, " |
| 70 | + "use 'auto', 'python', or 'rust'." |
| 71 | + ) |
| 72 | + if backend == "auto": |
| 73 | + version = _rust_version() |
| 74 | + return ( |
| 75 | + "rust" if version is not None and version >= _MIN_RUST_VERSION else "python" |
| 76 | + ) |
| 77 | + if backend == "rust": |
| 78 | + _import_rust_triangulation() # raise with guidance if unusable |
| 79 | + return backend |
| 80 | + |
| 81 | + |
| 82 | +def resolve_triangulation_class(backend="auto"): |
| 83 | + """Return the `Triangulation` class to use for *backend*. |
| 84 | +
|
| 85 | + Parameters |
| 86 | + ---------- |
| 87 | + backend : str or type |
| 88 | + ``"auto"`` (the module-level default backend, which prefers the Rust |
| 89 | + implementation when available), ``"python"``, ``"rust"``, or a |
| 90 | + `Triangulation`-compatible class. |
| 91 | + """ |
| 92 | + if isinstance(backend, type): |
| 93 | + return backend |
| 94 | + if backend == "auto": |
| 95 | + return Triangulation |
| 96 | + if backend == "python": |
| 97 | + from adaptive.learner.triangulation import Triangulation as tri_class |
| 98 | + |
| 99 | + return tri_class |
| 100 | + if backend == "rust": |
| 101 | + return _import_rust_triangulation() |
| 102 | + raise ValueError( |
| 103 | + f"Invalid triangulation backend {backend!r}, use 'auto', 'python', " |
| 104 | + "'rust', or a Triangulation-compatible class." |
| 105 | + ) |
| 106 | + |
| 107 | + |
| 108 | +TRIANGULATION_BACKEND: str = _select_backend() |
| 109 | + |
| 110 | +if TRIANGULATION_BACKEND == "rust": |
| 111 | + from adaptive_triangulation import ( |
| 112 | + Triangulation, |
| 113 | + circumsphere, |
| 114 | + fast_2d_circumcircle, |
| 115 | + fast_2d_point_in_simplex, |
| 116 | + fast_3d_circumcircle, |
| 117 | + fast_norm, |
| 118 | + orientation, |
| 119 | + point_in_simplex, |
| 120 | + simplex_volume_in_embedding, |
| 121 | + ) |
| 122 | +else: |
| 123 | + from adaptive.learner.triangulation import ( |
| 124 | + Triangulation, |
| 125 | + circumsphere, |
| 126 | + fast_2d_circumcircle, |
| 127 | + fast_2d_point_in_simplex, |
| 128 | + fast_3d_circumcircle, |
| 129 | + fast_norm, |
| 130 | + orientation, |
| 131 | + point_in_simplex, |
| 132 | + simplex_volume_in_embedding, |
| 133 | + ) |
| 134 | + |
| 135 | +__all__ = [ |
| 136 | + "TRIANGULATION_BACKEND", |
| 137 | + "Triangulation", |
| 138 | + "resolve_triangulation_class", |
| 139 | + "circumsphere", |
| 140 | + "fast_2d_circumcircle", |
| 141 | + "fast_2d_point_in_simplex", |
| 142 | + "fast_3d_circumcircle", |
| 143 | + "fast_norm", |
| 144 | + "orientation", |
| 145 | + "point_in_simplex", |
| 146 | + "simplex_volume_in_embedding", |
| 147 | +] |
0 commit comments