Skip to content

Commit d6b9aba

Browse files
committed
new module numba
1 parent d70246f commit d6b9aba

2 files changed

Lines changed: 247 additions & 0 deletions

File tree

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# SPDX-License-Identifier: MPL-2.0
2+
r"""Numba utilities, mainly used to deal with :ref:`numba-threading-layer` of :doc:`numba <numba:index>`.
3+
4+
``numba.config.THREADING_LAYER`` : env variable :envvar:`NUMBA_THREADING_LAYER`
5+
This can be set to a :class:`ThreadingLayer` or :class:`TheadingCategory`.
6+
``numba.config.THREADING_LAYER_PRIORITY`` : env variable :envvar:`NUMBA_THREADING_LAYER_PRIORITY`
7+
This can be set to a list of :class:`ThreadingLayer`\ s.
8+
9+
``fast-array-utils`` provides the following utilities:
10+
"""
11+
12+
from __future__ import annotations
13+
14+
import sys
15+
import warnings
16+
from functools import cache, update_wrapper, wraps
17+
from types import FunctionType
18+
from typing import TYPE_CHECKING, Literal, cast, overload
19+
20+
21+
if TYPE_CHECKING:
22+
from collections.abc import Callable, Iterable
23+
24+
25+
__all__ = ["TheadingCategory", "ThreadingLayer", "njit", "threading_layer"]
26+
27+
28+
type TheadingCategory = Literal["default", "safe", "threadsafe", "forksafe"]
29+
"""Identifier for a threading layer category."""
30+
type ThreadingLayer = Literal["tbb", "omp", "workqueue"]
31+
"""Identifier for a concrete threading layer."""
32+
33+
34+
LAYERS: dict[TheadingCategory, set[ThreadingLayer]] = {
35+
"default": {"tbb", "omp", "workqueue"},
36+
"safe": {"tbb"},
37+
"threadsafe": {"tbb", "omp"},
38+
"forksafe": {"tbb", "workqueue", *(() if sys.platform == "linux" else {"omp"})},
39+
}
40+
41+
42+
def threading_layer(layer_or_category: ThreadingLayer | TheadingCategory | None = None, /, priority: Iterable[ThreadingLayer] | None = None) -> ThreadingLayer:
    """Get numba’s configured threading layer as specified in :ref:`numba-threading-layer`.

    Arguments left as ``None`` are read from numba’s configuration:
    ``layer_or_category`` defaults to ``numba.config.THREADING_LAYER`` and
    ``priority`` defaults to ``numba.config.THREADING_LAYER_PRIORITY``.
    """
    import numba

    requested = numba.config.THREADING_LAYER if layer_or_category is None else layer_or_category
    order = numba.config.THREADING_LAYER_PRIORITY if priority is None else priority

    # the resolver is wrapped in `functools.cache`, so hand it a hashable tuple
    return _threading_layer(requested, tuple(order))
55+
56+
57+
@cache
def _threading_layer(layer_or_category: ThreadingLayer | TheadingCategory, /, priority: Iterable[ThreadingLayer]) -> ThreadingLayer:
    """Resolve ``layer_or_category`` to a concrete layer, trying ``priority`` in order.

    A value that is not a key of ``LAYERS`` is returned unchanged.
    For a category, the first entry of ``priority`` that belongs to the category
    and (unless it is ``"workqueue"``) whose ``numba.np.ufunc.<layer>pool`` module
    can be imported is returned.
    Raises :class:`ValueError` if no entry of ``priority`` qualifies.
    """
    from importlib import import_module

    available = LAYERS.get(layer_or_category)  # type: ignore[arg-type]
    if available is None:  # pragma: no cover
        # given by direct name
        return cast("ThreadingLayer", layer_or_category)

    # given by layer type (safe, …)
    for candidate in priority:
        if candidate not in available:  # pragma: no cover
            continue
        if candidate == "workqueue":
            # no import check is performed for this layer
            return candidate
        try:  # `importlib.util.find_spec` doesn’t work here
            import_module(f"numba.np.ufunc.{candidate}pool")
        except ImportError:
            continue
        # the layer has been found
        return candidate

    msg = f"No threading layer matching {layer_or_category!r} ({available=}, {priority=})"  # pragma: no cover
    raise ValueError(msg)  # pragma: no cover
