1111
1212from __future__ import annotations
1313
14+ import os
15+ import platform
16+ import subprocess
1417import sys
1518import warnings
1619from functools import cache , update_wrapper , wraps
2932"""Identifier for a threading layer category."""
3033type ThreadingLayer = Literal ["tbb" , "omp" , "workqueue" ]
3134"""Identifier for a concrete threading layer."""
35+ type _ParallelRuntimeProbeKey = tuple [str , ThreadingLayer | TheadingCategory , tuple [ThreadingLayer , ...], tuple [str , ...]]
3236
3337
3438LAYERS : dict [TheadingCategory , set [ThreadingLayer ]] = {
3943}
4044
4145
# Marker the probe script prints on success; the parent process looks for it in the probe's stdout.
_PARALLEL_RUNTIME_PROBE_SENTINEL = "FAST_ARRAY_UTILS_NUMBA_PROBE_OK"
# Passed as `timeout=` to `subprocess.run` when launching the probe (seconds).
_PARALLEL_RUNTIME_PROBE_TIMEOUT = 10
# Modules that, when already present in `sys.modules`, are imported by the probe script as well.
_PARALLEL_RUNTIME_PROBE_MODULE_WHITELIST = ("torch",)
49+
50+
4251def threading_layer (layer_or_category : ThreadingLayer | TheadingCategory | None = None , / , priority : Iterable [ThreadingLayer ] | None = None ) -> ThreadingLayer :
4352 """Get numba’s configured threading layer as specified in :ref:`numba-threading-layer`.
4453
@@ -84,6 +93,107 @@ def _is_in_unsafe_thread_pool() -> bool:
8493 return current_thread .name .startswith ("ThreadPoolExecutor" ) and threading_layer () not in LAYERS ["threadsafe" ]
8594
8695
def _is_apple_silicon() -> bool:
    """Return whether this interpreter runs on macOS with an ``arm64`` machine type."""
    if sys.platform != "darwin":
        return False
    return platform.machine() == "arm64"
98+
99+
def _is_torch_loaded() -> bool:
    """Return whether ``torch`` is already registered in ``sys.modules``.

    Only the module table is inspected; ``torch`` itself is never imported here.
    """
    loaded_modules = sys.modules
    return "torch" in loaded_modules
102+
103+
def _configured_threading_layer_or_category_without_probing() -> ThreadingLayer | TheadingCategory:
    """Return the raw value of ``numba.config.THREADING_LAYER``.

    The value is handed back unchanged (only cast for the type checker), so it
    may name either a concrete layer or a layer category.
    """
    import numba

    # Avoid `threading_layer()` here: resolving backends may import pool modules.
    configured = numba.config.THREADING_LAYER
    return cast("ThreadingLayer | TheadingCategory", configured)
109+
110+
def _configured_explicit_threading_layer_without_probing() -> ThreadingLayer | None:
    """Return the configured value if it is one of ``LAYERS["default"]``, else ``None``.

    ``None`` therefore means the configuration does not name one of those
    layers directly (for example, it names a category instead).
    """
    configured = _configured_threading_layer_or_category_without_probing()
    if configured in LAYERS["default"]:
        return configured
    return None
114+
115+
def _is_explicit_safe_threading_layer() -> bool:
    """Return whether the configuration explicitly names ``workqueue`` or ``tbb``."""
    explicit_layer = _configured_explicit_threading_layer_without_probing()
    return explicit_layer in ("workqueue", "tbb")
118+
119+
def _could_select_omp_from_threading_config_without_probing() -> bool:
    """Return whether the current numba configuration could resolve to ``omp``.

    If the configuration names a concrete layer, the answer is simply whether
    that layer is ``omp``. Otherwise the configured value is treated as a
    category and the answer is whether ``omp`` is a member of it.

    A value that is neither a known layer nor a known category yields
    ``False`` instead of raising ``KeyError``; reporting the invalid setting
    is left to whatever resolves the threading layer afterwards.
    """
    if (layer := _configured_explicit_threading_layer_without_probing()) is not None:
        return layer == "omp"
    category = cast("TheadingCategory", _configured_threading_layer_or_category_without_probing())
    # `.get` keeps an unrecognised configuration value from surfacing as a bare KeyError here.
    return "omp" in LAYERS.get(category, ())
124+
125+
def _needs_parallel_runtime_probe() -> bool:
    """Return whether the parallel runtime should be probed before running in parallel.

    A probe is requested only when all of the following hold: the platform is
    Apple Silicon, ``torch`` is loaded, the configuration does not explicitly
    name ``workqueue`` or ``tbb``, and the configuration could resolve to ``omp``.
    """
    affected_setup = _is_apple_silicon() and _is_torch_loaded()
    if affected_setup and not _is_explicit_safe_threading_layer():
        return _could_select_omp_from_threading_config_without_probing()
    return False
132+
133+
def _loaded_relevant_parallel_runtime_probe_modules() -> tuple[str, ...]:
    """Return the whitelisted module names currently present in ``sys.modules``.

    The result keeps the order of ``_PARALLEL_RUNTIME_PROBE_MODULE_WHITELIST``.
    """
    loaded = [name for name in _PARALLEL_RUNTIME_PROBE_MODULE_WHITELIST if name in sys.modules]
    return tuple(loaded)
136+
137+
def _parallel_runtime_probe_code(modules: tuple[str, ...]) -> str:
    """Build the source of the script that exercises numba's parallel runtime.

    The script imports each name in ``modules`` first, then ``numba`` and
    ``numpy``, compiles a small ``parallel=True`` reduction over ``prange``,
    checks its result against ``np.sum`` and finally prints
    ``_PARALLEL_RUNTIME_PROBE_SENTINEL``.

    Parameters
    ----------
    modules
        Module names to import at the top of the script, in the given order.

    Returns
    -------
    The script source, one statement per line, ending with a newline.
    """
    lines = [
        # Extra modules are imported before numba, in the order given.
        *(f"import {module}" for module in modules),
        "import numba",
        "import numpy as np",
        "",
        "@numba.njit(parallel=True, cache=False)",
        "def _probe(values):",
        "    total = 0.0",
        "    for i in numba.prange(values.shape[0]):",
        "        total += values[i]",
        "    return total",
        "",
        "values = np.arange(32, dtype=np.float64)",
        "assert _probe(values) == np.sum(values)",
        f"print({_PARALLEL_RUNTIME_PROBE_SENTINEL!r})",
        # The trailing empty entry makes the joined source end with a newline.
        "",
    ]
    # Join with a bare newline: any extra whitespace after it would indent top-level statements.
    return "\n".join(lines)
156+
157+
def _parallel_runtime_probe_key() -> _ParallelRuntimeProbeKey:
    """Collect the values that identify one probe run.

    Returns a tuple of the interpreter path, the configured threading layer or
    category, numba's configured layer priority (as a tuple) and the loaded
    whitelisted modules.
    """
    import numba

    executable = sys.executable
    layer_or_category = _configured_threading_layer_or_category_without_probing()
    # Converted to a tuple so the key is hashable.
    priority = tuple(cast("Iterable[ThreadingLayer]", numba.config.THREADING_LAYER_PRIORITY))
    loaded_modules = _loaded_relevant_parallel_runtime_probe_modules()
    return (executable, layer_or_category, priority, loaded_modules)
167+
168+
def _build_parallel_runtime_probe_env(key: _ParallelRuntimeProbeKey | None = None) -> dict[str, str]:
    """Build the environment for the probe subprocess.

    The result is a copy of ``os.environ`` in which ``NUMBA_THREADING_LAYER``
    and ``NUMBA_THREADING_LAYER_PRIORITY`` are set from ``key``; the priority
    entries are joined with single spaces.

    Parameters
    ----------
    key
        Probe key to take the settings from. When ``None``, the key is
        computed with ``_parallel_runtime_probe_key()``.
    """
    if key is None:
        key = _parallel_runtime_probe_key()
    _executable, layer_or_category, priority, _modules = key
    return {
        **os.environ,
        "NUMBA_THREADING_LAYER": layer_or_category,
        "NUMBA_THREADING_LAYER_PRIORITY": " ".join(priority),
    }
175+
176+
@cache
def _parallel_numba_runtime_is_safe_cached(key: _ParallelRuntimeProbeKey) -> bool:
    """Run the probe script in a subprocess and report whether it succeeded.

    The script is executed with the interpreter named in ``key`` and with the
    environment built from ``key``. The probe counts as successful only if the
    process exits with status 0 and its stdout contains
    ``_PARALLEL_RUNTIME_PROBE_SENTINEL``. Any exception while preparing or
    running the probe (including a timeout) is reported as ``False``.

    Results are memoised per ``key`` via ``functools.cache``.
    """
    try:
        executable, _layer_or_category, _priority, modules = key
        command = [executable, "-c", _parallel_runtime_probe_code(modules)]
        probe_env = _build_parallel_runtime_probe_env(key)
        completed = subprocess.run(
            command,
            capture_output=True,
            check=False,
            env=probe_env,
            text=True,
            timeout=_PARALLEL_RUNTIME_PROBE_TIMEOUT,
        )
    except Exception:
        # Best-effort probe: a failure to run it is treated the same as a failing probe.
        return False
    else:
        if completed.returncode != 0:
            return False
        return _PARALLEL_RUNTIME_PROBE_SENTINEL in completed.stdout
191+
192+
def _parallel_numba_runtime_is_safe() -> bool:
    """Return the (cached) probe result for the current probe key."""
    probe_key = _parallel_runtime_probe_key()
    return _parallel_numba_runtime_is_safe_cached(probe_key)
195+
196+
87197@overload
88198def njit [** P , R ](fn : Callable [P , R ], / ) -> Callable [P , R ]: ...
89199@overload
@@ -92,7 +202,7 @@ def njit[**P, R](fn: Callable[P, R] | None = None, /) -> Callable[P, R] | Callab
92202 """Jit-compile a function using numba.
93203
94204 On call, this function dispatches to a parallel or serial numba function,
95- depending on if it has been called from a thread pool .
205+ depending on the current threading environment .
96206 """
97207 # See https://github.com/numbagg/numbagg/pull/201/files#r1409374809
98208
@@ -109,11 +219,18 @@ def decorator(f: Callable[P, R], /) -> Callable[P, R]:
109219
110220 @wraps (f )
111221 def wrapper (* args : P .args , ** kwargs : P .kwargs ) -> R :
112- parallel = not _is_in_unsafe_thread_pool ()
113- if not parallel : # pragma: no cover
222+ if _is_in_unsafe_thread_pool (): # pragma: no cover
114223 msg = f"Detected unsupported threading environment. Trying to run { f .__name__ } in serial mode. In case of problems, install `tbb`."
115224 warnings .warn (msg , UserWarning , stacklevel = 2 )
116- return fns [parallel ](* args , ** kwargs )
225+ return fns [False ](* args , ** kwargs )
226+ if _needs_parallel_runtime_probe () and not _parallel_numba_runtime_is_safe ():
227+ msg = (
228+ f"Detected an unsupported numba parallel runtime. Running { f .__name__ } in serial mode as a workaround. "
229+ "Set `NUMBA_THREADING_LAYER=workqueue` or install `tbb` to avoid this fallback."
230+ )
231+ warnings .warn (msg , UserWarning , stacklevel = 2 )
232+ return fns [False ](* args , ** kwargs )
233+ return fns [True ](* args , ** kwargs )
117234
118235 return wrapper
119236
0 commit comments