Skip to content

Commit 48c7783

Browse files
authored
Backport PR scverse#4015 on branch 1.12.x (fix: use fast-array-utils’ threadsafe njit) (scverse#4022)
1 parent 0a4a597 commit 48c7783

13 files changed

Lines changed: 27 additions & 117 deletions

File tree

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727
with: { enable-cache: false }
2828
- id: get-envs
2929
run: |
30-
ENVS_JSON=$(NO_COLOR=1 uvx '--with=virtualenv<21' hatch env show --json | jq -c 'to_entries
30+
ENVS_JSON=$(NO_COLOR=1 uvx hatch env show --json | jq -c 'to_entries
3131
| map(
3232
select(.key | startswith("hatch-test"))
3333
| {
@@ -63,7 +63,7 @@ jobs:
6363
- name: Install dependencies
6464
run: |
6565
echo "::group::Install hatch"
66-
uv tool install hatch '--with=virtualenv<21'
66+
uv tool install hatch
6767
echo "::endgroup::"
6868
echo "::group::Create environment"
6969
hatch -v env create ${{ matrix.env.name }}

.readthedocs.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ build:
1616
- asdf global uv latest
1717
pre_build:
1818
# run towncrier to preview the next version’s release notes
19-
- ( find docs/release-notes -regex '[^.]+[.][^.]+.md' | grep -q . ) && uvx "--with=virtualenv<21" hatch run towncrier build --keep || true
19+
- ( find docs/release-notes -regex '[^.]+[.][^.]+.md' | grep -q . ) && uvx hatch run towncrier build --keep || true
2020
build:
2121
html:
22-
- uvx "--with=virtualenv<21" hatch run docs:build
22+
- uvx hatch run docs:build
2323
- mv docs/_build $READTHEDOCS_OUTPUT

docs/release-notes/4015.fix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix crashes when running numba in dask by using the thread-safe {func}`fast_array_utils.numba.njit` {smaller}`P Angerer`

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ dynamic = [ "version" ]
5353
dependencies = [
5454
"anndata>=0.10.8",
5555
"certifi",
56-
"fast-array-utils[accel,sparse]>=1.2.1",
56+
"fast-array-utils[accel,sparse]>=1.4",
5757
"h5py>=3.11",
5858
"joblib",
5959
"legacy-api-wrap>=1.5", # for positional API deprecations
@@ -215,8 +215,8 @@ lint.pylint.max-positional-args = 5
215215

216216
[tool.ruff.lint.flake8-tidy-imports.banned-api]
217217
"legacy_api_wrap.legacy_api".msg = "Use scanpy._compat.old_positionals instead"
218-
"numba.jit".msg = "Use `scanpy._compat.njit` instead"
219-
"numba.njit".msg = "Use `scanpy._compat.njit` instead"
218+
"numba.jit".msg = "Use `fast_array_utils.numba.njit` instead"
219+
"numba.njit".msg = "Use `fast_array_utils.numba.njit` instead"
220220
"numpy.bool_".msg = "Use `np.bool` instead"
221221
"pandas.api.types.is_categorical_dtype".msg = "Use isinstance(s.dtype, CategoricalDtype) instead"
222222
"pandas.value_counts".msg = "Use pd.Series(a).value_counts() instead"
@@ -291,7 +291,7 @@ run.source_pkgs = [ "scanpy" ]
291291
paths.source = [ "src", "**/site-packages" ]
292292
report.exclude_also = [
293293
# https://github.com/numba/numba/issues/4268
294-
"@(numba\\.|nb\\.)?njit.*",
294+
"@([\\w.]+.)?njit.*",
295295
"@deprecated.*",
296296
"if __name__ == .__main__.:",
297297
"if TYPE_CHECKING:",

src/scanpy/_compat.py

Lines changed: 2 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,16 @@
22

33
import sys
44
import warnings
5-
from functools import cache, partial, wraps
5+
from functools import cache, partial
66
from importlib.util import find_spec
77
from pathlib import Path
8-
from typing import TYPE_CHECKING, Literal, cast, overload
8+
from typing import TYPE_CHECKING
99

1010
import legacy_api_wrap
1111
from packaging.version import Version
1212
from scipy import sparse
1313

1414
if TYPE_CHECKING:
15-
from collections.abc import Callable
1615
from importlib.metadata import PackageMetadata
1716

1817

@@ -22,10 +21,8 @@
2221
"CSRBase",
2322
"DaskArray",
2423
"SpBase",
25-
"_numba_threading_layer",
2624
"deprecated",
2725
"fullname",
28-
"njit",
2926
"old_positionals",
3027
"pkg_metadata",
3128
"pkg_version",
@@ -123,99 +120,3 @@ def warn(
123120
warnings.warn( # noqa: TID251
124121
message, category, source=source, skip_file_prefixes=skip_file_prefixes
125122
)
126-
127-
128-
@overload
129-
def njit[**P, R](fn: Callable[P, R], /) -> Callable[P, R]: ...
130-
@overload
131-
def njit[**P, R]() -> Callable[[Callable[P, R]], Callable[P, R]]: ...
132-
def njit[**P, R](
133-
fn: Callable[P, R] | None = None, /
134-
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
135-
"""Jit-compile a function using numba.
136-
137-
On call, this function dispatches to a parallel or sequential numba function,
138-
depending on whether it has been called from a thread pool.
139-
140-
See <https://github.com/numbagg/numbagg/pull/201/files#r1409374809>
141-
"""
142-
143-
def decorator(f: Callable[P, R], /) -> Callable[P, R]:
144-
import numba
145-
146-
fns: dict[bool, Callable[P, R]] = {
147-
parallel: numba.njit(f, cache=True, parallel=parallel) # noqa: TID251
148-
for parallel in (True, False)
149-
}
150-
151-
@wraps(f)
152-
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
153-
parallel = not _is_in_unsafe_thread_pool()
154-
if not parallel:
155-
msg = (
156-
"Detected unsupported threading environment. "
157-
f"Trying to run {f.__name__} in serial mode. "
158-
"In case of problems, install `tbb`."
159-
)
160-
warn(msg, UserWarning)
161-
return fns[parallel](*args, **kwargs)
162-
163-
return wrapper
164-
165-
return decorator if fn is None else decorator(fn)
166-
167-
168-
type LayerType = Literal["default", "safe", "threadsafe", "forksafe"]
169-
type Layer = Literal["tbb", "omp", "workqueue"]
170-
171-
172-
LAYERS: dict[LayerType, set[Layer]] = {
173-
"default": {"tbb", "omp", "workqueue"},
174-
"safe": {"tbb"},
175-
"threadsafe": {"tbb", "omp"},
176-
"forksafe": {"tbb", "workqueue", *(() if sys.platform == "linux" else {"omp"})},
177-
}
178-
179-
180-
def _is_in_unsafe_thread_pool() -> bool:
181-
import threading
182-
183-
current_thread = threading.current_thread()
184-
# ThreadPoolExecutor threads typically have names like 'ThreadPoolExecutor-0_1'
185-
return (
186-
current_thread.name.startswith("ThreadPoolExecutor")
187-
and _numba_threading_layer() not in LAYERS["threadsafe"]
188-
)
189-
190-
191-
@cache
192-
def _numba_threading_layer() -> Layer:
193-
"""Get numba’s threading layer.
194-
195-
This function implements the algorithm as described in
196-
<https://numba.readthedocs.io/en/stable/user/threading-layer.html>
197-
"""
198-
import importlib
199-
200-
import numba
201-
202-
if (available := LAYERS.get(numba.config.THREADING_LAYER)) is None:
203-
# given by direct name
204-
return numba.config.THREADING_LAYER
205-
206-
# given by layer type (safe, …)
207-
for layer in cast("list[Layer]", numba.config.THREADING_LAYER_PRIORITY):
208-
if layer not in available:
209-
continue
210-
if layer != "workqueue":
211-
try: # `importlib.util.find_spec` doesn’t work here
212-
importlib.import_module(f"numba.np.ufunc.{layer}pool")
213-
except ImportError:
214-
continue
215-
# the layer has been found
216-
return layer
217-
msg = (
218-
f"No loadable threading layer: {numba.config.THREADING_LAYER=} "
219-
f" ({available=}, {numba.config.THREADING_LAYER_PRIORITY=})"
220-
)
221-
raise ValueError(msg)

src/scanpy/experimental/pp/_highly_variable_genes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
import numpy as np
99
import pandas as pd
1010
from anndata import AnnData
11+
from fast_array_utils.numba import njit
1112
from fast_array_utils.stats import mean_var
1213

1314
from ... import logging as logg
14-
from ..._compat import CSBase, njit, warn
15+
from ..._compat import CSBase, warn
1516
from ..._settings import Verbosity, settings
1617
from ..._utils import _doc_params, check_nonnegative_integers, view_to_actual
1718
from ...experimental._docs import (

src/scanpy/metrics/_gearys_c.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77

88
import numba
99
import numpy as np
10+
from fast_array_utils.numba import njit
1011

11-
from .._compat import CSRBase, njit
12+
from .._compat import CSRBase
1213
from .._utils import _doc_params
1314
from ..get import _get_obs_rep
1415
from ..neighbors._doc import doc_neighbors_key

src/scanpy/metrics/_morans_i.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77

88
import numba
99
import numpy as np
10+
from fast_array_utils.numba import njit
1011

11-
from .._compat import CSRBase, njit
12+
from .._compat import CSRBase
1213
from .._utils import _doc_params
1314
from ..get import _get_obs_rep
1415
from ..neighbors._doc import doc_neighbors_key

src/scanpy/preprocessing/_normalization.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
import numba
77
import numpy as np
88
from fast_array_utils import stats
9+
from fast_array_utils.numba import njit
910

1011
from .. import logging as logg
11-
from .._compat import CSBase, CSCBase, CSRBase, DaskArray, njit, old_positionals, warn
12+
from .._compat import CSBase, CSCBase, CSRBase, DaskArray, old_positionals, warn
1213
from .._utils import axis_mul_or_truediv, dematrix, view_to_actual
1314
from ..get import _get_obs_rep, _set_obs_rep
1415

src/scanpy/preprocessing/_qc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
import numpy as np
88
import pandas as pd
99
from fast_array_utils import stats
10+
from fast_array_utils.numba import njit
1011
from scipy import sparse
1112

1213
from scanpy.get import _get_obs_rep
1314
from scanpy.preprocessing._distributed import materialize_as_ndarray
1415

15-
from .._compat import CSBase, CSRBase, DaskArray, njit, warn
16+
from .._compat import CSBase, CSRBase, DaskArray, warn
1617
from .._utils import _doc_params, axis_nnz
1718
from ._docs import (
1819
doc_adata_basic,

0 commit comments

Comments
 (0)