Skip to content

Commit b4241d4

Browse files
fix: use parallelized numba functions if possible (#155)
Co-authored-by: Philipp A. <flying-sheep@web.de>
1 parent e50c44f commit b4241d4

11 files changed

Lines changed: 174 additions & 15 deletions

File tree

docs/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@
7171
zarr=("https://zarr.readthedocs.io/en/stable/", None),
7272
)
7373
nitpick_ignore = [
74+
("py:class", "P"),
75+
("py:class", "R"),
7476
("py:class", "Arr"),
7577
("py:class", "Array"),
7678
("py:class", "ToDType"),

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
fast-array-utils <self>
88
conv
99
stats
10+
numba
1011
typing
1112
testing
1213

docs/numba.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
``fast_array_utils.numba``
2+
==========================
3+
4+
.. automodule:: fast_array_utils.numba
5+
:members:

pyproject.toml

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ mypy_path = [ "$MYPY_CONFIG_FILE_DIR/typings", "$MYPY_CONFIG_FILE_DIR/src" ]
148148
stubPath = "./typings"
149149
reportPrivateUsage = false
150150

151+
[tool.ty]
152+
environment.extra-paths = [ "./typings" ]
153+
151154
[tool.pytest]
152155
strict = true
153156
addopts = [
@@ -164,12 +167,6 @@ doctest_subpackage_requires = [
164167
"src/fast_array_utils/_plugins/dask.py = dask",
165168
"src/fast_array_utils/_plugins/numba_sparse.py = numba;scipy",
166169
]
167-
filterwarnings = [
168-
"error",
169-
# codspeed seems to break this dtype added by h5py
170-
"ignore:.*numpy[.]longdouble:UserWarning",
171-
"ignore:FNV hashing is not implemented in Numba:UserWarning",
172-
]
173170
markers = [
174171
"benchmark: marks tests as benchmark (to run with `--codspeed`)",
175172
]

src/fast_array_utils/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,20 @@
88
This submodule requires :doc:`numba <numba:index>` to be installed
99
and contains statistics utilities.
1010
11+
:mod:`fast_array_utils.numba`
12+
This submodule contains numba utilities.
13+
1114
:mod:`fast_array_utils.typing` and :mod:`fast_array_utils.types`
1215
These submodules contain types for annotations and checks, respectively.
1316
Stubs for these types are available even if the respective packages are not installed.
1417
"""
1518

1619
from __future__ import annotations
1720

18-
from . import _plugins, conv, stats, types
21+
from . import _plugins, conv, numba, stats, types
1922

2023

21-
__all__ = ["conv", "stats", "types"]
24+
__all__ = ["conv", "numba", "stats", "types"]
2225

2326
_plugins.patch_dask()
2427
_plugins.register_numba_sparse()

src/fast_array_utils/conv/scipy/_to_dense.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import numba
77

8+
from ...numba import njit
9+
810

911
if TYPE_CHECKING:
1012
from typing import Any
@@ -18,14 +20,14 @@
1820
__all__ = ["_to_dense_csc_numba", "_to_dense_csr_numba"]
1921

2022

21-
@numba.njit(cache=True)
23+
@njit
2224
def _to_dense_csc_numba(x: CSBase, out: NDArray[np.number[Any]]) -> None:
2325
for c in numba.prange(out.shape[1]):
2426
for i in range(x.indptr[c], x.indptr[c + 1]):
2527
out[x.indices[i], c] = x.data[i]
2628

2729

28-
@numba.njit(cache=True)
30+
@njit
2931
def _to_dense_csr_numba(x: CSBase, out: NDArray[np.number[Any]]) -> None:
3032
for r in numba.prange(out.shape[0]):
3133
for i in range(x.indptr[r], x.indptr[r + 1]):

src/fast_array_utils/numba.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# SPDX-License-Identifier: MPL-2.0
2+
r"""Numba utilities, mainly used to deal with :ref:`numba-threading-layer` of :doc:`numba <numba:index>`.
3+
4+
``numba.config.THREADING_LAYER`` : env variable :envvar:`NUMBA_THREADING_LAYER`
5+
This can be set to a :class:`ThreadingLayer` or :class:`TheadingCategory`.
6+
``numba.config.THREADING_LAYER_PRIORITY`` : env variable :envvar:`NUMBA_THREADING_LAYER_PRIORITY`
7+
This can be set to a list of :class:`ThreadingLayer`\ s.
8+
9+
``fast-array-utils`` provides the following utilities:
10+
"""
11+
12+
from __future__ import annotations
13+
14+
import sys
15+
import warnings
16+
from functools import cache, update_wrapper, wraps
17+
from types import FunctionType
18+
from typing import TYPE_CHECKING, Literal, cast, overload
19+
20+
21+
if TYPE_CHECKING:
22+
from collections.abc import Callable, Iterable
23+
24+
25+
__all__ = ["TheadingCategory", "ThreadingLayer", "njit", "threading_layer"]
26+
27+
28+
type TheadingCategory = Literal["default", "safe", "threadsafe", "forksafe"]
29+
"""Identifier for a threading layer category."""
30+
type ThreadingLayer = Literal["tbb", "omp", "workqueue"]
31+
"""Identifier for a concrete threading layer."""
32+
33+
34+
LAYERS: dict[TheadingCategory, set[ThreadingLayer]] = {
35+
"default": {"tbb", "omp", "workqueue"},
36+
"safe": {"tbb"},
37+
"threadsafe": {"tbb", "omp"},
38+
"forksafe": {"tbb", "workqueue", *(() if sys.platform == "linux" else {"omp"})},
39+
}
40+
41+
42+
def threading_layer(layer_or_category: ThreadingLayer | TheadingCategory | None = None, /, priority: Iterable[ThreadingLayer] | None = None) -> ThreadingLayer:
    """Get numba’s configured threading layer as specified in :ref:`numba-threading-layer`.

    ``layer_or_category`` defaults to ``numba.config.THREADING_LAYER`` and ``priority`` to ``numba.config.THREADING_LAYER_PRIORITY``.
    """
    import numba

    # only consult numba’s configuration for arguments the caller left unset
    requested = numba.config.THREADING_LAYER if layer_or_category is None else layer_or_category
    ordering = numba.config.THREADING_LAYER_PRIORITY if priority is None else priority

    # the cached helper needs hashable arguments, hence the tuple
    return _threading_layer(requested, tuple(ordering))
56+
57+
@cache
58+
def _threading_layer(layer_or_category: ThreadingLayer | TheadingCategory, /, priority: Iterable[ThreadingLayer]) -> ThreadingLayer:
59+
import importlib
60+
61+
if (available := LAYERS.get(layer_or_category)) is None: # type: ignore[arg-type] # pragma: no cover
62+
return cast("ThreadingLayer", layer_or_category) # given by direct name
63+
64+
# given by layer type (safe, …)
65+
for layer in priority:
66+
if layer not in available: # pragma: no cover
67+
continue
68+
if layer != "workqueue":
69+
try: # `importlib.util.find_spec` doesn’t work here
70+
importlib.import_module(f"numba.np.ufunc.{layer}pool")
71+
except ImportError:
72+
continue
73+
# the layer has been found
74+
return layer
75+
msg = f"No threading layer matching {layer_or_category!r} ({available=}, {priority=})" # pragma: no cover
76+
raise ValueError(msg) # pragma: no cover
77+
78+
79+
def _is_in_unsafe_thread_pool() -> bool:
80+
import threading
81+
82+
current_thread = threading.current_thread()
83+
# ThreadPoolExecutor threads typically have names like 'ThreadPoolExecutor-0_1'
84+
return current_thread.name.startswith("ThreadPoolExecutor") and threading_layer() not in LAYERS["threadsafe"]
85+
86+
87+
@overload
88+
def njit[**P, R](fn: Callable[P, R], /) -> Callable[P, R]: ...
89+
@overload
90+
def njit[**P, R]() -> Callable[[Callable[P, R]], Callable[P, R]]: ...
91+
def njit[**P, R](fn: Callable[P, R] | None = None, /) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
92+
"""Jit-compile a function using numba.
93+
94+
On call, this function dispatches to a parallel or serial numba function,
95+
depending on if it has been called from a thread pool.
96+
"""
97+
# See https://github.com/numbagg/numbagg/pull/201/files#r1409374809
98+
99+
def decorator(f: Callable[P, R], /) -> Callable[P, R]:
100+
import numba
101+
102+
assert isinstance(f, FunctionType)
103+
104+
# use distinct names so numba doesn’t reuse the wrong version’s cache
105+
fns: dict[bool, Callable[P, R]] = {
106+
parallel: numba.njit(_copy_function(f, __qualname__=f"{f.__qualname__}-{'parallel' if parallel else 'serial'}"), cache=True, parallel=parallel)
107+
for parallel in (True, False)
108+
}
109+
110+
@wraps(f)
111+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
112+
parallel = not _is_in_unsafe_thread_pool()
113+
if not parallel: # pragma: no cover
114+
msg = f"Detected unsupported threading environment. Trying to run {f.__name__} in serial mode. In case of problems, install `tbb`."
115+
warnings.warn(msg, UserWarning, stacklevel=2)
116+
return fns[parallel](*args, **kwargs)
117+
118+
return wrapper
119+
120+
return decorator if fn is None else decorator(fn)
121+
122+
123+
def _copy_function[F: FunctionType](f: F, **overrides: object) -> F:
124+
new = FunctionType(code=f.__code__, globals=f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
125+
new.__kwdefaults__ = f.__kwdefaults__
126+
new = cast("F", update_wrapper(new, f))
127+
for key, value in overrides.items():
128+
setattr(new, key, value)
129+
return new

src/fast_array_utils/stats/_is_constant.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99

1010
from .. import types
11+
from ..numba import njit
1112

1213

1314
if TYPE_CHECKING:
@@ -64,7 +65,7 @@ def _is_constant_cs(a: types.CSBase, /, *, axis: Literal[0, 1] | None = None) ->
6465
return _is_constant_cs_major(a, shape)
6566

6667

67-
@numba.njit(cache=True)
68+
@njit
6869
def _is_constant_cs_major(a: types.CSBase, shape: tuple[int, int]) -> NDArray[np.bool]:
6970
n = len(a.indptr) - 1
7071
result = np.ones(n, dtype=np.bool)

src/fast_array_utils/stats/_mean_var.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88

99
from .. import types
10+
from ..numba import njit
1011
from ._power import power
1112

1213

@@ -79,7 +80,7 @@ def _sparse_mean_var(mtx: types.CSBase, /, *, axis: Literal[0, 1]) -> tuple[NDAr
7980
)
8081

8182

82-
@numba.njit(cache=True)
83+
@njit
8384
def sparse_mean_var_minor_axis(
8485
x: types.CSBase,
8586
*,
@@ -109,7 +110,7 @@ def sparse_mean_var_minor_axis(
109110
return means, variances
110111

111112

112-
@numba.njit(cache=True)
113+
@njit
113114
def sparse_mean_var_major_axis(
114115
x: types.CSBase,
115116
*,

typings/numba/__init__.pyi

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,27 @@
22
from collections.abc import Callable, Iterable
33
from typing import Literal, SupportsIndex, overload
44

5-
from .core.types import *
5+
from .core import config as config
6+
from .core.types import Type
67

78
type __Signature = str | Type
89
type _Signature = str | Type | tuple[__Signature, ...]
910

1011
# https://numba.readthedocs.io/en/stable/reference/jit-compilation.html#numba.jit
1112
@overload
12-
def njit[F: Callable[..., object]](f: F) -> F: ...
13+
def njit[F: Callable[..., object]](
14+
f: F,
15+
*,
16+
nopython: bool = True,
17+
nogil: bool = False,
18+
cache: bool = False,
19+
forceobj: bool = False,
20+
parallel: bool = False,
21+
error_model: Literal["python", "numpy"] = "python",
22+
fastmath: bool = False,
23+
locals: dict[str, object] = {},
24+
boundscheck: bool = False,
25+
) -> F: ...
1326
@overload
1427
def njit[F: Callable[..., object]](
1528
signature: _Signature | list[_Signature] | None = None,

0 commit comments

Comments
 (0)