|
2 | 2 |
|
3 | 3 | import sys |
4 | 4 | import warnings |
5 | | -from functools import cache, partial, wraps |
| 5 | +from functools import cache, partial |
6 | 6 | from importlib.util import find_spec |
7 | 7 | from pathlib import Path |
8 | | -from typing import TYPE_CHECKING, Literal, cast, overload |
| 8 | +from typing import TYPE_CHECKING |
9 | 9 |
|
10 | 10 | import legacy_api_wrap |
11 | 11 | from packaging.version import Version |
12 | 12 | from scipy import sparse |
13 | 13 |
|
14 | 14 | if TYPE_CHECKING: |
15 | | - from collections.abc import Callable |
16 | 15 | from importlib.metadata import PackageMetadata |
17 | 16 |
|
18 | 17 |
|
|
22 | 21 | "CSRBase", |
23 | 22 | "DaskArray", |
24 | 23 | "SpBase", |
25 | | - "_numba_threading_layer", |
26 | 24 | "deprecated", |
27 | 25 | "fullname", |
28 | | - "njit", |
29 | 26 | "old_positionals", |
30 | 27 | "pkg_metadata", |
31 | 28 | "pkg_version", |
@@ -123,99 +120,3 @@ def warn( |
123 | 120 | warnings.warn( # noqa: TID251 |
124 | 121 | message, category, source=source, skip_file_prefixes=skip_file_prefixes |
125 | 122 | ) |
126 | | - |
127 | | - |
128 | | -@overload |
129 | | -def njit[**P, R](fn: Callable[P, R], /) -> Callable[P, R]: ... |
130 | | -@overload |
131 | | -def njit[**P, R]() -> Callable[[Callable[P, R]], Callable[P, R]]: ... |
132 | | -def njit[**P, R]( |
133 | | - fn: Callable[P, R] | None = None, / |
134 | | -) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]: |
135 | | - """Jit-compile a function using numba. |
136 | | -
|
137 | | - On call, this function dispatches to a parallel or sequential numba function, |
138 | | - depending on if it has been called from a thread pool. |
139 | | -
|
140 | | - See <https://github.com/numbagg/numbagg/pull/201/files#r1409374809> |
141 | | - """ |
142 | | - |
143 | | - def decorator(f: Callable[P, R], /) -> Callable[P, R]: |
144 | | - import numba |
145 | | - |
146 | | - fns: dict[bool, Callable[P, R]] = { |
147 | | - parallel: numba.njit(f, cache=True, parallel=parallel) # noqa: TID251 |
148 | | - for parallel in (True, False) |
149 | | - } |
150 | | - |
151 | | - @wraps(f) |
152 | | - def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: |
153 | | - parallel = not _is_in_unsafe_thread_pool() |
154 | | - if not parallel: |
155 | | - msg = ( |
156 | | - "Detected unsupported threading environment. " |
157 | | - f"Trying to run {f.__name__} in serial mode. " |
158 | | - "In case of problems, install `tbb`." |
159 | | - ) |
160 | | - warn(msg, UserWarning) |
161 | | - return fns[parallel](*args, **kwargs) |
162 | | - |
163 | | - return wrapper |
164 | | - |
165 | | - return decorator if fn is None else decorator(fn) |
166 | | - |
167 | | - |
168 | | -type LayerType = Literal["default", "safe", "threadsafe", "forksafe"] |
169 | | -type Layer = Literal["tbb", "omp", "workqueue"] |
170 | | - |
171 | | - |
172 | | -LAYERS: dict[LayerType, set[Layer]] = { |
173 | | - "default": {"tbb", "omp", "workqueue"}, |
174 | | - "safe": {"tbb"}, |
175 | | - "threadsafe": {"tbb", "omp"}, |
176 | | - "forksafe": {"tbb", "workqueue", *(() if sys.platform == "linux" else {"omp"})}, |
177 | | -} |
178 | | - |
179 | | - |
180 | | -def _is_in_unsafe_thread_pool() -> bool: |
181 | | - import threading |
182 | | - |
183 | | - current_thread = threading.current_thread() |
184 | | - # ThreadPoolExecutor threads typically have names like 'ThreadPoolExecutor-0_1' |
185 | | - return ( |
186 | | - current_thread.name.startswith("ThreadPoolExecutor") |
187 | | - and _numba_threading_layer() not in LAYERS["threadsafe"] |
188 | | - ) |
189 | | - |
190 | | - |
191 | | -@cache |
192 | | -def _numba_threading_layer() -> Layer: |
193 | | - """Get numba’s threading layer. |
194 | | -
|
195 | | - This function implements the algorithm as described in |
196 | | - <https://numba.readthedocs.io/en/stable/user/threading-layer.html> |
197 | | - """ |
198 | | - import importlib |
199 | | - |
200 | | - import numba |
201 | | - |
202 | | - if (available := LAYERS.get(numba.config.THREADING_LAYER)) is None: |
203 | | - # given by direct name |
204 | | - return numba.config.THREADING_LAYER |
205 | | - |
206 | | - # given by layer type (safe, …) |
207 | | - for layer in cast("list[Layer]", numba.config.THREADING_LAYER_PRIORITY): |
208 | | - if layer not in available: |
209 | | - continue |
210 | | - if layer != "workqueue": |
211 | | - try: # `importlib.util.find_spec` doesn’t work here |
212 | | - importlib.import_module(f"numba.np.ufunc.{layer}pool") |
213 | | - except ImportError: |
214 | | - continue |
215 | | - # the layer has been found |
216 | | - return layer |
217 | | - msg = ( |
218 | | - f"No loadable threading layer: {numba.config.THREADING_LAYER=} " |
219 | | - f" ({available=}, {numba.config.THREADING_LAYER_PRIORITY=})" |
220 | | - ) |
221 | | - raise ValueError(msg) |
0 commit comments