Skip to content

Commit d70246f

Browse files
committed
removed monkeypatch manual context
1 parent 0984eba commit d70246f

1 file changed

Lines changed: 37 additions & 37 deletions

File tree

tests/test_numba.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import numba
2525

2626
from fast_array_utils import numba as fa_numba
27+
from fast_array_utils.numba import _parallel_runtime as probe
2728

2829

2930
def _return_true() -> bool:
@@ -39,7 +40,7 @@ def _sum_prange(values: NDArray[np.float64]) -> float:
3940

4041
@pytest.fixture(autouse=True)
4142
def clear_probe_cache() -> None:
42-
fa_numba._parallel_numba_runtime_is_safe_cached.cache_clear()
43+
probe._parallel_numba_runtime_is_safe_cached.cache_clear()
4344
fa_numba._threading_layer.cache_clear()
4445

4546

@@ -112,16 +113,16 @@ def _set_runtime(
112113
priority: tuple[fa_numba.ThreadingLayer, ...] = ("tbb", "omp", "workqueue"),
113114
layers: dict[fa_numba.TheadingCategory, set[fa_numba.ThreadingLayer]] | None = None,
114115
) -> None:
115-
monkeypatch.setattr(fa_numba.sys, "platform", platform_name)
116-
monkeypatch.setattr(fa_numba.platform, "machine", lambda: machine)
116+
monkeypatch.setattr(probe.sys, "platform", platform_name)
117+
monkeypatch.setattr(probe.platform, "machine", lambda: machine)
117118
for module in ("torch", "sklearn", "scanpy"):
118-
monkeypatch.delitem(fa_numba.sys.modules, module, raising=False)
119+
monkeypatch.delitem(probe.sys.modules, module, raising=False)
119120
for module in loaded:
120-
monkeypatch.setitem(fa_numba.sys.modules, module, object())
121+
monkeypatch.setitem(probe.sys.modules, module, object())
121122
monkeypatch.setattr(numba.config, "THREADING_LAYER", layer)
122123
monkeypatch.setattr(numba.config, "THREADING_LAYER_PRIORITY", list(priority))
123124
if layers is not None:
124-
monkeypatch.setattr(fa_numba, "LAYERS", layers)
125+
monkeypatch.setattr(probe, "LAYERS", layers)
125126

126127

127128
def _install_fake_njit(monkeypatch: pytest.MonkeyPatch, calls: list[bool]) -> None:
@@ -172,7 +173,7 @@ def test_probe_needed(
172173
expected: bool,
173174
) -> None:
174175
_set_runtime(monkeypatch, platform_name=platform_name, machine=machine, loaded=loaded, layer=layer, priority=priority, layers=layers)
175-
assert fa_numba._needs_parallel_runtime_probe() is expected
176+
assert probe._needs_parallel_runtime_probe() is expected
176177

177178

