Skip to content

Commit e91911a

Browse files
committed
fix: use parallelized numba functions if possible
1 parent cdaa7c9 commit e91911a

6 files changed

Lines changed: 118 additions & 11 deletions

File tree

pyproject.toml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,6 @@ doctest_subpackage_requires = [
177177
"src/fast_array_utils/_plugins/dask.py = dask",
178178
"src/fast_array_utils/_plugins/numba_sparse.py = numba;scipy",
179179
]
180-
filterwarnings = [
181-
"error",
182-
# codspeed seems to break this dtype added by h5py
183-
"ignore:.*numpy[.]longdouble:UserWarning",
184-
"ignore:FNV hashing is not implemented in Numba:UserWarning",
185-
]
186180
markers = [
187181
"benchmark: marks tests as benchmark (to run with `--codspeed`)",
188182
]

src/fast_array_utils/conv/scipy/_to_dense.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import numba
77

8+
from fast_array_utils.utils import njit
9+
810

911
if TYPE_CHECKING:
1012
from typing import Any
@@ -18,14 +20,14 @@
1820
__all__ = ["_to_dense_csc_numba", "_to_dense_csr_numba"]
1921

2022

21-
@numba.njit(cache=True)
23+
@njit
def _to_dense_csc_numba(x: CSBase, out: NDArray[np.number[Any]]) -> None:
    """Scatter the nonzeros of the CSC matrix `x` into the pre-allocated dense array `out`.

    Columns are independent of each other, so the outer loop is a `numba.prange`
    and can run in parallel when the function is compiled with `parallel=True`.
    """
    n_cols = out.shape[1]
    for col in numba.prange(n_cols):
        start, stop = x.indptr[col], x.indptr[col + 1]
        for nz in range(start, stop):
            out[x.indices[nz], col] = x.data[nz]
2628

2729

28-
@numba.njit(cache=True)
30+
@njit
2931
def _to_dense_csr_numba(x: CSBase, out: NDArray[np.number[Any]]) -> None:
3032
for r in numba.prange(out.shape[0]):
3133
for i in range(x.indptr[r], x.indptr[r + 1]):

src/fast_array_utils/stats/_is_constant.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import numba
88
import numpy as np
99

10+
from fast_array_utils.utils import njit
11+
1012
from .. import types
1113

1214

@@ -64,7 +66,7 @@ def _is_constant_cs(a: types.CSBase, /, *, axis: Literal[0, 1] | None = None) ->
6466
return _is_constant_cs_major(a, shape)
6567

6668

67-
@numba.njit(cache=True)
69+
@njit
6870
def _is_constant_cs_major(a: types.CSBase, shape: tuple[int, int]) -> NDArray[np.bool]:
6971
n = len(a.indptr) - 1
7072
result = np.ones(n, dtype=np.bool)

src/fast_array_utils/stats/_mean_var.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import numba
77
import numpy as np
88

9+
from fast_array_utils.utils import njit
10+
911
from .. import types
1012
from ._power import power
1113

@@ -79,7 +81,7 @@ def _sparse_mean_var(mtx: types.CSBase, /, *, axis: Literal[0, 1]) -> tuple[NDAr
7981
)
8082

8183

