|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import sys |
| 4 | +import warnings |
| 5 | +from functools import cache, wraps |
| 6 | +from typing import TYPE_CHECKING, Literal, cast, overload |
| 7 | + |
| 8 | + |
| 9 | +if TYPE_CHECKING: |
| 10 | + from collections.abc import Callable |
| 11 | + |
| 12 | +type LayerType = Literal["default", "safe", "threadsafe", "forksafe"] |
| 13 | +type Layer = Literal["tbb", "omp", "workqueue"] |
| 14 | + |
| 15 | + |
| 16 | +LAYERS: dict[LayerType, set[Layer]] = { |
| 17 | + "default": {"tbb", "omp", "workqueue"}, |
| 18 | + "safe": {"tbb"}, |
| 19 | + "threadsafe": {"tbb", "omp"}, |
| 20 | + "forksafe": {"tbb", "workqueue", *(() if sys.platform == "linux" else {"omp"})}, |
| 21 | +} |
| 22 | + |
| 23 | + |
| 24 | +@cache |
| 25 | +def _numba_threading_layer() -> Layer: |
| 26 | + """Get numba’s threading layer. |
| 27 | +
|
| 28 | + This function implements the algorithm as described in |
| 29 | + <https://numba.readthedocs.io/en/stable/user/threading-layer.html> |
| 30 | + """ |
| 31 | + import importlib |
| 32 | + |
| 33 | + import numba |
| 34 | + |
| 35 | + if (available := LAYERS.get(numba.config.THREADING_LAYER)) is None: |
| 36 | + # given by direct name |
| 37 | + return numba.config.THREADING_LAYER |
| 38 | + |
| 39 | + # given by layer type (safe, …) |
| 40 | + for layer in cast("list[Layer]", numba.config.THREADING_LAYER_PRIORITY): |
| 41 | + if layer not in available: |
| 42 | + continue |
| 43 | + if layer != "workqueue": |
| 44 | + try: # `importlib.util.find_spec` doesn’t work here |
| 45 | + importlib.import_module(f"numba.np.ufunc.{layer}pool") |
| 46 | + except ImportError: |
| 47 | + continue |
| 48 | + # the layer has been found |
| 49 | + return layer |
| 50 | + msg = f"No loadable threading layer: {numba.config.THREADING_LAYER=} ({available=}, {numba.config.THREADING_LAYER_PRIORITY=})" |
| 51 | + raise ValueError(msg) |
| 52 | + |
| 53 | + |
| 54 | +def _is_in_unsafe_thread_pool() -> bool: |
| 55 | + import threading |
| 56 | + |
| 57 | + current_thread = threading.current_thread() |
| 58 | + # ThreadPoolExecutor threads typically have names like 'ThreadPoolExecutor-0_1' |
| 59 | + return current_thread.name.startswith("ThreadPoolExecutor") and _numba_threading_layer() not in LAYERS["threadsafe"] |
| 60 | + |
| 61 | + |
| 62 | +@overload |
| 63 | +def njit[**P, R](fn: Callable[P, R], /) -> Callable[P, R]: ... |
| 64 | +@overload |
| 65 | +def njit[**P, R]() -> Callable[[Callable[P, R]], Callable[P, R]]: ... |
| 66 | +def njit[**P, R](fn: Callable[P, R] | None = None, /) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]: |
| 67 | + """Jit-compile a function using numba. |
| 68 | +
|
| 69 | + On call, this function dispatches to a parallel or sequential numba function, |
| 70 | + depending on if it has been called from a thread pool. |
| 71 | +
|
| 72 | + See <https://github.com/numbagg/numbagg/pull/201/files#r1409374809> |
| 73 | + """ |
| 74 | + |
| 75 | + def decorator(f: Callable[P, R], /) -> Callable[P, R]: |
| 76 | + import numba |
| 77 | + |
| 78 | + fns: dict[bool, Callable[P, R]] = {parallel: numba.njit(f, cache=True, parallel=parallel) for parallel in (True, False)} |
| 79 | + |
| 80 | + @wraps(f) |
| 81 | + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: |
| 82 | + parallel = not _is_in_unsafe_thread_pool() |
| 83 | + if not parallel: |
| 84 | + msg = f"Detected unsupported threading environment. Trying to run {f.__name__} in serial mode. In case of problems, install `tbb`." |
| 85 | + warnings.warn(msg, UserWarning, stacklevel=2) |
| 86 | + return fns[parallel](*args, **kwargs) |
| 87 | + |
| 88 | + return wrapper |
| 89 | + |
| 90 | + return decorator if fn is None else decorator(fn) |
0 commit comments