178179
def test_probe_check_is_lazy(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -188,26 +189,26 @@ def import_module(name: str, package: str | None = None) -> object:
188189

189190
monkeypatch.setattr(importlib, "import_module", import_module)
190191

191-
assert fa_numba._needs_parallel_runtime_probe() is True
192+
assert probe._needs_parallel_runtime_probe() is True
192193

193194

194195
def test_probe_uses_torch_context(monkeypatch: pytest.MonkeyPatch) -> None:
195196
_set_runtime(monkeypatch, loaded=(), layer="threadsafe", priority=("omp", "tbb"))
196197

197-
first = fa_numba._parallel_runtime_probe_key()
198-
monkeypatch.setitem(fa_numba.sys.modules, "sklearn", object())
199-
monkeypatch.setitem(fa_numba.sys.modules, "scanpy", object())
200-
second = fa_numba._parallel_runtime_probe_key()
201-
monkeypatch.setitem(fa_numba.sys.modules, "torch", object())
202-
third = fa_numba._parallel_runtime_probe_key()
198+
first = probe._parallel_runtime_probe_key()
199+
monkeypatch.setitem(probe.sys.modules, "sklearn", object())
200+
monkeypatch.setitem(probe.sys.modules, "scanpy", object())
201+
second = probe._parallel_runtime_probe_key()
202+
monkeypatch.setitem(probe.sys.modules, "torch", object())
203+
third = probe._parallel_runtime_probe_key()
203204

204-
assert fa_numba._loaded_relevant_parallel_runtime_probe_modules() == ("torch",)
205+
assert probe._loaded_relevant_parallel_runtime_probe_modules() == ("torch",)
205206
assert first == second
206207
assert first[3] == ()
207208
assert third[3] == ("torch",)
208209

209-
env = fa_numba._build_parallel_runtime_probe_env(third)
210-
code = fa_numba._parallel_runtime_probe_code(third[3])
210+
env = probe._build_parallel_runtime_probe_env(third)
211+
code = probe._parallel_runtime_probe_code(third[3])
211212

212213
assert env["NUMBA_THREADING_LAYER"] == "threadsafe"
213214
assert env["NUMBA_THREADING_LAYER_PRIORITY"] == "omp tbb"
@@ -222,21 +223,21 @@ def test_probe_result(monkeypatch: pytest.MonkeyPatch) -> None:
222223

223224
def run(cmd: list[str], /, **kwargs: object) -> subprocess.CompletedProcess[str]:
224225
calls.append((cmd, kwargs))
225-
return subprocess.CompletedProcess(cmd, 0, stdout=f"{fa_numba._PARALLEL_RUNTIME_PROBE_SENTINEL}\n", stderr="")
226+
return subprocess.CompletedProcess(cmd, 0, stdout=f"{probe._PARALLEL_RUNTIME_PROBE_SENTINEL}\n", stderr="")
226227

227-
monkeypatch.setattr(fa_numba.subprocess, "run", run)
228+
monkeypatch.setattr(probe.subprocess, "run", run)
228229

229-
assert fa_numba._parallel_numba_runtime_is_safe() is True
230-
assert fa_numba._parallel_numba_runtime_is_safe() is True
230+
assert probe._parallel_numba_runtime_is_safe() is True
231+
assert probe._parallel_numba_runtime_is_safe() is True
231232
assert calls == [
232233
(
233-
[fa_numba.sys.executable, "-c", fa_numba._parallel_runtime_probe_code(("torch",))],
234+
[probe.sys.executable, "-c", probe._parallel_runtime_probe_code(("torch",))],
234235
{
235236
"capture_output": True,
236237
"check": False,
237-
"env": fa_numba._build_parallel_runtime_probe_env(),
238+
"env": probe._build_parallel_runtime_probe_env(),
238239
"text": True,
239-
"timeout": fa_numba._PARALLEL_RUNTIME_PROBE_TIMEOUT,
240+
"timeout": probe._PARALLEL_RUNTIME_PROBE_TIMEOUT,
240241
},
241242
)
242243
]
@@ -264,9 +265,9 @@ def run(_cmd: list[str], /, **_kwargs: object) -> subprocess.CompletedProcess[st
264265
assert result is not None
265266
return result
266267

267-
monkeypatch.setattr(fa_numba.subprocess, "run", run)
268+
monkeypatch.setattr(probe.subprocess, "run", run)
268269

269-
assert fa_numba._parallel_numba_runtime_is_safe() is False
270+
assert probe._parallel_numba_runtime_is_safe() is False
270271

271272

272273
@pytest.mark.parametrize(
@@ -292,13 +293,13 @@ def test_njit_chooses_version(
292293

293294
monkeypatch.setattr(fa_numba, "_is_in_unsafe_thread_pool", lambda: unsafe_pool)
294295
if needs_probe is None:
295-
monkeypatch.setattr(fa_numba, "_needs_parallel_runtime_probe", lambda: pytest.fail("probe should not be consulted"))
296+
monkeypatch.setattr(probe, "_needs_parallel_runtime_probe", lambda: pytest.fail("probe should not be consulted"))
296297
else:
297-
monkeypatch.setattr(fa_numba, "_needs_parallel_runtime_probe", lambda: needs_probe)
298+
monkeypatch.setattr(probe, "_needs_parallel_runtime_probe", lambda: needs_probe)
298299
if probe_safe is None:
299-
monkeypatch.setattr(fa_numba, "_parallel_numba_runtime_is_safe", lambda: pytest.fail("probe should not run"))
300+
monkeypatch.setattr(probe, "_parallel_numba_runtime_is_safe", lambda: pytest.fail("probe should not run"))
300301
else:
301-
monkeypatch.setattr(fa_numba, "_parallel_numba_runtime_is_safe", lambda: probe_safe)
302+
monkeypatch.setattr(probe, "_parallel_numba_runtime_is_safe", lambda: probe_safe)
302303

303304
wrapped = fa_numba.njit(_return_true)
304305

@@ -313,15 +314,14 @@ def test_njit_chooses_version(
313314
assert calls == [expected]
314315

315316

316-
def test_serial_fallback() -> None:
317+
def test_serial_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
317318
values = np.arange(10, dtype=np.float64)
319+
monkeypatch.setattr(fa_numba, "_is_in_unsafe_thread_pool", lambda: False)
320+
monkeypatch.setattr(probe, "_needs_parallel_runtime_probe", lambda: True)
321+
monkeypatch.setattr(probe, "_parallel_numba_runtime_is_safe", lambda: False)
318322
wrapped = fa_numba.njit(_sum_prange)
319323

320-
with pytest.MonkeyPatch().context() as monkeypatch:
321-
monkeypatch.setattr(fa_numba, "_is_in_unsafe_thread_pool", lambda: False)
322-
monkeypatch.setattr(fa_numba, "_needs_parallel_runtime_probe", lambda: True)
323-
monkeypatch.setattr(fa_numba, "_parallel_numba_runtime_is_safe", lambda: False)
324-
with pytest.warns(UserWarning, match="unsupported numba parallel runtime"):
325-
result = wrapped(values)
324+
with pytest.warns(UserWarning, match="unsupported numba parallel runtime"):
325+
result = wrapped(values)
326326

327327
assert result == pytest.approx(np.sum(values))

0 commit comments

Comments
 (0)