Skip to content

Commit 075bbe6

Browse files
author
Jhonatan Ramos Felix
committed
added numba runtime tests
1 parent 9edf6cf commit 075bbe6

1 file changed

Lines changed: 258 additions & 0 deletions

File tree

tests/test_numba.py

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations

import importlib
import subprocess
import warnings
from collections.abc import Generator
from typing import Any

import numpy as np
import pytest


pytest.importorskip("numba")

import numba

from fast_array_utils import numba as fa_numba
20+
def _return_true() -> bool:
21+
return True
22+
23+
24+
def _sum_prange(values: np.ndarray[Any, Any]) -> float:
    """Sum *values* via a ``numba.prange`` loop, exercising the parallel backend."""
    acc = 0.0
    for idx in numba.prange(values.shape[0]):
        acc += values[idx]
    return acc
29+
30+
31+
@pytest.fixture(autouse=True)
def clear_probe_cache() -> Generator[None, None, None]:
    """Reset the cached probe result around every test in this module.

    Clearing only *before* a test (as the setup-only version did) lets a test
    that populates the cache — e.g. ``test_probe_result``, which caches a result
    produced by a monkeypatched ``subprocess.run`` — leak that fake result into
    whatever runs after this module.  Clearing again on teardown keeps the
    cached probe result strictly test-local.
    """
    fa_numba._parallel_numba_runtime_is_safe_cached.cache_clear()
    yield
    # Teardown: drop any result a test cached while its monkeypatches were active.
    fa_numba._parallel_numba_runtime_is_safe_cached.cache_clear()
34+
35+
36+
def _set_runtime(
    monkeypatch: pytest.MonkeyPatch,
    *,
    platform_name: str = "darwin",
    machine: str = "arm64",
    loaded: tuple[str, ...] = ("torch",),
    layer: fa_numba.ThreadingLayer | fa_numba.TheadingCategory = "default",
    priority: tuple[fa_numba.ThreadingLayer, ...] = ("tbb", "omp", "workqueue"),
    layers: dict[fa_numba.TheadingCategory, set[fa_numba.ThreadingLayer]] | None = None,
) -> None:
    """Pin platform, loaded modules, and numba threading config for one test."""
    monkeypatch.setattr(fa_numba.sys, "platform", platform_name)
    monkeypatch.setattr(fa_numba.platform, "machine", lambda: machine)
    # Start from a clean slate: none of the relevant modules look imported …
    for mod in ("torch", "sklearn", "scanpy"):
        monkeypatch.delitem(fa_numba.sys.modules, mod, raising=False)
    # … then mark exactly the requested ones as loaded.
    for mod in loaded:
        monkeypatch.setitem(fa_numba.sys.modules, mod, object())
    monkeypatch.setattr(numba.config, "THREADING_LAYER", layer)
    monkeypatch.setattr(numba.config, "THREADING_LAYER_PRIORITY", list(priority))
    if layers is not None:
        monkeypatch.setattr(fa_numba, "LAYERS", layers)
56+
57+
58+
def _install_fake_njit(monkeypatch: pytest.MonkeyPatch, calls: list[bool]) -> None:
    """Replace ``numba.njit`` with a stub whose "compiled" function reports *parallel*.

    Every invocation of the stub's compiled function appends the ``parallel``
    flag it was created with to *calls* and returns it, letting tests observe
    which variant was dispatched to.  ``cache=True`` is asserted because the
    real wrapper is expected to always request caching.
    """

    def fake_njit(func: object, /, *, cache: bool, parallel: bool) -> Any:
        assert cache is True

        def compiled(*args: object, **kwargs: object) -> bool:
            calls.append(parallel)
            return parallel

        return compiled

    monkeypatch.setattr(numba, "njit", fake_njit)
