Skip to content

Commit 9edf6cf

Browse files
author
Jhonatan Ramos Felix
committed
handle broken numba parallel runtime with torch on apple silicon
1 parent d5176d6 commit 9edf6cf

1 file changed

Lines changed: 121 additions & 4 deletions

File tree

src/fast_array_utils/numba.py

Lines changed: 121 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111

1212
from __future__ import annotations
1313

14+
import os
15+
import platform
16+
import subprocess
1417
import sys
1518
import warnings
1619
from functools import cache, update_wrapper, wraps
@@ -29,6 +32,7 @@
2932
"""Identifier for a threading layer category."""
3033
type ThreadingLayer = Literal["tbb", "omp", "workqueue"]
3134
"""Identifier for a concrete threading layer."""
35+
type _ParallelRuntimeProbeKey = tuple[str, ThreadingLayer | TheadingCategory, tuple[ThreadingLayer, ...], tuple[str, ...]]
3236

3337

3438
LAYERS: dict[TheadingCategory, set[ThreadingLayer]] = {
@@ -39,6 +43,11 @@
3943
}
4044

4145

46+
_PARALLEL_RUNTIME_PROBE_SENTINEL = "FAST_ARRAY_UTILS_NUMBA_PROBE_OK"
47+
_PARALLEL_RUNTIME_PROBE_TIMEOUT = 10
48+
_PARALLEL_RUNTIME_PROBE_MODULE_WHITELIST = ("torch",)
49+
50+
4251
def threading_layer(layer_or_category: ThreadingLayer | TheadingCategory | None = None, /, priority: Iterable[ThreadingLayer] | None = None) -> ThreadingLayer:
4352
"""Get numba’s configured threading layer as specified in :ref:`numba-threading-layer`.
4453
@@ -84,6 +93,107 @@ def _is_in_unsafe_thread_pool() -> bool:
8493
return current_thread.name.startswith("ThreadPoolExecutor") and threading_layer() not in LAYERS["threadsafe"]
8594

8695

96+
def _is_apple_silicon() -> bool:
97+
return sys.platform == "darwin" and platform.machine() == "arm64"
98+
99+
100+
def _is_torch_loaded() -> bool:
101+
return "torch" in sys.modules
102+
103+
104+
def _configured_threading_layer_or_category_without_probing() -> ThreadingLayer | TheadingCategory:
105+
import numba
106+
107+
# Avoid `threading_layer()` here: resolving backends may import pool modules.
108+
return cast("ThreadingLayer | TheadingCategory", numba.config.THREADING_LAYER)
109+
110+
111+
def _configured_explicit_threading_layer_without_probing() -> ThreadingLayer | None:
112+
layer_or_category = _configured_threading_layer_or_category_without_probing()
113+
return layer_or_category if layer_or_category in LAYERS["default"] else None
114+
115+
116+
def _is_explicit_safe_threading_layer() -> bool:
117+
return _configured_explicit_threading_layer_without_probing() in {"workqueue", "tbb"}
118+
119+
120+
def _could_select_omp_from_threading_config_without_probing() -> bool:
121+
if (layer := _configured_explicit_threading_layer_without_probing()) is not None:
122+
return layer == "omp"
123+
return "omp" in LAYERS[cast("TheadingCategory", _configured_threading_layer_or_category_without_probing())]
124+
125+
126+
def _needs_parallel_runtime_probe() -> bool:
127+
if not _is_apple_silicon() or not _is_torch_loaded():
128+
return False
129+
if _is_explicit_safe_threading_layer():
130+
return False
131+
return _could_select_omp_from_threading_config_without_probing()
132+
133+
134+
def _loaded_relevant_parallel_runtime_probe_modules() -> tuple[str, ...]:
135+
return tuple(module for module in _PARALLEL_RUNTIME_PROBE_MODULE_WHITELIST if module in sys.modules)
136+
137+
138+
def _parallel_runtime_probe_code(modules: tuple[str, ...]) -> str:
139+
lines = [*(f"import {module}" for module in modules), "import numba", "import numpy as np", ""]
140+
lines.extend(
141+
[
142+
"@numba.njit(parallel=True, cache=False)",
143+
"def _probe(values):",
144+
" total = 0.0",
145+
" for i in numba.prange(values.shape[0]):",
146+
" total += values[i]",
147+
" return total",
148+
"",
149+
"values = np.arange(32, dtype=np.float64)",
150+
"assert _probe(values) == np.sum(values)",
151+
f"print({_PARALLEL_RUNTIME_PROBE_SENTINEL!r})",
152+
"",
153+
]
154+
)
155+
return "\n".join(lines)
156+
157+
158+
def _parallel_runtime_probe_key() -> _ParallelRuntimeProbeKey:
159+
import numba
160+
161+
return (
162+
sys.executable,
163+
_configured_threading_layer_or_category_without_probing(),
164+
tuple(cast("Iterable[ThreadingLayer]", numba.config.THREADING_LAYER_PRIORITY)),
165+
_loaded_relevant_parallel_runtime_probe_modules(),
166+
)
167+
168+
169+
def _build_parallel_runtime_probe_env(key: _ParallelRuntimeProbeKey | None = None) -> dict[str, str]:
170+
_, layer_or_category, priority, _ = _parallel_runtime_probe_key() if key is None else key
171+
env = dict(os.environ)
172+
env["NUMBA_THREADING_LAYER"] = layer_or_category
173+
env["NUMBA_THREADING_LAYER_PRIORITY"] = " ".join(priority)
174+
return env
175+
176+
177+
@cache
178+
def _parallel_numba_runtime_is_safe_cached(key: _ParallelRuntimeProbeKey) -> bool:
179+
try:
180+
result = subprocess.run(
181+
[key[0], "-c", _parallel_runtime_probe_code(key[3])],
182+
capture_output=True,
183+
check=False,
184+
env=_build_parallel_runtime_probe_env(key),
185+
text=True,
186+
timeout=_PARALLEL_RUNTIME_PROBE_TIMEOUT,
187+
)
188+
except Exception:
189+
return False
190+
return result.returncode == 0 and _PARALLEL_RUNTIME_PROBE_SENTINEL in result.stdout
191+
192+
193+
def _parallel_numba_runtime_is_safe() -> bool:
194+
return _parallel_numba_runtime_is_safe_cached(_parallel_runtime_probe_key())
195+
196+
87197
@overload
88198
def njit[**P, R](fn: Callable[P, R], /) -> Callable[P, R]: ...
89199
@overload
@@ -92,7 +202,7 @@ def njit[**P, R](fn: Callable[P, R] | None = None, /) -> Callable[P, R] | Callab
92202
"""Jit-compile a function using numba.
93203
94204
On call, this function dispatches to a parallel or serial numba function,
95-
depending on if it has been called from a thread pool.
205+
depending on the current threading environment.
96206
"""
97207
# See https://github.com/numbagg/numbagg/pull/201/files#r1409374809
98208

@@ -109,11 +219,18 @@ def decorator(f: Callable[P, R], /) -> Callable[P, R]:
109219

110220
@wraps(f)
111221
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
112-
parallel = not _is_in_unsafe_thread_pool()
113-
if not parallel: # pragma: no cover
222+
if _is_in_unsafe_thread_pool(): # pragma: no cover
114223
msg = f"Detected unsupported threading environment. Trying to run {f.__name__} in serial mode. In case of problems, install `tbb`."
115224
warnings.warn(msg, UserWarning, stacklevel=2)
116-
return fns[parallel](*args, **kwargs)
225+
return fns[False](*args, **kwargs)
226+
if _needs_parallel_runtime_probe() and not _parallel_numba_runtime_is_safe():
227+
msg = (
228+
f"Detected an unsupported numba parallel runtime. Running {f.__name__} in serial mode as a workaround. "
229+
"Set `NUMBA_THREADING_LAYER=workqueue` or install `tbb` to avoid this fallback."
230+
)
231+
warnings.warn(msg, UserWarning, stacklevel=2)
232+
return fns[False](*args, **kwargs)
233+
return fns[True](*args, **kwargs)
117234

118235
return wrapper
119236

0 commit comments

Comments
 (0)