2424import numba
2525
2626from fast_array_utils import numba as fa_numba
27+ from fast_array_utils .numba import _parallel_runtime as probe
2728
2829
2930def _return_true () -> bool :
@@ -39,7 +40,7 @@ def _sum_prange(values: NDArray[np.float64]) -> float:
3940
4041@pytest .fixture (autouse = True )
4142def clear_probe_cache () -> None :
42- fa_numba ._parallel_numba_runtime_is_safe_cached .cache_clear ()
43+ probe ._parallel_numba_runtime_is_safe_cached .cache_clear ()
4344 fa_numba ._threading_layer .cache_clear ()
4445
4546
@@ -112,16 +113,16 @@ def _set_runtime(
112113 priority : tuple [fa_numba .ThreadingLayer , ...] = ("tbb" , "omp" , "workqueue" ),
113114 layers : dict [fa_numba .TheadingCategory , set [fa_numba .ThreadingLayer ]] | None = None ,
114115) -> None :
115- monkeypatch .setattr (fa_numba .sys , "platform" , platform_name )
116- monkeypatch .setattr (fa_numba .platform , "machine" , lambda : machine )
116+ monkeypatch .setattr (probe .sys , "platform" , platform_name )
117+ monkeypatch .setattr (probe .platform , "machine" , lambda : machine )
117118 for module in ("torch" , "sklearn" , "scanpy" ):
118- monkeypatch .delitem (fa_numba .sys .modules , module , raising = False )
119+ monkeypatch .delitem (probe .sys .modules , module , raising = False )
119120 for module in loaded :
120- monkeypatch .setitem (fa_numba .sys .modules , module , object ())
121+ monkeypatch .setitem (probe .sys .modules , module , object ())
121122 monkeypatch .setattr (numba .config , "THREADING_LAYER" , layer )
122123 monkeypatch .setattr (numba .config , "THREADING_LAYER_PRIORITY" , list (priority ))
123124 if layers is not None :
124- monkeypatch .setattr (fa_numba , "LAYERS" , layers )
125+ monkeypatch .setattr (probe , "LAYERS" , layers )
125126
126127
127128def _install_fake_njit (monkeypatch : pytest .MonkeyPatch , calls : list [bool ]) -> None :
@@ -172,7 +173,7 @@ def test_probe_needed(
172173 expected : bool ,
173174) -> None :
174175 _set_runtime (monkeypatch , platform_name = platform_name , machine = machine , loaded = loaded , layer = layer , priority = priority , layers = layers )
175- assert fa_numba ._needs_parallel_runtime_probe () is expected
176+ assert probe ._needs_parallel_runtime_probe () is expected
176177
177178
178179def test_probe_check_is_lazy (monkeypatch : pytest .MonkeyPatch ) -> None :
@@ -188,26 +189,26 @@ def import_module(name: str, package: str | None = None) -> object:
188189
189190 monkeypatch .setattr (importlib , "import_module" , import_module )
190191
191- assert fa_numba ._needs_parallel_runtime_probe () is True
192+ assert probe ._needs_parallel_runtime_probe () is True
192193
193194
194195def test_probe_uses_torch_context (monkeypatch : pytest .MonkeyPatch ) -> None :
195196 _set_runtime (monkeypatch , loaded = (), layer = "threadsafe" , priority = ("omp" , "tbb" ))
196197
197- first = fa_numba ._parallel_runtime_probe_key ()
198- monkeypatch .setitem (fa_numba .sys .modules , "sklearn" , object ())
199- monkeypatch .setitem (fa_numba .sys .modules , "scanpy" , object ())
200- second = fa_numba ._parallel_runtime_probe_key ()
201- monkeypatch .setitem (fa_numba .sys .modules , "torch" , object ())
202- third = fa_numba ._parallel_runtime_probe_key ()
198+ first = probe ._parallel_runtime_probe_key ()
199+ monkeypatch .setitem (probe .sys .modules , "sklearn" , object ())
200+ monkeypatch .setitem (probe .sys .modules , "scanpy" , object ())
201+ second = probe ._parallel_runtime_probe_key ()
202+ monkeypatch .setitem (probe .sys .modules , "torch" , object ())
203+ third = probe ._parallel_runtime_probe_key ()
203204
204- assert fa_numba ._loaded_relevant_parallel_runtime_probe_modules () == ("torch" ,)
205+ assert probe ._loaded_relevant_parallel_runtime_probe_modules () == ("torch" ,)
205206 assert first == second
206207 assert first [3 ] == ()
207208 assert third [3 ] == ("torch" ,)
208209
209- env = fa_numba ._build_parallel_runtime_probe_env (third )
210- code = fa_numba ._parallel_runtime_probe_code (third [3 ])
210+ env = probe ._build_parallel_runtime_probe_env (third )
211+ code = probe ._parallel_runtime_probe_code (third [3 ])
211212
212213 assert env ["NUMBA_THREADING_LAYER" ] == "threadsafe"
213214 assert env ["NUMBA_THREADING_LAYER_PRIORITY" ] == "omp tbb"
@@ -222,21 +223,21 @@ def test_probe_result(monkeypatch: pytest.MonkeyPatch) -> None:
222223
223224 def run (cmd : list [str ], / , ** kwargs : object ) -> subprocess .CompletedProcess [str ]:
224225 calls .append ((cmd , kwargs ))
225- return subprocess .CompletedProcess (cmd , 0 , stdout = f"{ fa_numba ._PARALLEL_RUNTIME_PROBE_SENTINEL } \n " , stderr = "" )
226+ return subprocess .CompletedProcess (cmd , 0 , stdout = f"{ probe ._PARALLEL_RUNTIME_PROBE_SENTINEL } \n " , stderr = "" )
226227
227- monkeypatch .setattr (fa_numba .subprocess , "run" , run )
228+ monkeypatch .setattr (probe .subprocess , "run" , run )
228229
229- assert fa_numba ._parallel_numba_runtime_is_safe () is True
230- assert fa_numba ._parallel_numba_runtime_is_safe () is True
230+ assert probe ._parallel_numba_runtime_is_safe () is True
231+ assert probe ._parallel_numba_runtime_is_safe () is True
231232 assert calls == [
232233 (
233- [fa_numba .sys .executable , "-c" , fa_numba ._parallel_runtime_probe_code (("torch" ,))],
234+ [probe .sys .executable , "-c" , probe ._parallel_runtime_probe_code (("torch" ,))],
234235 {
235236 "capture_output" : True ,
236237 "check" : False ,
237- "env" : fa_numba ._build_parallel_runtime_probe_env (),
238+ "env" : probe ._build_parallel_runtime_probe_env (),
238239 "text" : True ,
239- "timeout" : fa_numba ._PARALLEL_RUNTIME_PROBE_TIMEOUT ,
240+ "timeout" : probe ._PARALLEL_RUNTIME_PROBE_TIMEOUT ,
240241 },
241242 )
242243 ]
@@ -264,9 +265,9 @@ def run(_cmd: list[str], /, **_kwargs: object) -> subprocess.CompletedProcess[st
264265 assert result is not None
265266 return result
266267
267- monkeypatch .setattr (fa_numba .subprocess , "run" , run )
268+ monkeypatch .setattr (probe .subprocess , "run" , run )
268269
269- assert fa_numba ._parallel_numba_runtime_is_safe () is False
270+ assert probe ._parallel_numba_runtime_is_safe () is False
270271
271272
272273@pytest .mark .parametrize (
@@ -292,13 +293,13 @@ def test_njit_chooses_version(
292293
293294 monkeypatch .setattr (fa_numba , "_is_in_unsafe_thread_pool" , lambda : unsafe_pool )
294295 if needs_probe is None :
295- monkeypatch .setattr (fa_numba , "_needs_parallel_runtime_probe" , lambda : pytest .fail ("probe should not be consulted" ))
296+ monkeypatch .setattr (probe , "_needs_parallel_runtime_probe" , lambda : pytest .fail ("probe should not be consulted" ))
296297 else :
297- monkeypatch .setattr (fa_numba , "_needs_parallel_runtime_probe" , lambda : needs_probe )
298+ monkeypatch .setattr (probe , "_needs_parallel_runtime_probe" , lambda : needs_probe )
298299 if probe_safe is None :
299- monkeypatch .setattr (fa_numba , "_parallel_numba_runtime_is_safe" , lambda : pytest .fail ("probe should not run" ))
300+ monkeypatch .setattr (probe , "_parallel_numba_runtime_is_safe" , lambda : pytest .fail ("probe should not run" ))
300301 else :
301- monkeypatch .setattr (fa_numba , "_parallel_numba_runtime_is_safe" , lambda : probe_safe )
302+ monkeypatch .setattr (probe , "_parallel_numba_runtime_is_safe" , lambda : probe_safe )
302303
303304 wrapped = fa_numba .njit (_return_true )
304305
@@ -313,15 +314,14 @@ def test_njit_chooses_version(
313314 assert calls == [expected ]
314315
315316
316- def test_serial_fallback () -> None :
317+ def test_serial_fallback (monkeypatch : pytest . MonkeyPatch ) -> None :
317318 values = np .arange (10 , dtype = np .float64 )
319+ monkeypatch .setattr (fa_numba , "_is_in_unsafe_thread_pool" , lambda : False )
320+ monkeypatch .setattr (probe , "_needs_parallel_runtime_probe" , lambda : True )
321+ monkeypatch .setattr (probe , "_parallel_numba_runtime_is_safe" , lambda : False )
318322 wrapped = fa_numba .njit (_sum_prange )
319323
320- with pytest .MonkeyPatch ().context () as monkeypatch :
321- monkeypatch .setattr (fa_numba , "_is_in_unsafe_thread_pool" , lambda : False )
322- monkeypatch .setattr (fa_numba , "_needs_parallel_runtime_probe" , lambda : True )
323- monkeypatch .setattr (fa_numba , "_parallel_numba_runtime_is_safe" , lambda : False )
324- with pytest .warns (UserWarning , match = "unsupported numba parallel runtime" ):
325- result = wrapped (values )
324+ with pytest .warns (UserWarning , match = "unsupported numba parallel runtime" ):
325+ result = wrapped (values )
326326
327327 assert result == pytest .approx (np .sum (values ))
0 commit comments