69+
70+
71+
@pytest.mark.parametrize(
    ("platform_name", "machine", "loaded", "layer", "priority", "layers", "expected"),
    [
        # Expected True only on darwin/arm64 with torch loaded and a layer choice
        # whose resolution can include "omp" — presumably the OpenMP runtime is
        # the risky one in that combination (TODO confirm against fa_numba).
        pytest.param("darwin", "arm64", ("torch",), "default", ("tbb", "omp", "workqueue"), None, True, id="default"),
        pytest.param("darwin", "arm64", ("torch",), "omp", ("tbb", "omp", "workqueue"), None, True, id="omp"),
        pytest.param("darwin", "arm64", ("torch",), "threadsafe", ("omp", "tbb"), None, True, id="threadsafe"),
        # Layers that cannot resolve to omp ("workqueue", "safe" with tbb-only
        # priority) do not require a probe.
        pytest.param("darwin", "arm64", ("torch",), "workqueue", ("tbb", "omp", "workqueue"), None, False, id="workqueue"),
        pytest.param("darwin", "arm64", ("torch",), "safe", ("tbb",), None, False, id="safe"),
        pytest.param(
            "darwin",
            "arm64",
            ("torch",),
            "forksafe",
            ("tbb", "omp", "workqueue"),
            # Widen the "forksafe" category so it may include omp for this case.
            {**fa_numba.LAYERS, "forksafe": {"tbb", "omp", "workqueue"}},
            True,
            id="forksafe",
        ),
        # Without torch loaded, or off darwin/arm64, no probe is ever needed.
        pytest.param("darwin", "arm64", (), "default", ("tbb", "omp", "workqueue"), None, False, id="no-torch"),
        pytest.param("darwin", "x86_64", ("torch",), "default", ("tbb", "omp", "workqueue"), None, False, id="not-arm"),
        pytest.param("linux", "arm64", ("torch",), "default", ("tbb", "omp", "workqueue"), None, False, id="not-darwin"),
    ],
)
def test_probe_needed(
    monkeypatch: pytest.MonkeyPatch,
    platform_name: str,
    machine: str,
    loaded: tuple[str, ...],
    layer: fa_numba.ThreadingLayer | fa_numba.TheadingCategory,
    priority: tuple[fa_numba.ThreadingLayer, ...],
    layers: dict[fa_numba.TheadingCategory, set[fa_numba.ThreadingLayer]] | None,
    expected: bool,
) -> None:
    """``_needs_parallel_runtime_probe`` reacts to platform, loaded modules, and threading config."""
    _set_runtime(monkeypatch, platform_name=platform_name, machine=machine, loaded=loaded, layer=layer, priority=priority, layers=layers)
    assert fa_numba._needs_parallel_runtime_probe() is expected
106+
107+
108+
def test_probe_check_is_lazy(monkeypatch: pytest.MonkeyPatch) -> None:
    """Deciding whether a probe is needed must not touch numba's threading backend."""
    _set_runtime(monkeypatch)
    # Calling threading_layer() would initialize the backend — forbid it outright.
    monkeypatch.setattr(fa_numba, "threading_layer", lambda: pytest.fail("threading_layer() should not be called"))

    real_import = importlib.import_module

    def guarded_import(name: str, package: str | None = None) -> object:
        is_pool_backend = name.startswith("numba.np.ufunc.") and name.endswith("pool")
        if is_pool_backend:
            pytest.fail(f"backend pool module {name!r} should not be imported")
        return real_import(name, package)

    monkeypatch.setattr(importlib, "import_module", guarded_import)

    assert fa_numba._needs_parallel_runtime_probe() is True
122+
123+
124+
def test_probe_uses_torch_context(monkeypatch: pytest.MonkeyPatch) -> None:
    """Probe key/env/code reflect torch — but not sklearn/scanpy — being loaded."""
    _set_runtime(monkeypatch, loaded=(), layer="threadsafe", priority=("omp", "tbb"))

    key_empty = fa_numba._parallel_runtime_probe_key()
    # Irrelevant modules must not perturb the key …
    for extra in ("sklearn", "scanpy"):
        monkeypatch.setitem(fa_numba.sys.modules, extra, object())
    key_irrelevant = fa_numba._parallel_runtime_probe_key()
    # … while torch must show up in the key's 4th element.
    monkeypatch.setitem(fa_numba.sys.modules, "torch", object())
    key_torch = fa_numba._parallel_runtime_probe_key()

    assert fa_numba._loaded_relevant_parallel_runtime_probe_modules() == ("torch",)
    assert key_empty == key_irrelevant
    assert key_empty[3] == ()
    assert key_torch[3] == ("torch",)

    env = fa_numba._build_parallel_runtime_probe_env(key_torch)
    code = fa_numba._parallel_runtime_probe_code(key_torch[3])

    # The subprocess env mirrors the patched threading configuration …
    assert env["NUMBA_THREADING_LAYER"] == "threadsafe"
    assert env["NUMBA_THREADING_LAYER_PRIORITY"] == "omp tbb"
    # … and the probe script imports only the relevant loaded module.
    assert "import torch" in code
    assert "import sklearn" not in code
    assert "import scanpy" not in code
147+
148+
149+
def test_probe_result(monkeypatch: pytest.MonkeyPatch) -> None:
    """A probe that prints the sentinel is reported safe, and the result is cached."""
    _set_runtime(monkeypatch, loaded=("torch",))
    recorded: list[tuple[list[str], dict[str, object]]] = []

    def fake_run(cmd: list[str], /, **kwargs: object) -> subprocess.CompletedProcess[str]:
        recorded.append((cmd, kwargs))
        return subprocess.CompletedProcess(cmd, 0, stdout=f"{fa_numba._PARALLEL_RUNTIME_PROBE_SENTINEL}\n", stderr="")

    monkeypatch.setattr(fa_numba.subprocess, "run", fake_run)

    # Second call must hit the cache — only one subprocess invocation below.
    assert fa_numba._parallel_numba_runtime_is_safe() is True
    assert fa_numba._parallel_numba_runtime_is_safe() is True

    expected_cmd = [fa_numba.sys.executable, "-c", fa_numba._parallel_runtime_probe_code(("torch",))]
    expected_kwargs: dict[str, object] = {
        "capture_output": True,
        "check": False,
        "env": fa_numba._build_parallel_runtime_probe_env(),
        "text": True,
        "timeout": fa_numba._PARALLEL_RUNTIME_PROBE_TIMEOUT,
    }
    assert recorded == [(expected_cmd, expected_kwargs)]