77+
78+
79+
def _is_in_unsafe_thread_pool() -> bool:
    """Tell whether the current thread looks like a ``ThreadPoolExecutor`` worker while the threading layer is not in ``LAYERS["threadsafe"]``."""
    from threading import current_thread

    # ThreadPoolExecutor threads typically have names like 'ThreadPoolExecutor-0_1'
    if not current_thread().name.startswith("ThreadPoolExecutor"):
        return False
    return threading_layer() not in LAYERS["threadsafe"]
85+
86+
87+
@overload
88+
def njit[**P, R](fn: Callable[P, R], /) -> Callable[P, R]: ...
89+
@overload
90+
def njit[**P, R]() -> Callable[[Callable[P, R]], Callable[P, R]]: ...
91+
def njit[**P, R](fn: Callable[P, R] | None = None, /) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
92+
"""Jit-compile a function using numba.
93+
94+
On call, this function dispatches to a parallel or serial numba function,
95+
depending on the current threading environment.
96+
"""
97+
# See https://github.com/numbagg/numbagg/pull/201/files#r1409374809
98+
99+
def decorator(f: Callable[P, R], /) -> Callable[P, R]:
100+
import numba
101+
102+
from ._parallel_runtime import _needs_parallel_runtime_probe, _parallel_numba_runtime_is_safe
103+
104+
assert isinstance(f, FunctionType)
105+
106+
# use distinct names so numba doesn’t reuse the wrong version’s cache
107+
fns: dict[bool, Callable[P, R]] = {
108+
parallel: numba.njit(_copy_function(f, __qualname__=f"{f.__qualname__}-{'parallel' if parallel else 'serial'}"), cache=True, parallel=parallel)
109+
for parallel in (True, False)
110+
}
111+
112+
@wraps(f)
113+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
114+
if _is_in_unsafe_thread_pool(): # pragma: no cover
115+
msg = f"Detected unsupported threading environment. Trying to run {f.__name__} in serial mode. In case of problems, install `tbb`."
116+
warnings.warn(msg, UserWarning, stacklevel=2)
117+
return fns[False](*args, **kwargs)
118+
if _needs_parallel_runtime_probe() and not _parallel_numba_runtime_is_safe():
119+
msg = (
120+
f"Detected an unsupported numba parallel runtime. Running {f.__name__} in serial mode as a workaround. "
121+
"Set `NUMBA_THREADING_LAYER=workqueue` or install `tbb` to avoid this fallback."
122+
)
123+
warnings.warn(msg, UserWarning, stacklevel=2)
124+
return fns[False](*args, **kwargs)
125+
return fns[True](*args, **kwargs)
126+
127+
return wrapper
128+
129+
return decorator if fn is None else decorator(fn)
130+
131+
132+
def _copy_function[F: FunctionType](f: F, **overrides: object) -> F:
133+
new = FunctionType(code=f.__code__, globals=f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
134+
new.__kwdefaults__ = f.__kwdefaults__
135+
new = cast("F", update_wrapper(new, f))
136+
for key, value in overrides.items():
137+
setattr(new, key, value)
138+
return new
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# SPDX-License-Identifier: MPL-2.0
2+
3+
from __future__ import annotations
4+
5+
import os
6+
import platform
7+
import subprocess
8+
import sys
9+
from functools import cache
10+
from typing import TYPE_CHECKING
11+
12+
import numba
13+
14+
from . import LAYERS
15+
16+
17+
if TYPE_CHECKING:
18+
from . import TheadingCategory, ThreadingLayer
19+
20+
21+
type _ParallelRuntimeProbeKey = tuple[str, ThreadingLayer | TheadingCategory, tuple[ThreadingLayer, ...], tuple[str, ...]]
22+
23+
24+
_PARALLEL_RUNTIME_PROBE_SENTINEL = "FAST_ARRAY_UTILS_NUMBA_PROBE_OK"
25+
_PARALLEL_RUNTIME_PROBE_TIMEOUT = 20
26+
_PARALLEL_RUNTIME_PROBE_MODULE_WHITELIST = ("torch",)
27+
28+
29+
def _is_apple_silicon() -> bool:
    """Report whether ``sys.platform`` is ``"darwin"`` and the machine type is ``"arm64"``."""
    if sys.platform != "darwin":
        return False
    return platform.machine() == "arm64"
31+
32+
33+
def _could_select_omp_from_threading_config_without_probing() -> bool:
34+
match numba.config.THREADING_LAYER:
35+
case "omp":
36+
return True
37+
case "tbb" | "workqueue":
38+
return False
39+
case "default" | "safe" | "threadsafe" | "forksafe" as category:
40+
return "omp" in LAYERS[category]
41+
42+
43+
def _needs_parallel_runtime_probe() -> bool:
    """Decide whether the parallel runtime should be probed before use.

    A probe is only requested on Apple Silicon, with ``torch`` present in
    ``sys.modules``, and a threading configuration that could select ``omp``.
    """
    if not (_is_apple_silicon() and "torch" in sys.modules):
        return False
    layer_is_fixed = numba.config.THREADING_LAYER in {"workqueue", "tbb"}
    return not layer_is_fixed and _could_select_omp_from_threading_config_without_probing()
49+
50+
51+
def _loaded_relevant_parallel_runtime_probe_modules() -> tuple[str, ...]:
    """Return the whitelisted module names found in ``sys.modules``, in whitelist order."""
    return tuple(filter(sys.modules.__contains__, _PARALLEL_RUNTIME_PROBE_MODULE_WHITELIST))
53+
54+
55+
def _parallel_runtime_probe_code(modules: tuple[str, ...]) -> str:
    """Build the source code of the script executed by the probe subprocess.

    The script imports ``modules`` first, then numba and numpy, runs a small
    ``parallel=True`` reduction, checks it against ``np.sum`` and finally prints
    the sentinel. The returned string ends with a newline.
    """
    lines = [
        *(f"import {module}" for module in modules),
        "import numba",
        "import numpy as np",
        "",
        "@numba.njit(parallel=True, cache=False)",
        "def _probe(values):",
        # The nesting below is significant: the loop body must be indented
        # deeper than the `for`, otherwise the script fails to compile.
        "    total = 0.0",
        "    for i in numba.prange(values.shape[0]):",
        "        total += values[i]",
        "    return total",
        "",
        "values = np.arange(32, dtype=np.float64)",
        "assert _probe(values) == np.sum(values)",
        f"print({_PARALLEL_RUNTIME_PROBE_SENTINEL!r})",
        "",
    ]
    return "\n".join(lines)
71+
72+
73+
def _parallel_runtime_probe_key() -> _ParallelRuntimeProbeKey:
    """Collect the interpreter path, numba’s threading configuration and the loaded whitelisted modules into a hashable key."""
    config = numba.config
    layer_or_category = config.THREADING_LAYER
    priority = tuple(config.THREADING_LAYER_PRIORITY)
    modules = _loaded_relevant_parallel_runtime_probe_modules()
    return sys.executable, layer_or_category, priority, modules
80+
81+
82+
def _build_parallel_runtime_probe_env(key: _ParallelRuntimeProbeKey | None = None) -> dict[str, str]:
    """Return a copy of ``os.environ`` with numba’s threading variables taken from ``key``.

    When ``key`` is ``None``, the current probe key is used.
    ``NUMBA_THREADING_LAYER_PRIORITY`` is written as a space-separated list.
    """
    if key is None:
        key = _parallel_runtime_probe_key()
    _, layer_or_category, priority, _ = key
    return {
        **os.environ,
        "NUMBA_THREADING_LAYER": layer_or_category,
        "NUMBA_THREADING_LAYER_PRIORITY": " ".join(priority),
    }
88+
89+
90+
@cache
def _parallel_numba_runtime_is_safe_cached(key: _ParallelRuntimeProbeKey) -> bool:
    """Run the probe script in a subprocess and report whether it succeeded.

    Success means a zero exit code and the sentinel appearing in the captured
    standard output. Any exception while building or running the command counts
    as failure. Results are memoised per ``key``.
    """
    try:
        # The probe command is built from `sys.executable` plus a generated script
        # that only imports modules from a fixed whitelist.
        interpreter, modules = key[0], key[3]
        completed = subprocess.run(  # noqa: S603
            [interpreter, "-c", _parallel_runtime_probe_code(modules)],
            env=_build_parallel_runtime_probe_env(key),
            capture_output=True,
            text=True,
            check=False,
            timeout=_PARALLEL_RUNTIME_PROBE_TIMEOUT,
        )
    except Exception:  # noqa: BLE001
        return False
    if completed.returncode != 0:
        return False
    return _PARALLEL_RUNTIME_PROBE_SENTINEL in completed.stdout
106+
107+
108+
def _parallel_numba_runtime_is_safe() -> bool:
    """Return the (cached) probe result for the current probe key."""
    current_key = _parallel_runtime_probe_key()
    return _parallel_numba_runtime_is_safe_cached(current_key)

0 commit comments

Comments
 (0)