|
| 1 | +# SPDX-License-Identifier: MPL-2.0 |
| 2 | +r"""Numba utilities, mainly used to deal with :ref:`numba-threading-layer` of :doc:`numba <numba:index>`. |
| 3 | +
|
| 4 | +``numba.config.THREADING_LAYER`` : env variable :envvar:`NUMBA_THREADING_LAYER` |
| 5 | + This can be set to a :class:`ThreadingLayer` or :class:`TheadingCategory`. |
| 6 | +``numba.config.THREADING_LAYER_PRIORITY`` : env variable :envvar:`NUMBA_THREADING_LAYER_PRIORITY` |
| 7 | + This can be set to a list of :class:`ThreadingLayer`\ s. |
| 8 | +
|
| 9 | +``fast-array-utils`` provides the following utilities: |
| 10 | +""" |
| 11 | + |
| 12 | +from __future__ import annotations |
| 13 | + |
| 14 | +import sys |
| 15 | +import warnings |
| 16 | +from functools import cache, update_wrapper, wraps |
| 17 | +from types import FunctionType |
| 18 | +from typing import TYPE_CHECKING, Literal, cast, overload |
| 19 | + |
| 20 | + |
| 21 | +if TYPE_CHECKING: |
| 22 | + from collections.abc import Callable, Iterable |
| 23 | + |
| 24 | + |
| 25 | +__all__ = ["TheadingCategory", "ThreadingLayer", "njit", "threading_layer"] |
| 26 | + |
| 27 | + |
| 28 | +type TheadingCategory = Literal["default", "safe", "threadsafe", "forksafe"] |
| 29 | +"""Identifier for a threading layer category.""" |
| 30 | +type ThreadingLayer = Literal["tbb", "omp", "workqueue"] |
| 31 | +"""Identifier for a concrete threading layer.""" |
| 32 | + |
| 33 | + |
| 34 | +LAYERS: dict[TheadingCategory, set[ThreadingLayer]] = { |
| 35 | + "default": {"tbb", "omp", "workqueue"}, |
| 36 | + "safe": {"tbb"}, |
| 37 | + "threadsafe": {"tbb", "omp"}, |
| 38 | + "forksafe": {"tbb", "workqueue", *(() if sys.platform == "linux" else {"omp"})}, |
| 39 | +} |
| 40 | + |
| 41 | + |
| 42 | +def threading_layer(layer_or_category: ThreadingLayer | TheadingCategory | None = None, /, priority: Iterable[ThreadingLayer] | None = None) -> ThreadingLayer: |
| 43 | + """Get numba’s configured threading layer as specified in :ref:`numba-threading-layer`. |
| 44 | +
|
| 45 | + ``layer_or_category`` defaults ``numba.config.THREADING_LAYER`` and ``priority`` to ``numba.config.THREADING_LAYER_PRIORITY``. |
| 46 | + """ |
| 47 | + import numba |
| 48 | + |
| 49 | + if layer_or_category is None: |
| 50 | + layer_or_category = numba.config.THREADING_LAYER |
| 51 | + if priority is None: |
| 52 | + priority = numba.config.THREADING_LAYER_PRIORITY |
| 53 | + |
| 54 | + return _threading_layer(layer_or_category, tuple(priority)) |
| 55 | + |
| 56 | + |
| 57 | +@cache |
| 58 | +def _threading_layer(layer_or_category: ThreadingLayer | TheadingCategory, /, priority: Iterable[ThreadingLayer]) -> ThreadingLayer: |
| 59 | + import importlib |
| 60 | + |
| 61 | + if (available := LAYERS.get(layer_or_category)) is None: # type: ignore[arg-type] # pragma: no cover |
| 62 | + return cast("ThreadingLayer", layer_or_category) # given by direct name |
| 63 | + |
| 64 | + # given by layer type (safe, …) |
| 65 | + for layer in priority: |
| 66 | + if layer not in available: # pragma: no cover |
| 67 | + continue |
| 68 | + if layer != "workqueue": |
| 69 | + try: # `importlib.util.find_spec` doesn’t work here |
| 70 | + importlib.import_module(f"numba.np.ufunc.{layer}pool") |
| 71 | + except ImportError: |
| 72 | + continue |
| 73 | + # the layer has been found |
| 74 | + return layer |
| 75 | + msg = f"No threading layer matching {layer_or_category!r} ({available=}, {priority=})" # pragma: no cover |
| 76 | + raise ValueError(msg) # pragma: no cover |
| 77 | + |
| 78 | + |
| 79 | +def _is_in_unsafe_thread_pool() -> bool: |
| 80 | + import threading |
| 81 | + |
| 82 | + current_thread = threading.current_thread() |
| 83 | + # ThreadPoolExecutor threads typically have names like 'ThreadPoolExecutor-0_1' |
| 84 | + return current_thread.name.startswith("ThreadPoolExecutor") and threading_layer() not in LAYERS["threadsafe"] |
| 85 | + |
| 86 | + |
| 87 | +@overload |
| 88 | +def njit[**P, R](fn: Callable[P, R], /) -> Callable[P, R]: ... |
| 89 | +@overload |
| 90 | +def njit[**P, R]() -> Callable[[Callable[P, R]], Callable[P, R]]: ... |
| 91 | +def njit[**P, R](fn: Callable[P, R] | None = None, /) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]: |
| 92 | + """Jit-compile a function using numba. |
| 93 | +
|
| 94 | + On call, this function dispatches to a parallel or serial numba function, |
| 95 | + depending on if it has been called from a thread pool. |
| 96 | + """ |
| 97 | + # See https://github.com/numbagg/numbagg/pull/201/files#r1409374809 |
| 98 | + |
| 99 | + def decorator(f: Callable[P, R], /) -> Callable[P, R]: |
| 100 | + import numba |
| 101 | + |
| 102 | + assert isinstance(f, FunctionType) |
| 103 | + |
| 104 | + # use distinct names so numba doesn’t reuse the wrong version’s cache |
| 105 | + fns: dict[bool, Callable[P, R]] = { |
| 106 | + parallel: numba.njit(_copy_function(f, __qualname__=f"{f.__qualname__}-{'parallel' if parallel else 'serial'}"), cache=True, parallel=parallel) |
| 107 | + for parallel in (True, False) |
| 108 | + } |
| 109 | + |
| 110 | + @wraps(f) |
| 111 | + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: |
| 112 | + parallel = not _is_in_unsafe_thread_pool() |
| 113 | + if not parallel: # pragma: no cover |
| 114 | + msg = f"Detected unsupported threading environment. Trying to run {f.__name__} in serial mode. In case of problems, install `tbb`." |
| 115 | + warnings.warn(msg, UserWarning, stacklevel=2) |
| 116 | + return fns[parallel](*args, **kwargs) |
| 117 | + |
| 118 | + return wrapper |
| 119 | + |
| 120 | + return decorator if fn is None else decorator(fn) |
| 121 | + |
| 122 | + |
| 123 | +def _copy_function[F: FunctionType](f: F, **overrides: object) -> F: |
| 124 | + new = FunctionType(code=f.__code__, globals=f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__) |
| 125 | + new.__kwdefaults__ = f.__kwdefaults__ |
| 126 | + new = cast("F", update_wrapper(new, f)) |
| 127 | + for key, value in overrides.items(): |
| 128 | + setattr(new, key, value) |
| 129 | + return new |
0 commit comments