173+
174+
175+
@pytest.mark.parametrize(
    ("result", "error"),
    [
        pytest.param(subprocess.CompletedProcess(["python"], 1, stdout="", stderr="boom"), None, id="nonzero"),
        pytest.param(subprocess.CompletedProcess(["python"], 0, stdout="", stderr=""), None, id="missing-sentinel"),
        pytest.param(None, subprocess.TimeoutExpired(["python"], timeout=1), id="timeout"),
        pytest.param(None, RuntimeError("boom"), id="exception"),
    ],
)
def test_probe_failure(
    monkeypatch: pytest.MonkeyPatch,
    result: subprocess.CompletedProcess[str] | None,
    error: BaseException | None,
) -> None:
    """Any probe failure — bad exit, missing sentinel, timeout, exception — means unsafe."""
    _set_runtime(monkeypatch)

    def failing_run(cmd: list[str], /, **kwargs: object) -> subprocess.CompletedProcess[str]:
        # Either raise the parametrized error or hand back the failed result.
        if error is not None:
            raise error
        assert result is not None
        return result

    monkeypatch.setattr(fa_numba.subprocess, "run", failing_run)

    assert fa_numba._parallel_numba_runtime_is_safe() is False
200+
201+
202+
@pytest.mark.parametrize(
    ("unsafe_pool", "needs_probe", "probe_safe", "expected", "warning"),
    [
        # In an unsafe thread pool the probe must never be consulted: serial + warning.
        pytest.param(True, None, None, False, "unsupported threading environment", id="thread-pool"),
        pytest.param(False, True, False, False, "unsupported numba parallel runtime", id="probe-fails"),
        pytest.param(False, True, True, True, None, id="probe-passes"),
        # No probe needed: parallel version chosen without running the probe.
        pytest.param(False, False, None, True, None, id="no-probe"),
    ],
)
def test_njit_chooses_version(
    monkeypatch: pytest.MonkeyPatch,
    unsafe_pool: bool,
    needs_probe: bool | None,
    probe_safe: bool | None,
    expected: bool,
    warning: str | None,
) -> None:
    """``fa_numba.njit`` compiles the parallel variant iff the runtime is deemed safe.

    *expected* is the ``parallel=`` flag the fake ``numba.njit`` must receive;
    ``None`` for *needs_probe* / *probe_safe* means that helper must not be
    called at all on that code path (the patched lambda fails the test if it is).
    """
    calls: list[bool] = []
    _install_fake_njit(monkeypatch, calls)

    monkeypatch.setattr(fa_numba, "_is_in_unsafe_thread_pool", lambda: unsafe_pool)
    if needs_probe is None:
        monkeypatch.setattr(fa_numba, "_needs_parallel_runtime_probe", lambda: pytest.fail("probe should not be consulted"))
    else:
        monkeypatch.setattr(fa_numba, "_needs_parallel_runtime_probe", lambda: needs_probe)
    if probe_safe is None:
        monkeypatch.setattr(fa_numba, "_parallel_numba_runtime_is_safe", lambda: pytest.fail("probe should not run"))
    else:
        monkeypatch.setattr(fa_numba, "_parallel_numba_runtime_is_safe", lambda: probe_safe)

    wrapped = fa_numba.njit(_return_true)

    if warning is None:
        # No warning expected: record all warnings and assert none were emitted.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            assert wrapped() is expected
        assert not caught
    else:
        with pytest.warns(UserWarning, match=warning):
            assert wrapped() is expected
    # The fake compiled function ran exactly once, with the expected parallel flag.
    assert calls == [expected]
243+
244+
245+
def test_serial_fallback() -> None:
    """When the parallel runtime is unsafe, njit's serial fallback still computes correctly."""
    data = np.arange(10, dtype=np.float64)
    wrapped = fa_numba.njit(_sum_prange)

    # Force the "probe needed and probe failed" path at call time.
    stubs = {
        "_is_in_unsafe_thread_pool": lambda: False,
        "_needs_parallel_runtime_probe": lambda: True,
        "_parallel_numba_runtime_is_safe": lambda: False,
    }
    with (
        pytest.MonkeyPatch().context() as monkeypatch,
        pytest.warns(UserWarning, match="unsupported numba parallel runtime"),
    ):
        for attr, stub in stubs.items():
            monkeypatch.setattr(fa_numba, attr, stub)
        total = wrapped(data)

    assert total == pytest.approx(np.sum(data))

0 commit comments

Comments
 (0)