82-
@numba.njit(cache=True)
84+
@njit
8385
def sparse_mean_var_minor_axis(
8486
x: types.CSBase,
8587
*,
@@ -109,7 +111,7 @@ def sparse_mean_var_minor_axis(
109111
return means, variances
110112

111113

112-
@numba.njit(cache=True)
114+
@njit
113115
def sparse_mean_var_major_axis(
114116
x: types.CSBase,
115117
*,

src/fast_array_utils/utils.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
import warnings
5+
from functools import cache, wraps
6+
from typing import TYPE_CHECKING, Literal, cast, overload
7+
8+
9+
if TYPE_CHECKING:
10+
from collections.abc import Callable
11+
12+
type LayerType = Literal["default", "safe", "threadsafe", "forksafe"]
13+
type Layer = Literal["tbb", "omp", "workqueue"]
14+
15+
16+
LAYERS: dict[LayerType, set[Layer]] = {
17+
"default": {"tbb", "omp", "workqueue"},
18+
"safe": {"tbb"},
19+
"threadsafe": {"tbb", "omp"},
20+
"forksafe": {"tbb", "workqueue", *(() if sys.platform == "linux" else {"omp"})},
21+
}
22+
23+
24+
@cache
def _numba_threading_layer() -> Layer:
    """Get numba's threading layer.

    This function implements the algorithm as described in
    <https://numba.readthedocs.io/en/stable/user/threading-layer.html>
    """
    import importlib

    import numba

    requested = numba.config.THREADING_LAYER
    available = LAYERS.get(requested)
    if available is None:
        # given by direct name
        return requested

    def _loadable(layer: Layer) -> bool:
        """Check whether a layer's backing module can actually be imported."""
        if layer == "workqueue":
            # workqueue is always bundled with numba; nothing to import
            return True
        try:  # `importlib.util.find_spec` doesn't work here
            importlib.import_module(f"numba.np.ufunc.{layer}pool")
        except ImportError:
            return False
        return True

    # given by layer type (safe, …): pick the first loadable layer in priority order
    for layer in cast("list[Layer]", numba.config.THREADING_LAYER_PRIORITY):
        if layer in available and _loadable(layer):
            # the layer has been found
            return layer

    msg = f"No loadable threading layer: {numba.config.THREADING_LAYER=} ({available=}, {numba.config.THREADING_LAYER_PRIORITY=})"
    raise ValueError(msg)
52+
53+
54+
def _is_in_unsafe_thread_pool() -> bool:
    """Check whether we run on a thread-pool thread while numba's layer is not thread-safe."""
    import threading

    # ThreadPoolExecutor threads typically have names like 'ThreadPoolExecutor-0_1'
    thread_name = threading.current_thread().name
    if not thread_name.startswith("ThreadPoolExecutor"):
        return False
    # only consult the (cached) threading layer when we actually are in a pool
    return _numba_threading_layer() not in LAYERS["threadsafe"]
60+
61+
62+
@overload
63+
def njit[**P, R](fn: Callable[P, R], /) -> Callable[P, R]: ...
64+
@overload
65+
def njit[**P, R]() -> Callable[[Callable[P, R]], Callable[P, R]]: ...
66+
def njit[**P, R](fn: Callable[P, R] | None = None, /) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
67+
"""Jit-compile a function using numba.
68+
69+
On call, this function dispatches to a parallel or sequential numba function,
70+
depending on if it has been called from a thread pool.
71+
72+
See <https://github.com/numbagg/numbagg/pull/201/files#r1409374809>
73+
"""
74+
75+
def decorator(f: Callable[P, R], /) -> Callable[P, R]:
76+
import numba
77+
78+
fns: dict[bool, Callable[P, R]] = {parallel: numba.njit(f, cache=True, parallel=parallel) for parallel in (True, False)}
79+
80+
@wraps(f)
81+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
82+
parallel = not _is_in_unsafe_thread_pool()
83+
if not parallel:
84+
msg = f"Detected unsupported threading environment. Trying to run {f.__name__} in serial mode. In case of problems, install `tbb`."
85+
warnings.warn(msg, UserWarning, stacklevel=2)
86+
return fns[parallel](*args, **kwargs)
87+
88+
return wrapper
89+
90+
return decorator if fn is None else decorator(fn)

src/testing/fast_array_utils/pytest.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,23 @@ def _skip_if_unimportable(array_type: ArrayType) -> pytest.MarkDecorator:
8989
SUPPORTED_TYPE_PARAMS = [pytest.param(t, id=str(t), marks=_skip_if_unimportable(t)) for t in SUPPORTED_TYPES]
9090

9191

92+
@pytest.fixture(autouse=True)
def dask_single_threaded() -> Generator[None]:
    """Switch to a single-threaded scheduler for tests since numba crashes otherwise."""
    if find_spec("dask"):
        import dask.config

        # remember the active scheduler ("threads" is dask's default) and
        # restore it after the test, even when the test fails
        previous = dask.config.get("scheduler", "threads")
        dask.config.set(scheduler="single-threaded")
        try:
            yield
        finally:
            dask.config.set(scheduler=previous)
    else:
        # dask not installed: nothing to configure
        yield
107+
108+
92109
@pytest.fixture(scope="session", params=SUPPORTED_TYPE_PARAMS)
93110
def array_type(request: pytest.FixtureRequest) -> ArrayType:
94111
"""Fixture for a supported :class:`~testing.fast_array_utils.ArrayType`.

0 commit comments

Comments
 (0)